jet.hpp 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. #ifndef JET_HPP
  2. #define JET_HPP
  #include <cmath>
  #include <concepts>
  #include <iostream>
  #include <Eigen/Core>
  using Eigen::Vector;
  7. template<typename T, unsigned _n_vars = 1, typename diff_type = T>
  8. struct jet {
  9. T value;
  10. constexpr static unsigned int n_vars = _n_vars;
  11. Eigen::Array<diff_type, n_vars, 1> deriv;
  12. jet() = default;
  13. jet(T v) : value(v), deriv(Eigen::Array<diff_type, n_vars, 1>::Zero()){}
  14. jet(T v, Eigen::Array<T, n_vars, 1> d) : value(v), deriv(d){}
  15. jet(T v, T d) requires(n_vars == 1) : value(v), deriv(d){}
  16. operator T()const noexcept{
  17. return value;
  18. }
  19. // Addition operator
  20. jet<T, n_vars> operator+(const jet<T, n_vars>& other) const {
  21. return {value + other.value, deriv + other.deriv};
  22. }
  23. // Subtraction operator
  24. jet<T, n_vars> operator-(const jet<T, n_vars>& other) const {
  25. return {value - other.value, deriv - other.deriv};
  26. }
  27. // Multiplication operator
  28. jet<T, n_vars> operator*(const jet<T, n_vars>& other) const {
  29. return {value * other.value, value * other.deriv + deriv * other.value};
  30. }
  31. // Division operator
  32. jet<T, n_vars> operator/(const jet<T, n_vars>& other) const {
  33. T denominator = other.value * other.value;
  34. return {value / other.value, (deriv * other.value - value * other.deriv) / denominator};
  35. }
  36. // Compound assignment operators
  37. // +=
  38. jet<T, n_vars>& operator+=(const jet<T, n_vars>& other) {
  39. value += other.value;
  40. deriv += other.deriv;
  41. return *this;
  42. }
  43. // -=
  44. jet<T, n_vars>& operator-=(const jet<T, n_vars>& other) {
  45. value -= other.value;
  46. deriv -= other.deriv;
  47. return *this;
  48. }
  49. // *=
  50. jet<T, n_vars>& operator*=(const jet<T, n_vars>& other) {
  51. T oldValue = value;
  52. value *= other.value;
  53. deriv = oldValue * other.deriv + deriv * other.value;
  54. return *this;
  55. }
  56. // /=
  57. jet<T, n_vars>& operator/=(const jet<T, n_vars>& other) {
  58. T denominator = other.value * other.value;
  59. T oldValue = value;
  60. value /= other.value;
  61. deriv = (deriv * other.value - oldValue * other.deriv) / denominator;
  62. return *this;
  63. }
  64. template<std::convertible_to<T> R>
  65. jet<T, n_vars> operator+(const R& scalar) const {
  66. return {value + scalar, deriv};
  67. }
  68. // Subtraction operator with T value
  69. template<std::convertible_to<T> R>
  70. jet<T, n_vars> operator-(const R& scalar) const {
  71. return {value - scalar, deriv};
  72. }
  73. // Multiplication operator with T value
  74. template<std::convertible_to<T> R>
  75. jet<T, n_vars> operator*(const R& scalar) const {
  76. return {value * scalar, deriv * scalar};
  77. }
  78. // Division operator with T value
  79. template<std::convertible_to<T> R>
  80. jet<T, n_vars> operator/(const R& scalar) const {
  81. return {value / scalar, deriv / scalar};
  82. }
  83. // Compound assignment operators with T values
  84. // +=
  85. template<std::convertible_to<T> R>
  86. jet<T, n_vars>& operator+=(const R& scalar) {
  87. value += scalar;
  88. return *this;
  89. }
  90. // -=
  91. template<std::convertible_to<T> R>
  92. jet<T, n_vars>& operator-=(const R& scalar) {
  93. value -= scalar;
  94. return *this;
  95. }
  96. // *=
  97. template<std::convertible_to<T> R>
  98. jet<T, n_vars>& operator*=(const R& scalar) {
  99. value *= scalar;
  100. deriv *= scalar;
  101. return *this;
  102. }
  103. // /=
  104. template<std::convertible_to<T> R>
  105. jet<T, n_vars>& operator/=(const R& scalar) {
  106. value /= scalar;
  107. deriv /= scalar;
  108. return *this;
  109. }
  110. template<typename stream_t>
  111. friend stream_t& operator<<(stream_t& str, const jet<T, n_vars, diff_type>& j){
  112. str << '(' << j.value << ", d/dx = " << j.deriv.transpose() << ')';
  113. return str;
  114. }
  115. };
  116. template<typename T, unsigned N>
  117. jet<T, N> sin(const jet<T, N>& x) {
  118. using std::sin;
  119. using std::cos;
  120. return {sin(x.value), x.deriv * cos(x.value)};
  121. }
  122. template<typename T, unsigned N>
  123. jet<T, N> cos(const jet<T, N>& x) {
  124. using std::sin;
  125. using std::cos;
  126. return {cos(x.value), -x.deriv * sin(x.value)};
  127. }
  128. template<typename T, unsigned N>
  129. jet<T, N> exp(const jet<T, N>& x) {
  130. using std::exp;
  131. T expVal = exp(x.value);
  132. return {expVal, x.deriv * expVal};
  133. }
  134. template<typename T, unsigned N>
  135. jet<T, N> log(const jet<T, N>& x) {
  136. using std::log;
  137. return {log(x.value), x.deriv / x.value};
  138. }
  139. template<typename T, unsigned N>
  140. jet<T, N> sqrt(const jet<T, N>& x) {
  141. T sqrtVal = sqrt(x.value);
  142. return {sqrtVal, x.deriv / (T(2) * sqrtVal)};
  143. }
  144. #endif