IterationCompiler.cpp 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850
#include "IterationCompiler.h"
#include "NaiveIRGenerator.h"
#include "ExecData.h"
#include "Mandel.h"
#include "OpenClInternal.h"
#include "OpenClCode.h"

#include <omp.h>

#include <any>
#include <cmath>
#include <locale>
#include <sstream>
#include <string>

using namespace std::string_literals;
  12. namespace mnd
  13. {
  14. struct CompileVisitor
  15. {
  16. using Reg = asmjit::x86::Xmm;
  17. asmjit::x86::Compiler& cc;
  18. Reg& a;
  19. Reg& b;
  20. Reg& x;
  21. Reg& y;
  22. Reg visitNode(ir::Node& node)
  23. {
  24. auto& nodeData = std::visit([] (auto& n) -> std::any& { return n.nodeData; }, node);
  25. if (Reg* regPtr = std::any_cast<Reg>(&nodeData)) {
  26. return *regPtr;
  27. }
  28. else {
  29. Reg reg = std::visit(*this, node);
  30. nodeData = reg;
  31. return reg;
  32. }
  33. }
  34. CompileVisitor(asmjit::x86::Compiler& cc, Reg& a, Reg& b, Reg& x, Reg& y) :
  35. cc{ cc },
  36. a{ a }, b{ b },
  37. x{ x }, y{ y }
  38. {
  39. }
  40. Reg operator()(const ir::Constant& c) {
  41. auto constant = cc.newDoubleConst(asmjit::ConstPool::kScopeLocal, mnd::convert<double>(c.value));
  42. auto reg = cc.newXmmSd();
  43. std::string commentStr = "move constant [";
  44. commentStr += std::to_string(mnd::convert<double>(c.value));
  45. commentStr += "]";
  46. cc.comment(commentStr.c_str());
  47. cc.movsd(reg, constant);
  48. return reg;
  49. }
  50. Reg operator()(const ir::Variable& v) {
  51. if (v.name == "z_re") {
  52. return a;
  53. }
  54. else if (v.name == "z_im") {
  55. return b;
  56. }
  57. else if (v.name == "c_re") {
  58. return x;
  59. }
  60. else if (v.name == "c_im") {
  61. return y;
  62. }
  63. else
  64. throw mnd::ParseError(std::string("unknown variable: ") + v.name);
  65. }
  66. Reg operator()(const ir::Negation& n) {
  67. auto sub = cc.newXmmSd();
  68. cc.xorpd(sub, sub);
  69. cc.subsd(sub, visitNode(*n.value));
  70. return sub;
  71. }
  72. Reg operator()(const ir::Addition& add) {
  73. auto res = cc.newXmmSd();
  74. cc.movapd(res, visitNode(*add.left));
  75. cc.addsd(res, visitNode(*add.right));
  76. return res;
  77. }
  78. Reg operator()(const ir::Subtraction& add) {
  79. auto res = cc.newXmmSd();
  80. cc.movapd(res, visitNode(*add.left));
  81. cc.subsd(res, visitNode(*add.right));
  82. return res;
  83. }
  84. Reg operator()(const ir::Multiplication& add) {
  85. auto res = cc.newXmmSd();
  86. cc.movapd(res, visitNode(*add.left));
  87. cc.mulsd(res, visitNode(*add.right));
  88. return res;
  89. }
  90. Reg operator()(const ir::Division& add) {
  91. auto res = cc.newXmmSd();
  92. cc.movapd(res, visitNode(*add.left));
  93. cc.divsd(res, visitNode(*add.right));
  94. return res;
  95. }
  96. static double myAtan2(double y, double x)
  97. {
  98. double result = ::atan2(y, x);
  99. printf("atan2(%f, %f) = %f\n", y, x, result);
  100. return result;
  101. }
  102. Reg operator()(const ir::Atan2& at2) {
  103. using namespace asmjit;
  104. auto y = visitNode(*at2.left);
  105. auto x = visitNode(*at2.right);
  106. auto arg = cc.newXmmSd();
  107. double(*atanFunc)(double, double) = ::atan2;
  108. cc.comment("call atan2");
  109. auto call = cc.call(imm(atanFunc), FuncSignatureT<double, double, double>(CallConv::kIdHost));
  110. call->setArg(0, y);
  111. call->setArg(1, x);
  112. call->setRet(0, arg);
  113. return arg;
  114. }
  115. Reg operator()(const ir::Pow& p) {
  116. using namespace asmjit;
  117. auto a = visitNode(*p.left);
  118. auto b = visitNode(*p.right);
  119. auto arg = cc.newXmmSd();
  120. double(*powFunc)(double, double) = ::pow;
  121. cc.comment("call pow");
  122. auto call = cc.call(imm(powFunc), FuncSignatureT<double, double, double>(CallConv::kIdHost));
  123. call->setArg(0, a);
  124. call->setArg(1, b);
  125. call->setRet(0, arg);
  126. return arg;
  127. }
  128. Reg operator()(const ir::Cos& c) {
  129. using namespace asmjit;
  130. auto a = visitNode(*c.value);
  131. auto arg = cc.newXmmSd();
  132. double(*cosFunc)(double) = ::cos;
  133. cc.comment("call cos");
  134. auto call = cc.call(imm(cosFunc), FuncSignatureT<double, double>(CallConv::kIdHost));
  135. call->setArg(0, a);
  136. call->setRet(0, arg);
  137. return arg;
  138. }
  139. Reg operator()(const ir::Sin& s) {
  140. using namespace asmjit;
  141. auto a = visitNode(*s.value);
  142. auto arg = cc.newXmmSd();
  143. double(*sinFunc)(double) = ::sin;
  144. cc.comment("call sin");
  145. auto call = cc.call(imm(sinFunc), FuncSignatureT<double, double>(CallConv::kIdHost));
  146. call->setArg(0, a);
  147. call->setRet(0, arg);
  148. return arg;
  149. }
  150. Reg operator()(const ir::Exp& ex) {
  151. using namespace asmjit;
  152. auto a = visitNode(*ex.value);
  153. auto arg = cc.newXmmSd();
  154. double(*expFunc)(double) = ::exp;
  155. cc.comment("call exp");
  156. auto call = cc.call(imm(expFunc), FuncSignatureT<double, double>(CallConv::kIdHost));
  157. call->setArg(0, a);
  158. call->setRet(0, arg);
  159. return arg;
  160. }
  161. Reg operator()(const ir::Ln& l) {
  162. using namespace asmjit;
  163. auto a = visitNode(*l.value);
  164. auto arg = cc.newXmmSd();
  165. double(*logFunc)(double) = ::log;
  166. cc.comment("call log");
  167. auto call = cc.call(imm(logFunc), FuncSignatureT<double, double>(CallConv::kIdHost));
  168. call->setArg(0, a);
  169. call->setRet(0, arg);
  170. return arg;
  171. }
  172. };
  173. static void printD(double d) {
  174. printf("val: %f\n", d); fflush(stdout);
  175. }
  176. CompiledGenerator compile(const ir::Formula& formula)
  177. {
  178. using namespace asmjit;
  179. std::unique_ptr<mnd::ExecData> ed = std::make_unique<mnd::ExecData>();
  180. JitRuntime& jitRuntime = *ed->jitRuntime;
  181. ed->code->init(jitRuntime.codeInfo());
  182. x86::Compiler& comp = *ed->compiler;
  183. x86::Mem sixteen = comp.newDoubleConst(ConstPool::kScopeLocal, 16.0);
  184. Label startLoop = comp.newLabel();
  185. Label endLoop = comp.newLabel();
  186. x86::Gp maxIter = comp.newInt32();
  187. x86::Gp k = comp.newInt32();
  188. x86::Xmm x = comp.newXmmSd();
  189. x86::Xmm y = comp.newXmmSd();
  190. x86::Xmm a = comp.newXmmSd();
  191. x86::Xmm b = comp.newXmmSd();
  192. comp.addFunc(FuncSignatureT<int, double, double, int>(CallConv::kIdHost));
  193. comp.setArg(0, x);
  194. comp.setArg(1, y);
  195. CompileVisitor formVisitor{ comp, a, b, x, y };
  196. auto startA = std::visit(formVisitor, *formula.startA);
  197. auto startB = std::visit(formVisitor, *formula.startB);
  198. comp.movapd(a, startA);
  199. comp.movapd(b, startB);
  200. comp.setArg(2, maxIter);
  201. //comp.movapd(a, x);
  202. //comp.movapd(b, y);
  203. //comp.xorpd(a, a);
  204. //comp.movapd(b, b);
  205. comp.xor_(k, k);
  206. comp.bind(startLoop);
  207. auto newA = std::visit(formVisitor, *formula.newA);
  208. auto newB = std::visit(formVisitor, *formula.newB);
  209. comp.movapd(a, newA);
  210. comp.movapd(b, newB);
  211. x86::Xmm aa = comp.newXmmSd();
  212. x86::Xmm bb = comp.newXmmSd();
  213. comp.movapd(aa, a);
  214. comp.mulsd(aa, a);
  215. comp.movapd(bb, b);
  216. comp.mulsd(bb, b);
  217. comp.addsd(bb, aa);
  218. //auto call = comp.call(imm(printD), FuncSignatureT<void, double>(CallConv::kIdHost));
  219. //call->setArg(0, bb);
  220. comp.comisd(bb, sixteen);
  221. comp.jnb(endLoop);
  222. comp.inc(k);
  223. comp.cmp(k, maxIter);
  224. comp.jle(startLoop);
  225. comp.bind(endLoop);
  226. comp.ret(k);
  227. comp.endFunc();
  228. auto err = comp.finalize();
  229. if (err == asmjit::kErrorOk) {
  230. err = jitRuntime.add(&ed->iterationFunc, ed->code.get());
  231. if (err != asmjit::kErrorOk)
  232. throw "error adding function";
  233. }
  234. else {
  235. throw "error compiling";
  236. }
  237. return CompiledGenerator{ std::move(ed) };
  238. }
  239. struct CompileVisitorAVXFloat
  240. {
  241. using Reg = asmjit::x86::Ymm;
  242. asmjit::x86::Compiler& cc;
  243. Reg& a;
  244. Reg& b;
  245. Reg& x;
  246. Reg& y;
  247. Reg visitNode(ir::Node& node)
  248. {
  249. auto& nodeData = std::visit([] (auto& n) -> std::any& { return n.nodeData; }, node);
  250. if (Reg* regPtr = std::any_cast<Reg>(&nodeData)) {
  251. return *regPtr;
  252. }
  253. else {
  254. Reg reg = std::visit(*this, node);
  255. nodeData = reg;
  256. return reg;
  257. }
  258. }
  259. CompileVisitorAVXFloat(asmjit::x86::Compiler& cc, Reg& a, Reg& b, Reg& x, Reg& y) :
  260. cc{ cc },
  261. a{ a }, b{ b },
  262. x{ x }, y{ y }
  263. {
  264. }
  265. Reg operator()(const ir::Constant& c) {
  266. auto constant = cc.newFloatConst(asmjit::ConstPool::kScopeLocal, mnd::convert<float>(c.value));
  267. auto reg = cc.newYmmPs();
  268. std::string commentStr = "move constant [";
  269. commentStr += std::to_string(mnd::convert<double>(c.value));
  270. commentStr += "]";
  271. cc.comment(commentStr.c_str());
  272. cc.vbroadcastss(reg, constant);
  273. return reg;
  274. }
  275. Reg operator()(const ir::Variable& v) {
  276. if (v.name == "z_re") {
  277. return a;
  278. }
  279. else if (v.name == "z_im") {
  280. return b;
  281. }
  282. else if (v.name == "c_re") {
  283. return x;
  284. }
  285. else if (v.name == "c_im") {
  286. return y;
  287. }
  288. else
  289. throw mnd::ParseError(std::string("unknown variable: ") + v.name);
  290. }
  291. Reg operator()(const ir::Negation& n) {
  292. auto sub = cc.newYmmPs();
  293. cc.vxorps(sub, sub, sub);
  294. cc.vsubps(sub, sub, visitNode(*n.value));
  295. return sub;
  296. }
  297. Reg operator()(const ir::Addition& add) {
  298. auto res = cc.newYmmPs();
  299. cc.vaddps(res, visitNode(*add.left), visitNode(*add.right));
  300. return res;
  301. }
  302. Reg operator()(const ir::Subtraction& add) {
  303. auto res = cc.newYmmPs();
  304. cc.vsubps(res, visitNode(*add.left), visitNode(*add.right));
  305. return res;
  306. }
  307. Reg operator()(const ir::Multiplication& add) {
  308. auto res = cc.newYmmPs();
  309. cc.vmulps(res, visitNode(*add.left), visitNode(*add.right));
  310. return res;
  311. }
  312. Reg operator()(const ir::Division& add) {
  313. auto res = cc.newYmmPs();
  314. cc.vdivps(res, visitNode(*add.left), visitNode(*add.right));
  315. return res;
  316. }
  317. static double myAtan2(double y, double x)
  318. {
  319. double result = ::atan2(y, x);
  320. printf("atan2(%f, %f) = %f\n", y, x, result);
  321. return result;
  322. }
  323. Reg operator()(const ir::Atan2& at2) {
  324. using namespace asmjit;
  325. auto y = visitNode(*at2.left);
  326. auto x = visitNode(*at2.right);
  327. auto arg = cc.newYmmPs();
  328. /*
  329. double(*atanFunc)(double, double) = ::atan2;
  330. cc.comment("call atan2");
  331. auto call = cc.call(imm(atanFunc), FuncSignatureT<double, double, double>(CallConv::kIdHost));
  332. call->setArg(0, y);
  333. call->setArg(1, x);
  334. call->setRet(0, arg);*/
  335. return arg;
  336. }
  337. Reg operator()(const ir::Pow& p) {
  338. using namespace asmjit;
  339. auto a = visitNode(*p.left);
  340. auto b = visitNode(*p.right);
  341. auto arg = cc.newYmmPs();
  342. /*double(*powFunc)(double, double) = ::pow;
  343. cc.comment("call pow");
  344. auto call = cc.call(imm(powFunc), FuncSignatureT<double, double, double>(CallConv::kIdHost));
  345. call->setArg(0, a);
  346. call->setArg(1, b);
  347. call->setRet(0, arg);*/
  348. return arg;
  349. }
  350. Reg operator()(const ir::Cos& c) {
  351. using namespace asmjit;
  352. auto a = visitNode(*c.value);
  353. auto arg = cc.newYmmPs();
  354. /*double(*cosFunc)(double) = ::cos;
  355. cc.comment("call cos");
  356. auto call = cc.call(imm(cosFunc), FuncSignatureT<double, double>(CallConv::kIdHost));
  357. call->setArg(0, a);
  358. call->setRet(0, arg);*/
  359. return arg;
  360. }
  361. Reg operator()(const ir::Sin& s) {
  362. using namespace asmjit;
  363. auto a = visitNode(*s.value);
  364. auto arg = cc.newYmmPs();
  365. /*double(*sinFunc)(double) = ::sin;
  366. cc.comment("call sin");
  367. auto call = cc.call(imm(sinFunc), FuncSignatureT<double, double>(CallConv::kIdHost));
  368. call->setArg(0, a);
  369. call->setRet(0, arg);*/
  370. return arg;
  371. }
  372. Reg operator()(const ir::Exp& ex) {
  373. using namespace asmjit;
  374. auto a = visitNode(*ex.value);
  375. auto arg = cc.newYmmPs();
  376. /*double(*expFunc)(double) = ::exp;
  377. cc.comment("call exp");
  378. auto call = cc.call(imm(expFunc), FuncSignatureT<double, double>(CallConv::kIdHost));
  379. call->setArg(0, a);
  380. call->setRet(0, arg);*/
  381. return arg;
  382. }
  383. Reg operator()(const ir::Ln& l) {
  384. using namespace asmjit;
  385. auto a = visitNode(*l.value);
  386. auto arg = cc.newYmmPs();
  387. /*double(*logFunc)(double) = ::log;
  388. cc.comment("call log");
  389. auto call = cc.call(imm(logFunc), FuncSignatureT<double, double>(CallConv::kIdHost));
  390. call->setArg(0, a);
  391. call->setRet(0, arg);*/
  392. return arg;
  393. }
  394. };
  395. CompiledGeneratorVec compileAVXFloat(const ir::Formula& formula)
  396. {
  397. using namespace asmjit;
  398. std::unique_ptr<mnd::ExecData> ed = std::make_unique<mnd::ExecData>();
  399. JitRuntime& jitRuntime = *ed->jitRuntime;
  400. ed->code->init(jitRuntime.codeInfo());
  401. x86::Compiler& comp = *ed->compiler;
  402. x86::Mem sixteen = comp.newYmmConst(ConstPool::kScopeLocal, Data256::fromF32(16.0f));
  403. x86::Mem one = comp.newYmmConst(ConstPool::kScopeLocal, Data256::fromF32(1.0f));
  404. x86::Mem factors = comp.newYmmConst(ConstPool::kScopeLocal, Data256::fromF32(0, 1, 2, 3, 4, 5, 6, 7));
  405. Label startLoop = comp.newLabel();
  406. Label endLoop = comp.newLabel();
  407. x86::Gp maxIter = comp.newInt32();
  408. x86::Gp k = comp.newInt32();
  409. x86::Gp resPtr = comp.newGpq();
  410. x86::Ymm adder = comp.newYmmPs();
  411. x86::Ymm counter = comp.newYmmPs();
  412. x86::Xmm xorig = comp.newXmmSs();
  413. x86::Xmm yorig = comp.newXmmSs();
  414. x86::Ymm dx = comp.newYmmPs();
  415. x86::Ymm x = comp.newYmmPs();
  416. x86::Ymm y = comp.newYmmPs();
  417. x86::Ymm a = comp.newYmmPs();
  418. x86::Ymm b = comp.newYmmPs();
  419. comp.addFunc(FuncSignatureT<int, float, float, float, int, float*>(CallConv::kIdHost));
  420. comp.setArg(0, xorig);
  421. comp.setArg(1, yorig);
  422. comp.setArg(2, dx.xmm());
  423. comp.setArg(3, maxIter);
  424. comp.setArg(4, resPtr);
  425. comp.vmovaps(adder, one);
  426. comp.vxorps(counter, counter, counter);
  427. comp.vshufps(xorig, xorig, xorig, 0);
  428. comp.vshufps(yorig, yorig, yorig, 0);
  429. comp.vshufps(dx.half(), dx.half(), dx.half(), 0);
  430. comp.vinsertf128(x, xorig.ymm(), xorig, 1);
  431. comp.vinsertf128(y, yorig.ymm(), yorig, 1);
  432. comp.vinsertf128(dx, dx, dx.xmm(), 1);
  433. comp.vmulps(dx, dx, factors);
  434. comp.vaddps(x, x, dx);
  435. CompileVisitorAVXFloat formVisitor{ comp, a, b, x, y };
  436. auto startA = std::visit(formVisitor, *formula.startA);
  437. auto startB = std::visit(formVisitor, *formula.startB);
  438. comp.vmovaps(a, startA);
  439. comp.vmovaps(b, startB);
  440. comp.xor_(k, k);
  441. comp.bind(startLoop);
  442. auto newA = std::visit(formVisitor, *formula.newA);
  443. auto newB = std::visit(formVisitor, *formula.newB);
  444. comp.vmovaps(a, newA);
  445. comp.vmovaps(b, newB);
  446. x86::Ymm aa = comp.newYmmPs();
  447. x86::Ymm bb = comp.newYmmPs();
  448. x86::Ymm cmp = comp.newYmmPs();
  449. comp.vmulps(aa, a, a);
  450. comp.vmulps(bb, b, b);
  451. comp.vaddps(bb, bb, aa);
  452. comp.vcmpps(cmp, bb, sixteen, 18);
  453. comp.vandps(adder, adder, cmp);
  454. comp.vaddps(counter, counter, adder);
  455. comp.cmp(k, maxIter);
  456. comp.je(endLoop);
  457. comp.add(k, 1);
  458. comp.vtestps(cmp, cmp);
  459. comp.jne(startLoop);
  460. comp.bind(endLoop);
  461. comp.vmovups(x86::xmmword_ptr(resPtr), counter.half());
  462. comp.vextractf128(x86::xmmword_ptr(resPtr, 16), counter, 0x1);
  463. comp.ret(k);
  464. comp.endFunc();
  465. auto err = comp.finalize();
  466. if (err == asmjit::kErrorOk) {
  467. err = jitRuntime.add(&ed->iterationFunc, ed->code.get());
  468. if (err != asmjit::kErrorOk)
  469. throw "error adding function";
  470. }
  471. else {
  472. throw "error compiling";
  473. }
  474. return CompiledGeneratorVec{ std::move(ed) };
  475. }
  476. struct OpenClVisitor
  477. {
  478. int varnameCounter = 0;
  479. std::stringstream code;
  480. std::string floatTypeName;
  481. OpenClVisitor(int startVarname, const std::string& floatTypeName) :
  482. varnameCounter{ startVarname },
  483. floatTypeName{ floatTypeName }
  484. {
  485. }
  486. std::string createVarname(void)
  487. {
  488. return "tmp"s + std::to_string(varnameCounter++);
  489. }
  490. std::string visitNode(ir::Node& node)
  491. {
  492. auto& nodeData = std::visit([] (auto& n) -> std::any& { return n.nodeData; }, node);
  493. if (std::string* var = std::any_cast<std::string>(&nodeData)) {
  494. return *var;
  495. }
  496. else {
  497. std::string value = std::visit(*this, node);
  498. if (!std::get_if<ir::Variable>(&node) && !std::get_if<ir::Constant>(&node)) {
  499. std::string varname = createVarname();
  500. code << floatTypeName << " " << varname << " = " << value << ";" << std::endl;
  501. nodeData = varname;
  502. return varname;
  503. }
  504. return value;
  505. }
  506. }
  507. std::string operator()(const ir::Constant& c) {
  508. return std::to_string(mnd::convert<double>(c.value)) + ((floatTypeName == "float") ? "f" : "");
  509. }
  510. std::string operator()(const ir::Variable& v) {
  511. return v.name;
  512. }
  513. std::string operator()(const ir::Negation& n) {
  514. return "-("s + visitNode(*n.value) + ")";
  515. }
  516. std::string operator()(const ir::Addition& a) {
  517. return "("s + visitNode(*a.left) + ") + (" + visitNode(*a.right) + ")";
  518. }
  519. std::string operator()(const ir::Subtraction& a) {
  520. return "("s + visitNode(*a.left) + ") - (" + visitNode(*a.right) + ")";
  521. }
  522. std::string operator()(const ir::Multiplication& a) {
  523. return "("s + visitNode(*a.left) + ") * (" + visitNode(*a.right) + ")";
  524. }
  525. std::string operator()(const ir::Division& a) {
  526. return "("s + visitNode(*a.left) + ") / (" + visitNode(*a.right) + ")";
  527. }
  528. std::string operator()(const ir::Atan2& a) {
  529. return "atan2("s + visitNode(*a.left) + ", " + visitNode(*a.right) + ")";
  530. }
  531. std::string operator()(const ir::Pow& a) {
  532. return "pow("s + visitNode(*a.left) + ", " + visitNode(*a.right) + ")";
  533. }
  534. std::string operator()(const ir::Cos& a) {
  535. return "cos("s + visitNode(*a.value) + ")";
  536. }
  537. std::string operator()(const ir::Sin& a) {
  538. return "sin("s + visitNode(*a.value) + ")";
  539. }
  540. std::string operator()(const ir::Exp& a) {
  541. return "exp("s + visitNode(*a.value) + ")";
  542. }
  543. std::string operator()(const ir::Ln& a) {
  544. return "log("s + visitNode(*a.value) + ")";
  545. }
  546. };
  547. std::string compileToOpenCl(const ir::Formula& formula)
  548. {
  549. OpenClVisitor z0Visitor{ 0, "float" };
  550. std::string startA = z0Visitor.visitNode(*formula.startA);
  551. std::string startB = z0Visitor.visitNode(*formula.startB);
  552. OpenClVisitor ocv{ z0Visitor.varnameCounter, "float" };
  553. std::string newA = ocv.visitNode(*formula.newA);
  554. std::string newB = ocv.visitNode(*formula.newB);
  555. std::string prelude =
  556. "__kernel void iterate(__global float* A, const int width, float xl, float yt, float pixelScaleX, float pixelScaleY, int max, int smooth, int julia, float juliaX, float juliaY) {\n"
  557. " int index = get_global_id(0);\n"
  558. " int ix = index % width;\n"
  559. " int iy = index / width;\n"
  560. " float c_re = ix * pixelScaleX + xl;\n"
  561. " float c_im = iy * pixelScaleY + yt;\n";
  562. prelude += z0Visitor.code.str() +
  563. " float z_re = " + startA + ";\n" +
  564. " float z_im = " + startB + ";\n" +
  565. "\n"
  566. " int n = 0;\n"
  567. " while (n < max - 1) {\n";
  568. /*
  569. std::string orig =
  570. " float aa = a * a;"
  571. " float bb = b * b;"
  572. " float ab = a * b;"
  573. " a = aa - bb + x;"
  574. " b = ab + ab + y;";
  575. */
  576. std::string after =
  577. " if (z_re * z_re + z_im * z_im > 16) break;\n"
  578. " n++;\n"
  579. " }\n"
  580. " if (n >= max - 1) {\n"
  581. " A[index] = max;\n"
  582. " }\n"
  583. " else {\n"
  584. " A[index] = ((float)n);\n"
  585. " }\n"
  586. "}\n";
  587. std::string code = prelude + ocv.code.str();
  588. code += "z_re = " + newA + ";\n";
  589. code += "z_im = " + newB + ";\n";
  590. code += after;
  591. //code = mnd::getFloat_cl();
  592. printf("cl: %s\n", code.c_str()); fflush(stdout);
  593. return code;
  594. }
  595. std::string compileToOpenClDouble(const ir::Formula& formula)
  596. {
  597. OpenClVisitor z0Visitor{ 0, "double" };
  598. std::string startA = z0Visitor.visitNode(*formula.startA);
  599. std::string startB = z0Visitor.visitNode(*formula.startB);
  600. OpenClVisitor ocv{ z0Visitor.varnameCounter, "double" };
  601. std::string newA = ocv.visitNode(*formula.newA);
  602. std::string newB = ocv.visitNode(*formula.newB);
  603. std::string prelude =
  604. "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"
  605. "__kernel void iterate(__global float* A, const int width, double xl, double yt, double pixelScaleX, double pixelScaleY, int max, int smooth, int julia, double juliaX, double juliaY) {\n"
  606. " int index = get_global_id(0);\n"
  607. " int ix = index % width;\n"
  608. " int iy = index / width;\n"
  609. " double c_re = ix * pixelScaleX + xl;\n"
  610. " double c_im = iy * pixelScaleY + yt;\n";
  611. prelude += z0Visitor.code.str() +
  612. " double z_re = " + startA + ";\n" +
  613. " double z_im = " + startB + ";\n" +
  614. "\n"
  615. " int n = 0;\n"
  616. " while (n < max - 1) {\n";
  617. std::string after =
  618. " if (z_re * z_re + z_im * z_im > 16) break;\n"
  619. " n++;\n"
  620. " }\n"
  621. " if (n >= max - 1) {\n"
  622. " A[index] = max;\n"
  623. " }\n"
  624. " else {\n"
  625. " A[index] = ((float)n);\n"
  626. " }\n"
  627. "}\n";
  628. std::string code = prelude + ocv.code.str();
  629. code += "z_re = " + newA + ";\n";
  630. code += "z_im = " + newB + ";\n";
  631. code += after;
  632. //code = mnd::getFloat_cl();
  633. printf("cld: %s\n", code.c_str()); fflush(stdout);
  634. return code;
  635. }
  636. #ifdef WITH_OPENCL
  637. std::unique_ptr<MandelGenerator> compileCl(const ir::Formula& formula, MandelDevice& md)
  638. {
  639. return std::make_unique<CompiledClGenerator>(md, compileToOpenCl(formula));
  640. }
  641. std::unique_ptr<MandelGenerator> compileClDouble(const ir::Formula& formula, MandelDevice& md)
  642. {
  643. return std::make_unique<CompiledClGeneratorDouble>(md, compileToOpenClDouble(formula));
  644. }
  645. #endif
  646. std::vector<std::unique_ptr<mnd::MandelGenerator>> compileCpu(mnd::MandelContext& mndCtxt,
  647. const IterationFormula& z0,
  648. const IterationFormula& zi)
  649. {
  650. std::vector<std::unique_ptr<mnd::MandelGenerator>> vec;
  651. IterationFormula z0o = z0.clone();
  652. IterationFormula zio = zi.clone();
  653. z0o.optimize();
  654. zio.optimize();
  655. ir::Formula irf = mnd::expand(z0o, zio);
  656. irf.optimize();
  657. printf("ir: %s\n", irf.toString().c_str()); fflush(stdout);
  658. auto dg = std::make_unique<CompiledGenerator>(compile(irf));
  659. printf("asm: %s\n", dg->dump().c_str()); fflush(stdout);
  660. vec.push_back(std::move(dg));
  661. if (mndCtxt.getCpuInfo().hasAvx()) {
  662. auto dgavx = std::make_unique<CompiledGeneratorVec>(compileAVXFloat(irf));
  663. printf("asm avxvec: %s\n", dgavx->dump().c_str()); fflush(stdout);
  664. vec.push_back(std::move(dgavx));
  665. }
  666. //vec.push_back(std::make_unique<NaiveIRGenerator<mnd::DoubleDouble>>(irf));
  667. //vec.push_back(std::make_unique<NaiveIRGenerator<mnd::QuadDouble>>(irf));
  668. //auto dg = std::make_unique<NaiveIRGenerator>(*irf, mnd::getPrecision<double>());
  669. //vec.push_back(std::move(ng));
  670. return vec;
  671. }
  672. std::vector<std::unique_ptr<mnd::MandelGenerator>> compileOpenCl(mnd::MandelDevice& dev,
  673. const IterationFormula& z0,
  674. const IterationFormula& zi)
  675. {
  676. std::vector<std::unique_ptr<mnd::MandelGenerator>> vec;
  677. IterationFormula z0o = z0.clone();
  678. IterationFormula zio = zi.clone();
  679. z0o.optimize();
  680. zio.optimize();
  681. printf("if: %s\n", mnd::toString(*zio.expr).c_str()); fflush(stdout);
  682. ir::Formula irf = mnd::expand(z0o, zio);
  683. printf("ir: %s\n", irf.toString().c_str()); fflush(stdout);
  684. irf.optimize();
  685. printf("ir: %s\n", irf.toString().c_str()); fflush(stdout);
  686. auto fl = compileCl(irf, dev);
  687. vec.push_back(std::move(fl));
  688. if (dev.supportsDouble()) {
  689. irf.clearNodeData();
  690. auto fld = compileClDouble(irf, dev);
  691. vec.push_back(std::move(fld));
  692. }
  693. return vec;// { { mnd::GeneratorType::FLOAT, std::move(fl) } };
  694. }
  695. GeneratorCollection compileFormula(mnd::MandelContext& mndCtxt, const IterationFormula& z0,
  696. const IterationFormula& zi)
  697. {
  698. GeneratorCollection cr;
  699. cr.cpuGenerators = compileCpu(mndCtxt, z0, zi);
  700. for (auto& dev : mndCtxt.getDevices()) {
  701. auto gens = compileOpenCl(*dev, z0, zi);
  702. std::move(gens.begin(), gens.end(), std::back_inserter(cr.clGenerators));
  703. }
  704. cr.adaptiveGenerator = std::make_unique<mnd::AdaptiveGenerator>();
  705. if (!cr.clGenerators.empty()) {
  706. cr.adaptiveGenerator->addGenerator(mnd::getPrecision<float>(), *cr.clGenerators[0]);
  707. }
  708. if (!cr.cpuGenerators.empty()) {
  709. cr.adaptiveGenerator->addGenerator(mnd::getPrecision<double>(), *cr.cpuGenerators[0]);
  710. }
  711. return cr;
  712. }
  713. }
  714. mnd::GeneratorCollection::GeneratorCollection(void) :
  715. cpuGenerators{},
  716. clGenerators{},
  717. adaptiveGenerator{ nullptr }
  718. {
  719. }