소스 검색

improving compilation

Nicolas Winkler 5 년 전
부모
커밋
e2e54cb662
2개의 변경된 파일119개의 추가작업 그리고 30개의 파일을 삭제
  1. 4 3
      choosegenerators.cpp
  2. 115 27
      libmandel/src/IterationIR.cpp

+ 4 - 3
choosegenerators.cpp

@@ -300,9 +300,10 @@ void ChooseGenerators::on_compile_clicked()
 
     std::string expr = mnd::toString(*itf.expr);
     printf("%s\n", expr.c_str()); fflush(stdout);
-    //chosenGenerator = std::make_unique<mnd::NaiveGenerator>(std::move(itf), mnd::getPrecision<double>());
-    //return;
+    chosenGenerator = std::make_unique<mnd::NaiveGenerator>(std::move(itf), mnd::getPrecision<double>());
+    return;
     mnd::ir::Formula irform = mnd::expand(itf);
+    irform.constantPropagation();
     printf("%s\n", irform.toString().c_str()); fflush(stdout);
     auto cg = std::make_unique<mnd::CompiledGenerator>(mnd::compile(irform));
     std::string asmCode = cg->dump();
@@ -314,7 +315,7 @@ void ChooseGenerators::on_compile_clicked()
 
     const mnd::MandelDevice& dev = mndCtxt.getDevices()[0];
     try {
-        //chosenGenerator = mnd::compileCl(irform, dev);
+        chosenGenerator = mnd::compileCl(irform, dev);
     }
     catch(const std::string& msg) {
         printf("error compiling: %s", msg.c_str());

+ 115 - 27
libmandel/src/IterationIR.cpp

@@ -1,6 +1,7 @@
 #include "IterationIR.h"
 
 #include <utility>
+#include <optional>
 
 using namespace mnd;
 
@@ -302,6 +303,8 @@ struct ConstantPropagator
     mnd::ir::Formula& formula;
     mnd::util::Arena<Node>& arena;
 
+    using MaybeNode = std::optional<Node>;
+
     ConstantPropagator(mnd::ir::Formula& formula) :
         formula{ formula },
         arena{ formula.nodeArena }
@@ -313,101 +316,186 @@ struct ConstantPropagator
         visitNode(formula.newB);
     }
 
+    bool hasBeenVisited(Node* n) {
+        return std::visit([] (auto& x) {
+            if (auto* b = std::any_cast<bool>(&x.nodeData))
+                return *b;
+            else
+                return false;
+        }, *n);
+    }
+
     void visitNode(Node* n) {
-        std::visit(*this, *n, n);
+        if (!hasBeenVisited(n)) {
+            MaybeNode mbn = std::visit(*this, *n);
+            if (mbn.has_value()) {
+                *n = std::move(mbn.value());
+            }
+            std::visit([] (auto& x) { x.nodeData = true; }, *n);
+        }
     }
 
     ir::Constant* getIfConstant(Node* n) {
         return std::get_if<ir::Constant>(n);
     }
 
