Nicolas Winkler 5 years ago
parent
commit
341f5d8a86

+ 8 - 1
choosegenerators.cpp

@@ -297,9 +297,16 @@ void ChooseGenerators::on_compile_clicked()
     QString formula = this->ui->formula->text();
     mnd::IterationFormula itf{ mnd::parse(formula.toStdString()) };
     //chosenGenerator = std::make_unique<mnd::NaiveGenerator>(std::move(itf), mnd::getPrecision<double>());
-    auto cg = std::make_unique<mnd::CompiledGenerator>(mndCtxt);
+    mnd::ir::Formula irform = mnd::expand(itf);
+    auto cg = std::make_unique<mnd::CompiledGenerator>(mnd::compile(irform));
+    std::string expr = mnd::toString(*itf.expr);
     std::string asmCode = cg->dump();
+    printf("%s\n", expr.c_str()); fflush(stdout);
+    printf("%s\n", irform.toString().c_str()); fflush(stdout);
     printf("%s\n", asmCode.c_str()); fflush(stdout);
+    /*QMessageBox msgBox(nullptr);
+    msgBox.setText(QString::fromStdString(asmCode));
+    msgBox.exec();*/
     chosenGenerator = std::move(cg);
 }
 

+ 37 - 7
libmandel/include/Arena.h

@@ -1,6 +1,7 @@
 #ifndef MANDEL_ARENA_H
 #define MANDEL_ARENA_H
 
+#include <vector>
 #include <array>
 #include <utility>
 #include <memory>
@@ -14,22 +15,51 @@ namespace mnd
         {
             struct Chunk
             {
-                std::array<T, chunkSize> data;
+                char data[sizeof(T) * chunkSize];
                 int used = 0;
 
-                bool full(void) const { return used = chunkSize; }
-                T* allocate() { return data[used++]; }
+                bool full(void) const { return used == chunkSize; }
+
+                template<typename... Args>
+                T* allocate(Args&&... args)
+                {
+                    return new(reinterpret_cast<T*>(&data[(used++) * sizeof(T)])) T(std::forward<Args>(args)...);
+                }
+
+                ~Chunk(void)
+                {
+                    for (int i = used - 1; i >= 0; i--) {
+                        reinterpret_cast<T*>(&data[i * sizeof(T)])->~T();
+                    }
+                }
             };
 
             std::vector<std::unique_ptr<Chunk>> chunks;
         public:
-            T* allocate(void)
+
+            Arena(void) = default;
+            Arena(const Arena&) = delete;
+            Arena(Arena&&) = default;
+            ~Arena(void)
+            {
+                for (auto it = chunks.rbegin(); it != chunks.rend(); ++it) {
+                    *it = nullptr;
+                }
+            }
+
+            Arena& operator=(const Arena&) = delete;
+            Arena& operator=(Arena&&) = default;
+
+            Chunk& lastChunk(void) { return *chunks[chunks.size() - 1]; }
+
+            template<typename... Args>
+            T* allocate(Args&&... args)
             {
-                if (chunks.empty() || chunks[chunks.size() - 1].full()) {
-                    chunks.push_back(Chunk{});
+                if (chunks.empty() || lastChunk().full()) {
+                    chunks.push_back(std::make_unique<Chunk>());
                 }
 
-                return chunks[chunks.size() - 1].allocate();
+                return lastChunk().allocate(std::forward<Args>(args)...);
             }
         };
     }

+ 11 - 1
libmandel/include/IterationCompiler.h

@@ -2,6 +2,7 @@
 #define MANDEL_ITERATIONCOMPILER_H
 
 #include "Generators.h"
+#include "IterationIR.h"
 
 namespace mnd
 {
@@ -11,20 +12,29 @@ namespace mnd
     class MandelContext;
 
     mnd::ExecData compile(mnd::MandelContext& mndCtxt);
-}
 
+ 
+}
+void squareTest();
 
 class mnd::CompiledGenerator : public mnd::MandelGenerator
 {
     std::unique_ptr<ExecData> execData;
 public:
     CompiledGenerator(MandelContext& mndContext);
+    CompiledGenerator(std::unique_ptr<ExecData> execData);
+    CompiledGenerator(CompiledGenerator&&);
     virtual ~CompiledGenerator(void);
     virtual void generate(const MandelInfo& info, float* data);
 
     std::string dump(void) const;
 };
 
