Nicolas Winkler 5 yıl önce
ebeveyn
işleme
9fe1e6d54e

+ 1 - 0
libmandel/src/IterationCompiler.cpp

@@ -802,6 +802,7 @@ namespace mnd
         printf("if: %s\n", mnd::toString(*zio.expr).c_str()); fflush(stdout);
 
         ir::Formula irf = mnd::expand(z0o, zio);
+        printf("ir: %s\n", irf.toString().c_str()); fflush(stdout);
         irf.optimize();
         printf("ir: %s\n", irf.toString().c_str()); fflush(stdout);
         auto fl = compileCl(irf, dev);

+ 44 - 9
libmandel/src/IterationFormula.cpp

@@ -111,7 +111,19 @@ struct SimpleOptimizer
     {
         visitExpr(a.left);
         visitExpr(a.right);
-        // TODO implement
+
+        auto* leftConst = std::get_if<mnd::Constant>(a.left.get());
+        auto* rightConst = std::get_if<mnd::Constant>(a.right.get());
+
+        if (leftConst && rightConst) {
+            return mnd::Constant {
+                (leftConst->re * rightConst->re + leftConst->im * rightConst->im) /
+                (rightConst->re * rightConst->re + rightConst->im * rightConst->im),
+                (leftConst->im * rightConst->re - leftConst->re * rightConst->im) /
+                (rightConst->re * rightConst->re + rightConst->im * rightConst->im)
+            };
+        }
+
         return std::nullopt;
     }
 
@@ -238,6 +250,8 @@ class Parser
     std::vector<std::string> tokens;
     std::stack<char> operators;
     std::vector<mnd::Expression> output;
+
+    bool expectingBinaryOperator;
 public:
     Parser(const std::string& s) :
         in{ s },
