Pārlūkot izejas kodu

Improved Cl Code Generation

Nicolas Winkler 5 gadi atpakaļ
vecāks
revīzija
4ec2902292

+ 12 - 3
choosegenerators.cpp

@@ -296,18 +296,27 @@ void ChooseGenerators::on_compile_clicked()
 {
     QString formula = this->ui->formula->text();
     mnd::IterationFormula itf{ mnd::parse(formula.toStdString()) };
+
+    std::string expr = mnd::toString(*itf.expr);
+    printf("%s\n", expr.c_str()); fflush(stdout);
     //chosenGenerator = std::make_unique<mnd::NaiveGenerator>(std::move(itf), mnd::getPrecision<double>());
     mnd::ir::Formula irform = mnd::expand(itf);
+    printf("%s\n", irform.toString().c_str()); fflush(stdout);
     auto cg = std::make_unique<mnd::CompiledGenerator>(mnd::compile(irform));
-    std::string expr = mnd::toString(*itf.expr);
     std::string asmCode = cg->dump();
-    printf("%s\n", expr.c_str()); fflush(stdout);
-    printf("%s\n", irform.toString().c_str()); fflush(stdout);
     printf("%s\n", asmCode.c_str()); fflush(stdout);
     /*QMessageBox msgBox(nullptr);
     msgBox.setText(QString::fromStdString(asmCode));
     msgBox.exec();*/
     chosenGenerator = std::move(cg);
+
+    const mnd::MandelDevice& dev = mndCtxt.getDevices()[0];
+    try {
+        chosenGenerator = mnd::compileCl(irform, dev);
+    }
+    catch(const std::string& msg) {
+        printf("error compiling: %s", msg.c_str());
+    }
 }
 
 void ChooseGenerators::on_benchmark_clicked()

+ 1 - 1
libmandel/include/ClGenerators.h

@@ -44,7 +44,7 @@ protected:
 class mnd::ClGeneratorFloat : public ClGenerator
 {
 public:
-    ClGeneratorFloat(cl::Device device);
+    ClGeneratorFloat(cl::Device device, const std::string& code);
     virtual ~ClGeneratorFloat(void) = default;
 
 protected:

+ 22 - 2
libmandel/include/IterationCompiler.h

@@ -2,18 +2,21 @@
 #define MANDEL_ITERATIONCOMPILER_H
 
 #include "Generators.h"
+#include "ClGenerators.h"
 #include "IterationIR.h"
+#include <memory>
 
 namespace mnd
 {
     struct ExecData;
     class CompiledGenerator;
+    class CompiledClGenerator;
 
+    // forward declare
     class MandelContext;
+    class MandelDevice;
 
     mnd::ExecData compile(mnd::MandelContext& mndCtxt);
-
- 
 }
 void squareTest();
 
@@ -30,9 +33,26 @@ public:
     std::string dump(void) const;
 };
 
+
+#ifdef WITH_OPENCL
+class mnd::CompiledClGenerator : public mnd::ClGeneratorFloat
+{
+public:
+    CompiledClGenerator(const MandelDevice& device, const std::string& code);
+    //virtual ~CompiledGenerator(void);
+    //virtual void generate(const MandelInfo& info, float* data);
+    virtual std::string getKernelCode(bool smooth) const override;
+    virtual void generate(const MandelInfo& info, float* data);
+
+    //std::string dump(void) const;
+};
+#endif // WITH_OPENCL
+
+
 namespace mnd
 {
     CompiledGenerator compile(const ir::Formula& formula);
+    std::unique_ptr<MandelGenerator> compileCl(const ir::Formula& formula, const MandelDevice& md);
 }
 
 

+ 31 - 4
libmandel/include/IterationIR.h

@@ -4,6 +4,7 @@
 #include <string>
 #include <vector>
 #include <variant>
+#include <any>
 
 #include "IterationFormula.h"
 #include "Arena.h"
