CpuGeneratorsAVX512.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. #include "CpuGenerators.h"
  2. #include <immintrin.h>
  3. #include <omp.h>
  4. #include <memory>
  5. using mnd::CpuGenerator;
  6. namespace mnd
  7. {
  8. template class CpuGenerator<float, mnd::X86_AVX_512, false>;
  9. template class CpuGenerator<float, mnd::X86_AVX_512, true>;
  10. template class CpuGenerator<double, mnd::X86_AVX_512, false>;
  11. template class CpuGenerator<double, mnd::X86_AVX_512, true>;
  12. }
  13. template<bool parallel>
  14. void CpuGenerator<float, mnd::X86_AVX_512, parallel>::generate(const mnd::MandelInfo& info, float* data)
  15. {
  16. using T = float;
  17. const MandelViewport& view = info.view;
  18. const float dppf = float(view.width / info.bWidth);
  19. const float viewxf = float(view.x);
  20. __m512 viewx = _mm512_set1_ps(viewxf);
  21. __m512 dpp = _mm512_set1_ps(dppf);
  22. __m512 enumerate = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 };
  23. __m512 two = _mm512_set1_ps(2);
  24. #if defined(_OPENMP)
  25. if constexpr(parallel)
  26. omp_set_num_threads(omp_get_num_procs());
  27. #pragma omp parallel for schedule(static, 1) if (parallel)
  28. #endif
  29. for (long j = 0; j < info.bHeight; j++) {
  30. T y = T(view.y + double(j) * view.height / info.bHeight);
  31. __m512 ys = _mm512_set1_ps(y);
  32. for (long i = 0; i < info.bWidth; i += 2 * 16) {
  33. __m512 pixc0 = _mm512_add_ps(_mm512_set1_ps(float(i)), enumerate);
  34. __m512 pixc1 = _mm512_add_ps(_mm512_set1_ps(float(i + 16)), enumerate);
  35. //__m512 pixc2 = _mm512_add_ps(_mm512_set1_ps(float(i + 32)), enumerate);
  36. __m512 xs0 = _mm512_fmadd_ps(dpp, pixc0, viewx);
  37. __m512 xs1 = _mm512_fmadd_ps(dpp, pixc1, viewx);
  38. //__m512 xs2 = _mm512_fmadd_ps(dpp, pixc2, viewx);
  39. __m512 counter0 = _mm512_setzero_ps();
  40. __m512 counter1 = _mm512_setzero_ps();
  41. //__m512 counter2 = _mm512_setzero_ps();
  42. __m512 adder0 = _mm512_set1_ps(1);
  43. __m512 adder1 = _mm512_set1_ps(1);
  44. //__m512 adder2 = _mm512_set1_ps(1);
  45. __m512 resultsa0 = _mm512_setzero_ps();
  46. __m512 resultsa1 = _mm512_setzero_ps();
  47. //__m512 resultsa2 = _mm512_setzero_ps();
  48. __m512 resultsb0 = _mm512_setzero_ps();
  49. __m512 resultsb1 = _mm512_setzero_ps();
  50. //__m512 resultsb2 = _mm512_setzero_ps();
  51. __m512 threshold = _mm512_set1_ps(16);
  52. __m512 a0 = xs0;
  53. __m512 a1 = xs1;
  54. //__m512 a2 = xs2;
  55. __m512 b0 = ys;
  56. __m512 b1 = ys;
  57. //__m512 b2 = ys;
  58. if (info.smooth) {
  59. for (int k = 0; k < info.maxIter; k++) {
  60. __m512 aa0 = _mm512_mul_ps(a0, a0);
  61. __m512 aa1 = _mm512_mul_ps(a1, a1);
  62. //__m512 aa2 = _mm512_mul_ps(a2, a2);
  63. __m512 abab0 = _mm512_mul_ps(a0, b0);
  64. __m512 abab1 = _mm512_mul_ps(a1, b1);
  65. //__m512 abab2 = _mm512_mul_ps(a2, b2);
  66. __mmask16 cmp0 = _mm512_cmp_ps_mask(_mm512_fmadd_ps(b0, b0, aa0), threshold, _CMP_LE_OQ);
  67. __mmask16 cmp1 = _mm512_cmp_ps_mask(_mm512_fmadd_ps(b1, b1, aa1), threshold, _CMP_LE_OQ);
  68. //__mmask16 cmp2 = _mm512_cmp_ps_mask(_mm512_fmadd_ps(b2, b2, aa2), threshold, _CMP_LE_OQ);
  69. a0 = _mm512_sub_ps(aa0, _mm512_fmsub_ps(b0, b0, xs0));
  70. a1 = _mm512_sub_ps(aa1, _mm512_fmsub_ps(b1, b1, xs1));
  71. //a2 = _mm512_sub_ps(aa2, _mm512_fmsub_ps(b2, b2, xs2));
  72. b0 = _mm512_fmadd_ps(two, abab0, ys);
  73. b1 = _mm512_fmadd_ps(two, abab1, ys);
  74. //b2 = _mm512_fmadd_ps(two, abab2, ys);
  75. counter0 = _mm512_mask_add_ps(counter0, cmp0, counter0, adder0);
  76. counter1 = _mm512_mask_add_ps(counter1, cmp1, counter1, adder1);
  77. //counter2 = _mm512_mask_add_ps(counter2, cmp2, counter2, adder2);
  78. resultsa0 = _mm512_mask_blend_ps(cmp0, resultsa0, a0);
  79. resultsa1 = _mm512_mask_blend_ps(cmp1, resultsa1, a1);
  80. //resultsa2 = _mm512_mask_blend_ps(cmp2, resultsa2, a2);
  81. resultsb0 = _mm512_mask_blend_ps(cmp0, resultsb0, b0);
  82. resultsb1 = _mm512_mask_blend_ps(cmp1, resultsb1, b1);
  83. //resultsb2 = _mm512_mask_blend_ps(cmp2, resultsb2, b2);
  84. if (cmp0 == 0 && cmp1 == 0 /*&& cmp2 == 0*/) {
  85. break;
  86. }
  87. }
  88. }
  89. else {
  90. for (int k = 0; k < info.maxIter; k++) {
  91. __m512 aa0 = _mm512_mul_ps(a0, a0);
  92. __m512 aa1 = _mm512_mul_ps(a1, a1);
  93. //__m512 aa2 = _mm512_mul_ps(a2, a2);
  94. __m512 abab0 = _mm512_mul_ps(a0, b0);
  95. __m512 abab1 = _mm512_mul_ps(a1, b1);
  96. //__m512 abab2 = _mm512_mul_ps(a2, b2);
  97. __mmask16 cmp0 = _mm512_cmp_ps_mask(_mm512_fmadd_ps(b0, b0, aa0), threshold, _CMP_LE_OQ);
  98. __mmask16 cmp1 = _mm512_cmp_ps_mask(_mm512_fmadd_ps(b1, b1, aa1), threshold, _CMP_LE_OQ);
  99. //__mmask16 cmp2 = _mm512_cmp_ps_mask(_mm512_fmadd_ps(b2, b2, aa2), threshold, _CMP_LE_OQ);
  100. a0 = _mm512_sub_ps(aa0, _mm512_fmsub_ps(b0, b0, xs0));
  101. a1 = _mm512_sub_ps(aa1, _mm512_fmsub_ps(b1, b1, xs1));
  102. //a2 = _mm512_sub_ps(aa2, _mm512_fmsub_ps(b2, b2, xs2));
  103. b0 = _mm512_fmadd_ps(two, abab0, ys);
  104. b1 = _mm512_fmadd_ps(two, abab1, ys);
  105. //b2 = _mm512_fmadd_ps(two, abab2, ys);
  106. counter0 = _mm512_mask_add_ps(counter0, cmp0, counter0, adder0);
  107. counter1 = _mm512_mask_add_ps(counter1, cmp1, counter1, adder1);
  108. //counter2 = _mm512_mask_add_ps(counter2, cmp2, counter2, adder2);
  109. if (cmp0 == 0 && cmp1 == 0 /*&& cmp2 == 0*/) {
  110. break;
  111. }
  112. }
  113. }
  114. auto alignVec = [](float* data) -> float* {
  115. void* aligned = data;
  116. ::size_t length = 3 * 64 * sizeof(float);
  117. std::align(64, 48 * sizeof(float), aligned, length);
  118. return static_cast<float*>(aligned);
  119. };
  120. float resData[3 * 64];
  121. float* ftRes = alignVec(resData);
  122. float* resa = ftRes + 3 * 16;
  123. float* resb = ftRes + 6 * 16;
  124. _mm512_store_ps(ftRes, counter0);
  125. _mm512_store_ps(ftRes + 16, counter1);
  126. //_mm512_store_ps(ftRes + 32, counter2);
  127. if (info.smooth) {
  128. _mm512_store_ps(resa, resultsa0);
  129. _mm512_store_ps(resa + 16, resultsa1);
  130. //_mm512_store_ps(resa + 32, resultsa2);
  131. _mm512_store_ps(resb, resultsb0);
  132. _mm512_store_ps(resb + 16, resultsb1);
  133. //_mm512_store_ps(resb + 32, resultsb2);
  134. }
  135. for (int k = 0; k < 2 * 16 && i + k < info.bWidth; k++) {
  136. if (info.smooth) {
  137. data[i + k + j * info.bWidth] = ftRes[k] <= 0 ? info.maxIter :
  138. ftRes[k] >= info.maxIter ? info.maxIter :
  139. ((float)ftRes[k]) + 1 - ::log(::log(resa[k] * resa[k] + resb[k] * resb[k]) / 2) / ::log(2.0f);
  140. }
  141. else {
  142. data[i + k + j * info.bWidth] = ftRes[k] <= 0 ? info.maxIter : ftRes[k];
  143. }
  144. }
  145. }
  146. }
  147. }
  148. template<bool parallel>
  149. void CpuGenerator<double, mnd::X86_AVX_512, parallel>::generate(const mnd::MandelInfo& info, float* data)
  150. {
  151. using T = double;
  152. const MandelViewport& view = info.view;
  153. const double dppf = double(view.width / info.bWidth);
  154. const double viewxf = double(view.x);
  155. __m512d viewx = { viewxf, viewxf, viewxf, viewxf, viewxf, viewxf, viewxf, viewxf };
  156. __m512d dpp = { dppf, dppf, dppf, dppf, dppf, dppf, dppf, dppf };
  157. #if defined(_OPENMP)
  158. if constexpr(parallel)
  159. omp_set_num_threads(omp_get_num_procs());
  160. # pragma omp parallel for schedule(static, 1) if (parallel)
  161. #endif
  162. for (long j = 0; j < info.bHeight; j++) {
  163. T y = T(view.y + double(j) * view.height / info.bHeight);
  164. __m512d ys = { y, y, y, y, y, y, y, y };
  165. for (long i = 0; i < info.bWidth; i += 8) {
  166. __m512d pixc = { double(i), double(i + 1), double(i + 2), double(i + 3), double(i + 4), double(i + 5), double(i + 6), double(i + 7) };
  167. __m512d xs = _mm512_fmadd_pd(dpp, pixc, viewx);
  168. __m512d counter = { 0, 0, 0, 0, 0, 0, 0, 0 };
  169. __m512d adder = { 1, 1, 1, 1, 1, 1, 1, 1 };
  170. __m512d two = { 2, 2, 2, 2, 2, 2, 2, 2 };
  171. __m512d resultsa = { 0, 0, 0, 0, 0, 0, 0, 0 };
  172. __m512d resultsb = { 0, 0, 0, 0, 0, 0, 0, 0 };
  173. __m512d threshold = { 16.0f, 16.0f, 16.0f, 16.0f, 16.0f, 16.0f, 16.0f, 16.0f };
  174. __m512d a = xs;
  175. __m512d b = ys;
  176. if (info.smooth) {
  177. for (int k = 0; k < info.maxIter; k++) {
  178. __m512d aa = _mm512_mul_pd(a, a);
  179. __m512d ab = _mm512_mul_pd(a, b);
  180. __mmask8 cmp = _mm512_cmp_pd_mask(_mm512_fmadd_pd(b, b, aa), threshold, _CMP_LE_OQ);
  181. a = _mm512_sub_pd(aa, _mm512_fmsub_pd(b, b, xs));
  182. b = _mm512_fmadd_pd(two, ab, ys);
  183. resultsa = _mm512_mask_blend_pd(cmp, resultsa, a);
  184. resultsb = _mm512_mask_blend_pd(cmp, resultsb, b);
  185. counter = _mm512_mask_add_pd(counter, cmp, counter, adder);
  186. if (cmp == 0) {
  187. break;
  188. }
  189. }
  190. }
  191. else {
  192. for (int k = 0; k < info.maxIter; k++) {
  193. __m512d aa = _mm512_mul_pd(a, a);
  194. __m512d ab = _mm512_mul_pd(a, b);
  195. __mmask8 cmp = _mm512_cmp_pd_mask(_mm512_fmadd_pd(b, b, aa), threshold, _CMP_LE_OQ);
  196. a = _mm512_sub_pd(aa, _mm512_fmsub_pd(b, b, xs));
  197. b = _mm512_fmadd_pd(two, ab, ys);
  198. counter = _mm512_mask_add_pd(counter, cmp, counter, adder);
  199. if (cmp == 0) {
  200. break;
  201. }
  202. }
  203. }
  204. auto alignVec = [](double* data) -> double* {
  205. void* aligned = data;
  206. ::size_t length = 32 * sizeof(double);
  207. std::align(64, 24 * sizeof(double), aligned, length);
  208. return static_cast<double*>(aligned);
  209. };
  210. double resData[64];
  211. double* ftRes = alignVec(resData);
  212. double* resa = ftRes + 8;
  213. double* resb = ftRes + 16;
  214. _mm512_store_pd(ftRes, counter);
  215. if (info.smooth) {
  216. _mm512_store_pd(resa, resultsa);
  217. _mm512_store_pd(resb, resultsb);
  218. }
  219. for (int k = 0; k < 8 && i + k < info.bWidth; k++) {
  220. if (info.smooth) {
  221. data[i + k + j * info.bWidth] = ftRes[k] <= 0 ? info.maxIter :
  222. ftRes[k] >= info.maxIter ? info.maxIter :
  223. ((float)ftRes[k]) + 1 - ::log(::log((float) (resa[k] * resa[k] + resb[k] * resb[k])) / 2) / ::log(2.0f);
  224. }
  225. else {
  226. data[i + k + j * info.bWidth] = ftRes[k] <= 0 ? info.maxIter : ftRes[k];
  227. }
  228. }
  229. }
  230. }
  231. }