+namespace mnd
+{
+    CompiledGenerator compile(const ir::Formula& formula);
+}
+
 
 #endif // MANDEL_ITERATIONCOMPILER_H
 

+ 49 - 5
libmandel/include/IterationIR.h

@@ -3,6 +3,7 @@
 
 #include <string>
 #include <vector>
+#include <variant>
 
 #include "IterationFormula.h"
 #include "Arena.h"
@@ -13,24 +14,37 @@ namespace mnd
     {
         struct Constant;
         struct Variable;
+        struct UnaryOperation;
         struct Negation;
         struct BinaryOperation;
         struct Addition;
+        struct Subtraction;
         struct Multiplication;
+        struct Atan2;
+        struct Pow;
+        struct Cos;
+        struct Sin;
 
         using Node = std::variant<
             Constant,
             Variable,
             Negation,
             Addition,
-            Multiplication
+            Subtraction,
+            Multiplication,
+            Atan2,
+            Pow,
+            Cos,
+            Sin
         >;
 
-        class Formula
+        struct Formula
         {
             util::Arena<Node> nodeArena;
             Node* newA;
             Node* newB;
+
+            std::string toString(void) const;
         };
     }
 
@@ -50,16 +64,21 @@ struct mnd::ir::Variable
 };
 
 
-struct mnd::ir::Negation
+struct mnd::ir::UnaryOperation
 {
     Node* value;
 };
 
 
+struct mnd::ir::Negation : mnd::ir::UnaryOperation
+{
+};
+
+
 struct mnd::ir::BinaryOperation
 {
-    Node* a;
-    Node* b;
+    Node* left;
+    Node* right;
 };
 
 
@@ -68,9 +87,34 @@ struct mnd::ir::Addition : mnd::ir::BinaryOperation
 };
 
 
+struct mnd::ir::Subtraction : mnd::ir::BinaryOperation
+{
+};
+
+
 struct mnd::ir::Multiplication : mnd::ir::BinaryOperation
 {
 };
 
 
+struct mnd::ir::Atan2 : mnd::ir::BinaryOperation
+{
+};
+
+
+struct mnd::ir::Pow : mnd::ir::BinaryOperation
+{
+};
+
+
+struct mnd::ir::Cos : mnd::ir::UnaryOperation
+{
+};
+
+
+struct mnd::ir::Sin : mnd::ir::UnaryOperation
+{
+};
+
+
 #endif // MANDEL_ITERATIONIR_H

+ 245 - 3
libmandel/src/IterationCompiler.cpp

@@ -1,7 +1,7 @@
 #include "IterationCompiler.h"
-#include "IterationIR.h"
 
 #include <asmjit/asmjit.h>
+#include <cmath>
 #include "Mandel.h"
 #include <omp.h>
 
@@ -31,6 +31,195 @@ namespace mnd
 
         ~ExecData(void) = default;
     };
