IterationCompiler.cpp 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857
#include "IterationCompiler.h"
#include "NaiveIRGenerator.h"
#include "Mandel.h"
#ifdef WITH_ASMJIT
#include "ExecData.h"
#endif // WITH_ASMJIT
#include "OpenClInternal.h"
#include "OpenClCode.h"

#include <any>
#include <cmath>
#include <cstdio>
#include <iomanip>
#include <limits>
#include <locale>
#include <sstream>
#include <string>

#include <omp.h>

using namespace std::string_literals;
  14. namespace mnd
  15. {
  16. #ifdef WITH_ASMJIT
  17. #if defined(__x86_64__) || defined(_M_X64)
  18. struct CompileVisitor
  19. {
  20. using Reg = asmjit::x86::Xmm;
  21. asmjit::x86::Compiler& cc;
  22. Reg& a;
  23. Reg& b;
  24. Reg& x;
  25. Reg& y;
  26. Reg visitNode(ir::Node& node)
  27. {
  28. auto& nodeData = std::visit([] (auto& n) -> std::any& { return n.nodeData; }, node);
  29. if (Reg* regPtr = std::any_cast<Reg>(&nodeData)) {
  30. return *regPtr;
  31. }
  32. else {
  33. Reg reg = std::visit(*this, node);
  34. nodeData = reg;
  35. return reg;
  36. }
  37. }
  38. CompileVisitor(asmjit::x86::Compiler& cc, Reg& a, Reg& b, Reg& x, Reg& y) :
  39. cc{ cc },
  40. a{ a }, b{ b },
  41. x{ x }, y{ y }
  42. {
  43. }
  44. Reg operator()(const ir::Constant& c) {
  45. auto constant = cc.newDoubleConst(asmjit::ConstPool::kScopeLocal, mnd::convert<double>(c.value));
  46. auto reg = cc.newXmmSd();
  47. std::string commentStr = "move constant [";
  48. commentStr += std::to_string(mnd::convert<double>(c.value));
  49. commentStr += "]";
  50. cc.comment(commentStr.c_str());
  51. cc.movsd(reg, constant);
  52. return reg;
  53. }
  54. Reg operator()(const ir::Variable& v) {
  55. if (v.name == "z_re") {
  56. return a;
  57. }
  58. else if (v.name == "z_im") {
  59. return b;
  60. }
  61. else if (v.name == "c_re") {
  62. return x;
  63. }
  64. else if (v.name == "c_im") {
  65. return y;
  66. }
  67. else
  68. throw mnd::ParseError(std::string("unknown variable: ") + v.name);
  69. }
  70. Reg operator()(const ir::Negation& n) {
  71. auto sub = cc.newXmmSd();
  72. cc.xorpd(sub, sub);
  73. cc.subsd(sub, visitNode(*n.value));
  74. return sub;
  75. }
  76. Reg operator()(const ir::Addition& add) {
  77. auto res = cc.newXmmSd();
  78. cc.movapd(res, visitNode(*add.left));
  79. cc.addsd(res, visitNode(*add.right));
  80. return res;
  81. }
  82. Reg operator()(const ir::Subtraction& add) {
  83. auto res = cc.newXmmSd();
  84. cc.movapd(res, visitNode(*add.left));
  85. cc.subsd(res, visitNode(*add.right));
  86. return res;
  87. }
  88. Reg operator()(const ir::Multiplication& add) {
  89. auto res = cc.newXmmSd();
  90. cc.movapd(res, visitNode(*add.left));
  91. cc.mulsd(res, visitNode(*add.right));
  92. return res;
  93. }
  94. Reg operator()(const ir::Division& add) {
  95. auto res = cc.newXmmSd();
  96. cc.movapd(res, visitNode(*add.left));
  97. cc.divsd(res, visitNode(*add.right));
  98. return res;
  99. }
  100. static double myAtan2(double y, double x)
  101. {
  102. double result = ::atan2(y, x);
  103. printf("atan2(%f, %f) = %f\n", y, x, result);
  104. return result;
  105. }
  106. Reg operator()(const ir::Atan2& at2) {
  107. using namespace asmjit;
  108. auto y = visitNode(*at2.left);
  109. auto x = visitNode(*at2.right);
  110. auto arg = cc.newXmmSd();
  111. double(*atanFunc)(double, double) = ::atan2;
  112. cc.comment("call atan2");
  113. auto call = cc.call(imm(atanFunc), FuncSignatureT<double, double, double>(CallConv::kIdHost));
  114. call->setArg(0, y);
  115. call->setArg(1, x);
  116. call->setRet(0, arg);
  117. return arg;
  118. }
  119. Reg operator()(const ir::Pow& p) {
  120. using namespace asmjit;
  121. auto a = visitNode(*p.left);
  122. auto b = visitNode(*p.right);
  123. auto arg = cc.newXmmSd();
  124. double(*powFunc)(double, double) = ::pow;
  125. cc.comment("call pow");
  126. auto call = cc.call(imm(powFunc), FuncSignatureT<double, double, double>(CallConv::kIdHost));
  127. call->setArg(0, a);
  128. call->setArg(1, b);
  129. call->setRet(0, arg);
  130. return arg;
  131. }
  132. Reg operator()(const ir::Cos& c) {
  133. using namespace asmjit;
  134. auto a = visitNode(*c.value);
  135. auto arg = cc.newXmmSd();
  136. double(*cosFunc)(double) = ::cos;
  137. cc.comment("call cos");
  138. auto call = cc.call(imm(cosFunc), FuncSignatureT<double, double>(CallConv::kIdHost));
  139. call->setArg(0, a);
  140. call->setRet(0, arg);
  141. return arg;
  142. }
  143. Reg operator()(const ir::Sin& s) {
  144. using namespace asmjit;
  145. auto a = visitNode(*s.value);
  146. auto arg = cc.newXmmSd();
  147. double(*sinFunc)(double) = ::sin;
  148. cc.comment("call sin");
  149. auto call = cc.call(imm(sinFunc), FuncSignatureT<double, double>(CallConv::kIdHost));
  150. call->setArg(0, a);
  151. call->setRet(0, arg);
  152. return arg;
  153. }
  154. Reg operator()(const ir::Exp& ex) {
  155. using namespace asmjit;
  156. auto a = visitNode(*ex.value);
  157. auto arg = cc.newXmmSd();
  158. double(*expFunc)(double) = ::exp;
  159. cc.comment("call exp");
  160. auto call = cc.call(imm(expFunc), FuncSignatureT<double, double>(CallConv::kIdHost));
  161. call->setArg(0, a);
  162. call->setRet(0, arg);
  163. return arg;
  164. }
  165. Reg operator()(const ir::Ln& l) {
  166. using namespace asmjit;
  167. auto a = visitNode(*l.value);
  168. auto arg = cc.newXmmSd();
  169. double(*logFunc)(double) = ::log;
  170. cc.comment("call log");
  171. auto call = cc.call(imm(logFunc), FuncSignatureT<double, double>(CallConv::kIdHost));
  172. call->setArg(0, a);
  173. call->setRet(0, arg);
  174. return arg;
  175. }
  176. };
  177. static void printD(double d) {
  178. printf("val: %f\n", d); fflush(stdout);
  179. }
  180. CompiledGenerator compile(const ir::Formula& formula)
  181. {
  182. using namespace asmjit;
  183. std::unique_ptr<mnd::ExecData> ed = std::make_unique<mnd::ExecData>();
  184. JitRuntime& jitRuntime = *ed->jitRuntime;
  185. ed->code->init(jitRuntime.codeInfo());
  186. x86::Compiler& comp = *ed->compiler;
  187. x86::Mem sixteen = comp.newDoubleConst(ConstPool::kScopeLocal, 16.0);
  188. Label startLoop = comp.newLabel();
  189. Label endLoop = comp.newLabel();
  190. x86::Gp maxIter = comp.newInt32();
  191. x86::Gp k = comp.newInt32();
  192. x86::Xmm x = comp.newXmmSd();
  193. x86::Xmm y = comp.newXmmSd();
  194. x86::Xmm a = comp.newXmmSd();
  195. x86::Xmm b = comp.newXmmSd();
  196. comp.addFunc(FuncSignatureT<int, double, double, int>(CallConv::kIdHost));
  197. comp.setArg(0, x);
  198. comp.setArg(1, y);
  199. CompileVisitor formVisitor{ comp, a, b, x, y };
  200. auto startA = std::visit(formVisitor, *formula.startA);
  201. auto startB = std::visit(formVisitor, *formula.startB);
  202. comp.movapd(a, startA);
  203. comp.movapd(b, startB);
  204. comp.setArg(2, maxIter);
  205. //comp.movapd(a, x);
  206. //comp.movapd(b, y);
  207. //comp.xorpd(a, a);
  208. //comp.movapd(b, b);
  209. comp.xor_(k, k);
  210. comp.bind(startLoop);
  211. auto newA = std::visit(formVisitor, *formula.newA);
  212. auto newB = std::visit(formVisitor, *formula.newB);
  213. comp.movapd(a, newA);
  214. comp.movapd(b, newB);
  215. x86::Xmm aa = comp.newXmmSd();
  216. x86::Xmm bb = comp.newXmmSd();
  217. comp.movapd(aa, a);
  218. comp.mulsd(aa, a);
  219. comp.movapd(bb, b);
  220. comp.mulsd(bb, b);
  221. comp.addsd(bb, aa);
  222. //auto call = comp.call(imm(printD), FuncSignatureT<void, double>(CallConv::kIdHost));
  223. //call->setArg(0, bb);
  224. comp.comisd(bb, sixteen);
  225. comp.jnb(endLoop);
  226. comp.inc(k);
  227. comp.cmp(k, maxIter);
  228. comp.jle(startLoop);
  229. comp.bind(endLoop);
  230. comp.ret(k);
  231. comp.endFunc();
  232. auto err = comp.finalize();
  233. if (err == asmjit::kErrorOk) {
  234. err = jitRuntime.add(&ed->iterationFunc, ed->code.get());
  235. if (err != asmjit::kErrorOk)
  236. throw "error adding function";
  237. }
  238. else {
  239. throw "error compiling";
  240. }
  241. return CompiledGenerator{ std::move(ed) };
  242. }
  243. struct CompileVisitorAVXFloat
  244. {
  245. using Reg = asmjit::x86::Ymm;
  246. asmjit::x86::Compiler& cc;
  247. Reg& a;
  248. Reg& b;
  249. Reg& x;
  250. Reg& y;
  251. Reg visitNode(ir::Node& node)
  252. {
  253. auto& nodeData = std::visit([] (auto& n) -> std::any& { return n.nodeData; }, node);
  254. if (Reg* regPtr = std::any_cast<Reg>(&nodeData)) {
  255. return *regPtr;
  256. }
  257. else {
  258. Reg reg = std::visit(*this, node);
  259. nodeData = reg;
  260. return reg;
  261. }
  262. }
  263. CompileVisitorAVXFloat(asmjit::x86::Compiler& cc, Reg& a, Reg& b, Reg& x, Reg& y) :
  264. cc{ cc },
  265. a{ a }, b{ b },
  266. x{ x }, y{ y }
  267. {
  268. }
  269. Reg operator()(const ir::Constant& c) {
  270. auto constant = cc.newFloatConst(asmjit::ConstPool::kScopeLocal, mnd::convert<float>(c.value));
  271. auto reg = cc.newYmmPs();
  272. std::string commentStr = "move constant [";
  273. commentStr += std::to_string(mnd::convert<double>(c.value));
  274. commentStr += "]";
  275. cc.comment(commentStr.c_str());
  276. cc.vbroadcastss(reg, constant);
  277. return reg;
  278. }
  279. Reg operator()(const ir::Variable& v) {
  280. if (v.name == "z_re") {
  281. return a;
  282. }
  283. else if (v.name == "z_im") {
  284. return b;
  285. }
  286. else if (v.name == "c_re") {
  287. return x;
  288. }
  289. else if (v.name == "c_im") {
  290. return y;
  291. }
  292. else
  293. throw mnd::ParseError(std::string("unknown variable: ") + v.name);
  294. }
  295. Reg operator()(const ir::Negation& n) {
  296. auto sub = cc.newYmmPs();
  297. cc.vxorps(sub, sub, sub);
  298. cc.vsubps(sub, sub, visitNode(*n.value));
  299. return sub;
  300. }
  301. Reg operator()(const ir::Addition& add) {
  302. auto res = cc.newYmmPs();
  303. cc.vaddps(res, visitNode(*add.left), visitNode(*add.right));
  304. return res;
  305. }
  306. Reg operator()(const ir::Subtraction& add) {
  307. auto res = cc.newYmmPs();
  308. cc.vsubps(res, visitNode(*add.left), visitNode(*add.right));
  309. return res;
  310. }
  311. Reg operator()(const ir::Multiplication& add) {
  312. auto res = cc.newYmmPs();
  313. cc.vmulps(res, visitNode(*add.left), visitNode(*add.right));
  314. return res;
  315. }
  316. Reg operator()(const ir::Division& add) {
  317. auto res = cc.newYmmPs();
  318. cc.vdivps(res, visitNode(*add.left), visitNode(*add.right));
  319. return res;
  320. }
  321. static double myAtan2(double y, double x)
  322. {
  323. double result = ::atan2(y, x);
  324. printf("atan2(%f, %f) = %f\n", y, x, result);
  325. return result;
  326. }
  327. Reg operator()(const ir::Atan2& at2) {
  328. using namespace asmjit;
  329. auto y = visitNode(*at2.left);
  330. auto x = visitNode(*at2.right);
  331. auto arg = cc.newYmmPs();
  332. /*
  333. double(*atanFunc)(double, double) = ::atan2;
  334. cc.comment("call atan2");
  335. auto call = cc.call(imm(atanFunc), FuncSignatureT<double, double, double>(CallConv::kIdHost));
  336. call->setArg(0, y);
  337. call->setArg(1, x);
  338. call->setRet(0, arg);*/
  339. return arg;
  340. }
  341. Reg operator()(const ir::Pow& p) {
  342. using namespace asmjit;
  343. auto a = visitNode(*p.left);
  344. auto b = visitNode(*p.right);
  345. auto arg = cc.newYmmPs();
  346. /*double(*powFunc)(double, double) = ::pow;
  347. cc.comment("call pow");
  348. auto call = cc.call(imm(powFunc), FuncSignatureT<double, double, double>(CallConv::kIdHost));
  349. call->setArg(0, a);
  350. call->setArg(1, b);
  351. call->setRet(0, arg);*/
  352. return arg;
  353. }
  354. Reg operator()(const ir::Cos& c) {
  355. using namespace asmjit;
  356. auto a = visitNode(*c.value);
  357. auto arg = cc.newYmmPs();
  358. /*double(*cosFunc)(double) = ::cos;
  359. cc.comment("call cos");
  360. auto call = cc.call(imm(cosFunc), FuncSignatureT<double, double>(CallConv::kIdHost));
  361. call->setArg(0, a);
  362. call->setRet(0, arg);*/
  363. return arg;
  364. }
  365. Reg operator()(const ir::Sin& s) {
  366. using namespace asmjit;
  367. auto a = visitNode(*s.value);
  368. auto arg = cc.newYmmPs();
  369. /*double(*sinFunc)(double) = ::sin;
  370. cc.comment("call sin");
  371. auto call = cc.call(imm(sinFunc), FuncSignatureT<double, double>(CallConv::kIdHost));
  372. call->setArg(0, a);
  373. call->setRet(0, arg);*/
  374. return arg;
  375. }
  376. Reg operator()(const ir::Exp& ex) {
  377. using namespace asmjit;
  378. auto a = visitNode(*ex.value);
  379. auto arg = cc.newYmmPs();
  380. /*double(*expFunc)(double) = ::exp;
  381. cc.comment("call exp");
  382. auto call = cc.call(imm(expFunc), FuncSignatureT<double, double>(CallConv::kIdHost));
  383. call->setArg(0, a);
  384. call->setRet(0, arg);*/
  385. return arg;
  386. }
  387. Reg operator()(const ir::Ln& l) {
  388. using namespace asmjit;
  389. auto a = visitNode(*l.value);
  390. auto arg = cc.newYmmPs();
  391. /*double(*logFunc)(double) = ::log;
  392. cc.comment("call log");
  393. auto call = cc.call(imm(logFunc), FuncSignatureT<double, double>(CallConv::kIdHost));
  394. call->setArg(0, a);
  395. call->setRet(0, arg);*/
  396. return arg;
  397. }
  398. };
  399. CompiledGeneratorVec compileAVXFloat(const ir::Formula& formula)
  400. {
  401. using namespace asmjit;
  402. std::unique_ptr<mnd::ExecData> ed = std::make_unique<mnd::ExecData>();
  403. JitRuntime& jitRuntime = *ed->jitRuntime;
  404. ed->code->init(jitRuntime.codeInfo());
  405. x86::Compiler& comp = *ed->compiler;
  406. x86::Mem sixteen = comp.newYmmConst(ConstPool::kScopeLocal, Data256::fromF32(16.0f));
  407. x86::Mem one = comp.newYmmConst(ConstPool::kScopeLocal, Data256::fromF32(1.0f));
  408. x86::Mem factors = comp.newYmmConst(ConstPool::kScopeLocal, Data256::fromF32(0, 1, 2, 3, 4, 5, 6, 7));
  409. Label startLoop = comp.newLabel();
  410. Label endLoop = comp.newLabel();
  411. x86::Gp maxIter = comp.newInt32();
  412. x86::Gp k = comp.newInt32();
  413. x86::Gp resPtr = comp.newGpq();
  414. x86::Ymm adder = comp.newYmmPs();
  415. x86::Ymm counter = comp.newYmmPs();
  416. x86::Xmm xorig = comp.newXmmSs();
  417. x86::Xmm yorig = comp.newXmmSs();
  418. x86::Ymm dx = comp.newYmmPs();
  419. x86::Ymm x = comp.newYmmPs();
  420. x86::Ymm y = comp.newYmmPs();
  421. x86::Ymm a = comp.newYmmPs();
  422. x86::Ymm b = comp.newYmmPs();
  423. comp.addFunc(FuncSignatureT<int, float, float, float, int, float*>(CallConv::kIdHost));
  424. comp.setArg(0, xorig);
  425. comp.setArg(1, yorig);
  426. comp.setArg(2, dx.xmm());
  427. comp.setArg(3, maxIter);
  428. comp.setArg(4, resPtr);
  429. comp.vmovaps(adder, one);
  430. comp.vxorps(counter, counter, counter);
  431. comp.vshufps(xorig, xorig, xorig, 0);
  432. comp.vshufps(yorig, yorig, yorig, 0);
  433. comp.vshufps(dx.half(), dx.half(), dx.half(), 0);
  434. comp.vinsertf128(x, xorig.ymm(), xorig, 1);
  435. comp.vinsertf128(y, yorig.ymm(), yorig, 1);
  436. comp.vinsertf128(dx, dx, dx.xmm(), 1);
  437. comp.vmulps(dx, dx, factors);
  438. comp.vaddps(x, x, dx);
  439. CompileVisitorAVXFloat formVisitor{ comp, a, b, x, y };
  440. auto startA = std::visit(formVisitor, *formula.startA);
  441. auto startB = std::visit(formVisitor, *formula.startB);
  442. comp.vmovaps(a, startA);
  443. comp.vmovaps(b, startB);
  444. comp.xor_(k, k);
  445. comp.bind(startLoop);
  446. auto newA = std::visit(formVisitor, *formula.newA);
  447. auto newB = std::visit(formVisitor, *formula.newB);
  448. comp.vmovaps(a, newA);
  449. comp.vmovaps(b, newB);
  450. x86::Ymm aa = comp.newYmmPs();
  451. x86::Ymm bb = comp.newYmmPs();
  452. x86::Ymm cmp = comp.newYmmPs();
  453. comp.vmulps(aa, a, a);
  454. comp.vmulps(bb, b, b);
  455. comp.vaddps(bb, bb, aa);
  456. comp.vcmpps(cmp, bb, sixteen, 18);
  457. comp.vandps(adder, adder, cmp);
  458. comp.vaddps(counter, counter, adder);
  459. comp.cmp(k, maxIter);
  460. comp.je(endLoop);
  461. comp.add(k, 1);
  462. comp.vtestps(cmp, cmp);
  463. comp.jne(startLoop);
  464. comp.bind(endLoop);
  465. comp.vmovups(x86::xmmword_ptr(resPtr), counter.half());
  466. comp.vextractf128(x86::xmmword_ptr(resPtr, 16), counter, 0x1);
  467. comp.ret(k);
  468. comp.endFunc();
  469. auto err = comp.finalize();
  470. if (err == asmjit::kErrorOk) {
  471. err = jitRuntime.add(&ed->iterationFunc, ed->code.get());
  472. if (err != asmjit::kErrorOk)
  473. throw "error adding function";
  474. }
  475. else {
  476. throw "error compiling";
  477. }
  478. return CompiledGeneratorVec{ std::move(ed) };
  479. }
  480. #endif // defined(__x86_64__) || defined(_M_X64)
  481. #endif // WITH_ASMJIT
  482. struct OpenClVisitor
  483. {
  484. int varnameCounter = 0;
  485. std::stringstream code;
  486. std::string floatTypeName;
  487. OpenClVisitor(int startVarname, const std::string& floatTypeName) :
  488. varnameCounter{ startVarname },
  489. floatTypeName{ floatTypeName }
  490. {
  491. }
  492. std::string createVarname(void)
  493. {
  494. return "tmp"s + std::to_string(varnameCounter++);
  495. }
  496. std::string visitNode(ir::Node& node)
  497. {
  498. auto& nodeData = std::visit([] (auto& n) -> std::any& { return n.nodeData; }, node);
  499. if (std::string* var = std::any_cast<std::string>(&nodeData)) {
  500. return *var;
  501. }
  502. else {
  503. std::string value = std::visit(*this, node);
  504. if (!std::get_if<ir::Variable>(&node) && !std::get_if<ir::Constant>(&node)) {
  505. std::string varname = createVarname();
  506. code << floatTypeName << " " << varname << " = " << value << ";" << std::endl;
  507. nodeData = varname;
  508. return varname;
  509. }
  510. return value;
  511. }
  512. }
  513. std::string operator()(const ir::Constant& c) {
  514. return std::to_string(mnd::convert<double>(c.value)) + ((floatTypeName == "float") ? "f" : "");
  515. }
  516. std::string operator()(const ir::Variable& v) {
  517. return v.name;
  518. }
  519. std::string operator()(const ir::Negation& n) {
  520. return "-("s + visitNode(*n.value) + ")";
  521. }
  522. std::string operator()(const ir::Addition& a) {
  523. return "("s + visitNode(*a.left) + ") + (" + visitNode(*a.right) + ")";
  524. }
  525. std::string operator()(const ir::Subtraction& a) {
  526. return "("s + visitNode(*a.left) + ") - (" + visitNode(*a.right) + ")";
  527. }
  528. std::string operator()(const ir::Multiplication& a) {
  529. return "("s + visitNode(*a.left) + ") * (" + visitNode(*a.right) + ")";
  530. }
  531. std::string operator()(const ir::Division& a) {
  532. return "("s + visitNode(*a.left) + ") / (" + visitNode(*a.right) + ")";
  533. }
  534. std::string operator()(const ir::Atan2& a) {
  535. return "atan2("s + visitNode(*a.left) + ", " + visitNode(*a.right) + ")";
  536. }
  537. std::string operator()(const ir::Pow& a) {
  538. return "pow("s + visitNode(*a.left) + ", " + visitNode(*a.right) + ")";
  539. }
  540. std::string operator()(const ir::Cos& a) {
  541. return "cos("s + visitNode(*a.value) + ")";
  542. }
  543. std::string operator()(const ir::Sin& a) {
  544. return "sin("s + visitNode(*a.value) + ")";
  545. }
  546. std::string operator()(const ir::Exp& a) {
  547. return "exp("s + visitNode(*a.value) + ")";
  548. }
  549. std::string operator()(const ir::Ln& a) {
  550. return "log("s + visitNode(*a.value) + ")";
  551. }
  552. };
  553. std::string compileToOpenCl(const ir::Formula& formula)
  554. {
  555. OpenClVisitor z0Visitor{ 0, "float" };
  556. std::string startA = z0Visitor.visitNode(*formula.startA);
  557. std::string startB = z0Visitor.visitNode(*formula.startB);
  558. OpenClVisitor ocv{ z0Visitor.varnameCounter, "float" };
  559. std::string newA = ocv.visitNode(*formula.newA);
  560. std::string newB = ocv.visitNode(*formula.newB);
  561. std::string prelude =
  562. "__kernel void iterate(__global float* A, const int width, float xl, float yt, float pixelScaleX, float pixelScaleY, int max, int smooth, int julia, float juliaX, float juliaY) {\n"
  563. " int index = get_global_id(0);\n"
  564. " int ix = index % width;\n"
  565. " int iy = index / width;\n"
  566. " float c_re = ix * pixelScaleX + xl;\n"
  567. " float c_im = iy * pixelScaleY + yt;\n";
  568. prelude += z0Visitor.code.str() +
  569. " float z_re = " + startA + ";\n" +
  570. " float z_im = " + startB + ";\n" +
  571. "\n"
  572. " int n = 0;\n"
  573. " while (n < max - 1) {\n";
  574. /*
  575. std::string orig =
  576. " float aa = a * a;"
  577. " float bb = b * b;"
  578. " float ab = a * b;"
  579. " a = aa - bb + x;"
  580. " b = ab + ab + y;";
  581. */
  582. std::string after =
  583. " if (z_re * z_re + z_im * z_im > 16) break;\n"
  584. " n++;\n"
  585. " }\n"
  586. " if (n >= max - 1) {\n"
  587. " A[index] = max;\n"
  588. " }\n"
  589. " else {\n"
  590. " A[index] = ((float)n);\n"
  591. " }\n"
  592. "}\n";
  593. std::string code = prelude + ocv.code.str();
  594. code += "z_re = " + newA + ";\n";
  595. code += "z_im = " + newB + ";\n";
  596. code += after;
  597. //code = mnd::getFloat_cl();
  598. printf("cl: %s\n", code.c_str()); fflush(stdout);
  599. return code;
  600. }
  601. std::string compileToOpenClDouble(const ir::Formula& formula)
  602. {
  603. OpenClVisitor z0Visitor{ 0, "double" };
  604. std::string startA = z0Visitor.visitNode(*formula.startA);
  605. std::string startB = z0Visitor.visitNode(*formula.startB);
  606. OpenClVisitor ocv{ z0Visitor.varnameCounter, "double" };
  607. std::string newA = ocv.visitNode(*formula.newA);
  608. std::string newB = ocv.visitNode(*formula.newB);
  609. std::string prelude =
  610. "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"
  611. "__kernel void iterate(__global float* A, const int width, double xl, double yt, double pixelScaleX, double pixelScaleY, int max, int smooth, int julia, double juliaX, double juliaY) {\n"
  612. " int index = get_global_id(0);\n"
  613. " int ix = index % width;\n"
  614. " int iy = index / width;\n"
  615. " double c_re = ix * pixelScaleX + xl;\n"
  616. " double c_im = iy * pixelScaleY + yt;\n";
  617. prelude += z0Visitor.code.str() +
  618. " double z_re = " + startA + ";\n" +
  619. " double z_im = " + startB + ";\n" +
  620. "\n"
  621. " int n = 0;\n"
  622. " while (n < max - 1) {\n";
  623. std::string after =
  624. " if (z_re * z_re + z_im * z_im > 16) break;\n"
  625. " n++;\n"
  626. " }\n"
  627. " if (n >= max - 1) {\n"
  628. " A[index] = max;\n"
  629. " }\n"
  630. " else {\n"
  631. " A[index] = ((float)n);\n"
  632. " }\n"
  633. "}\n";
  634. std::string code = prelude + ocv.code.str();
  635. code += "z_re = " + newA + ";\n";
  636. code += "z_im = " + newB + ";\n";
  637. code += after;
  638. //code = mnd::getFloat_cl();
  639. printf("cld: %s\n", code.c_str()); fflush(stdout);
  640. return code;
  641. }
  642. #ifdef WITH_OPENCL
  643. std::unique_ptr<MandelGenerator> compileCl(const ir::Formula& formula, MandelDevice& md)
  644. {
  645. return std::make_unique<CompiledClGenerator>(md, compileToOpenCl(formula));
  646. }
  647. std::unique_ptr<MandelGenerator> compileClDouble(const ir::Formula& formula, MandelDevice& md)
  648. {
  649. return std::make_unique<CompiledClGeneratorDouble>(md, compileToOpenClDouble(formula));
  650. }
  651. #endif
  652. std::vector<std::unique_ptr<mnd::MandelGenerator>> compileCpu(mnd::MandelContext& mndCtxt,
  653. const IterationFormula& z0,
  654. const IterationFormula& zi)
  655. {
  656. std::vector<std::unique_ptr<mnd::MandelGenerator>> vec;
  657. IterationFormula z0o = z0.clone();
  658. IterationFormula zio = zi.clone();
  659. z0o.optimize();
  660. zio.optimize();
  661. ir::Formula irf = mnd::expand(z0o, zio);
  662. irf.optimize();
  663. #ifdef WITH_ASMJIT
  664. #if defined(__x86_64__) || defined(_M_X64)
  665. printf("ir: %s\n", irf.toString().c_str()); fflush(stdout);
  666. auto dg = std::make_unique<CompiledGenerator>(compile(irf));
  667. printf("asm: %s\n", dg->dump().c_str()); fflush(stdout);
  668. vec.push_back(std::move(dg));
  669. if (mndCtxt.getCpuInfo().hasAvx()) {
  670. auto dgavx = std::make_unique<CompiledGeneratorVec>(compileAVXFloat(irf));
  671. printf("asm avxvec: %s\n", dgavx->dump().c_str()); fflush(stdout);
  672. vec.push_back(std::move(dgavx));
  673. }
  674. #endif // defined(__x86_64__) || defined(_M_X64)
  675. #endif // WITH_ASMJIT
  676. return vec;
  677. }
  678. std::vector<std::unique_ptr<mnd::MandelGenerator>> compileOpenCl(mnd::MandelDevice& dev,
  679. const IterationFormula& z0,
  680. const IterationFormula& zi)
  681. {
  682. std::vector<std::unique_ptr<mnd::MandelGenerator>> vec;
  683. IterationFormula z0o = z0.clone();
  684. IterationFormula zio = zi.clone();
  685. z0o.optimize();
  686. zio.optimize();
  687. printf("if: %s\n", mnd::toString(*zio.expr).c_str()); fflush(stdout);
  688. ir::Formula irf = mnd::expand(z0o, zio);
  689. printf("ir: %s\n", irf.toString().c_str()); fflush(stdout);
  690. irf.optimize();
  691. printf("ir: %s\n", irf.toString().c_str()); fflush(stdout);
  692. auto fl = compileCl(irf, dev);
  693. vec.push_back(std::move(fl));
  694. if (dev.supportsDouble()) {
  695. irf.clearNodeData();
  696. auto fld = compileClDouble(irf, dev);
  697. vec.push_back(std::move(fld));
  698. }
  699. return vec;// { { mnd::GeneratorType::FLOAT, std::move(fl) } };
  700. }
  701. GeneratorCollection compileFormula(mnd::MandelContext& mndCtxt, const IterationFormula& z0,
  702. const IterationFormula& zi)
  703. {
  704. GeneratorCollection cr;
  705. cr.cpuGenerators = compileCpu(mndCtxt, z0, zi);
  706. for (auto& dev : mndCtxt.getDevices()) {
  707. auto gens = compileOpenCl(*dev, z0, zi);
  708. std::move(gens.begin(), gens.end(), std::back_inserter(cr.clGenerators));
  709. }
  710. cr.adaptiveGenerator = std::make_unique<mnd::AdaptiveGenerator>();
  711. if (!cr.clGenerators.empty()) {
  712. cr.adaptiveGenerator->addGenerator(mnd::getPrecision<float>(), *cr.clGenerators[0]);
  713. }
  714. if (!cr.cpuGenerators.empty()) {
  715. cr.adaptiveGenerator->addGenerator(mnd::getPrecision<double>(), *cr.cpuGenerators[0]);
  716. }
  717. return cr;
  718. }
  719. }
  720. mnd::GeneratorCollection::GeneratorCollection(void) :
  721. cpuGenerators{},
  722. clGenerators{},
  723. adaptiveGenerator{ nullptr }
  724. {
  725. }