@@ -249,41 +263,53 @@ public:
         std::string token;
         bool unary = true;
         while (getToken(token)) {
-            if (std::regex_match(token, num)) {
-                output.push_back(mnd::Constant{ std::atof(token.c_str()) });
-            }
-            else if (std::regex_match(token, floatNum)) {
+            if (std::regex_match(token, num) || std::regex_match(token, floatNum)) {
                 output.push_back(mnd::Constant{ std::atof(token.c_str()) });
+                expectingBinaryOperator = true;
             }
             else if (std::regex_match(token, ident)) {
                 output.push_back(mnd::Variable{ token });
+                expectingBinaryOperator = true;
             }
             else if (token == "+" || token == "-") {
-                while (!operators.empty() && getTopPrecedence() >= 1) {
-                    popOperator();
+                if (expectingBinaryOperator) {
+                    while (!operators.empty() && getTopPrecedence() >= 1) {
+                        popOperator();
+                    }
+                    operators.push(token[0]);
+                    expectingBinaryOperator = false;
+                }
+                else { // unary op
+                    if (token == "-")
+                        operators.push('m');
+                    else
+                        throw ParseError("unary '+' is not allowed");
                 }
-                operators.push(token[0]);
             }
             else if (token == "*" || token == "/") {
                 while (!operators.empty() && getTopPrecedence() >= 2) {
                     popOperator();
                 }
                 operators.push(token[0]);
+                expectingBinaryOperator = false;
             }
             else if (token == "^") {
                 while (!operators.empty() && getTopPrecedence() > 3) {
                     popOperator();
                 }
                 operators.push(token[0]);
+                expectingBinaryOperator = false;
             }
             else if (token == "(") {
                 operators.push(token[0]);
+                expectingBinaryOperator = false;
             }
             else if (token == ")") {
                 while (operators.top() != '(') {
                     popOperator();
                 }
                 operators.pop();
+                expectingBinaryOperator = true;
             }
         }
         while (!operators.empty())
@@ -306,6 +332,15 @@ public:
         mnd::Expression& left = output.at(output.size() - 2);
         mnd::Expression& right = output.at(output.size() - 1);
         mnd::Expression newExpr = mnd::Constant{ 0.0 };
+
+        // handle unary minus separately
+        if (top == 'm') {
+            auto neg = mnd::Negation{ std::make_unique<mnd::Expression>(std::move(right)) };
+            output.pop_back();
+            output.push_back(std::move(neg));
+            return;
+        }
+
         if (top == '+' || top == '-') {
             newExpr = mnd::Addition {
                 std::make_unique<mnd::Expression>(std::move(left)),
@@ -345,7 +380,7 @@ public:
 
     int getPrecedence(char op) const {
         char t = op;
-        if (t == '+' || t == '-')
+        if (t == '+' || t == '-' || t == 'm') // 'm' == unary minus
             return 1;
         else if (t == '*' || t == '/')
             return 2;

+ 176 - 58
libmandel/src/IterationIR.cpp

@@ -115,11 +115,48 @@ namespace mnd
             return { newa, newb };
         }
 
-        NodePair operator() (const Division& mul)
+        NodePair operator() (const Division& div)
         {
-            // TODO implement
-            throw "unimplemented";
-            return { nullptr, nullptr };
+            auto [a, b] = std::visit(*this, *div.left);
+            auto [c, d] = std::visit(*this, *div.right);
+
+            return division(a, b, c, d);
+        }
+
+        NodePair division(Node* a, Node* b, Node* c, Node* d)
+        {
+            Node* ac = arena.allocate(ir::Multiplication{ a, c });
+            Node* bd = arena.allocate(ir::Multiplication{ b, d });
+            Node* bc = arena.allocate(ir::Multiplication{ b, c });
+            Node* ad = arena.allocate(ir::Multiplication{ a, d });
+
+            Node* cc = arena.allocate(ir::Multiplication{ c, c });
+            Node* dd = arena.allocate(ir::Multiplication{ d, d });
+
+            Node* ac_bd = arena.allocate(ir::Addition{ ac, bd });
+            Node* bc_ad = arena.allocate(ir::Subtraction{ bc, ad });
+
+            Node* den = arena.allocate(ir::Addition{ cc, dd });
+            Node* factor = arena.allocate(ir::Division{ one, den });
+
+            Node* re = arena.allocate(ir::Multiplication{ factor, ac_bd });
+            Node* im = arena.allocate(ir::Multiplication{ factor, bc_ad });
+
+            return { re, im };
+        }
+
+        NodePair oneOver(Node* a, Node* b)
+        {
+            Node* cc = arena.allocate(ir::Multiplication{ a, a });
+            Node* dd = arena.allocate(ir::Multiplication{ b, b });
+
+            Node* den = arena.allocate(ir::Addition{ cc, dd });
+            Node* factor = arena.allocate(ir::Division{ one, den });
+
+            Node* re = arena.allocate(ir::Multiplication{ factor, a });
+            Node* im = arena.allocate(ir::Negation{ arena.allocate(ir::Multiplication{ factor, b }) });
+
+            return { re, im };
         }
 
         NodePair operator() (const Pow& p)
@@ -170,9 +207,8 @@ namespace mnd
             auto [a, b] = val;
 
             if (exponent < 0) {
-                // TODO implement
-                exponent = 0;
-                //return arena.allocate(ir::Division{ one });
+                auto [inva, invb] = intPow(val, -exponent);
+                return oneOver(inva, invb);
             }
 
             if (exponent == 0)
@@ -232,67 +268,130 @@ namespace mnd
 }
 
 
-std::string mnd::ir::Formula::toString(void) const
+using namespace std::string_literals;
+struct ToStringVisitor
 {
-    struct ToStringVisitor
-    {
-        std::string operator()(const ir::Constant& c) {
-            return mnd::toString(c.value);
-        }
+    // return string and precedence
+    using Ret = std::pair<std::string, int>;
 
-        std::string operator()(const ir::Variable& v) {
-            return v.name;
-        }
+    Ret operator()(const ir::Constant& c) {
+        return { mnd::toString(c.value), 0 };
+    }
 
-        std::string operator()(const ir::Negation& n) {
-            return "-(" + std::visit(*this, *n.value) + ")";
-        }
+    Ret operator()(const ir::Variable& v) {
+        return { v.name, 0 };
+    }
 
-        std::string operator()(const ir::Addition& n) {
-            return "(" + std::visit(*this, *n.left) + ") + (" + std::visit(*this, *n.right) + ")";
-        }
+    Ret operator()(const ir::Negation& n) {
+        auto [str, prec] = std::visit(*this, *n.value);
+        if (prec > 0)
+            return { "-("s + str + ")", 2 };
+        else
+            return { "-"s + str, 2 };
+    }
 
-        std::string operator()(const ir::Subtraction& n) {
-            return "(" + std::visit(*this, *n.left) + ") - (" + std::visit(*this, *n.right) + ")";
-        }
+    Ret operator()(const ir::Addition& n) {
+        auto [strl, precl] = std::visit(*this, *n.left);
+        auto [strr, precr] = std::visit(*this, *n.right);
+        std::string ret;
+        if (precl > 4)
+            ret += strl + " + ";
+        else
+            ret += "(" + strl + ") + ";
+        if (precr > 4)
+            ret += strr;
+        else
+            ret += "(" + strr + ")";
+        return { ret, 4 };
+    }
 
-        std::string operator()(const ir::Multiplication& n) {
-            return "(" + std::visit(*this, *n.left) + ") * (" + std::visit(*this, *n.right) + ")";
-        }
+    Ret operator()(const ir::Subtraction& n) {
+        auto [strl, precl] = std::visit(*this, *n.left);
+        auto [strr, precr] = std::visit(*this, *n.right);
+        std::string ret;
+        if (precl > 4)
+            ret += strl + " - ";
+        else
+            ret += "(" + strl + ") - ";
+        if (precr >= 4)
+            ret += strr;
+        else
+            ret += "(" + strr + ")";
+        return { ret, 4 };
+    }
 
-        std::string operator()(const ir::Division& n) {
-            return "(" + std::visit(*this, *n.left) + ") / (" + std::visit(*this, *n.right) + ")";
-        }
+    Ret operator()(const ir::Multiplication& n) {
+        auto [strl, precl] = std::visit(*this, *n.left);
+        auto [strr, precr] = std::visit(*this, *n.right);
+        std::string ret;
+        if (precl > 3)
+            ret += strl + " * ";
+        else
+            ret += "(" + strl + ") * ";
+        if (precr > 3)
+            ret += strr;
+        else
+            ret += "(" + strr + ")";
+        return { ret, 3 };
+    }
 
-        std::string operator()(const ir::Atan2& n) {
-            return "atan2(" + std::visit(*this, *n.left) + ", " + std::visit(*this, *n.right) + ")";
-        }
+    Ret operator()(const ir::Division& n) {
+        auto [strl, precl] = std::visit(*this, *n.left);
+        auto [strr, precr] = std::visit(*this, *n.right);
+        std::string ret;
+        if (precl > 3)
+            ret += strl + " / ";
+        else
+            ret += "(" + strl + ") / ";
+        if (precr >= 3)
+            ret += strr;
+        else
+            ret += "(" + strr + ")";
+        return { ret, 3 };
+    }
 
-        std::string operator()(const ir::Pow& n) {
-            return std::visit(*this, *n.left) + " ^ " + std::visit(*this, *n.right);
-        }
+    Ret operator()(const ir::Atan2& n) {
+        return { "atan2(" + std::visit(*this, *n.left).first + ", " + std::visit(*this, *n.right).first + ")", 1 };
+    }
 
-        std::string operator()(const ir::Cos& n) {
-            return "cos(" + std::visit(*this, *n.value) + ")";
-        }
+    Ret operator()(const ir::Pow& n) {
+        auto [strl, precl] = std::visit(*this, *n.left);
+        auto [strr, precr] = std::visit(*this, *n.right);
+        std::string ret;
+        if (precl >= 2)
+            ret += strl + " ^ ";
+        else
+            ret += "(" + strl + ") ^ ";
+        if (precr > 2)
+            ret += strr;
+        else
+            ret += "(" + strr + ")";
+        return { ret, 2 };
+    }
 
-        std::string operator()(const ir::Sin& n) {
-            return "sin(" + std::visit(*this, *n.value) + ")";
-        }
+    Ret operator()(const ir::Cos& n) {
+        return { "cos(" + std::visit(*this, *n.value).first + ")", 1 };
+    }
 
-        std::string operator()(const ir::Exp& n) {
-            return "exp(" + std::visit(*this, *n.value) + ")";
-        }
+    Ret operator()(const ir::Sin& n) {
+        return { "sin(" + std::visit(*this, *n.value).first + ")", 1 };
+    }
 
-        std::string operator()(const ir::Ln& n) {
-            return "ln(" + std::visit(*this, *n.value) + ")";
-        }
-    };
+    Ret operator()(const ir::Exp& n) {
+        return { "exp(" + std::visit(*this, *n.value).first + ")", 1 };
+    }
 
-    return std::string("a = ") + std::visit(ToStringVisitor{}, *this->newA) + 
-        "\nb = " + std::visit(ToStringVisitor{}, *this->newB) +
-        "\nx = " + std::visit(ToStringVisitor{}, *this->startA) +
-        "\ny = " + std::visit(ToStringVisitor{}, *this->startB);
+    Ret operator()(const ir::Ln& n) {
+        return { "ln(" + std::visit(*this, *n.value).first + ")", 1 };
+    }
+};
+
+std::string mnd::ir::Formula::toString(void) const
+{
+    return std::string("a = ") + std::visit(ToStringVisitor{}, *this->newA).first + 
+        "\nb = " + std::visit(ToStringVisitor{}, *this->newB).first +
+        "\nx = " + std::visit(ToStringVisitor{}, *this->startA).first +
+        "\ny = " + std::visit(ToStringVisitor{}, *this->startB).first;
 }
 
 struct ConstantPropagator
@@ -398,23 +497,42 @@ struct ConstantPropagator
     MaybeNode operator()(ir::Multiplication& n) {
         visitNode(n.left);
         visitNode(n.right);
+        printf("simpifying mul: %s * %s\n", std::visit(ToStringVisitor{}, *n.left).first.c_str(), std::visit(ToStringVisitor{}, *n.right).first.c_str());
         auto* ca = getIfConstant(n.left);
         auto* cb = getIfConstant(n.right);
         if (ca && cb) {
             return ir::Constant{ ca->value * cb->value };
         }
-        else if (ca && ca->value == 0) {
+        if (ca && ca->value == 0) {
             return ir::Constant{ 0 };
         }
-        else if (cb && cb->value == 0) {
+        if (cb && cb->value == 0) {
             return ir::Constant{ 0 };
         }
-        else if (ca && ca->value == 1) {
+        if (ca && ca->value == 1) {
             return *n.right;
         }
-        else if (cb && cb->value == 1) {
+        if (cb && cb->value == 1) {
             return *n.left;
         }
+        if (ca) {
+            auto* rightMul = std::get_if<ir::Multiplication>(n.right);
+            if (rightMul) {
+                auto* clr = getIfConstant(rightMul->left);
+                if (clr) {
+                    printf("left %s, right %s\n", mnd::toString(ca->value).c_str(), mnd::toString(clr->value).c_str());
+                    //ca->value *= clr->value;
+                    n.right = rightMul->right;
+                    auto mul = ir::Multiplication{ arena.allocate(ir::Constant{ ca->value * clr->value }), rightMul->right };//
+                    auto maybeBetter = this->operator()(mul);
+                    return maybeBetter.value_or(mul);
+                }
+            }
+        }
+        if (cb) {
+            // move constants to the left
+            std::swap(n.left, n.right);
+        }
         return std::nullopt;
     }