+
+
+    struct CompileVisitor
+    {
+        using Reg = asmjit::x86::Xmm;
+
+        asmjit::x86::Compiler& cc;
+        Reg& a;
+        Reg& b;
+        Reg& x;
+        Reg& y;
+
+        CompileVisitor(asmjit::x86::Compiler& cc, Reg& a, Reg& b, Reg& x, Reg& y) :
+            cc{ cc },
+            a{ a }, b{ b },
+            x{ x }, y{ y }
+        {
+        }
+
+        Reg operator()(const ir::Constant& c) {
+            auto constant = cc.newDoubleConst(asmjit::ConstPool::kScopeLocal, c.value);
+            auto reg = cc.newXmmSd();
+            cc.movsd(reg, constant);
+            return reg;
+        }
+
+        Reg operator()(const ir::Variable& v) {
+            if (v.name == "a") {
+                return a;
+            }
+            else if (v.name == "b") {
+                return b;
+            }
+            else if (v.name == "x") {
+                return x;
+            }
+            else if (v.name == "y") {
+                return y;
+            }
+            else
+                throw "unknown variable";
+        }
+
+        Reg operator()(const ir::Negation& n) {
+            auto sub = cc.newXmmSd();
+            cc.xorpd(sub, sub);
+            cc.subsd(sub, std::visit((*this), *n.value));
+            return sub;
+        }
+
+        Reg operator()(const ir::Addition& add) {
+            auto res = cc.newXmmSd();
+            cc.movapd(res, std::visit((*this), *add.left));
+            cc.addsd(res, std::visit((*this), *add.right));
+            return res;
+        }
+
+        Reg operator()(const ir::Subtraction& add) {
+            auto res = cc.newXmmSd();
+            cc.movapd(res, std::visit((*this), *add.left));
+            cc.subsd(res, std::visit((*this), *add.right));
+            return res;
+        }
+
+        Reg operator()(const ir::Multiplication& add) {
+            auto res = cc.newXmmSd();
+            cc.movapd(res, std::visit((*this), *add.left));
+            cc.mulsd(res, std::visit((*this), *add.right));
+            return res;
+        }
+
+        Reg operator()(const ir::Atan2& at2) {
+            using namespace asmjit;
+            auto y = std::visit((*this), *at2.left);
+            auto x = std::visit((*this), *at2.right);
+
+            auto arg = cc.newXmmSd();
+            double(*atanFunc)(double, double) = ::atan2;
+            auto call = cc.call(imm(atanFunc), FuncSignatureT<double, double, double>(CallConv::kIdHostCDecl));
+            call->setArg(0, y);
+            call->setArg(1, x);
+            call->setRet(0, arg);
+            return arg;
+        }
+
+        Reg operator()(const ir::Pow& p) {
+            using namespace asmjit;
+            auto a = std::visit((*this), *p.left);
+            auto b = std::visit((*this), *p.right);
+
+            auto arg = cc.newXmmSd();
+            double(*powFunc)(double, double) = ::pow;
+            auto call = cc.call(imm(powFunc), FuncSignatureT<double, double, double>(CallConv::kIdHostCDecl));
+            call->setArg(0, a);
+            call->setArg(1, b);
+            call->setRet(0, arg);
+            return arg;
+        }
+
+        Reg operator()(const ir::Cos& c) {
+            using namespace asmjit;
+            auto a = std::visit((*this), *c.value);
+
+            auto arg = cc.newXmmSd();
+            double(*cosFunc)(double) = ::cos;
+            auto call = cc.call(imm(cosFunc), FuncSignatureT<double, double>(CallConv::kIdHostCDecl));
+            call->setArg(0, a);
+            call->setRet(0, arg);
+            return arg;
+        }
+
+        Reg operator()(const ir::Sin& s) {
+            using namespace asmjit;
+            auto a = std::visit((*this), *s.value);
+
+            auto arg = cc.newXmmSd();
+            double(*sinFunc)(double) = ::sin;
+            auto call = cc.call(imm(sinFunc), FuncSignatureT<double, double>(CallConv::kIdHostCDecl));
+            call->setArg(0, a);
+            call->setRet(0, arg);
+            return arg;
+        }
+    };
+
+
+    CompiledGenerator compile(const ir::Formula& formula)
+    {
+        using namespace asmjit;
+        std::unique_ptr<mnd::ExecData> ed = std::make_unique<mnd::ExecData>();
+        JitRuntime& jitRuntime = *ed->jitRuntime;
+        ed->code->init(jitRuntime.codeInfo());
+
+        x86::Compiler& comp = *ed->compiler;
+
+        x86::Mem sixteen = comp.newDoubleConst(ConstPool::kScopeLocal, 16.0);
+
+        Label startLoop = comp.newLabel();
+        Label endLoop = comp.newLabel();
+        x86::Gp maxIter = comp.newInt32();
+        x86::Gp k = comp.newInt32();
+        x86::Xmm x = comp.newXmmSd();
+        x86::Xmm y = comp.newXmmSd();
+        x86::Xmm a = comp.newXmmSd();
+        x86::Xmm b = comp.newXmmSd();
+        comp.addFunc(FuncSignatureT<int, double, double, int>(CallConv::kIdHost));
+        comp.setArg(0, x);
+        comp.setArg(1, y);
+        comp.setArg(2, maxIter);
+        comp.movapd(a, x);
+        comp.movapd(b, y);
+
+        comp.xor_(k, k);
+
+        comp.bind(startLoop);
+
+        CompileVisitor cv{ comp, a, b, x, y };
+        auto newA = std::visit(cv, *formula.newA);
+        auto newB = std::visit(cv, *formula.newB);
+        comp.movapd(a, newA);
+        comp.movapd(b, newB);
+
+        x86::Xmm aa = comp.newXmmSd();
+        x86::Xmm bb = comp.newXmmSd();
+        comp.movapd(aa, a);
+        comp.mulsd(aa, a);
+        comp.movapd(bb, b);
+        comp.mulsd(bb, b);
+        comp.addsd(bb, aa);
+
+        comp.comisd(bb, sixteen);
+        comp.jle(endLoop);
+
+        comp.inc(k);
+        comp.cmp(k, maxIter);
+        comp.jne(startLoop);
+        comp.bind(endLoop);
+        comp.ret(k);
+        comp.endFunc();
+        auto err = comp.finalize();
+        if (err == asmjit::kErrorOk) {
+            err = jitRuntime.add(&ed->iterationFunc, ed->code.get());
+            if (err != asmjit::kErrorOk)
+                throw "error adding function";
+        }
+        else {
+            throw "error compiling";
+        }
+        return CompiledGenerator{ std::move(ed) };
+    }
 }
 
 
