Ver código fonte

avxfma faast

Nicolas Winkler 5 anos atrás
pai
commit
42dfaeabc6
1 arquivos alterados com 44 adições e 17 exclusões
  1. 44 17
      libmandel/src/CpuGeneratorsAVXFMA.cpp

+ 44 - 17
libmandel/src/CpuGeneratorsAVXFMA.cpp

@@ -44,12 +44,14 @@ void CpuGenerator<float, mnd::X86_AVX_FMA, parallel>::generate(const mnd::Mandel
         T y = T(view.y) + T(j) * T(view.height / info.bHeight);
         __m256 ys = {y, y, y, y, y, y, y, y};
         long i = 0;
-        for (i; i < info.bWidth; i += 16) {
+        for (i; i < info.bWidth; i += 24) {
             __m256 pixc = { float(i), float(i + 1), float(i + 2), float(i + 3), float(i + 4), float(i + 5), float(i + 6), float(i + 7) };
             __m256 pixc2 = { float(i + 8), float(i + 9), float(i + 10), float(i + 11), float(i + 12), float(i + 13), float(i + 14), float(i + 15) };
+            __m256 pixc3 = { float(i + 16), float(i + 17), float(i + 18), float(i + 19), float(i + 20), float(i + 21), float(i + 22), float(i + 23) };
 
             __m256 xs = _mm256_add_ps(_mm256_mul_ps(dpp, pixc), viewx);
             __m256 xs2 = _mm256_add_ps(_mm256_mul_ps(dpp, pixc2), viewx);
+            __m256 xs3 = _mm256_add_ps(_mm256_mul_ps(dpp, pixc3), viewx);
 
             __m256 counter = { 0, 0, 0, 0, 0, 0, 0, 0 };
             __m256 adder = { 1, 1, 1, 1, 1, 1, 1, 1 };
@@ -61,62 +63,84 @@ void CpuGenerator<float, mnd::X86_AVX_FMA, parallel>::generate(const mnd::Mandel
             __m256 resultsa2 = { 0, 0, 0, 0, 0, 0, 0, 0 };
             __m256 resultsb2 = { 0, 0, 0, 0, 0, 0, 0, 0 };
 
+            __m256 counter3 = { 0, 0, 0, 0, 0, 0, 0, 0 };
+            __m256 adder3 = { 1, 1, 1, 1, 1, 1, 1, 1 };
+            __m256 resultsa3 = { 0, 0, 0, 0, 0, 0, 0, 0 };
+            __m256 resultsb3 = { 0, 0, 0, 0, 0, 0, 0, 0 };
+
             __m256 threshold = { 16.0f, 16.0f, 16.0f, 16.0f, 16.0f, 16.0f, 16.0f, 16.0f };
             __m256 two = { 2, 2, 2, 2, 2, 2, 2, 2 };
 
             __m256 a = xs;
             __m256 a2 = xs2;
+            __m256 a3 = xs3;
             __m256 b = ys;
             __m256 b2 = ys;
+            __m256 b3 = ys;
 
             __m256 cx = info.julia ? juliaX : xs;
             __m256 cx2 = info.julia ? juliaX : xs2;
+            __m256 cx3 = info.julia ? juliaX : xs3;
             __m256 cy = info.julia ? juliaY : ys;
 
             if (info.smooth) {
                 for (int k = 0; k < info.maxIter; k++) {
                     __m256 bb = _mm256_mul_ps(b, b);
                     __m256 bb2 = _mm256_mul_ps(b2, b2);
+                    __m256 bb3 = _mm256_mul_ps(b3, b3);
                     __m256 ab = _mm256_mul_ps(a, b);
                     __m256 ab2 = _mm256_mul_ps(a2, b2);
+                    __m256 ab3 = _mm256_mul_ps(a3, b3);
                     a = _mm256_add_ps(_mm256_fmsub_ps(a, a, bb), cx);
                     a2 = _mm256_add_ps(_mm256_fmsub_ps(a2, a2, bb2), cx2);
+                    a3 = _mm256_add_ps(_mm256_fmsub_ps(a3, a3, bb3), cx3);
                     b = _mm256_fmadd_ps(two, ab, cy);
                     b2 = _mm256_fmadd_ps(two, ab2, cy);
+                    b3 = _mm256_fmadd_ps(two, ab3, cy);
                     __m256 cmp = _mm256_cmp_ps(_mm256_fmadd_ps(a, a, bb), threshold, _CMP_LE_OQ);
                     __m256 cmp2 = _mm256_cmp_ps(_mm256_fmadd_ps(a2, a2, bb2), threshold, _CMP_LE_OQ);
+                    __m256 cmp3 = _mm256_cmp_ps(_mm256_fmadd_ps(a3, a3, bb3), threshold, _CMP_LE_OQ);
                     resultsa = _mm256_or_ps(_mm256_andnot_ps(cmp, resultsa), _mm256_and_ps(cmp, a));
                     resultsb = _mm256_or_ps(_mm256_andnot_ps(cmp, resultsb), _mm256_and_ps(cmp, b));
                     resultsa2 = _mm256_or_ps(_mm256_andnot_ps(cmp2, resultsa2), _mm256_and_ps(cmp2, a2));
                     resultsb2 = _mm256_or_ps(_mm256_andnot_ps(cmp2, resultsb2), _mm256_and_ps(cmp2, b2));
+                    resultsa3 = _mm256_or_ps(_mm256_andnot_ps(cmp3, resultsa3), _mm256_and_ps(cmp3, a3));
+                    resultsb3 = _mm256_or_ps(_mm256_andnot_ps(cmp3, resultsb3), _mm256_and_ps(cmp3, b3));
                     adder = _mm256_and_ps(adder, cmp);
                     counter = _mm256_add_ps(counter, adder);
                     adder2 = _mm256_and_ps(adder2, cmp2);
                     counter2 = _mm256_add_ps(counter2, adder2);
-                    if ((k & 0x7) == 0 && _mm256_testz_ps(cmp, cmp) != 0 && _mm256_testz_ps(cmp2, cmp2) != 0) {
+                    adder3 = _mm256_and_ps(adder3, cmp3);
+                    counter3 = _mm256_add_ps(counter3, adder3);
+                    if ((k & 0x7) == 0 && _mm256_testz_ps(cmp, cmp) != 0 && _mm256_testz_ps(cmp2, cmp2) != 0 && _mm256_testz_ps(cmp3, cmp3) != 0) {
                         break;
                     }
                 }
             }
             else {
                 for (int k = 0; k < info.maxIter; k++) {
-                    __m256 aa = _mm256_mul_ps(a, a);
-                    __m256 aa2 = _mm256_mul_ps(a2, a2);
                     __m256 bb = _mm256_mul_ps(b, b);
                     __m256 bb2 = _mm256_mul_ps(b2, b2);
-                    __m256 abab = _mm256_mul_ps(a, b); abab = _mm256_add_ps(abab, abab);
-                    __m256 abab2 = _mm256_mul_ps(a2, b2); abab2 = _mm256_add_ps(abab2, abab2);
-                    a = _mm256_add_ps(_mm256_sub_ps(aa, bb), cx);
-                    a2 = _mm256_add_ps(_mm256_sub_ps(aa2, bb2), cx2);
-                    b = _mm256_add_ps(abab, cy);
-                    b2 = _mm256_add_ps(abab2, cy);
-                    __m256 cmp = _mm256_cmp_ps(_mm256_add_ps(aa, bb), threshold, _CMP_LE_OQ);
-                    __m256 cmp2 = _mm256_cmp_ps(_mm256_add_ps(aa2, bb2), threshold, _CMP_LE_OQ);
+                    __m256 bb3 = _mm256_mul_ps(b3, b3);
+                    __m256 ab = _mm256_mul_ps(a, b);
+                    __m256 ab2 = _mm256_mul_ps(a2, b2);
+                    __m256 ab3 = _mm256_mul_ps(a3, b3);
+                    a = _mm256_add_ps(_mm256_fmsub_ps(a, a, bb), cx);
+                    a2 = _mm256_add_ps(_mm256_fmsub_ps(a2, a2, bb2), cx2);
+                    a3 = _mm256_add_ps(_mm256_fmsub_ps(a3, a3, bb3), cx3);
+                    b = _mm256_fmadd_ps(two, ab, cy);
+                    b2 = _mm256_fmadd_ps(two, ab2, cy);
+                    b3 = _mm256_fmadd_ps(two, ab3, cy);
+                    __m256 cmp = _mm256_cmp_ps(_mm256_fmadd_ps(a, a, bb), threshold, _CMP_LE_OQ);
+                    __m256 cmp2 = _mm256_cmp_ps(_mm256_fmadd_ps(a2, a2, bb2), threshold, _CMP_LE_OQ);
+                    __m256 cmp3 = _mm256_cmp_ps(_mm256_fmadd_ps(a3, a3, bb3), threshold, _CMP_LE_OQ);
                     adder = _mm256_and_ps(adder, cmp);
                     counter = _mm256_add_ps(counter, adder);
                     adder2 = _mm256_and_ps(adder2, cmp2);
                     counter2 = _mm256_add_ps(counter2, adder2);
-                    if ((k & 0x7) == 0 && _mm256_testz_ps(cmp, cmp) != 0 && _mm256_testz_ps(cmp2, cmp2) != 0) {
+                    adder3 = _mm256_and_ps(adder3, cmp3);
+                    counter3 = _mm256_add_ps(counter3, adder3);
+                    if ((k & 0x7) == 0 && _mm256_testz_ps(cmp, cmp) != 0 && _mm256_testz_ps(cmp2, cmp2) != 0 && _mm256_testz_ps(cmp3, cmp3) != 0) {
                         break;
                     }
                 }
@@ -130,18 +154,21 @@ void CpuGenerator<float, mnd::X86_AVX_FMA, parallel>::generate(const mnd::Mandel
                 return static_cast<float*>(aligned);
             };
 
-            float resData[64];
+            float resData[96];
             float* ftRes = alignVec(resData);
-            float* resa = ftRes + 16;
-            float* resb = resa + 16;
+            float* resa = ftRes + 24;
+            float* resb = resa + 24;
 
             _mm256_store_ps(ftRes, counter);
             _mm256_store_ps(ftRes + 8, counter2);
+            _mm256_store_ps(ftRes + 16, counter3);
             _mm256_store_ps(resa, resultsa);
             _mm256_store_ps(resa + 8, resultsa2);
+            _mm256_store_ps(resa + 16, resultsa3);
             _mm256_store_ps(resb, resultsb);
             _mm256_store_ps(resb + 8, resultsb2);
-            for (int k = 0; k < 16 && i + k < info.bWidth; k++) {
+            _mm256_store_ps(resb + 16, resultsb3);
+            for (int k = 0; k < 24 && i + k < info.bWidth; k++) {
                 if (info.smooth) {
                     data[i + k + j * info.bWidth] = ftRes[k] <= 0 ? info.maxIter :
                         ftRes[k] >= info.maxIter ? info.maxIter :