Browse Source

float avx compiler

Nicolas Winkler 5 years ago
parent
commit
fc96323fef

+ 5 - 3
Almond.cpp

@@ -172,14 +172,14 @@ void Almond::on_displayInfo_stateChanged(int checked)
 void Almond::on_chooseGenerator_clicked()
 {
     std::unique_ptr<ChooseGenerators> generatorsDialog;
-    if (currentView == MANDELBROT)
+    if (currentView == MANDELBROT || currentView == JULIA)
         generatorsDialog = std::make_unique<ChooseGenerators>(mandelContext, *mandelGenerator, *this);
     else if (currentView == CUSTOM)
         generatorsDialog = std::make_unique<ChooseGenerators>(mandelContext, this->currentCustom->gc, *customGenerator, *this);
     else
         return;
 
-    generatorsDialog->exec();
+    auto response = generatorsDialog->exec();
 
     auto gen = generatorsDialog->extractChosenGenerator();
     if (gen) {
@@ -310,7 +310,9 @@ void Almond::on_radioButton_2_toggled(bool checked)
 
 void Almond::on_createCustom_clicked()
 {
-    customGeneratorDialog->exec();
+    auto response = customGeneratorDialog->exec();
+    if (response != 1)
+        return;
     if (auto* frac = customGeneratorDialog->getLastCompiled()) {
         customGenerator = frac->gc.adaptiveGenerator.get();
         customGenerators.push_back(std::make_unique<FractalDef>(std::move(*frac)));

+ 1 - 0
Bitmap.h

@@ -35,6 +35,7 @@ public:
     template<typename T>
     Bitmap<T> map(std::function<T(Pixel)> f) const {
         Bitmap<T> b{ width, height };
+#pragma omp parallel for
         for (long i = 0; i < width * height; i++) {
             b.pixels[i] = f(pixels[i]);
         }

+ 0 - 6
customgenerator.cpp

@@ -54,12 +54,6 @@ void CustomGenerator::compile()
 }
 
 
-void CustomGenerator::on_compile_clicked()
-{
-    compile();
-}
-
-
 FractalDef* CustomGenerator::getLastCompiled(void)
 {
     if (!fractalDefs.empty())

+ 0 - 2
customgenerator.h

@@ -35,8 +35,6 @@ public:
     void compile();
 
 private slots:
-    void on_compile_clicked();
-
     void on_buttonBox_accepted();
 
 private:

+ 18 - 2
libmandel/include/IterationGenerator.h

@@ -18,6 +18,7 @@ namespace mnd
     template<typename T>
     class NaiveIRGenerator;
     class CompiledGenerator;
+    class CompiledGeneratorVec;
     class CompiledClGenerator;
     class CompiledClGeneratorDouble;
 
@@ -70,15 +71,30 @@ public:
 #if defined(__x86_64__) || defined(_M_X64)
 class mnd::CompiledGenerator : public mnd::MandelGenerator
 {
+protected:
     std::unique_ptr<ExecData> execData;
 public:
-    CompiledGenerator(std::unique_ptr<ExecData> execData);
+    CompiledGenerator(std::unique_ptr<ExecData> execData,
+        mnd::Precision prec = mnd::Precision::DOUBLE,
+        mnd::CpuExtension ex = mnd::CpuExtension::NONE);
+    CompiledGenerator(const CompiledGenerator&) = delete;
     CompiledGenerator(CompiledGenerator&&);
     virtual ~CompiledGenerator(void);
-    virtual void generate(const MandelInfo& info, float* data);
+    virtual void generate(const MandelInfo& info, float* data) override;
 
     std::string dump(void) const;
 };
+
+
+// Generator variant of CompiledGenerator that overrides generate().
+// Non-copyable; movable.
+class mnd::CompiledGeneratorVec : public mnd::CompiledGenerator
+{
+public:
+    CompiledGeneratorVec(std::unique_ptr<ExecData> execData);
+
+    // Copying is disabled; the move constructor is defined out of line.
+    CompiledGeneratorVec(const CompiledGeneratorVec&) = delete;
+    CompiledGeneratorVec(CompiledGeneratorVec&&);
+
+    ~CompiledGeneratorVec() override;
+
+    void generate(const MandelInfo& info, float* data) override;
+};
 #endif
 
 

+ 2 - 2
libmandel/src/CpuGeneratorsAVX.cpp

@@ -497,9 +497,9 @@ void CpuGenerator<mnd::DoubleDouble, mnd::X86_AVX, parallel>::generate(const mnd
 
             for (int k = 0; k < 4 && i + k < info.bWidth; k++) {
                 if (info.smooth)
-                    data[i + k + j * info.bWidth] = ftRes[k] <= 0 ? info.maxIter :
+                    data[i + k + j * info.bWidth] = float(ftRes[k] <= 0 ? info.maxIter :
                         ftRes[k] >= info.maxIter ? info.maxIter :
-                        ((float)ftRes[k]) + 1 - ::log(::log(resa[k] * resa[k] + resb[k] * resb[k]) / 2) / ::log(2.0f);
+                        ((float)ftRes[k]) + 1 - ::logf(::logf(float(resa[k] * resa[k] + resb[k] * resb[k])) / 2) / ::logf(2.0f));
                 else
                     data[i + k + j * info.bWidth] = ftRes[k] > 0 ? float(ftRes[k]) : info.maxIter;
             }

+ 284 - 0
libmandel/src/IterationCompiler.cpp

@@ -279,6 +279,287 @@ namespace mnd
     }
 
 
+    // Expression visitor that emits packed single-precision AVX instructions
+    // (one Ymm register per IR node) through an asmjit x86::Compiler.
+    //
+    // The formula variables z_re, z_im, c_re and c_im are mapped to the
+    // registers a, b, x and y handed to the constructor.
+    //
+    // NOTE: Atan2, Pow, Cos, Sin, Exp and Ln are not implemented for this
+    // backend yet; they evaluate their operands and return a register that
+    // no instruction writes to.
+    struct CompileVisitorAVXFloat
+    {
+        using Reg = asmjit::x86::Ymm;
+
+        asmjit::x86::Compiler& cc;
+        Reg& a;
+        Reg& b;
+        Reg& x;
+        Reg& y;
+
+        // Returns the register for `node`. The register is stored in the
+        // node's nodeData after the first visit, so a node that is visited
+        // again reuses the same register instead of emitting code twice.
+        Reg visitNode(ir::Node& node)
+        {
+            std::any& cache = std::visit([] (auto& n) -> std::any& { return n.nodeData; }, node);
+            if (Reg* cached = std::any_cast<Reg>(&cache))
+                return *cached;
+
+            Reg computed = std::visit(*this, node);
+            cache = computed;
+            return computed;
+        }
+
+        CompileVisitorAVXFloat(asmjit::x86::Compiler& cc, Reg& a, Reg& b, Reg& x, Reg& y) :
+            cc{ cc },
+            a{ a }, b{ b },
+            x{ x }, y{ y }
+        {
+        }
+
+        // Broadcasts the constant (converted to float) into all lanes.
+        Reg operator()(const ir::Constant& c) {
+            auto constant = cc.newFloatConst(asmjit::ConstPool::kScopeLocal, mnd::convert<float>(c.value));
+            auto reg = cc.newYmmPs();
+            const std::string commentStr =
+                "move constant [" + std::to_string(mnd::convert<double>(c.value)) + "]";
+            cc.comment(commentStr.c_str());
+            cc.vbroadcastss(reg, constant);
+            return reg;
+        }
+
+        // Maps a formula variable name to its register.
+        // Throws mnd::ParseError for any other name.
+        Reg operator()(const ir::Variable& v) {
+            if (v.name == "z_re")
+                return a;
+            if (v.name == "z_im")
+                return b;
+            if (v.name == "c_re")
+                return x;
+            if (v.name == "c_im")
+                return y;
+            throw mnd::ParseError(std::string("unknown variable: ") + v.name);
+        }
+
+        // Computes 0 - value.
+        Reg operator()(const ir::Negation& n) {
+            auto sub = cc.newYmmPs();
+            cc.vxorps(sub, sub, sub);
+            Reg operand = visitNode(*n.value);
+            cc.vsubps(sub, sub, operand);
+            return sub;
+        }
+
+        // For the binary operators below the operands are visited in separate
+        // statements (left, then right). Visiting them as function arguments
+        // would leave the emission order of the operand code unspecified.
+
+        Reg operator()(const ir::Addition& add) {
+            auto res = cc.newYmmPs();
+            Reg lhs = visitNode(*add.left);
+            Reg rhs = visitNode(*add.right);
+            cc.vaddps(res, lhs, rhs);
+            return res;
+        }
+
+        Reg operator()(const ir::Subtraction& sub) {
+            auto res = cc.newYmmPs();
+            Reg lhs = visitNode(*sub.left);
+            Reg rhs = visitNode(*sub.right);
+            cc.vsubps(res, lhs, rhs);
+            return res;
+        }
+
+        Reg operator()(const ir::Multiplication& mul) {
+            auto res = cc.newYmmPs();
+            Reg lhs = visitNode(*mul.left);
+            Reg rhs = visitNode(*mul.right);
+            cc.vmulps(res, lhs, rhs);
+            return res;
+        }
+
+        Reg operator()(const ir::Division& div) {
+            auto res = cc.newYmmPs();
+            Reg lhs = visitNode(*div.left);
+            Reg rhs = visitNode(*div.right);
+            cc.vdivps(res, lhs, rhs);
+            return res;
+        }
+
+        // Debugging helper: atan2 that also prints its arguments and result.
+        // Not referenced anywhere inside this struct.
+        static double myAtan2(double y, double x)
+        {
+            double result = ::atan2(y, x);
+            printf("atan2(%f, %f) = %f\n", y, x, result);
+            return result;
+        }
+
+        // Placeholder for a binary function without a vector implementation:
+        // the operand code is still emitted, but the returned register is
+        // never written.
+        // TODO: call a vectorized routine (or fall back to per-lane scalar
+        // calls) and store the result in the returned register.
+        Reg unimplementedBinary(ir::Node& left, ir::Node& right) {
+            (void) visitNode(left);
+            (void) visitNode(right);
+            return cc.newYmmPs();
+        }
+
+        // Placeholder for a unary function without a vector implementation;
+        // see unimplementedBinary.
+        Reg unimplementedUnary(ir::Node& operand) {
+            (void) visitNode(operand);
+            return cc.newYmmPs();
+        }
+
+        Reg operator()(const ir::Atan2& at2) {
+            return unimplementedBinary(*at2.left, *at2.right);
+        }
+
+        Reg operator()(const ir::Pow& p) {
+            return unimplementedBinary(*p.left, *p.right);
+        }
+
+        Reg operator()(const ir::Cos& c) {
+            return unimplementedUnary(*c.value);
+        }
+
+        Reg operator()(const ir::Sin& s) {
+            return unimplementedUnary(*s.value);
+        }
+
+        Reg operator()(const ir::Exp& ex) {
+            return unimplementedUnary(*ex.value);
+        }
+
+        Reg operator()(const ir::Ln& l) {
+            return unimplementedUnary(*l.value);
+        }
+    };
+
+    // JIT-compiles `formula` into a function with the host-ABI signature
+    //     int f(float x0, float y, float dx, int maxIter, float* result)
+    // that iterates 8 horizontally adjacent points (x0 + lane * dx, y) at
+    // once and stores 8 floats at `result`: for each lane, the number of
+    // iterations after which a^2 + b^2 still compared <= 16. The function
+    // returns its loop counter.
+    //
+    // Throws a string literal (const char*) if asmjit fails to finalize the
+    // code or to add it to the runtime; the visitor may throw
+    // mnd::ParseError for unknown variable names.
+    CompiledGeneratorVec compileAVXFloat(const ir::Formula& formula)
+    {
+        using namespace asmjit;
+        std::unique_ptr<mnd::ExecData> ed = std::make_unique<mnd::ExecData>();
+        JitRuntime& jitRuntime = *ed->jitRuntime;
+        ed->code->init(jitRuntime.codeInfo());
+
+        x86::Compiler& comp = *ed->compiler;
+
+        x86::Mem sixteen = comp.newYmmConst(ConstPool::kScopeLocal, Data256::fromF32(16.0f));
+        x86::Mem one = comp.newYmmConst(ConstPool::kScopeLocal, Data256::fromF32(1.0f));
+        x86::Mem factors = comp.newYmmConst(ConstPool::kScopeLocal, Data256::fromF32(0, 1, 2, 3, 4, 5, 6, 7));
+
+        Label startLoop = comp.newLabel();
+        Label endLoop = comp.newLabel();
+        x86::Gp maxIter = comp.newInt32();
+        x86::Gp k = comp.newInt32();
+        x86::Gp resPtr = comp.newGpq();
+        x86::Ymm adder = comp.newYmmPs();
+        x86::Ymm counter = comp.newYmmPs();
+        x86::Xmm xorig = comp.newXmmSs();
+        x86::Xmm yorig = comp.newXmmSs();
+        x86::Ymm dx = comp.newYmmPs();
+        x86::Ymm x = comp.newYmmPs();
+        x86::Ymm y = comp.newYmmPs();
+        x86::Ymm a = comp.newYmmPs();
+        x86::Ymm b = comp.newYmmPs();
+        comp.addFunc(FuncSignatureT<int, float, float, float, int, float*>(CallConv::kIdHost));
+
+        comp.setArg(0, xorig);
+        comp.setArg(1, yorig);
+        comp.setArg(2, dx.xmm());
+        comp.setArg(3, maxIter);
+        comp.setArg(4, resPtr);
+
+        // adder holds 1.0 in every lane that is still counted; counter starts at 0.
+        comp.vmovaps(adder, one);
+        comp.vxorps(counter, counter, counter);
+
+        // Broadcast the scalar arguments to all 8 lanes: shuffle within the
+        // low 128 bits, then copy the low half into the high half.
+        comp.vshufps(xorig, xorig, xorig, 0);
+        comp.vshufps(yorig, yorig, yorig, 0);
+        comp.vshufps(dx.half(), dx.half(), dx.half(), 0);
+        comp.vinsertf128(x, xorig.ymm(), xorig, 1);
+        comp.vinsertf128(y, yorig.ymm(), yorig, 1);
+        comp.vinsertf128(dx, dx, dx.xmm(), 1);
+
+        // x[lane] = x0 + lane * dx
+        comp.vmulps(dx, dx, factors);
+        comp.vaddps(x, x, dx);
+
+        CompileVisitorAVXFloat formVisitor{ comp, a, b, x, y };
+        auto startA = std::visit(formVisitor, *formula.startA);
+        auto startB = std::visit(formVisitor, *formula.startB);
+        comp.vmovaps(a, startA);
+        comp.vmovaps(b, startB);
+
+        comp.xor_(k, k);
+
+        comp.bind(startLoop);
+
+
+        auto newA = std::visit(formVisitor, *formula.newA);
+        auto newB = std::visit(formVisitor, *formula.newB);
+        // The visitor returns the registers a / b themselves for the plain
+        // variables z_re / z_im. Snapshot newB before a is overwritten so
+        // that a formula whose new imaginary part is just z_re still reads
+        // the value of a from the previous iteration.
+        x86::Ymm nextB = comp.newYmmPs();
+        comp.vmovaps(nextB, newB);
+        comp.vmovaps(a, newA);
+        comp.vmovaps(b, nextB);
+
+        x86::Ymm aa = comp.newYmmPs();
+        x86::Ymm bb = comp.newYmmPs();
+        x86::Ymm cmp = comp.newYmmPs();
+        comp.vmulps(aa, a, a);
+        comp.vmulps(bb, b, b);
+        comp.vaddps(bb, bb, aa);
+        // Predicate 18 (0x12) is LE_OQ in the VCMPPS predicate table:
+        // lanes with a^2 + b^2 <= 16 get an all-ones mask.
+        comp.vcmpps(cmp, bb, sixteen, 18);
+        // A lane that left the bound once has its adder zeroed for good.
+        comp.vandps(adder, adder, cmp);
+        comp.vaddps(counter, counter, adder);
+
+        comp.cmp(k, maxIter);
+        comp.je(endLoop);
+        comp.add(k, 1);
+
+        // Keep looping while at least one lane of the mask is still set.
+        comp.vtestps(cmp, cmp);
+        comp.jne(startLoop);
+
+        comp.bind(endLoop);
+
+        // Unaligned stores of the low and high 4 counters.
+        comp.vmovups(x86::xmmword_ptr(resPtr), counter.half());
+        comp.vextractf128(x86::xmmword_ptr(resPtr, 16), counter, 0x1);
+
+        comp.ret(k);
+        comp.endFunc();
+        auto err = comp.finalize();
+        if (err != asmjit::kErrorOk)
+            throw "error compiling";
+
+        err = jitRuntime.add(&ed->iterationFunc, ed->code.get());
+        if (err != asmjit::kErrorOk)
+            throw "error adding function";
+
+        return CompiledGeneratorVec{ std::move(ed) };
+    }
+
+
     struct OpenClVisitor
     {
         int varnameCounter = 0;
@@ -442,13 +723,16 @@ namespace mnd
         irf.optimize();
         printf("ir: %s\n", irf.toString().c_str()); fflush(stdout);
         auto dg = std::make_unique<CompiledGenerator>(compile(irf));
+        auto dgavx = std::make_unique<CompiledGeneratorVec>(compileAVXFloat(irf));
         printf("asm: %s\n", dg->dump().c_str()); fflush(stdout);
+        printf("asm avxvec: %s\n", dgavx->dump().c_str()); fflush(stdout);
         
         //auto dg = std::make_unique<NaiveIRGenerator>(*irf, mnd::getPrecision<double>());
 
         std::vector<std::unique_ptr<mnd::MandelGenerator>> vec;
         //vec.push_back(std::move(ng));
         vec.push_back(std::move(dg));
+        vec.push_back(std::move(dgavx));
         return vec;
     }
 

+ 3 - 0
libmandel/src/IterationFormula.cpp

@@ -207,6 +207,9 @@ mnd::IterationFormula mnd::IterationFormula::clone(void) const
             else if constexpr (std::is_same<T, mnd::Negation>::value) {
                 return mnd::Negation{ cloner(*x.operand) };
             }
+            else if constexpr (std::is_same<T, mnd::Addition>::value) {
+                return mnd::Addition{ cloner(*x.left), cloner(*x.right), x.subtraction };
+            }
             else {
                 return T{ cloner(*x.left), cloner(*x.right) };
             }

+ 41 - 2
libmandel/src/IterationGenerator.cpp

@@ -260,12 +260,14 @@ double NaiveIRGenerator<T>::calc(mnd::ir::Node* expr, double a, double b, double
 
 
 using mnd::CompiledGenerator;
+using mnd::CompiledGeneratorVec;
 using mnd::CompiledClGenerator;
 using mnd::CompiledClGeneratorDouble;
 
 
-CompiledGenerator::CompiledGenerator(std::unique_ptr<mnd::ExecData> execData) :
-    MandelGenerator{ mnd::Precision::DOUBLE },
+CompiledGenerator::CompiledGenerator(std::unique_ptr<mnd::ExecData> execData,
+    mnd::Precision prec, mnd::CpuExtension ex) :
+    MandelGenerator{ prec, ex },
     execData{ std::move(execData) }
 {
 }
@@ -328,6 +330,43 @@ std::string CompiledGenerator::dump(void) const
 }
 
 
+// Takes ownership of the execution data and registers this generator as
+// single precision with the X86_AVX CPU extension.
+CompiledGeneratorVec::CompiledGeneratorVec(std::unique_ptr<mnd::ExecData> execData) :
+    CompiledGenerator{
+        std::move(execData),
+        mnd::Precision::FLOAT,
+        mnd::CpuExtension::X86_AVX
+    }
+{
+}
+
+
+// Member-wise move.
+CompiledGeneratorVec::CompiledGeneratorVec(CompiledGeneratorVec&&) = default;
+
+
+// Nothing to release beyond what the base class already owns.
+CompiledGeneratorVec::~CompiledGeneratorVec() = default;
+
+
+// Fills `data` (row-major, info.bWidth * info.bHeight floats) by calling the
+// JIT-compiled iteration function once per batch of 8 horizontally adjacent
+// pixels. Rows are distributed over OpenMP threads.
+//
+// The compiled function always writes 8 results, so it writes into a local
+// buffer and only the lanes that fall inside the row are copied out.
+void CompiledGeneratorVec::generate(const mnd::MandelInfo& info, float* data)
+{
+    using IterFunc = int (*)(float, float, float, int, float*);
+
+    // Loop-invariant: resolve the function pointer and the pixel width once.
+    const IterFunc iterFunc = asmjit::ptr_as_func<IterFunc>(this->execData->iterationFunc);
+    const float dx = static_cast<float>(mnd::convert<double>(info.view.width / info.bWidth));
+
+    omp_set_num_threads(omp_get_num_procs());
+#pragma omp parallel for schedule(static, 1)
+    for (int i = 0; i < info.bHeight; i++) {
+        const float y = static_cast<float>(
+            mnd::convert<double>(info.view.y + info.view.height * i / info.bHeight));
+        for (int j = 0; j < info.bWidth; j += 8) {
+            const float x = static_cast<float>(
+                mnd::convert<double>(info.view.x + info.view.width * j / info.bWidth));
+            float result[8];
+            // The returned loop counter is not needed here.
+            iterFunc(x, y, dx, info.maxIter - 1, result);
+
+            for (int lane = 0; lane < 8 && j + lane < info.bWidth; lane++)
+                data[i * info.bWidth + j + lane] = result[lane];
+        }
+    }
+}
+
+
 #ifdef WITH_OPENCL
 CompiledClGenerator::CompiledClGenerator(mnd::MandelDevice& device, const std::string& code) :
     ClGeneratorFloat{ device, code }