@@ -45,6 +234,16 @@ CompiledGenerator::CompiledGenerator(mnd::MandelContext& mndContext) :
 }
 
 
+CompiledGenerator::CompiledGenerator(std::unique_ptr<mnd::ExecData> execData) :
+    MandelGenerator{ 1.0e-15 },
+    execData{ std::move(execData) }
+{
+}
+
+
+CompiledGenerator::CompiledGenerator(CompiledGenerator&&) = default;
+
+
 CompiledGenerator::~CompiledGenerator(void)
 {
 }
@@ -91,8 +290,6 @@ void CompiledGenerator::generate(const mnd::MandelInfo& info, float* data)
 }
 
 
-
-
 std::string CompiledGenerator::dump(void) const
 {
     asmjit::String d;
@@ -232,5 +429,50 @@ namespace mnd
 }
 
 
+void squareTest()
+{
+    mnd::Expression power = mnd::Pow{
+        std::make_unique<mnd::Expression>(mnd::Variable{ "z" }),
+        std::make_unique<mnd::Expression>(mnd::Constant{ 2.3 })
+    };
+
+    mnd::IterationFormula fmla(std::move(power));
+
+    mnd::ir::Formula p = mnd::expand(fmla);
+
+    mnd::ExecData ed;
+    JitRuntime& jitRuntime = *ed.jitRuntime;
+    ed.code->init(jitRuntime.codeInfo());
+
+    x86::Compiler& comp = *ed.compiler;
+
+    comp.addFunc(FuncSignatureT<double, double, double>(CallConv::kIdHost));
+    x86::Xmm x = comp.newXmmSd();
+    x86::Xmm y = comp.newXmmSd();
+    x86::Xmm a = comp.newXmmSd();
+    x86::Xmm b = comp.newXmmSd();
+    comp.setArg(0, x);
+    comp.setArg(1, y);
+    comp.movapd(a, x);
+    comp.movapd(b, y);
+
+
+    mnd::CompileVisitor cv{ comp, a, b, x, y };
+    auto newA = std::visit(cv, *p.newA);
+    auto newB = std::visit(cv, *p.newB);
+    comp.movapd(a, newA);
+    comp.movapd(b, newB);
+    comp.ret(b);
+    comp.endFunc();
+    comp.finalize();
+
+    double (*func)(double, double);
+
+    jitRuntime.add(&func, ed.code.get());
+
+    double result = func(1.0, 3.0);
+    printf("result: %f\n", result);
+}
+
 
 

+ 7 - 1
libmandel/src/IterationFormula.cpp

