IterationIR.cpp 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645
  1. #include "IterationIR.h"
  2. #include <utility>
  3. #include <optional>
  4. using namespace mnd;
  5. namespace mnd
  6. {
  7. using ir::Node;
  8. struct ConvertVisitor
  9. {
  10. using NodePair = std::pair<Node*, Node*>;
  11. util::Arena<Node>& arena;
  12. const mnd::IterationFormula& iterationFormula;
  13. Node* zero;
  14. Node* half;
  15. Node* one;
  16. ConvertVisitor(util::Arena<Node>& arena, const mnd::IterationFormula& iterationFormula) :
  17. arena{ arena },
  18. iterationFormula{ iterationFormula }
  19. {
  20. zero = arena.allocate(ir::Constant{ 0.0 });
  21. half = arena.allocate(ir::Constant{ 0.5 });
  22. one = arena.allocate(ir::Constant{ 1.0 });
  23. }
  24. NodePair operator() (const Constant& c)
  25. {
  26. Node* cnst = arena.allocate(ir::Constant{ c.re });
  27. Node* zero = arena.allocate(ir::Constant{ c.im });
  28. return { cnst, zero };
  29. }
  30. NodePair operator() (const Variable& v)
  31. {
  32. //printf("var %s\n", v.name.c_str()); fflush(stdout);
  33. if (v.name == "i") {
  34. return { zero, one };
  35. }
  36. else if (iterationFormula.containsVariable(v.name)) {
  37. Node* a = arena.allocate(ir::Variable{ v.name + "_re" });
  38. Node* b = arena.allocate(ir::Variable{ v.name + "_im" });
  39. return { a, b };
  40. }
  41. else
  42. throw "unknown variable";
  43. }
  44. NodePair operator() (const Negation& v)
  45. {
  46. auto [opa, opb] = std::visit(*this, *v.operand);
  47. Node* a = arena.allocate(ir::Negation{ opa });
  48. Node* b = arena.allocate(ir::Negation{ opb });
  49. return { a, b };
  50. }
  51. NodePair operator() (const Addition& add)
  52. {
  53. auto [lefta, leftb] = std::visit(*this, *add.left);
  54. auto [righta, rightb] = std::visit(*this, *add.right);
  55. if (add.subtraction) {
  56. Node* a = arena.allocate(ir::Subtraction{ lefta, righta });
  57. Node* b = arena.allocate(ir::Subtraction{ leftb, rightb });
  58. return { a, b };
  59. }
  60. else {
  61. Node* a = arena.allocate(ir::Addition{ lefta, righta });
  62. Node* b = arena.allocate(ir::Addition{ leftb, rightb });
  63. return { a, b };
  64. }
  65. }
  66. NodePair operator() (const Multiplication& mul)
  67. {
  68. auto [a, b] = std::visit(*this, *mul.left);
  69. auto [c, d] = std::visit(*this, *mul.right);
  70. return multiplication(a, b, c, d);
  71. }
  72. NodePair multiplication(Node* a, Node* b, Node* c, Node* d)
  73. {
  74. Node* ac = arena.allocate(ir::Multiplication{ a, c });
  75. Node* bd = arena.allocate(ir::Multiplication{ b, d });
  76. Node* ad = arena.allocate(ir::Multiplication{ a, d });
  77. Node* bc = arena.allocate(ir::Multiplication{ b, c });
  78. Node* newa = arena.allocate(ir::Subtraction{ ac, bd });
  79. Node* newb = arena.allocate(ir::Addition{ ad, bc });
  80. return { newa, newb };
  81. }
  82. NodePair sq(Node* a, Node* b)
  83. {
  84. Node* aa = arena.allocate(ir::Multiplication{ a, a });
  85. Node* bb = arena.allocate(ir::Multiplication{ b, b });
  86. Node* ab = arena.allocate(ir::Multiplication{ a, b });
  87. Node* newa = arena.allocate(ir::Subtraction{ aa, bb });
  88. Node* newb = arena.allocate(ir::Addition{ ab, ab });
  89. return { newa, newb };
  90. }
  91. NodePair operator() (const Division& div)
  92. {
  93. auto [a, b] = std::visit(*this, *div.left);
  94. auto [c, d] = std::visit(*this, *div.right);
  95. return division(a, b, c, d);
  96. }
  97. NodePair division(Node* a, Node* b, Node* c, Node* d)
  98. {
  99. Node* ac = arena.allocate(ir::Multiplication{ a, c });
  100. Node* bd = arena.allocate(ir::Multiplication{ b, d });
  101. Node* bc = arena.allocate(ir::Multiplication{ b, c });
  102. Node* ad = arena.allocate(ir::Multiplication{ a, d });
  103. Node* cc = arena.allocate(ir::Multiplication{ c, c });
  104. Node* dd = arena.allocate(ir::Multiplication{ d, d });
  105. Node* ac_bd = arena.allocate(ir::Addition{ ac, bd });
  106. Node* bc_ad = arena.allocate(ir::Subtraction{ bc, ad });
  107. Node* den = arena.allocate(ir::Addition{ cc, dd });
  108. Node* factor = arena.allocate(ir::Division{ one, den });
  109. Node* re = arena.allocate(ir::Multiplication{ factor, ac_bd });
  110. Node* im = arena.allocate(ir::Multiplication{ factor, bc_ad });
  111. return { re, im };
  112. }
  113. NodePair oneOver(Node* a, Node* b)
  114. {
  115. Node* cc = arena.allocate(ir::Multiplication{ a, a });
  116. Node* dd = arena.allocate(ir::Multiplication{ b, b });
  117. Node* den = arena.allocate(ir::Addition{ cc, dd });
  118. Node* factor = arena.allocate(ir::Division{ one, den });
  119. Node* re = arena.allocate(ir::Multiplication{ factor, a });
  120. Node* im = arena.allocate(ir::Negation{ arena.allocate(ir::Multiplication{ factor, b }) });
  121. return { re, im };
  122. }
  123. NodePair operator() (const Pow& p)
  124. {
  125. auto [a, b] = std::visit(*this, *p.left);
  126. auto [c, d] = std::visit(*this, *p.right);
  127. if (p.integerExponent) {
  128. if (auto* ex = std::get_if<ir::Constant>(c)) {
  129. return intPow({ a, b }, int(ex->value));
  130. }
  131. }
  132. if (p.realExponent) {
  133. return realPow({ a, b }, c);
  134. }
  135. auto arg = arena.allocate(ir::Atan2{ b, a });
  136. auto aa = arena.allocate(ir::Multiplication{ a, a });
  137. auto bb = arena.allocate(ir::Multiplication{ b, b });
  138. auto absSq = arena.allocate(ir::Addition{ aa, bb });
  139. auto halfc = arena.allocate(ir::Multiplication{ c, half });
  140. auto darg = arena.allocate(ir::Multiplication{ d, arg });
  141. auto minusdarg = arena.allocate(ir::Negation{ darg });
  142. auto abspowc = arena.allocate(ir::Pow{ absSq, halfc });
  143. auto expdarg = arena.allocate(ir::Exp{ minusdarg });
  144. auto newAbs = arena.allocate(ir::Multiplication{ abspowc, expdarg });
  145. auto carg = arena.allocate(ir::Multiplication{ arg, c });
  146. auto halfd = arena.allocate(ir::Multiplication{ d, half });
  147. //absSq = arena.allocate(ir::Addition{ absSq, half });
  148. auto lnabsSq = arena.allocate(ir::Ln{ absSq });
  149. auto halfdlnabsSq = arena.allocate(ir::Multiplication{ halfd, lnabsSq });
  150. auto newArg = arena.allocate(ir::Addition{ halfdlnabsSq, carg });
  151. auto cosArg = arena.allocate(ir::Cos{ newArg });
  152. auto sinArg = arena.allocate(ir::Sin{ newArg });
  153. auto newA = arena.allocate(ir::Multiplication{ cosArg, newAbs });
  154. auto newB = arena.allocate(ir::Multiplication{ sinArg, newAbs });
  155. return { newA, newB };
  156. }
  157. NodePair intPow(NodePair val, int exponent) {
  158. auto [a, b] = val;
  159. if (exponent < 0) {
  160. auto [inva, invb] = intPow(val, -exponent);
  161. return oneOver(inva, invb);
  162. }
  163. if (exponent == 0)
  164. return { one, zero };
  165. else if (exponent == 1)
  166. return val;
  167. else if (exponent == 2)
  168. return sq(a, b);
  169. else {
  170. bool isEven = (exponent % 2) == 0;
  171. if (isEven) {
  172. NodePair square = sq(a, b);
  173. return intPow(square, exponent / 2);
  174. }
  175. else {
  176. int expm1 = exponent - 1;
  177. NodePair square = sq(a, b);
  178. auto[pa, pb] = intPow(square, expm1 / 2);
  179. return multiplication(pa, pb, a, b);
  180. }
  181. }
  182. }
  183. NodePair realPow(NodePair val, Node* exponent) {
  184. auto [a, b] = val;
  185. auto arg = arena.allocate(ir::Atan2{ b, a });
  186. auto aa = arena.allocate(ir::Multiplication{ a, a });
  187. auto bb = arena.allocate(ir::Multiplication{ b, b });
  188. auto absSq = arena.allocate(ir::Addition{ aa, bb });
  189. auto halfc = arena.allocate(ir::Multiplication{ exponent, half });
  190. auto newAbs = arena.allocate(ir::Pow{ absSq, halfc });
  191. auto newArg = arena.allocate(ir::Multiplication{ arg, exponent });
  192. auto cosArg = arena.allocate(ir::Cos{ newArg });
  193. auto sinArg = arena.allocate(ir::Sin{ newArg });
  194. auto newA = arena.allocate(ir::Multiplication{ cosArg, newAbs });
  195. auto newB = arena.allocate(ir::Multiplication{ sinArg, newAbs });
  196. return { newA, newB };
  197. }
  198. };
  199. ir::Formula expand(const mnd::IterationFormula& z0, const mnd::IterationFormula& zi)
  200. {
  201. ir::Formula formula;
  202. ConvertVisitor cv0{ formula.nodeArena, z0 };
  203. ConvertVisitor cvi{ formula.nodeArena, zi };
  204. std::tie(formula.startA, formula.startB) = std::visit(cv0, *z0.expr);
  205. std::tie(formula.newA, formula.newB) = std::visit(cvi, *zi.expr);
  206. return formula;
  207. }
  208. }
  209. using namespace std::string_literals;
  210. struct ToStringVisitor
  211. {
  212. // return string and precedence
  213. using Ret = std::pair<std::string, int>;
  214. Ret operator()(const ir::Constant& c) {
  215. return { mnd::toString(c.value), 0 };
  216. }
  217. Ret operator()(const ir::Variable& v) {
  218. return { v.name, 0 };
  219. }
  220. Ret operator()(const ir::Negation& n) {
  221. auto [str, prec] = std::visit(*this, *n.value);
  222. if (prec > 0)
  223. return { "-("s + str + ")", 2 };
  224. else
  225. return { "-"s + str, 2 };
  226. }
  227. Ret operator()(const ir::Addition& n) {
  228. auto [strl, precl] = std::visit(*this, *n.left);
  229. auto [strr, precr] = std::visit(*this, *n.right);
  230. std::string ret;
  231. if (precl > 4)
  232. ret += strl + " + ";
  233. else
  234. ret += "(" + strl + ") + ";
  235. if (precr > 4)
  236. ret += strr;
  237. else
  238. ret += "(" + strr + ")";
  239. return { ret, 4 };
  240. }
  241. Ret operator()(const ir::Subtraction& n) {
  242. auto [strl, precl] = std::visit(*this, *n.left);
  243. auto [strr, precr] = std::visit(*this, *n.right);
  244. std::string ret;
  245. if (precl > 4)
  246. ret += strl + " - ";
  247. else
  248. ret += "(" + strl + ") - ";
  249. if (precr >= 4)
  250. ret += strr;
  251. else
  252. ret += "(" + strr + ")";
  253. return { ret, 4 };
  254. }
  255. Ret operator()(const ir::Multiplication& n) {
  256. auto [strl, precl] = std::visit(*this, *n.left);
  257. auto [strr, precr] = std::visit(*this, *n.right);
  258. std::string ret;
  259. if (precl > 3)
  260. ret += strl + " * ";
  261. else
  262. ret += "(" + strl + ") * ";
  263. if (precr > 3)
  264. ret += strr;
  265. else
  266. ret += "(" + strr + ")";
  267. return { ret, 3 };
  268. }
  269. Ret operator()(const ir::Division& n) {
  270. auto [strl, precl] = std::visit(*this, *n.left);
  271. auto [strr, precr] = std::visit(*this, *n.right);
  272. std::string ret;
  273. if (precl > 3)
  274. ret += strl + " / ";
  275. else
  276. ret += "(" + strl + ") / ";
  277. if (precr >= 3)
  278. ret += strr;
  279. else
  280. ret += "(" + strr + ")";
  281. return { ret, 3 };
  282. }
  283. Ret operator()(const ir::Atan2& n) {
  284. return { "atan2(" + std::visit(*this, *n.left).first + ", " + std::visit(*this, *n.right).first + ")", 1 };
  285. }
  286. Ret operator()(const ir::Pow& n) {
  287. auto [strl, precl] = std::visit(*this, *n.left);
  288. auto [strr, precr] = std::visit(*this, *n.right);
  289. std::string ret;
  290. if (precl >= 2)
  291. ret += strl + " ^ ";
  292. else
  293. ret += "(" + strl + ") ^ ";
  294. if (precr > 2)
  295. ret += strr;
  296. else
  297. ret += "(" + strr + ")";
  298. return { ret, 2 };
  299. }
  300. Ret operator()(const ir::Cos& n) {
  301. return { "cos(" + std::visit(*this, *n.value).first + ")", 1 };
  302. }
  303. Ret operator()(const ir::Sin& n) {
  304. return { "sin(" + std::visit(*this, *n.value).first + ")", 1 };
  305. }
  306. Ret operator()(const ir::Exp& n) {
  307. return { "exp(" + std::visit(*this, *n.value).first + ")", 1 };
  308. }
  309. Ret operator()(const ir::Ln& n) {
  310. return { "ln(" + std::visit(*this, *n.value).first + ")", 1 };
  311. }
  312. };
  313. std::string mnd::ir::Formula::toString(void) const
  314. {
  315. return std::string("a = ") + std::visit(ToStringVisitor{}, *this->newA).first +
  316. "\nb = " + std::visit(ToStringVisitor{}, *this->newB).first +
  317. "\nx = " + std::visit(ToStringVisitor{}, *this->startA).first +
  318. "\ny = " + std::visit(ToStringVisitor{}, *this->startB).first;
  319. }
  320. struct ConstantPropagator
  321. {
  322. mnd::ir::Formula& formula;
  323. mnd::util::Arena<Node>& arena;
  324. using MaybeNode = std::optional<Node>;
  325. ConstantPropagator(mnd::ir::Formula& formula) :
  326. formula{ formula },
  327. arena{ formula.nodeArena }
  328. {
  329. }
  330. void propagateConstants(void) {
  331. visitNode(formula.newA);
  332. visitNode(formula.newB);
  333. visitNode(formula.newA);
  334. visitNode(formula.newB);
  335. }
  336. bool hasBeenVisited(Node* n) {
  337. return std::visit([] (auto& x) {
  338. if (auto* b = std::any_cast<bool>(&x.nodeData))
  339. return *b;
  340. else
  341. return false;
  342. }, *n);
  343. }
  344. void visitNode(Node* n) {
  345. if (!hasBeenVisited(n)) {
  346. MaybeNode mbn = std::visit(*this, *n);
  347. if (mbn.has_value()) {
  348. *n = std::move(mbn.value());
  349. }
  350. std::visit([] (auto& x) { x.nodeData = true; }, *n);
  351. }
  352. }
  353. ir::Constant* getIfConstant(Node* n) {
  354. return std::get_if<ir::Constant>(n);
  355. }
  356. MaybeNode operator()(ir::Constant& x) {
  357. return std::nullopt;
  358. }
  359. MaybeNode operator()(ir::Variable& x) {
  360. return std::nullopt;
  361. }
  362. MaybeNode operator()(ir::Negation& n) {
  363. visitNode(n.value);
  364. if (auto* c = getIfConstant(n.value)) {
  365. return ir::Constant{ -c->value };
  366. }
  367. if (auto* neg = std::get_if<ir::Negation>(n.value)) {
  368. return *neg->value;
  369. }
  370. return std::nullopt;
  371. }
  372. MaybeNode operator()(ir::Addition& n) {
  373. visitNode(n.left);
  374. visitNode(n.right);
  375. auto* ca = getIfConstant(n.left);
  376. auto* cb = getIfConstant(n.right);
  377. if (ca && cb) {
  378. return ir::Constant{ ca->value + cb->value };
  379. }
  380. else if (ca && ca->value == 0) {
  381. return *n.right;
  382. }
  383. else if (cb && cb->value == 0) {
  384. return *n.left;
  385. }
  386. else if (cb) {
  387. // move constants to the left
  388. std::swap(n.left, n.right);
  389. }
  390. else if (auto* nright = std::get_if<ir::Negation>(n.right)) {
  391. return ir::Subtraction{ n.left, nright->value };
  392. }
  393. return std::nullopt;
  394. }
  395. MaybeNode operator()(ir::Subtraction& n) {
  396. visitNode(n.left);
  397. visitNode(n.right);
  398. auto* ca = getIfConstant(n.left);
  399. auto* cb = getIfConstant(n.right);
  400. if (ca && cb) {
  401. return ir::Constant{ ca->value - cb->value };
  402. }
  403. else if (ca && ca->value == 0) {
  404. return ir::Negation{ n.right };
  405. }
  406. else if (cb && cb->value == 0) {
  407. return *n.left;
  408. }
  409. return std::nullopt;
  410. }
  411. MaybeNode operator()(ir::Multiplication& n) {
  412. visitNode(n.left);
  413. visitNode(n.right);
  414. printf("simpifying mul: %s * %s\n", std::visit(ToStringVisitor{}, *n.left).first.c_str(), std::visit(ToStringVisitor{}, *n.right).first.c_str());
  415. auto* ca = getIfConstant(n.left);
  416. auto* cb = getIfConstant(n.right);
  417. if (ca && cb) {
  418. return ir::Constant{ ca->value * cb->value };
  419. }
  420. if (ca && ca->value == 0) {
  421. return ir::Constant{ 0 };
  422. }
  423. if (cb && cb->value == 0) {
  424. return ir::Constant{ 0 };
  425. }
  426. if (ca && ca->value == 1) {
  427. return *n.right;
  428. }
  429. if (cb && cb->value == 1) {
  430. return *n.left;
  431. }
  432. if (ca) {
  433. auto* rightMul = std::get_if<ir::Multiplication>(n.right);
  434. if (rightMul) {
  435. auto* clr = getIfConstant(rightMul->left);
  436. if (clr) {
  437. printf("left %s, right %s\n", mnd::toString(ca->value).c_str(), mnd::toString(clr->value).c_str());
  438. //ca->value *= clr->value;
  439. n.right = rightMul->right;
  440. auto mul = ir::Multiplication{ arena.allocate(ir::Constant{ ca->value * clr->value }), rightMul->right };//
  441. auto maybeBetter = this->operator()(mul);
  442. return maybeBetter.value_or(mul);
  443. }
  444. }
  445. }
  446. if (cb) {
  447. // move constants to the left
  448. std::swap(n.left, n.right);
  449. }
  450. return std::nullopt;
  451. }
  452. MaybeNode operator()(ir::Division& n) {
  453. visitNode(n.left);
  454. visitNode(n.right);
  455. auto* ca = getIfConstant(n.left);
  456. auto* cb = getIfConstant(n.right);
  457. if (ca && cb) {
  458. return ir::Constant{ ca->value / cb->value };
  459. }
  460. else if (ca && ca->value == 0) {
  461. return ir::Constant{ 0 };
  462. }
  463. else if (cb && cb->value == 1) {
  464. return *n.left;
  465. }
  466. return std::nullopt;
  467. }
  468. MaybeNode operator()(ir::Atan2& n) {
  469. visitNode(n.left);
  470. visitNode(n.right);
  471. auto* ca = getIfConstant(n.left);
  472. auto* cb = getIfConstant(n.right);
  473. if (ca && cb) {
  474. return ir::Constant{ mnd::atan2(ca->value, cb->value) };
  475. }
  476. return std::nullopt;
  477. }
  478. MaybeNode operator()(ir::Pow& n) {
  479. visitNode(n.left);
  480. visitNode(n.right);
  481. auto* ca = getIfConstant(n.left);
  482. auto* cb = getIfConstant(n.right);
  483. if (ca && cb) {
  484. return ir::Constant{ mnd::pow(ca->value, cb->value) };
  485. }
  486. else if (cb && cb->value == 1) {
  487. return *n.left;
  488. }
  489. else if (cb && cb->value == 1) {
  490. return ir::Constant{ 1 };
  491. }
  492. return std::nullopt;
  493. }
  494. MaybeNode operator()(ir::Cos& n) {
  495. visitNode(n.value);
  496. auto* ca = getIfConstant(n.value);
  497. if (ca) {
  498. return ir::Constant{ mnd::cos(ca->value) };
  499. }
  500. return std::nullopt;
  501. }
  502. MaybeNode operator()(ir::Sin& n) {
  503. visitNode(n.value);
  504. auto* ca = getIfConstant(n.value);
  505. if (ca) {
  506. return ir::Constant{ mnd::sin(ca->value) };
  507. }
  508. return std::nullopt;
  509. }
  510. MaybeNode operator()(ir::Exp& n) {
  511. visitNode(n.value);
  512. auto* ca = getIfConstant(n.value);
  513. if (ca) {
  514. return ir::Constant{ mnd::exp(ca->value) };
  515. }
  516. return std::nullopt;
  517. }
  518. MaybeNode operator()(ir::Ln& n) {
  519. visitNode(n.value);
  520. auto* ca = getIfConstant(n.value);
  521. if (ca) {
  522. return ir::Constant{ mnd::log(ca->value) };
  523. }
  524. return std::nullopt;
  525. }
  526. };
  527. void mnd::ir::Formula::constantPropagation(void)
  528. {
  529. ConstantPropagator cp { *this };
  530. cp.propagateConstants();
  531. }
  532. void mnd::ir::Formula::optimize(void)
  533. {
  534. constantPropagation();
  535. }
  536. void mnd::ir::Formula::clearNodeData(void)
  537. {
  538. nodeArena.forAll([] (Node& n) {
  539. std::visit([] (auto& x) {
  540. x.nodeData.reset();
  541. }, n);
  542. });
  543. }