Procházet zdrojové kódy

improving compilation

Nicolas Winkler před 5 roky
rodič
revize
6a7b99c4b7

+ 24 - 6
choosegenerators.cpp

@@ -130,6 +130,14 @@ ChooseGenerators::ChooseGenerators(mnd::MandelContext& mndCtxt, QWidget *parent)
     ui->progressBar->setRange(0, 1000);
     benchmarker.setMaxThreadCount(1);
 
+    QFont f("unexistent");
+    f.setStyleHint(QFont::Monospace);
+    f.setPointSize(12);
+    ui->formula->setFont(f);
+    ui->label_2->setFont(f);
+    ui->initialFormula->setFont(f);
+    ui->label_5->setFont(f);
+
     QRegExp floatingpoint{ "^[-+]?(\\d*\\.?\\d+|\\d+\\.?\\d*)([eE][-+]\\d+)?$" };
     floatValidator = std::make_unique<QRegExpValidator>(floatingpoint, this);
 
@@ -295,14 +303,26 @@ void ChooseGenerators::on_generatorTable_cellDoubleClicked(int row, int column)
 void ChooseGenerators::on_compile_clicked()
 {
     QString formula = this->ui->formula->text();
+    QString z0formula = this->ui->initialFormula->text();
     mnd::IterationFormula itf{ mnd::parse(formula.toStdString()) };
+    mnd::IterationFormula z0{ mnd::parse(z0formula.toStdString()) };
     itf.optimize();
+    z0.optimize();
+
+
+    const mnd::MandelDevice& dev = mndCtxt.getDevices()[0];
+    auto cls = mnd::compileOpenCl(dev, z0, itf);
+    chosenGenerator = compileCpu(mndCtxt, z0, itf);
+
 
     std::string expr = mnd::toString(*itf.expr);
-    printf("%s\n", expr.c_str()); fflush(stdout);
-    //chosenGenerator = std::make_unique<mnd::NaiveGenerator>(std::move(itf), mnd::getPrecision<double>());
+    printf("zi := %s\n", expr.c_str()); fflush(stdout);
+    expr = mnd::toString(*z0.expr);
+    printf("z0 := %s\n", expr.c_str()); fflush(stdout);
+    //chosenGenerator = std::make_unique<mnd::NaiveGenerator>(std::move(itf), std::move(z0), mnd::getPrecision<double>());
     //return;
-    mnd::ir::Formula irform = mnd::expand(itf);
+    mnd::ir::Formula irform = mnd::expand(itf, z0);
+    printf("%s\n", irform.toString().c_str()); fflush(stdout);
     irform.constantPropagation();
     printf("%s\n", irform.toString().c_str()); fflush(stdout);
     auto cg = std::make_unique<mnd::CompiledGenerator>(mnd::compile(irform));
@@ -312,10 +332,8 @@ void ChooseGenerators::on_compile_clicked()
     msgBox.setText(QString::fromStdString(asmCode));
     msgBox.exec();*/
     chosenGenerator = std::move(cg);
-
-    const mnd::MandelDevice& dev = mndCtxt.getDevices()[0];
     try {
-        chosenGenerator = mnd::compileCl(irform, dev);
+        //chosenGenerator = mnd::compileCl(irform, dev);
     }
     catch(const std::string& msg) {
         printf("error compiling: %s", msg.c_str());

+ 97 - 20
choosegenerators.ui

@@ -6,8 +6,8 @@
    <rect>
     <x>0</x>
     <y>0</y>
-    <width>1167</width>
-    <height>663</height>
+    <width>911</width>
+    <height>653</height>
    </rect>
   </property>
   <property name="windowTitle">
@@ -29,33 +29,123 @@
           </attribute>
           <layout class="QHBoxLayout" name="horizontalLayout_4">
            <item>
+            <widget class="QLabel" name="label_3">
+             <property name="sizePolicy">
+              <sizepolicy hsizetype="Expanding" vsizetype="Preferred">
+               <horstretch>0</horstretch>
+               <verstretch>0</verstretch>
+              </sizepolicy>
+             </property>
+             <property name="text">
+              <string>&lt;html&gt;&lt;head/&gt;&lt;body&gt;&lt;p&gt;&lt;span style=&quot; font-size:12pt; font-weight:600;&quot;&gt;Formula Editor&lt;/span&gt;&lt;/p&gt;&lt;/body&gt;&lt;/html&gt;</string>
+             </property>
+            </widget>
+           </item>
+           <item>
             <layout class="QVBoxLayout" name="verticalLayout_5">
              <item>
               <layout class="QFormLayout" name="formLayout">
-               <item row="0" column="0">
+               <item row="3" column="0">
                 <widget class="QLabel" name="label_2">
+                 <property name="font">
+                  <font>
+                   <pointsize>12</pointsize>
+                  </font>
+                 </property>
                  <property name="text">
-                  <string> z = </string>
+                  <string>z := </string>
                  </property>
                 </widget>
                </item>
-               <item row="0" column="1">
+               <item row="3" column="1">
                 <widget class="QLineEdit" name="formula">
+                 <property name="font">
+                  <font>
+                   <pointsize>12</pointsize>
+                  </font>
+                 </property>
                  <property name="text">
                   <string>z^2+c</string>
                  </property>
                 </widget>
                </item>
-               <item row="1" column="0">
+               <item row="5" column="0">
                 <widget class="QPushButton" name="benchmark">
                  <property name="text">
                   <string>Benchmark</string>
                  </property>
                 </widget>
                </item>
-               <item row="1" column="1">
+               <item row="5" column="1">
                 <widget class="QLineEdit" name="compBenchResult"/>
                </item>
+               <item row="0" column="0" colspan="2">
+                <widget class="QLabel" name="descriptionInitial">
+                 <property name="font">
+                  <font>
+                   <pointsize>12</pointsize>
+                  </font>
+                 </property>
+                 <property name="text">
+                  <string>&lt;html&gt;&lt;head/&gt;&lt;body&gt;&lt;p&gt;&lt;span style=&quot; font-weight:600;&quot;&gt;Initial Iteration Value&lt;/span&gt;&lt;/p&gt;&lt;p&gt;&lt;span style=&quot; font-size:9pt;&quot;&gt;Specify the initial value for z depending on c.&lt;/span&gt;&lt;/p&gt;&lt;/body&gt;&lt;/html&gt;</string>
+                 </property>
+                 <property name="textFormat">
+                  <enum>Qt::RichText</enum>
+                 </property>
+                </widget>
+               </item>
+               <item row="1" column="0">
+                <widget class="QLabel" name="label_5">
+                 <property name="font">
+                  <font>
+                   <pointsize>12</pointsize>
+                  </font>
+                 </property>
+                 <property name="text">
+                  <string>&lt;html&gt;&lt;head/&gt;&lt;body&gt;&lt;p&gt;z&lt;span style=&quot; vertical-align:sub;&quot;&gt;0&lt;/span&gt; = &lt;/p&gt;&lt;/body&gt;&lt;/html&gt;</string>
+                 </property>
+                 <property name="textFormat">
+                  <enum>Qt::RichText</enum>
+                 </property>
+                </widget>
+               </item>
+               <item row="1" column="1">
+                <widget class="QLineEdit" name="initialFormula">
+                 <property name="font">
+                  <font>
+                   <pointsize>12</pointsize>
+                  </font>
+                 </property>
+                 <property name="text">
+                  <string>0</string>
+                 </property>
+                </widget>
+               </item>
+               <item row="2" column="0" colspan="2">
+                <widget class="QLabel" name="label_4">
+                 <property name="font">
+                  <font>
+                   <pointsize>12</pointsize>
+                  </font>
+                 </property>
+                 <property name="text">
+                  <string>&lt;html&gt;&lt;head/&gt;&lt;body&gt;&lt;p&gt;&lt;span style=&quot; font-weight:600;&quot;&gt;Iteration Formula&lt;/span&gt;&lt;/p&gt;&lt;p&gt;&lt;span style=&quot; font-size:9pt;&quot;&gt;Specify the iteration formula in terms of z and c.&lt;/span&gt;&lt;/p&gt;&lt;/body&gt;&lt;/html&gt;</string>
+                 </property>
+                </widget>
+               </item>
+               <item row="4" column="1">
+                <spacer name="verticalSpacer">
+                 <property name="orientation">
+                  <enum>Qt::Vertical</enum>
+                 </property>
+                 <property name="sizeHint" stdset="0">
+                  <size>
+                   <width>20</width>
+                   <height>40</height>
+                  </size>
+                 </property>
+                </spacer>
+               </item>
               </layout>
              </item>
              <item>
@@ -67,19 +157,6 @@
              </item>
             </layout>
            </item>
-           <item>
-            <widget class="QLabel" name="label_3">
-             <property name="sizePolicy">
-              <sizepolicy hsizetype="Expanding" vsizetype="Preferred">
-               <horstretch>0</horstretch>
-               <verstretch>0</verstretch>
-              </sizepolicy>
-             </property>
-             <property name="text">
-              <string>You can enter your formula.</string>
-             </property>
-            </widget>
-           </item>
           </layout>
          </widget>
          <widget class="QWidget" name="tab">

+ 12 - 2
libmandel/include/IterationCompiler.h

@@ -16,9 +16,19 @@ namespace mnd
     class MandelContext;
     class MandelDevice;
 
-    mnd::ExecData compile(mnd::MandelContext& mndCtxt);
+    enum class GeneratorType : int;
+
+    //mnd::ExecData compile(mnd::MandelContext& mndCtxt);
+
+    std::unique_ptr<mnd::MandelGenerator> compileCpu(mnd::MandelContext& mndCtxt,
+        const IterationFormula& z0,
+        const IterationFormula& z);
+
+    std::vector<std::pair<mnd::GeneratorType, std::unique_ptr<mnd::MandelGenerator>>> compileOpenCl(const mnd::MandelDevice& dev,
+        const IterationFormula& z0,
+        const IterationFormula& z);
 }
-void squareTest();
+//void squareTest();
 
 
 

+ 9 - 2
libmandel/include/IterationFormula.h

@@ -5,6 +5,7 @@
 #include <memory>
 #include <string>
 #include <stdexcept>
+#include <optional>
 
 #include "Types.h"
 
@@ -46,10 +47,16 @@ namespace mnd
 
 struct mnd::IterationFormula
 {
+    std::vector<std::string> variables;
     std::unique_ptr<Expression> expr;
-    IterationFormula(Expression expr);
+    IterationFormula(std::unique_ptr<Expression> expr, const std::vector<std::string>& variables = { "c", "z" });
+    IterationFormula(Expression expr, const std::vector<std::string>& variables = { "c", "z" });
 
+    std::optional<std::string> findUnknownVariables(const Expression& expr);
     void optimize(void);
+    bool containsVariable(const std::string& name) const;
+
+    IterationFormula clone(void) const;
 };
 
 
@@ -78,7 +85,7 @@ struct mnd::Variable
 struct mnd::Negation
 {
     std::unique_ptr<Expression> operand;
-    /*inline UnaryOperation(const UnaryOperation& other) :
+    /*inline Negation(const Negation& other) :
         operand{ std::make_unique<Expression>(*other.operand) }
     {}*/
 };

+ 6 - 3
libmandel/include/IterationGenerator.h

@@ -25,16 +25,17 @@ namespace mnd
 class mnd::IterationGenerator : public mnd::MandelGenerator
 {
 protected:
-    IterationFormula itf;
+    IterationFormula z0;
+    IterationFormula zi;
 public:
-    IterationGenerator(IterationFormula itf, const mnd::Real& prec);
+    IterationGenerator(IterationFormula z0, IterationFormula zi, const mnd::Real& prec);
 };
 
 
 class mnd::NaiveGenerator : public mnd::IterationGenerator
 {
 public:
-    NaiveGenerator(IterationFormula itf, const mnd::Real& prec);
+    NaiveGenerator(IterationFormula z0, IterationFormula zi, const mnd::Real& prec);
 
     virtual void generate(const MandelInfo& info, float* data);
 private:
@@ -43,6 +44,7 @@ private:
 };
 
 
+#if defined(__x86_64__) || defined(_M_X64)
 class mnd::CompiledGenerator : public mnd::MandelGenerator
 {
     std::unique_ptr<ExecData> execData;
@@ -54,6 +56,7 @@ public:
 
     std::string dump(void) const;
 };
+#endif
 
 
 #ifdef WITH_OPENCL

+ 4 - 1
libmandel/include/IterationIR.h

@@ -55,6 +55,9 @@ namespace mnd
         struct Formula
         {
             util::Arena<Node> nodeArena;
+            Node* startA;
+            Node* startB;
+
             Node* newA;
             Node* newB;
 
@@ -64,7 +67,7 @@ namespace mnd
         };
     }
 
