CpuGeneratorsAVX512.cpp

#include "CpuGenerators.h"
#include <immintrin.h>
#include <omp.h>
#include <cmath>   // ::log used in the smooth-coloring write-back below
#include <memory>  // std::align

using mnd::CpuGenerator;

namespace mnd
{
    // Explicit instantiations for the serial and parallel float/double generators.
    template class CpuGenerator<float, mnd::X86_AVX_512, false>;
    template class CpuGenerator<float, mnd::X86_AVX_512, true>;
    template class CpuGenerator<double, mnd::X86_AVX_512, false>;
    template class CpuGenerator<double, mnd::X86_AVX_512, true>;
}
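
// Single-precision generator: each __m512 register holds 16 float lanes, and the
// inner loop processes two registers (32 pixels) per iteration. The commented-out
// *2 variables are a disabled third register that would extend this to 48 pixels.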
template<bool parallel>
void CpuGenerator<float, mnd::X86_AVX_512, parallel>::generate(const mnd::MandelInfo& info, float* data)
{
    using T = float;
    const MandelViewport& view = info.view;
    const float dppf = float(view.width / info.bWidth);
    const float viewxf = float(view.x);
    __m512 viewx = _mm512_set1_ps(viewxf);
    __m512 dpp = _mm512_set1_ps(dppf);
    __m512 enumerate = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 };
    __m512 two = _mm512_set1_ps(2);
    T jX = mnd::convert<T>(info.juliaX);
    T jY = mnd::convert<T>(info.juliaY);
    __m512 juliaX = _mm512_set1_ps(jX);
    __m512 juliaY = _mm512_set1_ps(jY);
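    // With the parallel template parameter set, image rows are distributed across
    // all available hardware threads via OpenMP; otherwise the pragma is inactive.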
#if defined(_OPENMP)
    if constexpr(parallel)
        omp_set_num_threads(omp_get_num_procs());
#pragma omp parallel for schedule(static, 1) if (parallel)
#endif
    for (long j = 0; j < info.bHeight; j++) {
        T y = T(view.y + double(j) * view.height / info.bHeight);
        __m512 ys = _mm512_set1_ps(y);
        for (long i = 0; i < info.bWidth; i += 2 * 16) {
            __m512 pixc0 = _mm512_add_ps(_mm512_set1_ps(float(i)), enumerate);
            __m512 pixc1 = _mm512_add_ps(_mm512_set1_ps(float(i + 16)), enumerate);
            //__m512 pixc2 = _mm512_add_ps(_mm512_set1_ps(float(i + 32)), enumerate);
            __m512 xs0 = _mm512_fmadd_ps(dpp, pixc0, viewx);
            __m512 xs1 = _mm512_fmadd_ps(dpp, pixc1, viewx);
            //__m512 xs2 = _mm512_fmadd_ps(dpp, pixc2, viewx);
            __m512 counter0 = _mm512_setzero_ps();
            __m512 counter1 = _mm512_setzero_ps();
            //__m512 counter2 = _mm512_setzero_ps();
            __m512 adder0 = _mm512_set1_ps(1);
            __m512 adder1 = _mm512_set1_ps(1);
            //__m512 adder2 = _mm512_set1_ps(1);
            __m512 resultsa0 = _mm512_setzero_ps();
            __m512 resultsa1 = _mm512_setzero_ps();
            //__m512 resultsa2 = _mm512_setzero_ps();
            __m512 resultsb0 = _mm512_setzero_ps();
            __m512 resultsb1 = _mm512_setzero_ps();
            //__m512 resultsb2 = _mm512_setzero_ps();
            __m512 threshold = _mm512_set1_ps(16);
            __m512 cx0 = xs0;
            __m512 cx1 = xs1;
            __m512 cy = ys;
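            // Mandelbrot mode iterates z = z^2 + c with c equal to the pixel
            // coordinate; Julia mode instead uses the fixed constant (juliaX, juliaY)
            // while z starts at the pixel coordinate.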
            if (info.julia) {
                cx0 = juliaX;
                cx1 = juliaX;
                cy = juliaY;
            }
            __m512 a0 = xs0;
            __m512 a1 = xs1;
            //__m512 a2 = xs2;
            __m512 b0 = ys;
            __m512 b1 = ys;
            //__m512 b2 = ys;
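            // Escape-time iteration on 2x16 lanes: a/b hold the real and imaginary
            // parts of z, the masks mark lanes that still satisfy |z|^2 <= 16, and the
            // counters accumulate one iteration per still-active lane. The smooth branch
            // additionally records the last iterate per lane (resultsa/resultsb) for
            // continuous coloring; the loop exits early once every lane has escaped.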
            if (info.smooth) {
                __mmask16 cmp0 = 0xFFFF;
                __mmask16 cmp1 = 0xFFFF;
                for (int k = 0; k < info.maxIter; k++) {
                    __m512 aa0 = _mm512_mul_ps(a0, a0);
                    __m512 aa1 = _mm512_mul_ps(a1, a1);
                    //__m512 aa2 = _mm512_mul_ps(a2, a2);
                    __m512 abab0 = _mm512_mul_ps(a0, b0);
                    __m512 abab1 = _mm512_mul_ps(a1, b1);
                    //__m512 abab2 = _mm512_mul_ps(a2, b2);
                    a0 = _mm512_sub_ps(aa0, _mm512_fmsub_ps(b0, b0, cx0));
                    a1 = _mm512_sub_ps(aa1, _mm512_fmsub_ps(b1, b1, cx1));
                    //a2 = _mm512_sub_ps(aa2, _mm512_fmsub_ps(b2, b2, xs2));
                    b0 = _mm512_fmadd_ps(two, abab0, cy);
                    b1 = _mm512_fmadd_ps(two, abab1, cy);
                    //b2 = _mm512_fmadd_ps(two, abab2, ys);
                    resultsa0 = _mm512_mask_blend_ps(cmp0, resultsa0, a0);
                    resultsa1 = _mm512_mask_blend_ps(cmp1, resultsa1, a1);
                    //resultsa2 = _mm512_mask_blend_ps(cmp2, resultsa2, a2);
                    resultsb0 = _mm512_mask_blend_ps(cmp0, resultsb0, b0);
                    resultsb1 = _mm512_mask_blend_ps(cmp1, resultsb1, b1);
                    //resultsb2 = _mm512_mask_blend_ps(cmp2, resultsb2, b2);
                    cmp0 = _mm512_cmp_ps_mask(_mm512_fmadd_ps(b0, b0, aa0), threshold, _CMP_LE_OQ);
                    cmp1 = _mm512_cmp_ps_mask(_mm512_fmadd_ps(b1, b1, aa1), threshold, _CMP_LE_OQ);
                    //__mmask16 cmp2 = _mm512_cmp_ps_mask(_mm512_fmadd_ps(b2, b2, aa2), threshold, _CMP_LE_OQ);
                    counter0 = _mm512_mask_add_ps(counter0, cmp0, counter0, adder0);
                    counter1 = _mm512_mask_add_ps(counter1, cmp1, counter1, adder1);
                    //counter2 = _mm512_mask_add_ps(counter2, cmp2, counter2, adder2);
                    if (cmp0 == 0 && cmp1 == 0 /*&& cmp2 == 0*/) {
                        break;
                    }
                }
            }
            else {
                for (int k = 0; k < info.maxIter; k++) {
                    __m512 aa0 = _mm512_mul_ps(a0, a0);
                    __m512 aa1 = _mm512_mul_ps(a1, a1);
                    //__m512 aa2 = _mm512_mul_ps(a2, a2);
                    __m512 abab0 = _mm512_mul_ps(a0, b0);
                    __m512 abab1 = _mm512_mul_ps(a1, b1);
                    //__m512 abab2 = _mm512_mul_ps(a2, b2);
                    __mmask16 cmp0 = _mm512_cmp_ps_mask(_mm512_fmadd_ps(b0, b0, aa0), threshold, _CMP_LE_OQ);
                    __mmask16 cmp1 = _mm512_cmp_ps_mask(_mm512_fmadd_ps(b1, b1, aa1), threshold, _CMP_LE_OQ);
                    //__mmask16 cmp2 = _mm512_cmp_ps_mask(_mm512_fmadd_ps(b2, b2, aa2), threshold, _CMP_LE_OQ);
                    a0 = _mm512_sub_ps(aa0, _mm512_fmsub_ps(b0, b0, cx0));
                    a1 = _mm512_sub_ps(aa1, _mm512_fmsub_ps(b1, b1, cx1));
                    //a2 = _mm512_sub_ps(aa2, _mm512_fmsub_ps(b2, b2, xs2));
                    b0 = _mm512_fmadd_ps(two, abab0, cy);
                    b1 = _mm512_fmadd_ps(two, abab1, cy);
                    //b2 = _mm512_fmadd_ps(two, abab2, ys);
                    counter0 = _mm512_mask_add_ps(counter0, cmp0, counter0, adder0);
                    counter1 = _mm512_mask_add_ps(counter1, cmp1, counter1, adder1);
                    //counter2 = _mm512_mask_add_ps(counter2, cmp2, counter2, adder2);
                    if (cmp0 == 0 && cmp1 == 0 /*&& cmp2 == 0*/) {
                        break;
                    }
                }
            }
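            // _mm512_store_ps requires a 64-byte-aligned destination; the lambda aligns
            // a pointer into the over-sized stack buffer so the vector results can be
            // stored there before the scalar write-back.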
            auto alignVec = [](float* data) -> float* {
                void* aligned = data;
                ::size_t length = 3 * 64 * sizeof(float);
                std::align(64, 48 * sizeof(float), aligned, length);
                return static_cast<float*>(aligned);
            };
            float resData[3 * 64];
            float* ftRes = alignVec(resData);
            float* resa = ftRes + 3 * 16;
            float* resb = ftRes + 6 * 16;
            _mm512_store_ps(ftRes, counter0);
            _mm512_store_ps(ftRes + 16, counter1);
            //_mm512_store_ps(ftRes + 32, counter2);
            if (info.smooth) {
                _mm512_store_ps(resa, resultsa0);
                _mm512_store_ps(resa + 16, resultsa1);
                //_mm512_store_ps(resa + 32, resultsa2);
                _mm512_store_ps(resb, resultsb0);
                _mm512_store_ps(resb + 16, resultsb1);
                //_mm512_store_ps(resb + 32, resultsb2);
            }
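            // Write back up to 32 pixels; with smoothing, the iteration count is refined
            // to n + 1 - log(log|z|) / log 2 using the stored final iterate, giving a
            // continuous gradient instead of integer bands. Lanes that never escaped are
            // clamped to maxIter.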
            for (int k = 0; k < 2 * 16 && i + k < info.bWidth; k++) {
                if (info.smooth) {
                    data[i + k + j * info.bWidth] = ftRes[k] < 0 ? info.maxIter :
                        ftRes[k] >= info.maxIter ? info.maxIter :
                        ((float)ftRes[k]) + 1 - ::log(::log(resa[k] * resa[k] + resb[k] * resb[k]) / 2) / ::log(2.0f);
                }
                else {
                    data[i + k + j * info.bWidth] = ftRes[k] < 0 ? info.maxIter : ftRes[k];
                }
            }
        }
    }
}
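
// Double-precision variant: each __m512d register holds 8 lanes, so the inner loop
// advances 8 pixels per iteration; the output buffer is still an array of floats.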
template<bool parallel>
void CpuGenerator<double, mnd::X86_AVX_512, parallel>::generate(const mnd::MandelInfo& info, float* data)
{
    using T = double;
    const MandelViewport& view = info.view;
    const double dppf = double(view.width / info.bWidth);
    const double viewxf = double(view.x);
    __m512d viewx = { viewxf, viewxf, viewxf, viewxf, viewxf, viewxf, viewxf, viewxf };
    __m512d dpp = { dppf, dppf, dppf, dppf, dppf, dppf, dppf, dppf };
    T jX = mnd::convert<T>(info.juliaX);
    T jY = mnd::convert<T>(info.juliaY);
    __m512d juliaX = _mm512_set1_pd(jX);
    __m512d juliaY = _mm512_set1_pd(jY);
#if defined(_OPENMP)
    if constexpr(parallel)
        omp_set_num_threads(omp_get_num_procs());
#pragma omp parallel for schedule(static, 1) if (parallel)
#endif
    for (long j = 0; j < info.bHeight; j++) {
        T y = T(view.y + double(j) * view.height / info.bHeight);
        __m512d ys = { y, y, y, y, y, y, y, y };
        for (long i = 0; i < info.bWidth; i += 8) {
            __m512d pixc = { double(i), double(i + 1), double(i + 2), double(i + 3), double(i + 4), double(i + 5), double(i + 6), double(i + 7) };
            __m512d xs = _mm512_fmadd_pd(dpp, pixc, viewx);
            __m512d counter = { 0, 0, 0, 0, 0, 0, 0, 0 };
            __m512d adder = { 1, 1, 1, 1, 1, 1, 1, 1 };
            __m512d two = { 2, 2, 2, 2, 2, 2, 2, 2 };
            __m512d resultsa = { 0, 0, 0, 0, 0, 0, 0, 0 };
            __m512d resultsb = { 0, 0, 0, 0, 0, 0, 0, 0 };
            __m512d threshold = { 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0 };
            __m512d cx = xs;
            __m512d cy = ys;
            if (info.julia) {
                cx = juliaX;
                cy = juliaY;
            }
            __m512d a = xs;
            __m512d b = ys;
            if (info.smooth) {
                __mmask8 cmp = 0xFF;
                for (int k = 0; k < info.maxIter; k++) {
                    __m512d aa = _mm512_mul_pd(a, a);
                    __m512d ab = _mm512_mul_pd(a, b);
                    a = _mm512_sub_pd(aa, _mm512_fmsub_pd(b, b, cx));
                    b = _mm512_fmadd_pd(two, ab, cy);
                    resultsa = _mm512_mask_blend_pd(cmp, resultsa, a);
                    resultsb = _mm512_mask_blend_pd(cmp, resultsb, b);
                    cmp = _mm512_cmp_pd_mask(_mm512_fmadd_pd(b, b, aa), threshold, _CMP_LE_OQ);
                    counter = _mm512_mask_add_pd(counter, cmp, counter, adder);
                    if (cmp == 0) {
                        break;
                    }
                }
            }
            else {
                for (int k = 0; k < info.maxIter; k++) {
                    __m512d aa = _mm512_mul_pd(a, a);
                    __m512d ab = _mm512_mul_pd(a, b);
                    __mmask8 cmp = _mm512_cmp_pd_mask(_mm512_fmadd_pd(b, b, aa), threshold, _CMP_LE_OQ);
                    a = _mm512_sub_pd(aa, _mm512_fmsub_pd(b, b, cx));
                    b = _mm512_fmadd_pd(two, ab, cy);
                    counter = _mm512_mask_add_pd(counter, cmp, counter, adder);
                    if (cmp == 0) {
                        break;
                    }
                }
            }
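            // Same alignment trick as in the float version: _mm512_store_pd needs a
            // 64-byte-aligned destination inside the over-sized stack buffer.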
            auto alignVec = [](double* data) -> double* {
                void* aligned = data;
                ::size_t length = 32 * sizeof(double);
                std::align(64, 24 * sizeof(double), aligned, length);
                return static_cast<double*>(aligned);
            };
            double resData[64];
            double* ftRes = alignVec(resData);
            double* resa = ftRes + 8;
            double* resb = ftRes + 16;
            _mm512_store_pd(ftRes, counter);
            if (info.smooth) {
                _mm512_store_pd(resa, resultsa);
                _mm512_store_pd(resb, resultsb);
            }
            for (int k = 0; k < 8 && i + k < info.bWidth; k++) {
                if (info.smooth) {
                    data[i + k + j * info.bWidth] = ftRes[k] < 0 ? info.maxIter :
                        ftRes[k] >= info.maxIter ? info.maxIter :
                        ((float)ftRes[k]) + 1 - ::log(::log((float) (resa[k] * resa[k] + resb[k] * resb[k])) / 2) / ::log(2.0f);
                }
                else {
                    data[i + k + j * info.bWidth] = ftRes[k] < 0 ? info.maxIter : ftRes[k];
                }
            }
        }
    }
}