Ver código fonte

implemented constant propagation

Nicolas Winkler 5 anos atrás
pai
commit
3cc70f20bf

+ 2 - 2
choosegenerators.cpp

@@ -300,8 +300,8 @@ void ChooseGenerators::on_compile_clicked()
 
     std::string expr = mnd::toString(*itf.expr);
     printf("%s\n", expr.c_str()); fflush(stdout);
-    chosenGenerator = std::make_unique<mnd::NaiveGenerator>(std::move(itf), mnd::getPrecision<double>());
-    return;
+    //chosenGenerator = std::make_unique<mnd::NaiveGenerator>(std::move(itf), mnd::getPrecision<double>());
+    //return;
     mnd::ir::Formula irform = mnd::expand(itf);
     printf("%s\n", irform.toString().c_str()); fflush(stdout);
     auto cg = std::make_unique<mnd::CompiledGenerator>(mnd::compile(irform));

+ 4 - 0
libmandel/include/Types.h

@@ -64,6 +64,10 @@ namespace mnd
     inline Float512 log(const Float512& x) { return boost::multiprecision::log(x); }
     inline Float512 log2(const Float512& x) { return boost::multiprecision::log2(x); }
     inline Float512 pow(const Float512& x, const Float512& y) { return boost::multiprecision::pow(x, y); }
+    inline Float512 atan2(const Float512& y, const Float512& x) { return boost::multiprecision::atan2(y, x); }
+    inline Float512 cos(const Float512& x) { return boost::multiprecision::cos(x); }
+    inline Float512 sin(const Float512& x) { return boost::multiprecision::sin(x); }
+    inline Float512 exp(const Float512& x) { return boost::multiprecision::exp(x); }
 
     using Real = Float512;
     using Integer = boost::multiprecision::int512_t;

+ 0 - 1
libmandel/src/IterationCompiler.cpp

@@ -116,7 +116,6 @@ namespace mnd
 
         Reg operator()(const ir::Multiplication& add) {
             auto res = cc.newXmmSd();
-            cc.comment("multiply");
             cc.movapd(res, visitNode(*add.left));
             cc.mulsd(res, visitNode(*add.right));
             return res;

+ 1 - 1
libmandel/src/IterationGenerator.cpp

@@ -91,7 +91,7 @@ std::complex<double> NaiveGenerator::calc(mnd::Expression& expr, std::complex<do
     std::visit([this, &result, z, c] (auto&& ex) {
         using T = std::decay_t<decltype(ex)>;
         if constexpr (std::is_same<T, mnd::Constant>::value) {
-            result = std::complex{ ex.re, ex.im };
+            result = std::complex{ mnd::convert<double>(ex.re), mnd::convert<double>(ex.im) };
         }
         else if constexpr (std::is_same<T, mnd::Variable>::value) {
             if (ex.name == "z")

+ 115 - 1
libmandel/src/IterationIR.cpp

@@ -297,10 +297,124 @@ std::string mnd::ir::Formula::toString(void) const
         "\nb = " + std::visit(ToStringVisitor{}, *this->newB);
 }
 
+struct ConstantPropagator
+{
+    mnd::ir::Formula& formula;
+    mnd::util::Arena<Node>& arena;
+
+    ConstantPropagator(mnd::ir::Formula& formula) :
+        formula{ formula },
+        arena{ formula.nodeArena }
+    {
+    }
+
+    void propagateConstants(void) {
+        visitNode(formula.newA);
+        visitNode(formula.newB);
+    }
+
+    void visitNode(Node* n) {
+        std::visit(*this, *n, n);
+    }
+
+    ir::Constant* getIfConstant(Node* n) {
+        return std::get_if<ir::Constant>(n);
+    }
+
+    void operator()(ir::Constant& x, Node* node) {
+        x.nodeDate = true;
+    }
+    void operator()(ir::Variable& x, Node* node) {
+        x.nodeDate = true;
+    }
+
+    void operator()(ir::Negation& n, Node* node) {
+        if (auto* c = getIfConstant(n.value)) {
+            *node = ir::Constant{ -c->value };
+        }
+    }
+
+    void operator()(ir::Addition& n, Node* node) {
+        auto* ca = getIfConstant(n.left);
+        auto* cb = getIfConstant(n.right);
+        if (ca && cb) {
+            *node = ir::Constant{ ca->value + cb->value };
+        }
+    }
+
+    void operator()(ir::Subtraction& n, Node* node) {
+        auto* ca = getIfConstant(n.left);
+        auto* cb = getIfConstant(n.right);
+        if (ca && cb) {
+            *node = ir::Constant{ ca->value - cb->value };
+        }
+    }
+
+    void operator()(ir::Multiplication& n, Node* node) {
+        auto* ca = getIfConstant(n.left);
+        auto* cb = getIfConstant(n.right);
+        if (ca && cb) {
+            *node = ir::Constant{ ca->value * cb->value };
+        }
+    }
+
+    void operator()(ir::Division& n, Node* node) {
+        auto* ca = getIfConstant(n.left);
+        auto* cb = getIfConstant(n.right);
+        if (ca && cb) {
+            *node = ir::Constant{ ca->value / cb->value };
+        }
+    }
+
+    void operator()(ir::Atan2& n, Node* node) {
+        auto* ca = getIfConstant(n.left);
+        auto* cb = getIfConstant(n.right);
+        if (ca && cb) {
+            *node = ir::Constant{ mnd::atan2(ca->value, cb->value) };
+        }
+    }
+
+    void operator()(ir::Pow& n, Node* node) {
+        auto* ca = getIfConstant(n.left);
+        auto* cb = getIfConstant(n.right);
+        if (ca && cb) {
+            *node = ir::Constant{ mnd::pow(ca->value, cb->value) };
+        }
+    }
+
+    void operator()(ir::Cos& n, Node* node) {
+        auto* ca = getIfConstant(n.value);
+        if (ca) {
+            *node = ir::Constant{ mnd::cos(ca->value) };
+        }
+    }
+
+    void operator()(ir::Sin& n, Node* node) {
+        auto* ca = getIfConstant(n.value);
+        if (ca) {
+            *node = ir::Constant{ mnd::sin(ca->value) };
+        }
+    }
+
+    void operator()(ir::Exp& n, Node* node) {
+        auto* ca = getIfConstant(n.value);
+        if (ca) {
+            *node = ir::Constant{ mnd::exp(ca->value) };
+        }
+    }
+
+    void operator()(ir::Ln& n, Node* node) {
+        auto* ca = getIfConstant(n.value);
+        if (ca) {
+            *node = ir::Constant{ mnd::log(ca->value) };
+        }
+    }
+};
 
 void mnd::ir::Formula::constantPropagation(void)
 {
-
+    ConstantPropagator cp { *this };
+    cp.propagateConstants();
 }