@@ -12,6 +13,8 @@ namespace mnd
 {
     namespace ir
     {
+        struct NodeBase;
+
         struct Constant;
         struct Variable;
         struct UnaryOperation;
@@ -52,68 +55,92 @@ namespace mnd
 }
 
 
-struct mnd::ir::Constant
+struct mnd::ir::NodeBase
+{
+    std::any nodeData;
+};
+
+
+struct mnd::ir::Constant : NodeBase
 {
     double value;
+    inline Constant(double val) : value{ val } {}
 };
 
 
-struct mnd::ir::Variable
+struct mnd::ir::Variable : NodeBase
 {
     std::string name;
+    inline Variable(const std::string name) : name{ name } {}
 };
 
 
-struct mnd::ir::UnaryOperation
+struct mnd::ir::UnaryOperation : NodeBase
 {
     Node* value;
+    inline UnaryOperation(Node* value) : value{ value } {}
 };
 
 
 struct mnd::ir::Negation : mnd::ir::UnaryOperation
 {
+    inline Negation(Node* value) : UnaryOperation{ value } {}
 };
 
 
-struct mnd::ir::BinaryOperation
+struct mnd::ir::BinaryOperation : NodeBase
 {
     Node* left;
     Node* right;
+    inline BinaryOperation(Node* left, Node* right) :
+        left{ left }, right{ right } {}
 };
 
 
 struct mnd::ir::Addition : mnd::ir::BinaryOperation
 {
+    inline Addition(Node* left, Node* right) :
+        BinaryOperation{ left, right } {}
 };
 
 
 struct mnd::ir::Subtraction : mnd::ir::BinaryOperation
 {
+    inline Subtraction(Node* left, Node* right) :
+        BinaryOperation{ left, right } {}
 };
 
 
 struct mnd::ir::Multiplication : mnd::ir::BinaryOperation
 {
+    inline Multiplication(Node* left, Node* right) :
+        BinaryOperation{ left, right } {}
 };
 
 
 struct mnd::ir::Atan2 : mnd::ir::BinaryOperation
 {
+    inline Atan2(Node* left, Node* right) :
+        BinaryOperation{ left, right } {}
 };
 
 
 struct mnd::ir::Pow : mnd::ir::BinaryOperation
 {
+    inline Pow(Node* left, Node* right) :
+        BinaryOperation{ left, right } {}
 };
 
 
 struct mnd::ir::Cos : mnd::ir::UnaryOperation
 {
+    inline Cos(Node* value) : UnaryOperation{ value } {}
 };
 
 
 struct mnd::ir::Sin : mnd::ir::UnaryOperation
 {
+    inline Sin(Node* value) : UnaryOperation{ value } {}
 };
 
 

+ 5 - 1
libmandel/include/Mandel.h

@@ -25,6 +25,8 @@ namespace mnd
     class MandelContext;
     class MandelDevice;
 
+    struct ClDeviceWrapper;
+
     extern MandelContext initializeContext(void);
 
     const std::string& getGeneratorName(mnd::GeneratorType);
@@ -66,10 +68,11 @@ private:
 
     std::string vendor;
     std::string name;
+    std::unique_ptr<ClDeviceWrapper> clDevice;
 
     std::map<GeneratorType, std::unique_ptr<MandelGenerator>> mandelGenerators;
 
-    MandelDevice(void);
+    MandelDevice(ClDeviceWrapper);
 public:
     MandelDevice(const MandelDevice&) = delete;
     MandelDevice(MandelDevice&&) = default;
@@ -80,6 +83,7 @@ public:
     inline const std::string& getName(void) const { return name; }
 
     MandelGenerator* getGenerator(GeneratorType type) const;
+    inline const ClDeviceWrapper& getClDevice(void) const { return *clDevice; }
 
     std::vector<GeneratorType> getSupportedTypes(void) const;
 };

+ 24 - 0
libmandel/include/OpenClInternal.h

@@ -0,0 +1,24 @@
+#ifndef MANDEL_OPENCLINTERNAL_H
+#define MANDEL_OPENCLINTERNAL_H
+
+#ifdef WITH_OPENCL
+#ifdef __APPLE__
+#include <OpenCL/cl.hpp>
+#else
+#include <CL/cl.hpp>
+#endif
+#endif
+namespace mnd
+{
+    struct ClDeviceWrapper
+    {
+#ifdef WITH_OPENCL
+        cl::Device device;
+#endif
+    };
+}
+
+
+#endif // MANDEL_OPENCLINTERNAL_H
+
+

+ 2 - 4
libmandel/src/ClGenerators.cpp

@@ -121,15 +121,13 @@ void ClGenerator::generate(const mnd::MandelInfo& info, float* data)
 }
 
 