-    void operator()(ir::Constant& x, Node* node) {
-        x.nodeDate = true;
+    MaybeNode operator()(ir::Constant& x) {
+        return std::nullopt;
     }
-    void operator()(ir::Variable& x, Node* node) {
-        x.nodeDate = true;
+    MaybeNode operator()(ir::Variable& x) {
+        return std::nullopt;
     }
 
-    void operator()(ir::Negation& n, Node* node) {
+    MaybeNode operator()(ir::Negation& n) {
+        visitNode(n.value);
         if (auto* c = getIfConstant(n.value)) {
-            *node = ir::Constant{ -c->value };
+            return ir::Constant{ -c->value };
+        }
+        if (auto* neg = std::get_if<ir::Negation>(n.value)) {
+            return *neg->value;
         }
+        return std::nullopt;
     }
 
-    void operator()(ir::Addition& n, Node* node) {
+    MaybeNode operator()(ir::Addition& n) {
+        visitNode(n.left);
+        visitNode(n.right);
         auto* ca = getIfConstant(n.left);
         auto* cb = getIfConstant(n.right);
         if (ca && cb) {
-            *node = ir::Constant{ ca->value + cb->value };
+            return ir::Constant{ ca->value + cb->value };
+        }
+        else if (ca && ca->value == 0) {
+            return *n.right;
+        }
+        else if (cb && cb->value == 0) {
+            return *n.left;
+        }
+        else if (auto* nright = std::get_if<ir::Negation>(n.right)) {
+            return ir::Subtraction{ n.left, nright->value };
         }
+        return std::nullopt;
     }
 
-    void operator()(ir::Subtraction& n, Node* node) {
+    MaybeNode operator()(ir::Subtraction& n) {
+        visitNode(n.left);
+        visitNode(n.right);
         auto* ca = getIfConstant(n.left);
         auto* cb = getIfConstant(n.right);
         if (ca && cb) {
-            *node = ir::Constant{ ca->value - cb->value };
+            return ir::Constant{ ca->value - cb->value };
         }
+        else if (ca && ca->value == 0) {
+            return ir::Negation{ n.right };
+        }
+        else if (cb && cb->value == 0) {
+            return *n.left;
+        }
+        return std::nullopt;
     }
 
-    void operator()(ir::Multiplication& n, Node* node) {
+    MaybeNode operator()(ir::Multiplication& n) {
+        visitNode(n.left);
+        visitNode(n.right);
         auto* ca = getIfConstant(n.left);
         auto* cb = getIfConstant(n.right);
         if (ca && cb) {
-            *node = ir::Constant{ ca->value * cb->value };
+            return ir::Constant{ ca->value * cb->value };
+        }
+        else if (ca && ca->value == 0) {
+            return ir::Constant{ 0 };
+        }
+        else if (cb && cb->value == 0) {
+            return ir::Constant{ 0 };
+        }
+        else if (ca && ca->value == 1) {
+            return *n.right;
         }
+        else if (cb && cb->value == 1) {
+            return *n.left;
+        }
+        return std::nullopt;
     }
 
-    void operator()(ir::Division& n, Node* node) {
+    MaybeNode operator()(ir::Division& n) {
+        visitNode(n.left);
+        visitNode(n.right);
         auto* ca = getIfConstant(n.left);
         auto* cb = getIfConstant(n.right);
         if (ca && cb) {
-            *node = ir::Constant{ ca->value / cb->value };
+            return ir::Constant{ ca->value / cb->value };
+        }
+        else if (ca && ca->value == 0) {
+            return ir::Constant{ 0 };
+        }
+        else if (cb && cb->value == 1) {
+            return *n.left;
         }
+        return std::nullopt;
     }
 
-    void operator()(ir::Atan2& n, Node* node) {
+    MaybeNode operator()(ir::Atan2& n) {
+        visitNode(n.left);
+        visitNode(n.right);
         auto* ca = getIfConstant(n.left);
         auto* cb = getIfConstant(n.right);
         if (ca && cb) {
-            *node = ir::Constant{ mnd::atan2(ca->value, cb->value) };
+            return ir::Constant{ mnd::atan2(ca->value, cb->value) };
         }
+        return std::nullopt;
     }
 
-    void operator()(ir::Pow& n, Node* node) {
+    MaybeNode operator()(ir::Pow& n) {
+        visitNode(n.left);
+        visitNode(n.right);
         auto* ca = getIfConstant(n.left);
         auto* cb = getIfConstant(n.right);
         if (ca && cb) {
-            *node = ir::Constant{ mnd::pow(ca->value, cb->value) };
+            return ir::Constant{ mnd::pow(ca->value, cb->value) };
+        }
+        else if (cb && cb->value == 1) {
+            return *n.left;
+        }
+        else if (cb && cb->value == 1) {
+            return ir::Constant{ 1 };
         }
+        return std::nullopt;
     }
 
-    void operator()(ir::Cos& n, Node* node) {
+    MaybeNode operator()(ir::Cos& n) {
+        visitNode(n.value);
         auto* ca = getIfConstant(n.value);
         if (ca) {
-            *node = ir::Constant{ mnd::cos(ca->value) };
+            return ir::Constant{ mnd::cos(ca->value) };
         }
+        return std::nullopt;
     }
 
-    void operator()(ir::Sin& n, Node* node) {
+    MaybeNode operator()(ir::Sin& n) {
+        visitNode(n.value);
         auto* ca = getIfConstant(n.value);
         if (ca) {
-            *node = ir::Constant{ mnd::sin(ca->value) };
+            return ir::Constant{ mnd::sin(ca->value) };
         }
+        return std::nullopt;
     }
 
-    void operator()(ir::Exp& n, Node* node) {
+    MaybeNode operator()(ir::Exp& n) {
+        visitNode(n.value);
         auto* ca = getIfConstant(n.value);
         if (ca) {
-            *node = ir::Constant{ mnd::exp(ca->value) };
+            return ir::Constant{ mnd::exp(ca->value) };
         }
+        return std::nullopt;
     }
 
-    void operator()(ir::Ln& n, Node* node) {
+    MaybeNode operator()(ir::Ln& n) {
+        visitNode(n.value);
         auto* ca = getIfConstant(n.value);
         if (ca) {
-            *node = ir::Constant{ mnd::log(ca->value) };
+            return ir::Constant{ mnd::log(ca->value) };
         }
+        return std::nullopt;
     }
 };