IterationCompiler.cpp 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845
#include "IterationCompiler.h"
#include "ExecData.h"
#include "Mandel.h"
#include "OpenClInternal.h"
#include "OpenClCode.h"

#include <omp.h>

#include <any>
#include <cmath>
#include <iomanip>
#include <limits>
#include <locale>
#include <sstream>
#include <string>

using namespace std::string_literals;
  11. namespace mnd
  12. {
  13. struct CompileVisitor
  14. {
  15. using Reg = asmjit::x86::Xmm;
  16. asmjit::x86::Compiler& cc;
  17. Reg& a;
  18. Reg& b;
  19. Reg& x;
  20. Reg& y;
  21. Reg visitNode(ir::Node& node)
  22. {
  23. auto& nodeData = std::visit([] (auto& n) -> std::any& { return n.nodeData; }, node);
  24. if (Reg* regPtr = std::any_cast<Reg>(&nodeData)) {
  25. return *regPtr;
  26. }
  27. else {
  28. Reg reg = std::visit(*this, node);
  29. nodeData = reg;
  30. return reg;
  31. }
  32. }
  33. CompileVisitor(asmjit::x86::Compiler& cc, Reg& a, Reg& b, Reg& x, Reg& y) :
  34. cc{ cc },
  35. a{ a }, b{ b },
  36. x{ x }, y{ y }
  37. {
  38. }
  39. Reg operator()(const ir::Constant& c) {
  40. auto constant = cc.newDoubleConst(asmjit::ConstPool::kScopeLocal, mnd::convert<double>(c.value));
  41. auto reg = cc.newXmmSd();
  42. std::string commentStr = "move constant [";
  43. commentStr += std::to_string(mnd::convert<double>(c.value));
  44. commentStr += "]";
  45. cc.comment(commentStr.c_str());
  46. cc.movsd(reg, constant);
  47. return reg;
  48. }
  49. Reg operator()(const ir::Variable& v) {
  50. if (v.name == "z_re") {
  51. return a;
  52. }
  53. else if (v.name == "z_im") {
  54. return b;
  55. }
  56. else if (v.name == "c_re") {
  57. return x;
  58. }
  59. else if (v.name == "c_im") {
  60. return y;
  61. }
  62. else
  63. throw mnd::ParseError(std::string("unknown variable: ") + v.name);
  64. }
  65. Reg operator()(const ir::Negation& n) {
  66. auto sub = cc.newXmmSd();
  67. cc.xorpd(sub, sub);
  68. cc.subsd(sub, visitNode(*n.value));
  69. return sub;
  70. }
  71. Reg operator()(const ir::Addition& add) {
  72. auto res = cc.newXmmSd();
  73. cc.movapd(res, visitNode(*add.left));
  74. cc.addsd(res, visitNode(*add.right));
  75. return res;
  76. }
  77. Reg operator()(const ir::Subtraction& add) {
  78. auto res = cc.newXmmSd();
  79. cc.movapd(res, visitNode(*add.left));
  80. cc.subsd(res, visitNode(*add.right));
  81. return res;
  82. }
  83. Reg operator()(const ir::Multiplication& add) {
  84. auto res = cc.newXmmSd();
  85. cc.movapd(res, visitNode(*add.left));
  86. cc.mulsd(res, visitNode(*add.right));
  87. return res;
  88. }
  89. Reg operator()(const ir::Division& add) {
  90. auto res = cc.newXmmSd();
  91. cc.movapd(res, visitNode(*add.left));
  92. cc.divsd(res, visitNode(*add.right));
  93. return res;
  94. }
  95. static double myAtan2(double y, double x)
  96. {
  97. double result = ::atan2(y, x);
  98. printf("atan2(%f, %f) = %f\n", y, x, result);
  99. return result;
  100. }
  101. Reg operator()(const ir::Atan2& at2) {
  102. using namespace asmjit;
  103. auto y = visitNode(*at2.left);
  104. auto x = visitNode(*at2.right);
  105. auto arg = cc.newXmmSd();
  106. double(*atanFunc)(double, double) = ::atan2;
  107. cc.comment("call atan2");
  108. auto call = cc.call(imm(atanFunc), FuncSignatureT<double, double, double>(CallConv::kIdHost));
  109. call->setArg(0, y);
  110. call->setArg(1, x);
  111. call->setRet(0, arg);
  112. return arg;
  113. }
  114. Reg operator()(const ir::Pow& p) {
  115. using namespace asmjit;
  116. auto a = visitNode(*p.left);
  117. auto b = visitNode(*p.right);
  118. auto arg = cc.newXmmSd();
  119. double(*powFunc)(double, double) = ::pow;
  120. cc.comment("call pow");
  121. auto call = cc.call(imm(powFunc), FuncSignatureT<double, double, double>(CallConv::kIdHost));
  122. call->setArg(0, a);
  123. call->setArg(1, b);
  124. call->setRet(0, arg);
  125. return arg;
  126. }
  127. Reg operator()(const ir::Cos& c) {
  128. using namespace asmjit;
  129. auto a = visitNode(*c.value);
  130. auto arg = cc.newXmmSd();
  131. double(*cosFunc)(double) = ::cos;
  132. cc.comment("call cos");
  133. auto call = cc.call(imm(cosFunc), FuncSignatureT<double, double>(CallConv::kIdHost));
  134. call->setArg(0, a);
  135. call->setRet(0, arg);
  136. return arg;
  137. }
  138. Reg operator()(const ir::Sin& s) {
  139. using namespace asmjit;
  140. auto a = visitNode(*s.value);
  141. auto arg = cc.newXmmSd();
  142. double(*sinFunc)(double) = ::sin;
  143. cc.comment("call sin");
  144. auto call = cc.call(imm(sinFunc), FuncSignatureT<double, double>(CallConv::kIdHost));
  145. call->setArg(0, a);
  146. call->setRet(0, arg);
  147. return arg;
  148. }
  149. Reg operator()(const ir::Exp& ex) {
  150. using namespace asmjit;
  151. auto a = visitNode(*ex.value);
  152. auto arg = cc.newXmmSd();
  153. double(*expFunc)(double) = ::exp;
  154. cc.comment("call exp");
  155. auto call = cc.call(imm(expFunc), FuncSignatureT<double, double>(CallConv::kIdHost));
  156. call->setArg(0, a);
  157. call->setRet(0, arg);
  158. return arg;
  159. }
  160. Reg operator()(const ir::Ln& l) {
  161. using namespace asmjit;
  162. auto a = visitNode(*l.value);
  163. auto arg = cc.newXmmSd();
  164. double(*logFunc)(double) = ::log;
  165. cc.comment("call log");
  166. auto call = cc.call(imm(logFunc), FuncSignatureT<double, double>(CallConv::kIdHost));
  167. call->setArg(0, a);
  168. call->setRet(0, arg);
  169. return arg;
  170. }
  171. };
  172. static void printD(double d) {
  173. printf("val: %f\n", d); fflush(stdout);
  174. }
  175. CompiledGenerator compile(const ir::Formula& formula)
  176. {
  177. using namespace asmjit;
  178. std::unique_ptr<mnd::ExecData> ed = std::make_unique<mnd::ExecData>();
  179. JitRuntime& jitRuntime = *ed->jitRuntime;
  180. ed->code->init(jitRuntime.codeInfo());
  181. x86::Compiler& comp = *ed->compiler;
  182. x86::Mem sixteen = comp.newDoubleConst(ConstPool::kScopeLocal, 16.0);
  183. Label startLoop = comp.newLabel();
  184. Label endLoop = comp.newLabel();
  185. x86::Gp maxIter = comp.newInt32();
  186. x86::Gp k = comp.newInt32();
  187. x86::Xmm x = comp.newXmmSd();
  188. x86::Xmm y = comp.newXmmSd();
  189. x86::Xmm a = comp.newXmmSd();
  190. x86::Xmm b = comp.newXmmSd();
  191. comp.addFunc(FuncSignatureT<int, double, double, int>(CallConv::kIdHost));
  192. comp.setArg(0, x);
  193. comp.setArg(1, y);
  194. CompileVisitor formVisitor{ comp, a, b, x, y };
  195. auto startA = std::visit(formVisitor, *formula.startA);
  196. auto startB = std::visit(formVisitor, *formula.startB);
  197. comp.movapd(a, startA);
  198. comp.movapd(b, startB);
  199. comp.setArg(2, maxIter);
  200. //comp.movapd(a, x);
  201. //comp.movapd(b, y);
  202. //comp.xorpd(a, a);
  203. //comp.movapd(b, b);
  204. comp.xor_(k, k);
  205. comp.bind(startLoop);
  206. auto newA = std::visit(formVisitor, *formula.newA);
  207. auto newB = std::visit(formVisitor, *formula.newB);
  208. comp.movapd(a, newA);
  209. comp.movapd(b, newB);
  210. x86::Xmm aa = comp.newXmmSd();
  211. x86::Xmm bb = comp.newXmmSd();
  212. comp.movapd(aa, a);
  213. comp.mulsd(aa, a);
  214. comp.movapd(bb, b);
  215. comp.mulsd(bb, b);
  216. comp.addsd(bb, aa);
  217. //auto call = comp.call(imm(printD), FuncSignatureT<void, double>(CallConv::kIdHost));
  218. //call->setArg(0, bb);
  219. comp.comisd(bb, sixteen);
  220. comp.jnb(endLoop);
  221. comp.inc(k);
  222. comp.cmp(k, maxIter);
  223. comp.jle(startLoop);
  224. comp.bind(endLoop);
  225. comp.ret(k);
  226. comp.endFunc();
  227. auto err = comp.finalize();
  228. if (err == asmjit::kErrorOk) {
  229. err = jitRuntime.add(&ed->iterationFunc, ed->code.get());
  230. if (err != asmjit::kErrorOk)
  231. throw "error adding function";
  232. }
  233. else {
  234. throw "error compiling";
  235. }
  236. return CompiledGenerator{ std::move(ed) };
  237. }
  238. struct CompileVisitorAVXFloat
  239. {
  240. using Reg = asmjit::x86::Ymm;
  241. asmjit::x86::Compiler& cc;
  242. Reg& a;
  243. Reg& b;
  244. Reg& x;
  245. Reg& y;
  246. Reg visitNode(ir::Node& node)
  247. {
  248. auto& nodeData = std::visit([] (auto& n) -> std::any& { return n.nodeData; }, node);
  249. if (Reg* regPtr = std::any_cast<Reg>(&nodeData)) {
  250. return *regPtr;
  251. }
  252. else {
  253. Reg reg = std::visit(*this, node);
  254. nodeData = reg;
  255. return reg;
  256. }
  257. }
  258. CompileVisitorAVXFloat(asmjit::x86::Compiler& cc, Reg& a, Reg& b, Reg& x, Reg& y) :
  259. cc{ cc },
  260. a{ a }, b{ b },
  261. x{ x }, y{ y }
  262. {
  263. }
  264. Reg operator()(const ir::Constant& c) {
  265. auto constant = cc.newFloatConst(asmjit::ConstPool::kScopeLocal, mnd::convert<float>(c.value));
  266. auto reg = cc.newYmmPs();
  267. std::string commentStr = "move constant [";
  268. commentStr += std::to_string(mnd::convert<double>(c.value));
  269. commentStr += "]";
  270. cc.comment(commentStr.c_str());
  271. cc.vbroadcastss(reg, constant);
  272. return reg;
  273. }
  274. Reg operator()(const ir::Variable& v) {
  275. if (v.name == "z_re") {
  276. return a;
  277. }
  278. else if (v.name == "z_im") {
  279. return b;
  280. }
  281. else if (v.name == "c_re") {
  282. return x;
  283. }
  284. else if (v.name == "c_im") {
  285. return y;
  286. }
  287. else
  288. throw mnd::ParseError(std::string("unknown variable: ") + v.name);
  289. }
  290. Reg operator()(const ir::Negation& n) {
  291. auto sub = cc.newYmmPs();
  292. cc.vxorps(sub, sub, sub);
  293. cc.vsubps(sub, sub, visitNode(*n.value));
  294. return sub;
  295. }
  296. Reg operator()(const ir::Addition& add) {
  297. auto res = cc.newYmmPs();
  298. cc.vaddps(res, visitNode(*add.left), visitNode(*add.right));
  299. return res;
  300. }
  301. Reg operator()(const ir::Subtraction& add) {
  302. auto res = cc.newYmmPs();
  303. cc.vsubps(res, visitNode(*add.left), visitNode(*add.right));
  304. return res;
  305. }
  306. Reg operator()(const ir::Multiplication& add) {
  307. auto res = cc.newYmmPs();
  308. cc.vmulps(res, visitNode(*add.left), visitNode(*add.right));
  309. return res;
  310. }
  311. Reg operator()(const ir::Division& add) {
  312. auto res = cc.newYmmPs();
  313. cc.vdivps(res, visitNode(*add.left), visitNode(*add.right));
  314. return res;
  315. }
  316. static double myAtan2(double y, double x)
  317. {
  318. double result = ::atan2(y, x);
  319. printf("atan2(%f, %f) = %f\n", y, x, result);
  320. return result;
  321. }
  322. Reg operator()(const ir::Atan2& at2) {
  323. using namespace asmjit;
  324. auto y = visitNode(*at2.left);
  325. auto x = visitNode(*at2.right);
  326. auto arg = cc.newYmmPs();
  327. /*
  328. double(*atanFunc)(double, double) = ::atan2;
  329. cc.comment("call atan2");
  330. auto call = cc.call(imm(atanFunc), FuncSignatureT<double, double, double>(CallConv::kIdHost));
  331. call->setArg(0, y);
  332. call->setArg(1, x);
  333. call->setRet(0, arg);*/
  334. return arg;
  335. }
  336. Reg operator()(const ir::Pow& p) {
  337. using namespace asmjit;
  338. auto a = visitNode(*p.left);
  339. auto b = visitNode(*p.right);
  340. auto arg = cc.newYmmPs();
  341. /*double(*powFunc)(double, double) = ::pow;
  342. cc.comment("call pow");
  343. auto call = cc.call(imm(powFunc), FuncSignatureT<double, double, double>(CallConv::kIdHost));
  344. call->setArg(0, a);
  345. call->setArg(1, b);
  346. call->setRet(0, arg);*/
  347. return arg;
  348. }
  349. Reg operator()(const ir::Cos& c) {
  350. using namespace asmjit;
  351. auto a = visitNode(*c.value);
  352. auto arg = cc.newYmmPs();
  353. /*double(*cosFunc)(double) = ::cos;
  354. cc.comment("call cos");
  355. auto call = cc.call(imm(cosFunc), FuncSignatureT<double, double>(CallConv::kIdHost));
  356. call->setArg(0, a);
  357. call->setRet(0, arg);*/
  358. return arg;
  359. }
  360. Reg operator()(const ir::Sin& s) {
  361. using namespace asmjit;
  362. auto a = visitNode(*s.value);
  363. auto arg = cc.newYmmPs();
  364. /*double(*sinFunc)(double) = ::sin;
  365. cc.comment("call sin");
  366. auto call = cc.call(imm(sinFunc), FuncSignatureT<double, double>(CallConv::kIdHost));
  367. call->setArg(0, a);
  368. call->setRet(0, arg);*/
  369. return arg;
  370. }
  371. Reg operator()(const ir::Exp& ex) {
  372. using namespace asmjit;
  373. auto a = visitNode(*ex.value);
  374. auto arg = cc.newYmmPs();
  375. /*double(*expFunc)(double) = ::exp;
  376. cc.comment("call exp");
  377. auto call = cc.call(imm(expFunc), FuncSignatureT<double, double>(CallConv::kIdHost));
  378. call->setArg(0, a);
  379. call->setRet(0, arg);*/
  380. return arg;
  381. }
  382. Reg operator()(const ir::Ln& l) {
  383. using namespace asmjit;
  384. auto a = visitNode(*l.value);
  385. auto arg = cc.newYmmPs();
  386. /*double(*logFunc)(double) = ::log;
  387. cc.comment("call log");
  388. auto call = cc.call(imm(logFunc), FuncSignatureT<double, double>(CallConv::kIdHost));
  389. call->setArg(0, a);
  390. call->setRet(0, arg);*/
  391. return arg;
  392. }
  393. };
  394. CompiledGeneratorVec compileAVXFloat(const ir::Formula& formula)
  395. {
  396. using namespace asmjit;
  397. std::unique_ptr<mnd::ExecData> ed = std::make_unique<mnd::ExecData>();
  398. JitRuntime& jitRuntime = *ed->jitRuntime;
  399. ed->code->init(jitRuntime.codeInfo());
  400. x86::Compiler& comp = *ed->compiler;
  401. x86::Mem sixteen = comp.newYmmConst(ConstPool::kScopeLocal, Data256::fromF32(16.0f));
  402. x86::Mem one = comp.newYmmConst(ConstPool::kScopeLocal, Data256::fromF32(1.0f));
  403. x86::Mem factors = comp.newYmmConst(ConstPool::kScopeLocal, Data256::fromF32(0, 1, 2, 3, 4, 5, 6, 7));
  404. Label startLoop = comp.newLabel();
  405. Label endLoop = comp.newLabel();
  406. x86::Gp maxIter = comp.newInt32();
  407. x86::Gp k = comp.newInt32();
  408. x86::Gp resPtr = comp.newGpq();
  409. x86::Ymm adder = comp.newYmmPs();
  410. x86::Ymm counter = comp.newYmmPs();
  411. x86::Xmm xorig = comp.newXmmSs();
  412. x86::Xmm yorig = comp.newXmmSs();
  413. x86::Ymm dx = comp.newYmmPs();
  414. x86::Ymm x = comp.newYmmPs();
  415. x86::Ymm y = comp.newYmmPs();
  416. x86::Ymm a = comp.newYmmPs();
  417. x86::Ymm b = comp.newYmmPs();
  418. comp.addFunc(FuncSignatureT<int, float, float, float, int, float*>(CallConv::kIdHost));
  419. comp.setArg(0, xorig);
  420. comp.setArg(1, yorig);
  421. comp.setArg(2, dx.xmm());
  422. comp.setArg(3, maxIter);
  423. comp.setArg(4, resPtr);
  424. comp.vmovaps(adder, one);
  425. comp.vxorps(counter, counter, counter);
  426. comp.vshufps(xorig, xorig, xorig, 0);
  427. comp.vshufps(yorig, yorig, yorig, 0);
  428. comp.vshufps(dx.half(), dx.half(), dx.half(), 0);
  429. comp.vinsertf128(x, xorig.ymm(), xorig, 1);
  430. comp.vinsertf128(y, yorig.ymm(), yorig, 1);
  431. comp.vinsertf128(dx, dx, dx.xmm(), 1);
  432. comp.vmulps(dx, dx, factors);
  433. comp.vaddps(x, x, dx);
  434. CompileVisitorAVXFloat formVisitor{ comp, a, b, x, y };
  435. auto startA = std::visit(formVisitor, *formula.startA);
  436. auto startB = std::visit(formVisitor, *formula.startB);
  437. comp.vmovaps(a, startA);
  438. comp.vmovaps(b, startB);
  439. comp.xor_(k, k);
  440. comp.bind(startLoop);
  441. auto newA = std::visit(formVisitor, *formula.newA);
  442. auto newB = std::visit(formVisitor, *formula.newB);
  443. comp.vmovaps(a, newA);
  444. comp.vmovaps(b, newB);
  445. x86::Ymm aa = comp.newYmmPs();
  446. x86::Ymm bb = comp.newYmmPs();
  447. x86::Ymm cmp = comp.newYmmPs();
  448. comp.vmulps(aa, a, a);
  449. comp.vmulps(bb, b, b);
  450. comp.vaddps(bb, bb, aa);
  451. comp.vcmpps(cmp, bb, sixteen, 18);
  452. comp.vandps(adder, adder, cmp);
  453. comp.vaddps(counter, counter, adder);
  454. comp.cmp(k, maxIter);
  455. comp.je(endLoop);
  456. comp.add(k, 1);
  457. comp.vtestps(cmp, cmp);
  458. comp.jne(startLoop);
  459. comp.bind(endLoop);
  460. comp.vmovups(x86::xmmword_ptr(resPtr), counter.half());
  461. comp.vextractf128(x86::xmmword_ptr(resPtr, 16), counter, 0x1);
  462. comp.ret(k);
  463. comp.endFunc();
  464. auto err = comp.finalize();
  465. if (err == asmjit::kErrorOk) {
  466. err = jitRuntime.add(&ed->iterationFunc, ed->code.get());
  467. if (err != asmjit::kErrorOk)
  468. throw "error adding function";
  469. }
  470. else {
  471. throw "error compiling";
  472. }
  473. return CompiledGeneratorVec{ std::move(ed) };
  474. }
  475. struct OpenClVisitor
  476. {
  477. int varnameCounter = 0;
  478. std::stringstream code;
  479. std::string floatTypeName;
  480. OpenClVisitor(int startVarname, const std::string& floatTypeName) :
  481. varnameCounter{ startVarname },
  482. floatTypeName{ floatTypeName }
  483. {
  484. }
  485. std::string createVarname(void)
  486. {
  487. return "tmp"s + std::to_string(varnameCounter++);
  488. }
  489. std::string visitNode(ir::Node& node)
  490. {
  491. auto& nodeData = std::visit([] (auto& n) -> std::any& { return n.nodeData; }, node);
  492. if (std::string* var = std::any_cast<std::string>(&nodeData)) {
  493. return *var;
  494. }
  495. else {
  496. std::string value = std::visit(*this, node);
  497. if (!std::get_if<ir::Variable>(&node) && !std::get_if<ir::Constant>(&node)) {
  498. std::string varname = createVarname();
  499. code << floatTypeName << " " << varname << " = " << value << ";" << std::endl;
  500. nodeData = varname;
  501. return varname;
  502. }
  503. return value;
  504. }
  505. }
  506. std::string operator()(const ir::Constant& c) {
  507. return std::to_string(mnd::convert<double>(c.value)) + ((floatTypeName == "float") ? "f" : "");
  508. }
  509. std::string operator()(const ir::Variable& v) {
  510. return v.name;
  511. }
  512. std::string operator()(const ir::Negation& n) {
  513. return "-("s + visitNode(*n.value) + ")";
  514. }
  515. std::string operator()(const ir::Addition& a) {
  516. return "("s + visitNode(*a.left) + ") + (" + visitNode(*a.right) + ")";
  517. }
  518. std::string operator()(const ir::Subtraction& a) {
  519. return "("s + visitNode(*a.left) + ") - (" + visitNode(*a.right) + ")";
  520. }
  521. std::string operator()(const ir::Multiplication& a) {
  522. return "("s + visitNode(*a.left) + ") * (" + visitNode(*a.right) + ")";
  523. }
  524. std::string operator()(const ir::Division& a) {
  525. return "("s + visitNode(*a.left) + ") / (" + visitNode(*a.right) + ")";
  526. }
  527. std::string operator()(const ir::Atan2& a) {
  528. return "atan2("s + visitNode(*a.left) + ", " + visitNode(*a.right) + ")";
  529. }
  530. std::string operator()(const ir::Pow& a) {
  531. return "pow("s + visitNode(*a.left) + ", " + visitNode(*a.right) + ")";
  532. }
  533. std::string operator()(const ir::Cos& a) {
  534. return "cos("s + visitNode(*a.value) + ")";
  535. }
  536. std::string operator()(const ir::Sin& a) {
  537. return "sin("s + visitNode(*a.value) + ")";
  538. }
  539. std::string operator()(const ir::Exp& a) {
  540. return "exp("s + visitNode(*a.value) + ")";
  541. }
  542. std::string operator()(const ir::Ln& a) {
  543. return "log("s + visitNode(*a.value) + ")";
  544. }
  545. };
  546. std::string compileToOpenCl(const ir::Formula& formula)
  547. {
  548. OpenClVisitor z0Visitor{ 0, "float" };
  549. std::string startA = z0Visitor.visitNode(*formula.startA);
  550. std::string startB = z0Visitor.visitNode(*formula.startB);
  551. OpenClVisitor ocv{ z0Visitor.varnameCounter, "float" };
  552. std::string newA = ocv.visitNode(*formula.newA);
  553. std::string newB = ocv.visitNode(*formula.newB);
  554. std::string prelude =
  555. "__kernel void iterate(__global float* A, const int width, float xl, float yt, float pixelScaleX, float pixelScaleY, int max, int smooth, int julia, float juliaX, float juliaY) {\n"
  556. " int index = get_global_id(0);\n"
  557. " int ix = index % width;\n"
  558. " int iy = index / width;\n"
  559. " float c_re = ix * pixelScaleX + xl;\n"
  560. " float c_im = iy * pixelScaleY + yt;\n";
  561. prelude += z0Visitor.code.str() +
  562. " float z_re = " + startA + ";\n" +
  563. " float z_im = " + startB + ";\n" +
  564. "\n"
  565. " int n = 0;\n"
  566. " while (n < max - 1) {\n";
  567. /*
  568. std::string orig =
  569. " float aa = a * a;"
  570. " float bb = b * b;"
  571. " float ab = a * b;"
  572. " a = aa - bb + x;"
  573. " b = ab + ab + y;";
  574. */
  575. std::string after =
  576. " if (z_re * z_re + z_im * z_im > 16) break;\n"
  577. " n++;\n"
  578. " }\n"
  579. " if (n >= max - 1) {\n"
  580. " A[index] = max;\n"
  581. " }\n"
  582. " else {\n"
  583. " A[index] = ((float)n);\n"
  584. " }\n"
  585. "}\n";
  586. std::string code = prelude + ocv.code.str();
  587. code += "z_re = " + newA + ";\n";
  588. code += "z_im = " + newB + ";\n";
  589. code += after;
  590. //code = mnd::getFloat_cl();
  591. printf("cl: %s\n", code.c_str()); fflush(stdout);
  592. return code;
  593. }
  594. std::string compileToOpenClDouble(const ir::Formula& formula)
  595. {
  596. OpenClVisitor z0Visitor{ 0, "double" };
  597. std::string startA = z0Visitor.visitNode(*formula.startA);
  598. std::string startB = z0Visitor.visitNode(*formula.startB);
  599. OpenClVisitor ocv{ z0Visitor.varnameCounter, "double" };
  600. std::string newA = ocv.visitNode(*formula.newA);
  601. std::string newB = ocv.visitNode(*formula.newB);
  602. std::string prelude =
  603. "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"
  604. "__kernel void iterate(__global float* A, const int width, double xl, double yt, double pixelScaleX, double pixelScaleY, int max, int smooth, int julia, double juliaX, double juliaY) {\n"
  605. " int index = get_global_id(0);\n"
  606. " int ix = index % width;\n"
  607. " int iy = index / width;\n"
  608. " double c_re = ix * pixelScaleX + xl;\n"
  609. " double c_im = iy * pixelScaleY + yt;\n";
  610. prelude += z0Visitor.code.str() +
  611. " double z_re = " + startA + ";\n" +
  612. " double z_im = " + startB + ";\n" +
  613. "\n"
  614. " int n = 0;\n"
  615. " while (n < max - 1) {\n";
  616. std::string after =
  617. " if (z_re * z_re + z_im * z_im > 16) break;\n"
  618. " n++;\n"
  619. " }\n"
  620. " if (n >= max - 1) {\n"
  621. " A[index] = max;\n"
  622. " }\n"
  623. " else {\n"
  624. " A[index] = ((float)n);\n"
  625. " }\n"
  626. "}\n";
  627. std::string code = prelude + ocv.code.str();
  628. code += "z_re = " + newA + ";\n";
  629. code += "z_im = " + newB + ";\n";
  630. code += after;
  631. //code = mnd::getFloat_cl();
  632. printf("cld: %s\n", code.c_str()); fflush(stdout);
  633. return code;
  634. }
  635. #ifdef WITH_OPENCL
  636. std::unique_ptr<MandelGenerator> compileCl(const ir::Formula& formula, MandelDevice& md)
  637. {
  638. return std::make_unique<CompiledClGenerator>(md, compileToOpenCl(formula));
  639. }
  640. std::unique_ptr<MandelGenerator> compileClDouble(const ir::Formula& formula, MandelDevice& md)
  641. {
  642. return std::make_unique<CompiledClGeneratorDouble>(md, compileToOpenClDouble(formula));
  643. }
  644. #endif
  645. std::vector<std::unique_ptr<mnd::MandelGenerator>> compileCpu(mnd::MandelContext& mndCtxt,
  646. const IterationFormula& z0,
  647. const IterationFormula& zi)
  648. {
  649. std::vector<std::unique_ptr<mnd::MandelGenerator>> vec;
  650. IterationFormula z0o = z0.clone();
  651. IterationFormula zio = zi.clone();
  652. z0o.optimize();
  653. zio.optimize();
  654. ir::Formula irf = mnd::expand(z0o, zio);
  655. irf.optimize();
  656. printf("ir: %s\n", irf.toString().c_str()); fflush(stdout);
  657. auto dg = std::make_unique<CompiledGenerator>(compile(irf));
  658. printf("asm: %s\n", dg->dump().c_str()); fflush(stdout);
  659. vec.push_back(std::move(dg));
  660. if (mndCtxt.getCpuInfo().hasAvx()) {
  661. auto dgavx = std::make_unique<CompiledGeneratorVec>(compileAVXFloat(irf));
  662. printf("asm avxvec: %s\n", dgavx->dump().c_str()); fflush(stdout);
  663. vec.push_back(std::move(dgavx));
  664. }
  665. //auto dg = std::make_unique<NaiveIRGenerator>(*irf, mnd::getPrecision<double>());
  666. //vec.push_back(std::move(ng));
  667. return vec;
  668. }
  669. std::vector<std::unique_ptr<mnd::MandelGenerator>> compileOpenCl(mnd::MandelDevice& dev,
  670. const IterationFormula& z0,
  671. const IterationFormula& zi)
  672. {
  673. std::vector<std::unique_ptr<mnd::MandelGenerator>> vec;
  674. IterationFormula z0o = z0.clone();
  675. IterationFormula zio = zi.clone();
  676. z0o.optimize();
  677. zio.optimize();
  678. printf("if: %s\n", mnd::toString(*zio.expr).c_str()); fflush(stdout);
  679. ir::Formula irf = mnd::expand(z0o, zio);
  680. irf.optimize();
  681. printf("ir: %s\n", irf.toString().c_str()); fflush(stdout);
  682. auto fl = compileCl(irf, dev);
  683. vec.push_back(std::move(fl));
  684. if (dev.supportsDouble()) {
  685. irf.clearNodeData();
  686. auto fld = compileClDouble(irf, dev);
  687. vec.push_back(std::move(fld));
  688. }
  689. return vec;// { { mnd::GeneratorType::FLOAT, std::move(fl) } };
  690. }
  691. GeneratorCollection compileFormula(mnd::MandelContext& mndCtxt, const IterationFormula& z0,
  692. const IterationFormula& zi)
  693. {
  694. GeneratorCollection cr;
  695. cr.cpuGenerators = compileCpu(mndCtxt, z0, zi);
  696. for (auto& dev : mndCtxt.getDevices()) {
  697. auto gens = compileOpenCl(*dev, z0, zi);
  698. std::move(gens.begin(), gens.end(), std::back_inserter(cr.clGenerators));
  699. }
  700. cr.adaptiveGenerator = std::make_unique<mnd::AdaptiveGenerator>();
  701. if (!cr.clGenerators.empty()) {
  702. cr.adaptiveGenerator->addGenerator(mnd::getPrecision<float>(), *cr.clGenerators[0]);
  703. }
  704. if (!cr.cpuGenerators.empty()) {
  705. cr.adaptiveGenerator->addGenerator(mnd::getPrecision<double>(), *cr.cpuGenerators[0]);
  706. }
  707. return cr;
  708. }
  709. }
  710. mnd::GeneratorCollection::GeneratorCollection(void) :
  711. cpuGenerators{},
  712. clGenerators{},
  713. adaptiveGenerator{ nullptr }
  714. {
  715. }