-    ir::Formula expand(const mnd::IterationFormula& fmla);
+    ir::Formula expand(const mnd::IterationFormula& fmla, const mnd::IterationFormula& z0);
 }
 
 

+ 2 - 2
libmandel/include/Mandel.h

@@ -21,7 +21,7 @@ namespace asmjit { class JitRuntime; }
 
 namespace mnd
 {
-    enum class GeneratorType;
+    enum class GeneratorType : int;
     class MandelContext;
     class MandelDevice;
 
@@ -34,7 +34,7 @@ namespace mnd
 }
 
 
-enum class mnd::GeneratorType
+enum class mnd::GeneratorType : int
 {
     FLOAT,
     FLOAT_SSE2,

+ 74 - 32
libmandel/src/IterationCompiler.cpp

@@ -52,16 +52,16 @@ namespace mnd
         }
 
         Reg operator()(const ir::Variable& v) {
-            if (v.name == "a") {
+            if (v.name == "z_re") {
                 return a;
             }
-            else if (v.name == "b") {
+            else if (v.name == "z_im") {
                 return b;
             }
-            else if (v.name == "x") {
+            else if (v.name == "c_re") {
                 return x;
             }
-            else if (v.name == "y") {
+            else if (v.name == "c_im") {
                 return y;
             }
             else
@@ -201,22 +201,29 @@ namespace mnd
         x86::Xmm a = comp.newXmmSd();
         x86::Xmm b = comp.newXmmSd();
         comp.addFunc(FuncSignatureT<int, double, double, int>(CallConv::kIdHost));
+
         comp.setArg(0, x);
         comp.setArg(1, y);
+
+        CompileVisitor formVisitor{ comp, a, b, x, y };
+        auto startA = std::visit(formVisitor, *formula.startA);
+        auto startB = std::visit(formVisitor, *formula.startB);
+        comp.movapd(a, startA);
+        comp.movapd(b, startB);
+
         comp.setArg(2, maxIter);
         //comp.movapd(a, x);
         //comp.movapd(b, y);
 
-        comp.xorpd(a, a);
-        comp.movapd(b, b);
+        //comp.xorpd(a, a);
+        //comp.movapd(b, b);
 
         comp.xor_(k, k);
 
         comp.bind(startLoop);
 
-        CompileVisitor cv{ comp, a, b, x, y };
-        auto newA = std::visit(cv, *formula.newA);
-        auto newB = std::visit(cv, *formula.newB);
+        auto newA = std::visit(formVisitor, *formula.newA);
+        auto newB = std::visit(formVisitor, *formula.newB);
         comp.movapd(a, newA);
         comp.movapd(b, newB);
 
@@ -233,7 +240,7 @@ namespace mnd
 
         comp.inc(k);
         comp.cmp(k, maxIter);
-        comp.jne(startLoop);
+        comp.jle(startLoop);
         comp.bind(endLoop);
         comp.ret(k);
         comp.endFunc();
@@ -255,6 +262,11 @@ namespace mnd
         int varnameCounter = 0;
         std::stringstream code;
 
+        OpenClVisitor(int startVarname) :
+            varnameCounter{ startVarname }
+        {
+        }
+
         std::string createVarname(void)
         {
             return "tmp"s + std::to_string(varnameCounter++);
@@ -333,33 +345,38 @@ namespace mnd
 
     std::string compileToOpenCl(const ir::Formula& formula)
     {
-        OpenClVisitor ocv;
+        OpenClVisitor z0Visitor{ 0 };
+        std::string startA = z0Visitor.visitNode(*formula.startA);
+        std::string startB = z0Visitor.visitNode(*formula.startB);
+
+        OpenClVisitor ocv{ z0Visitor.varnameCounter };
         std::string newA = ocv.visitNode(*formula.newA);
         std::string newB = ocv.visitNode(*formula.newB);
 
-        std::string prelude = 
-"__kernel void iterate(__global float* A, const int width, float xl, float yt, float pixelScaleX, float pixelScaleY, int max, int smooth, int julia, float juliaX, float juliaY) {\n"
-"   int index = get_global_id(0);\n"
-"   int ix = index % width;\n"
-"   int iy = index / width;\n"
-"   float a = 0;\n"
-"   float b = 0;\n"
-"   float x = ix * pixelScaleX + xl;\n"
-"   float y = iy * pixelScaleY + yt;\n"
-"\n"
-"   int n = 0;\n"
-"   while (n < max - 1) {\n";
-
+        std::string prelude =
+            "__kernel void iterate(__global float* A, const int width, float xl, float yt, float pixelScaleX, float pixelScaleY, int max, int smooth, int julia, float juliaX, float juliaY) {\n"
+            "   int index = get_global_id(0);\n"
+            "   int ix = index % width;\n"
+            "   int iy = index / width;\n"
+            "   float c_re = ix * pixelScaleX + xl;\n"
+            "   float c_im = iy * pixelScaleY + yt;\n";
+        prelude += z0Visitor.code.str() +
+            "   float z_re = " + startA + ";\n" +
+            "   float z_im = " + startB + ";\n" +
+            "\n"
+            "   int n = 0;\n"
+            "   while (n < max - 1) {\n";
+        /*
         std::string orig = 
 "       float aa = a * a;"
 "       float bb = b * b;"
 "       float ab = a * b;"
 "       a = aa - bb + x;"
 "       b = ab + ab + y;";
-
+*/
     
         std::string after = 
-"       if (a * a + b * b > 16) break;\n"
+"       if (z_re * z_re + z_im * z_im > 16) break;\n"
 "       n++;\n"
 "   }\n"
 "   if (n >= max - 1) {\n"
@@ -372,8 +389,8 @@ namespace mnd
 
 
         std::string code = prelude + ocv.code.str();
-        code += "a = " + newA + ";\n";
-        code += "b = " + newB + ";\n";
+        code += "z_re = " + newA + ";\n";
+        code += "z_im = " + newB + ";\n";
         code += after;
         //code = mnd::getFloat_cl();
         printf("cl: %s\n", code.c_str()); fflush(stdout);
@@ -386,8 +403,31 @@ namespace mnd
         return std::make_unique<CompiledClGenerator>(md, compileToOpenCl(formula));
     }
 #endif
+
+    std::unique_ptr<mnd::MandelGenerator> compileCpu(mnd::MandelContext& mndCtxt,
+        const IterationFormula& z0,
+        const IterationFormula& zi)
+    {
+        auto ng = std::make_unique<NaiveGenerator>(z0.clone(), zi.clone(), mnd::getPrecision<double>());
+        
+        //ir::Formula irf = mnd::expand(zi, z0);
+        //auto dg = std::make_unique<CompiledGenerator>(compile(irf));
+
+        return ng;
+    }
+
+    std::vector<std::pair<mnd::GeneratorType, std::unique_ptr<mnd::MandelGenerator>>> compileOpenCl(const mnd::MandelDevice& dev,
+        const IterationFormula& z0,
+        const IterationFormula& zi)
+    {
+        ir::Formula irf = mnd::expand(zi, z0);
+        auto fl = compileCl(irf, dev);
+        return {};// { { mnd::GeneratorType::FLOAT, std::move(fl) } };
+    }
 }
 
+
+
 using namespace asmjit;
 
 struct Visitor
@@ -398,6 +438,7 @@ struct Visitor
 
 namespace mnd
 {
+    /*
     mnd::ExecData compile(mnd::MandelContext& mndCtxt)
     {
 
@@ -465,9 +506,10 @@ namespace mnd
             throw "error compiling";
         }
         return ed;
-    }
+    }*/
 
 
+    /*
     mnd::ExecData compile(mnd::MandelContext& mndCtxt, const IterationFormula& formula)
     {
         mnd::ExecData ed;
@@ -515,10 +557,10 @@ namespace mnd
             throw "error compiling";
         }
         return ed;
-    }
+    }*/
 }
 
-
+/*
 void squareTest()
 {
     mnd::Expression power = mnd::Pow{
@@ -563,6 +605,6 @@ void squareTest()
     double result = func(1.0, 3.0);
     printf("result: %f\n", result);
 }
-
+*/
 
 

+ 90 - 2
libmandel/src/IterationFormula.cpp

@@ -9,8 +9,19 @@
 using mnd::ParseError;
 
 
-mnd::IterationFormula::IterationFormula(mnd::Expression expr) :
-    expr{ std::make_unique<mnd::Expression>(std::move(expr)) }
+mnd::IterationFormula::IterationFormula(std::unique_ptr<Expression> expr, const std::vector<std::string>& variables) :
+    expr{ std::move(expr) },
+    variables{ variables }
+{
+    auto maybeUnknown = findUnknownVariables(*this->expr);
+    if (maybeUnknown.has_value()) {
+        throw ParseError(std::string("unknown variable: ") + maybeUnknown.value());
+    }
+}
+
+
+mnd::IterationFormula::IterationFormula(mnd::Expression expr, const std::vector<std::string>& variables) :
+    IterationFormula{ std::make_unique<mnd::Expression>(std::move(expr)), variables }
 {
 }
 
@@ -124,6 +135,45 @@ struct SimpleOptimizer
 };
 
 
+std::optional<std::string> mnd::IterationFormula::findUnknownVariables(const Expression& expr)
+{
+    std::string unknownVariable;
+    std::function<bool(const Expression&)> isCorrect;
+    auto corrLambda = [this, &isCorrect, &unknownVariable](const auto& x) {
+        using T = std::decay_t<decltype(x)>;
+        if constexpr (std::is_same<T, mnd::Variable>::value) {
+            if (containsVariable(x.name)) {
+                return true;
+            }
+            else {
+                unknownVariable = x.name;
+                return false;
+            }
+        }
+        else if constexpr (std::is_same<T, mnd::Negation>::value) {
+            return isCorrect(*x.operand);
+        }
+        else if constexpr (std::is_same<T, mnd::Addition>::value ||
+            std::is_same<T, mnd::Multiplication>::value ||
+            std::is_same<T, mnd::Division>::value ||
+            std::is_same<T, mnd::Pow>::value) {
+            return isCorrect(*x.left) && isCorrect(*x.right);
+        }
+        return true;
+    };
+    isCorrect = [corrLambda](const mnd::Expression& x) {
+        return std::visit(corrLambda, x);
+    };
+    bool allCorrect = isCorrect(expr);
+    if (allCorrect) {
+        return std::nullopt;
+    }
+    else {
+        return unknownVariable;
+    }
+}
+
+
 void mnd::IterationFormula::optimize(void)
 {
     SimpleOptimizer so;
@@ -131,6 +181,43 @@ void mnd::IterationFormula::optimize(void)
 }
 
 
+bool mnd::IterationFormula::containsVariable(const std::string& name) const
+{
+    for (const auto& varname : variables) {
+        if (varname == name)
+            return true;
+    }
+    return false;
+}
+
+
+mnd::IterationFormula mnd::IterationFormula::clone(void) const
+{
+
+
+    std::function<std::unique_ptr<mnd::Expression>(const mnd::Expression&)> cloner;
+    cloner = [&cloner](const mnd::Expression& e) {
+        return std::make_unique<mnd::Expression>(std::visit([&cloner](const auto& x) -> mnd::Expression {
+            using T = std::decay_t<decltype(x)>;
+            if constexpr (std::is_same<T, mnd::Constant>::value) {
+                return mnd::Constant{ 0, 0 };
+            }
+            else if constexpr (std::is_same<T, mnd::Variable>::value) {
+                return mnd::Variable{ x.name };
+            }
+            else if constexpr (std::is_same<T, mnd::Negation>::value) {
+                return mnd::Negation{ cloner(*x.operand) };
+            }
+            else {
+                return T{ cloner(*x.left), cloner(*x.right) };
+            }
+        }, e));
+    };
+    IterationFormula cl{ cloner(*expr), this->variables };
+    return cl;
+}
+
+
 static const std::string regexIdent = "[A-Za-z][A-Za-z0-9]*";
 static const std::string regexNum = "[1-9][0-9]*";
 static const std::string regexFloat = "(\\d*\\.?\\d+|\\d+\\.?\\d*)([eE][-+]\\d+)?";
@@ -158,6 +245,7 @@ public:
     void parse(void)
     {
         std::string token;
+        bool unary = true;
         while (getToken(token)) {
             if (std::regex_match(token, num)) {
                 output.push_back(mnd::Constant{ std::atof(token.c_str()) });

+ 12 - 13
libmandel/src/IterationGenerator.cpp

@@ -12,19 +12,21 @@ using mnd::NaiveGenerator;
 using mnd::IterationFormula;
 
 
-IterationGenerator::IterationGenerator(IterationFormula itf,
+IterationGenerator::IterationGenerator(IterationFormula z0, IterationFormula zi,
                                    const mnd::Real& prec) :
     mnd::MandelGenerator{ prec },
-    itf{ std::move(itf) }
+    z0{ std::move(z0) },
+    zi{ std::move(zi) }
 {
 }
 
 
-NaiveGenerator::NaiveGenerator(IterationFormula itf,
+NaiveGenerator::NaiveGenerator(IterationFormula z0, IterationFormula zi,
                                    const mnd::Real& prec) :
-    IterationGenerator{ std::move(itf), prec }
+    IterationGenerator{ std::move(z0), std::move(zi), prec }
 {
-    this->itf.optimize();
+    this->z0.optimize();
+    this->zi.optimize();
 }
 
 
@@ -51,12 +53,9 @@ void NaiveGenerator::generate(const mnd::MandelInfo& info, float* data)
         for (i; i < info.bWidth; i++) {
             T x = viewx + T(double(i)) * wpp;
 
-            T cx = info.julia ? juliaX : x;
-            T cy = info.julia ? juliaY : y;
-            std::complex<double> z{ x, y };
-            if (!info.julia) {
-                z = 0;
-            }
+            T cx = x;
+            T cy = y;
+            std::complex<double> z = calc(*z0.expr, { 0, 0 }, { x, y });
             std::complex<double> c{ cx, cy };
 
             int k = 0;
@@ -84,7 +83,7 @@ void NaiveGenerator::generate(const mnd::MandelInfo& info, float* data)
 
 std::complex<double> NaiveGenerator::iterate(std::complex<double> z, std::complex<double> c)
 {
-    auto& expr = *itf.expr;
+    auto& expr = *zi.expr;
     return calc(expr, z, c);
 }
 
@@ -183,7 +182,7 @@ void CompiledGenerator::generate(const mnd::MandelInfo& info, float* data)
             double x = mnd::convert<double>(info.view.x + info.view.width * j / info.bWidth);
             IterFunc iterFunc = asmjit::ptr_as_func<IterFunc>(this->execData->iterationFunc);
             int k = iterFunc(x, y, info.maxIter);
-            data[i * info.bWidth + j] = k;
+            data[i * info.bWidth + j] = float(k);
         }
     }
 }

+ 10 - 14
libmandel/src/IterationIR.cpp

@@ -14,13 +14,15 @@ namespace mnd
     {
         using NodePair = std::pair<Node*, Node*>;
         util::Arena<Node>& arena;
+        const mnd::IterationFormula& iterationFormula;
 
         Node* zero;
         Node* half;
         Node* one;
 
-        ConvertVisitor(util::Arena<Node>& arena) :
-            arena{ arena }
+        ConvertVisitor(util::Arena<Node>& arena, const mnd::IterationFormula& iterationFormula) :
+            arena{ arena },
+            iterationFormula{ iterationFormula }
         {
             zero = arena.allocate(ir::Constant{ 0.0 });
             half = arena.allocate(ir::Constant{ 0.5 });
@@ -38,18 +40,11 @@ namespace mnd
         NodePair operator() (const Variable& v)
         {
             //printf("var %s\n", v.name.c_str()); fflush(stdout);
-            if (v.name == "z") {
-                Node* a = arena.allocate(ir::Variable{ "a" });
-                Node* b = arena.allocate(ir::Variable{ "b" });
-
+            if (iterationFormula.containsVariable(v.name)) {
+                Node* a = arena.allocate(ir::Variable{ v.name + "_re" });
+                Node* b = arena.allocate(ir::Variable{ v.name + "_im" });
                 return { a, b };
             }
-            else if (v.name == "c") {
-                Node* x = arena.allocate(ir::Variable{ "x" });
-                Node* y = arena.allocate(ir::Variable{ "y" });
-
-                return { x, y };
-            }
             else if (v.name == "i") {
                 return { zero, one };
             }
@@ -225,10 +220,11 @@ namespace mnd
         }
     };
 
-    ir::Formula expand(const mnd::IterationFormula& fmla)
+    ir::Formula expand(const mnd::IterationFormula& fmla, const mnd::IterationFormula& z0)
     {
         ir::Formula formula;
-        ConvertVisitor cv{ formula.nodeArena };
+        ConvertVisitor cv{ formula.nodeArena, fmla };
+        std::tie(formula.startA, formula.startB) = std::visit(cv, *z0.expr);
         std::tie(formula.newA, formula.newB) = std::visit(cv, *fmla.expr);
         return formula;
     }