-ClGeneratorFloat::ClGeneratorFloat(cl::Device device) :
+ClGeneratorFloat::ClGeneratorFloat(cl::Device device, const std::string& code) :
     ClGenerator{ device, mnd::getPrecision<float>() }
 {
     context = Context{ device };
     Program::Sources sources;
 
-    std::string kcode = this->getKernelCode(false);
-
-    sources.push_back({ kcode.c_str(), kcode.length() });
+    sources.push_back({ code.c_str(), code.length() });
 
     program = Program{ context, sources };
     if (program.build({ device }) != CL_SUCCESS) {

+ 195 - 14
libmandel/src/IterationCompiler.cpp

@@ -1,10 +1,16 @@
 #include "IterationCompiler.h"
 
+#include "Mandel.h"
+#include "OpenClInternal.h"
+#include "OpenClCode.h"
+
 #include <asmjit/asmjit.h>
 #include <cmath>
-#include "Mandel.h"
 #include <omp.h>
+#include <any>
+#include <string>
 
+using namespace std::string_literals;
 namespace mnd
 {
     struct ExecData
@@ -43,6 +49,19 @@ namespace mnd
         Reg& x;
         Reg& y;
 
+        Reg visitNode(ir::Node& node)
+        {
+            auto& nodeData = std::visit([] (auto& n) -> std::any& { return n.nodeData; }, node);
+            if (Reg* regPtr = std::any_cast<Reg>(&nodeData)) {
+                return *regPtr;
+            }
+            else {
+                Reg reg = std::visit(*this, node);
+                nodeData = reg;
+                return reg;
+            }
+        }
+
         CompileVisitor(asmjit::x86::Compiler& cc, Reg& a, Reg& b, Reg& x, Reg& y) :
             cc{ cc },
             a{ a }, b{ b },
@@ -77,35 +96,35 @@ namespace mnd
         Reg operator()(const ir::Negation& n) {
             auto sub = cc.newXmmSd();
             cc.xorpd(sub, sub);
-            cc.subsd(sub, std::visit((*this), *n.value));
+            cc.subsd(sub, visitNode(*n.value));
             return sub;
         }
 
         Reg operator()(const ir::Addition& add) {
             auto res = cc.newXmmSd();
-            cc.movapd(res, std::visit((*this), *add.left));
-            cc.addsd(res, std::visit((*this), *add.right));
+            cc.movapd(res, visitNode(*add.left));
+            cc.addsd(res, visitNode(*add.right));
             return res;
         }
 
         Reg operator()(const ir::Subtraction& add) {
             auto res = cc.newXmmSd();
-            cc.movapd(res, std::visit((*this), *add.left));
-            cc.subsd(res, std::visit((*this), *add.right));
+            cc.movapd(res, visitNode(*add.left));
+            cc.subsd(res, visitNode(*add.right));
             return res;
         }
 
         Reg operator()(const ir::Multiplication& add) {
             auto res = cc.newXmmSd();
-            cc.movapd(res, std::visit((*this), *add.left));
-            cc.mulsd(res, std::visit((*this), *add.right));
+            cc.movapd(res, visitNode(*add.left));
+            cc.mulsd(res, visitNode(*add.right));
             return res;
         }
 
         Reg operator()(const ir::Atan2& at2) {
             using namespace asmjit;
-            auto y = std::visit((*this), *at2.left);
-            auto x = std::visit((*this), *at2.right);
+            auto y = visitNode(*at2.left);
+            auto x = visitNode(*at2.right);
 
             auto arg = cc.newXmmSd();
             double(*atanFunc)(double, double) = ::atan2;
@@ -118,8 +137,8 @@ namespace mnd
 
         Reg operator()(const ir::Pow& p) {
             using namespace asmjit;
-            auto a = std::visit((*this), *p.left);
-            auto b = std::visit((*this), *p.right);
+            auto a = visitNode(*p.left);
+            auto b = visitNode(*p.right);
 
             auto arg = cc.newXmmSd();
             double(*powFunc)(double, double) = ::pow;
@@ -132,7 +151,7 @@ namespace mnd
 
         Reg operator()(const ir::Cos& c) {
             using namespace asmjit;
-            auto a = std::visit((*this), *c.value);
+            auto a = visitNode(*c.value);
 
             auto arg = cc.newXmmSd();
             double(*cosFunc)(double) = ::cos;
@@ -144,7 +163,7 @@ namespace mnd
 
         Reg operator()(const ir::Sin& s) {
             using namespace asmjit;
-            auto a = std::visit((*this), *s.value);
+            auto a = visitNode(*s.value);
 
             auto arg = cc.newXmmSd();
             double(*sinFunc)(double) = ::sin;
@@ -220,10 +239,133 @@ namespace mnd
         }
         return CompiledGenerator{ std::move(ed) };
     }
+
+
+    struct OpenClVisitor
+    {
+        int varnameCounter = 0;
+        std::stringstream code;
+
+        std::string createVarname(void)
+        {
+            return "tmp"s + std::to_string(varnameCounter++);
+        }
+
+        std::string visitNode(ir::Node& node)
+        {
+            auto& nodeData = std::visit([] (auto& n) -> std::any& { return n.nodeData; }, node);
+            if (std::string* var = std::any_cast<std::string>(&nodeData)) {
+                return *var;
+            }
+            else {
+                std::string value = std::visit(*this, node);
+                std::string varname = createVarname();
+                code << "float " << varname << " = " << value << ";" << std::endl;
+                nodeData = varname;
+                return varname;
+            }
+        }
+
+        std::string operator()(const ir::Constant& c) {
+            return std::to_string(c.value);
+        }
+
+        std::string operator()(const ir::Variable& v) {
+            return v.name;
+        }
+
+        std::string operator()(const ir::Negation& n) {
+            return "-("s + visitNode(*n.value) + ")";
+        }
+
+        std::string operator()(const ir::Addition& a) {
+            return "("s + visitNode(*a.left) + ") + (" + visitNode(*a.right) + ")";
+        }
+
+        std::string operator()(const ir::Subtraction& a) {
+            return "("s + visitNode(*a.left) + ") - (" + visitNode(*a.right) + ")";
+        }
+
+        std::string operator()(const ir::Multiplication& a) {
+            return "("s + visitNode(*a.left) + ") * (" + visitNode(*a.right) + ")";
+        }
+
+        std::string operator()(const ir::Atan2& a) {
+            return "atan2("s + visitNode(*a.left) + ", " + visitNode(*a.right) + ")";
+        }
+
+        std::string operator()(const ir::Pow& a) {
+            return "pow("s + visitNode(*a.left) + ", " + visitNode(*a.right) + ")";
+        }
+
+        std::string operator()(const ir::Cos& a) {
+            return "cos("s + visitNode(*a.value) + ")";
+        }
+
+        std::string operator()(const ir::Sin& a) {
+            return "sin("s + visitNode(*a.value) + ")";
+        }
+    };
+
+    std::string compileToOpenCl(const ir::Formula& formula)
+    {
+        OpenClVisitor ocv;
+        std::string newA = ocv.visitNode(*formula.newA);
+        std::string newB = ocv.visitNode(*formula.newB);
+        std::string prelude = 
+"__kernel void iterate(__global float* A, const int width, float xl, float yt, float pixelScaleX, float pixelScaleY, int max, int smooth, int julia, float juliaX, float juliaY) {\n"
+"   int index = get_global_id(0);\n"
+"   int ix = index % width;\n"
+"   int iy = index / width;\n"
+"   float a = ix * pixelScaleX + xl;\n"
+"   float b = iy * pixelScaleY + yt;\n"
+"   float x = a;\n"
+"   float y = b;\n"
+"\n"
+"   int n = 0;\n"
+"   while (n < max - 1) {\n";
+
+        std::string orig = 
+"       float aa = a * a;"
+"       float bb = b * b;"
+"       float ab = a * b;"
+"       a = aa - bb + x;"
+"       b = ab + ab + y;";
+
+    
+        std::string after = 
+"       if (a * a + b * b > 16) break;\n"
+"       n++;\n"
+"   }\n"
+"   if (n >= max - 1) {\n"
+"       A[index] = max;\n"
+"   }\n"
+"   else {\n"
+"       A[index] = ((float)n);\n"
+"   }\n"
+"}\n";
+
+
+        std::string code = prelude + ocv.code.str();
+        code += "a = " + newA + ";\n";
+        code += "b = " + newB + ";\n";
+        code += after;
+        //code = mnd::getFloat_cl();
+        printf("cl: %s\n", code.c_str());
+        return code;
+    }
+
+#ifdef WITH_OPENCL
+    std::unique_ptr<MandelGenerator> compileCl(const ir::Formula& formula, const MandelDevice& md)
+    {
+        return std::make_unique<CompiledClGenerator>(md, compileToOpenCl(formula));
+    }
+#endif
 }
 
 
 using mnd::CompiledGenerator;
+using mnd::CompiledClGenerator;
 
 
 
@@ -298,6 +440,45 @@ std::string CompiledGenerator::dump(void) const
 }
 
 
+#ifdef WITH_OPENCL
+CompiledClGenerator::CompiledClGenerator(const MandelDevice& device, const std::string& code) :
+    ClGeneratorFloat{ device.getClDevice().device, code }
+{
+}
+
+
+std::string CompiledClGenerator::getKernelCode(bool smooth) const
+{
+    return "";
+}
+
+void CompiledClGenerator::generate(const mnd::MandelInfo& info, float* data)
+{
+    ::size_t bufferSize = info.bWidth * info.bHeight * sizeof(float);
+
+    cl::Buffer buffer_A(context, CL_MEM_WRITE_ONLY, bufferSize);
+    float pixelScaleX = float(info.view.width / info.bWidth);
+    float pixelScaleY = float(info.view.height / info.bHeight);
+
+    cl::Kernel iterate = cl::Kernel(program, "iterate");
+    iterate.setArg(0, buffer_A);
+    iterate.setArg(1, int(info.bWidth));
+    iterate.setArg(2, float(info.view.x));
+    iterate.setArg(3, float(info.view.y));
+    iterate.setArg(4, float(pixelScaleX));
+    iterate.setArg(5, float(pixelScaleY));
+    iterate.setArg(6, int(info.maxIter));
+    iterate.setArg(7, int(info.smooth ? 1 : 0));
+    iterate.setArg(8, int(info.julia ? 1 : 0));
+    iterate.setArg(9, float(info.juliaX));
+    iterate.setArg(10, float(info.juliaY));
+
+    queue.enqueueNDRangeKernel(iterate, 0, cl::NDRange(info.bWidth * info.bHeight / 4));
+    queue.enqueueReadBuffer(buffer_A, CL_TRUE, 0, bufferSize, data);
+}
+
+#endif // WITH_OPENCL
+
 using namespace asmjit;
 
 struct Visitor

+ 1 - 0
libmandel/src/IterationIR.cpp

@@ -29,6 +29,7 @@ namespace mnd
 
         NodePair operator() (const Variable& v)
         {
+            //printf("var %s\n", v.name.c_str()); fflush(stdout);
             if (v.name == "z") {
                 Node* a = arena.allocate(ir::Variable{ "a" });
                 Node* b = arena.allocate(ir::Variable{ "b" });

+ 6 - 3
libmandel/src/Mandel.cpp

@@ -4,6 +4,8 @@
 #include "CpuGenerators.h"
 #include "JuliaGenerators.h"
 #include "ClGenerators.h"
+#include "OpenClInternal.h"
+#include "OpenClCode.h"
 
 #include <asmjit/asmjit.h>
 
@@ -78,7 +80,8 @@ MandelContext mnd::initializeContext(void)
 }
 
 
-MandelDevice::MandelDevice(void)
+MandelDevice::MandelDevice(ClDeviceWrapper device) :
+    clDevice{ std::make_unique<ClDeviceWrapper>(std::move(device)) }
 {
 }
 
@@ -267,7 +270,7 @@ std::vector<MandelDevice> MandelContext::createDevices(void)
             auto supportsDouble = extensions.find("cl_khr_fp64") != std::string::npos;
 
             //printf("Device extensions: %s\n", ext.c_str());
-            MandelDevice md;
+            MandelDevice md{ ClDeviceWrapper{ device } };
 
             //printf("clock: %d", device.getInfo<CL_DEVICE_MAX_CLOCK_FREQUENCY>());
 
@@ -275,7 +278,7 @@ std::vector<MandelDevice> MandelContext::createDevices(void)
             md.vendor = device.getInfo<CL_DEVICE_VENDOR>();
             //printf("    using opencl device: %s\n", md.name.c_str());
             try {
-                md.mandelGenerators.insert({ GeneratorType::FLOAT, std::make_unique<ClGeneratorFloat>(device) });
+                md.mandelGenerators.insert({ GeneratorType::FLOAT, std::make_unique<ClGeneratorFloat>(device, mnd::getFloat_cl()) });
                 md.mandelGenerators.insert({ GeneratorType::FIXED64, std::make_unique<ClGenerator64>(device) });
                 md.mandelGenerators.insert({ GeneratorType::DOUBLE_FLOAT, std::make_unique<ClGeneratorDoubleFloat>(device) });
             }