@@ -19,6 +19,7 @@ mnd::IterationFormula::IterationFormula(mnd::Expression expr) :
 
 static const std::string regexIdent = "[A-Za-z][A-Za-z0-9]*";
 static const std::string regexNum = "[1-9][0-9]*";
+static const std::string regexFloat = "(\\d*\\.?\\d+|\\d+\\.?\\d*)([eE][-+]\\d+)?";
 
 
 class Parser
@@ -26,6 +27,7 @@ class Parser
     static const std::regex tokenize;
     static const std::regex ident;
     static const std::regex num;
+    static const std::regex floatNum;
     std::string in;
     std::regex_iterator<std::string::iterator> rit;
 
@@ -46,6 +48,9 @@ public:
             if (std::regex_match(token, num)) {
                 output.push_back(mnd::Constant{ std::atof(token.c_str()) });
             }
+            else if (std::regex_match(token, floatNum)) {
+                output.push_back(mnd::Constant{ std::atof(token.c_str()) });
+            }
             else if (std::regex_match(token, ident)) {
                 output.push_back(mnd::Variable{ token });
             }
@@ -157,9 +162,10 @@ private:
     }
 };
 
-const std::regex Parser::tokenize = std::regex(regexIdent + "|" + regexNum + "|[\\+\\-\\*/\\^]|[\\(\\)]");
+const std::regex Parser::tokenize = std::regex(regexIdent + "|" + regexFloat + "|[\\+\\-\\*/\\^]|[\\(\\)]");
 const std::regex Parser::ident = std::regex(regexIdent);
 const std::regex Parser::num = std::regex(regexNum);
+const std::regex Parser::floatNum = std::regex(regexFloat);
 
 
 

+ 169 - 2
libmandel/src/IterationIR.cpp

@@ -1,22 +1,189 @@
 #include "IterationIR.h"
 
-
+#include <utility>
 
 using namespace mnd;
 
 
 namespace mnd
 {
+    using ir::Node;
 
     struct ConvertVisitor
     {
-        util::Arena<ir::Node>& arena;
+        using NodePair = std::pair<Node*, Node*>;
+        util::Arena<Node>& arena;
+
+        ConvertVisitor(util::Arena<Node>& arena) :
+            arena{ arena }
+        {
+        }
+
+        NodePair operator() (const Constant& c)
+        {
+            Node* cnst = arena.allocate(ir::Constant{ c.value });
+            Node* zero = arena.allocate(ir::Constant{ 0.0 });
+
+            return { cnst, zero };
+        }
+
+        NodePair operator() (const Variable& v)
+        {
+            if (v.name == "z") {
+                Node* a = arena.allocate(ir::Variable{ "a" });
+                Node* b = arena.allocate(ir::Variable{ "b" });
+
+                return { a, b };
+            }
+            else if (v.name == "c") {
+                Node* x = arena.allocate(ir::Variable{ "x" });
+                Node* y = arena.allocate(ir::Variable{ "y" });
+
+                return { x, y };
+            }
+            else if (v.name == "i") {
+                Node* x = arena.allocate(ir::Constant{ 0.0 });
+                Node* y = arena.allocate(ir::Constant{ 1.0 });
+
+                return { x, y };
+            }
+            else
+                throw "unknown variable";
+        }
+
+        NodePair operator() (const UnaryOperation& v)
+        {
+            auto [opa, opb] = std::visit(*this, *v.operand);
+
+            Node* a = arena.allocate(ir::Negation{ opa });
+            Node* b = arena.allocate(ir::Negation{ opb });
+
+            return { a, b };
+        }
+
+        NodePair operator() (const Addition& add)
+        {
+            auto [lefta, leftb] = std::visit(*this, *add.left);
+            auto [righta, rightb] = std::visit(*this, *add.right);
+
+            if (add.subtraction) {
+                Node* a = arena.allocate(ir::Subtraction{ lefta, righta });
+                Node* b = arena.allocate(ir::Subtraction{ leftb, rightb });
+
+                return { a, b };
+            }
+            else {
+                Node* a = arena.allocate(ir::Addition{ lefta, righta });
+                Node* b = arena.allocate(ir::Addition{ leftb, rightb });
+
+                return { a, b };
+            }
+        }
+
+        NodePair operator() (const Multiplication& mul)
+        {
+            auto [a, b] = std::visit(*this, *mul.left);
+            auto [c, d] = std::visit(*this, *mul.right);
+
+            Node* ac = arena.allocate(ir::Multiplication{ a, c });
+            Node* bd = arena.allocate(ir::Multiplication{ b, d });
+            Node* ad = arena.allocate(ir::Multiplication{ a, d });
+            Node* bc = arena.allocate(ir::Multiplication{ b, c });
+
+            Node* newa = arena.allocate(ir::Subtraction{ ac, bd });
+            Node* newb = arena.allocate(ir::Addition{ ad, bc });
+
+            return { newa, newb };
+        }
+
+        NodePair operator() (const Division& mul)
+        {
+            // TODO implement
+            throw "unimplemented";
+            return { nullptr, nullptr };
+        }
+
+        NodePair operator() (const Pow& p)
+        {
+            auto [a, b] = std::visit(*this, *p.left);
+            auto [c, unused] = std::visit(*this, *p.right);
 
+            auto half = arena.allocate(ir::Constant{ 0.5 });
+
+            auto arg = arena.allocate(ir::Atan2{ b, a });
+            auto aa = arena.allocate(ir::Multiplication{ a, a });
+            auto bb = arena.allocate(ir::Multiplication{ b, b });
+            auto absSq = arena.allocate(ir::Addition{ aa, bb });
+
+            auto halfc = arena.allocate(ir::Multiplication{ c, half });
+
+            auto newAbs = arena.allocate(ir::Pow{ absSq, halfc });
+            auto newArg = arena.allocate(ir::Multiplication{ arg, c });
+
+            auto cosArg = arena.allocate(ir::Cos{ newArg });
+            auto sinArg = arena.allocate(ir::Sin{ newArg });
+            auto newA = arena.allocate(ir::Multiplication{ cosArg, newAbs });
+            auto newB = arena.allocate(ir::Multiplication{ sinArg, newAbs });
+
+            return { newA, newB };
+        }
     };
 
     ir::Formula expand(const mnd::IterationFormula& fmla)
     {
         ir::Formula formula;
+        ConvertVisitor cv{ formula.nodeArena };
+        std::tie(formula.newA, formula.newB) = std::visit(cv, *fmla.expr);
+        return formula;
     }
 }
 
+
+std::string mnd::ir::Formula::toString(void) const
+{
+    struct ToStringVisitor
+    {
+        std::string operator()(const ir::Constant& c) {
+            return std::to_string(c.value);
+        }
+
+        std::string operator()(const ir::Variable& v) {
+            return v.name;
+        }
+
+        std::string operator()(const ir::Negation& n) {
+            return "-(" + std::visit(*this, *n.value) + ")";
+        }
+
+        std::string operator()(const ir::Addition& n) {
+            return "(" + std::visit(*this, *n.left) + ") + (" + std::visit(*this, *n.right) + ")";
+        }
+
+        std::string operator()(const ir::Subtraction& n) {
+            return "(" + std::visit(*this, *n.left) + ") - (" + std::visit(*this, *n.right) + ")";
+        }
+
+        std::string operator()(const ir::Multiplication& n) {
+            return "(" + std::visit(*this, *n.left) + ") * (" + std::visit(*this, *n.right) + ")";
+        }
+
+        std::string operator()(const ir::Atan2& n) {
+            return "atan2(" + std::visit(*this, *n.left) + ", " + std::visit(*this, *n.right) + ")";
+        }
+
+        std::string operator()(const ir::Pow& n) {
+            return std::visit(*this, *n.left) + " ^ " + std::visit(*this, *n.right);
+        }
+
+        std::string operator()(const ir::Cos& n) {
+            return "cos(" + std::visit(*this, *n.value) + ")";
+        }
+
+        std::string operator()(const ir::Sin& n) {
+            return "sin(" + std::visit(*this, *n.value) + ")";
+        }
+    };
+
+    return std::string("a = ") + std::visit(ToStringVisitor{}, *this->newA) + 
+        "\nb = " + std::visit(ToStringVisitor{}, *this->newB);
+}