Selaa lähdekoodia

introduced arrays

Nicolas Winkler 6 vuotta sitten
vanhempi
commit
7a482933ee

+ 92 - 8
src/CodegenVisitor.cpp

@@ -17,7 +17,8 @@ using namespace qlow;
 llvm::Value* ExpressionCodegenVisitor::visit(sem::LocalVariableExpression& lve, llvm::IRBuilder<>& builder)
 {
     assert(lve.var->allocaInst != nullptr);
-    if (llvm::dyn_cast<llvm::AllocaInst>(lve.var->allocaInst)) {
+    // TODO improve handling of arrays
+    if (llvm::dyn_cast<llvm::AllocaInst>(lve.var->allocaInst) && !lve.type->isArrayType()) {
         llvm::Type* returnType = lve.type->getLlvmType(builder.getContext());
         llvm::Value* val = builder.CreateLoad(returnType, lve.var->allocaInst);
         return val;
@@ -152,8 +153,30 @@ llvm::Value* ExpressionCodegenVisitor::visit(sem::NewExpression& nexpr, llvm::IR
 llvm::Value* ExpressionCodegenVisitor::visit(sem::NewArrayExpression& naexpr, llvm::IRBuilder<>& builder)
 {
     using llvm::Value;
-    // TODO implement
-    return nullptr;
+
+    sem::Context& semCtxt = naexpr.context;
+    llvm::LLVMContext& llvmCtxt = builder.getContext();
+
+    const llvm::DataLayout& layout = builder.GetInsertBlock()->getModule()->getDataLayout();
+    llvm::Type* llvmTy = naexpr.elementType->getLlvmType(llvmCtxt);
+    llvm::Type* arrayStructType = naexpr.type->getLlvmType(llvmCtxt);
+    auto elementSize = layout.getTypeAllocSize(llvmTy);
+
+    llvm::Value* lengthExpr = naexpr.length->accept(*this, builder);
+    llvm::Value* allocSize = builder.CreateMul(lengthExpr, llvm::ConstantInt::get(llvm::Type::getInt64Ty(llvmCtxt), elementSize));
+
+    auto mallocCall = llvm::CallInst::CreateMalloc(builder.GetInsertBlock(), allocSize->getType(), llvmTy, allocSize, nullptr, nullptr, "");
+    //auto casted = builder.CreateBitCast(mallocCall, llvmTy);
+    builder.GetInsertBlock()->getInstList().push_back(llvm::cast<llvm::Instruction>(mallocCall));
+
+    llvm::Value* result = builder.CreateAlloca(arrayStructType);
+    llvm::Value* arrRef = builder.CreateStructGEP(arrayStructType, result, 0);
+    llvm::Value* lenRef = builder.CreateStructGEP(arrayStructType, result, 1);
+
+    builder.CreateStore(mallocCall, arrRef);
+    builder.CreateStore(lengthExpr, lenRef);
+
+    return result; // builder.CreateGEP(result, llvm::ConstantInt::get(llvm::Type::getInt64Ty(llvmCtxt), 0));
 }
 
 
@@ -249,6 +272,30 @@ llvm::Value* ExpressionCodegenVisitor::visit(sem::AddressExpression& node, llvm:
 }
 
 
+llvm::Value* ExpressionCodegenVisitor::visit(sem::ArrayAccessExpression& node, llvm::IRBuilder<>& builder)
+{
+    auto array = node.array->accept(*this, builder);
+    auto index = node.index->accept(*this, builder);
+
+    auto arrType = node.array->type;
+    if (!arrType->isArrayType()) {
+        throw "trying to access non-array type as array.";
+    }
+
+    sem::ArrayType* at = static_cast<sem::ArrayType*>(arrType);
+    auto elemType = at->getArrayOf();
+
+    auto arrPtr = builder.CreateStructGEP(at->getLlvmType(builder.getContext()), array, 0);
+    // TODO implement range checks
+    //auto length = builder.CreateStructGEP(at->getLlvmType(builder.getContext()), array, 1);
+
+    auto arr = builder.CreateLoad(arrPtr);
+    auto accessVal = builder.CreateGEP(arr, index);
+
+    return builder.CreateLoad(accessVal);
+}
+
+
 llvm::Value* ExpressionCodegenVisitor::visit(sem::IntConst& node, llvm::IRBuilder<>& builder)
 {
     return llvm::ConstantInt::get(builder.getContext(),
@@ -285,8 +332,8 @@ llvm::Value* LValueVisitor::visit(sem::LocalVariableExpression& lve, qlow::gen::
         return lve.var->allocaInst;
     }
     else {
-        return lve.var->allocaInst;
-        //throw "unable to find alloca instance of local variable";
+        throw "unable to find alloca instance of local variable";
+        //return lve.var->allocaInst;
     }
 }
 
@@ -335,10 +382,38 @@ llvm::Value* LValueVisitor::visit(sem::FieldAccessExpression& access, qlow::gen:
 }
 
 
-llvm::Value* StatementVisitor::visit(sem::DoEndBlock& assignment,
+llvm::Value* LValueVisitor::visit(sem::ArrayAccessExpression& node, qlow::gen::FunctionGenerator& fg)
+{
+    auto& builder = fg.builder;
+    auto array = node.array->accept(fg.expressionVisitor, builder);
+    auto index = node.index->accept(fg.expressionVisitor, builder);
+
+    auto arrType = node.array->type;
+    if (!arrType->isArrayType()) {
+        throw "trying to access non-array type as array.";
+    }
+
+    //auto ostr = llvm::raw_os_ostream(Printer::getInstance());
+    //fg.getModule()->print(ostr, nullptr);
+
+    auto elemType = arrType->getArrayOf();
+
+    auto arrPtr = builder.CreateStructGEP(arrType->getLlvmType(builder.getContext()), array, 0);
+    // TODO implement range check
+    //auto length = builder.CreateStructGEP(arrType->getLlvmType(builder.getContext()), array, 1);
+
+    auto arr = builder.CreateLoad(arrPtr);
+
+    auto accessVal = builder.CreateGEP(arr, index);
+
+    return accessVal;
+}
+
+
+llvm::Value* StatementVisitor::visit(sem::DoEndBlock& block,
         qlow::gen::FunctionGenerator& fg)
 {
-    for (auto& statement : assignment.statements) {
+    for (auto& statement : block.statements) {
         statement->accept(*this, fg);
     }
     return nullptr;
@@ -426,7 +501,13 @@ llvm::Value* StatementVisitor::visit(sem::AssignmentStatement& assignment,
     auto val = assignment.value->accept(fg.expressionVisitor, fg.builder);
     auto target = assignment.target->accept(fg.lvalueVisitor, fg);
     
-    return fg.builder.CreateStore(val, target);
+    if (val->getType()->isPointerTy() && val->getType()->getPointerElementType()->isStructTy()) {
+        const llvm::DataLayout& layout = fg.builder.GetInsertBlock()->getModule()->getDataLayout();
+        return fg.builder.CreateMemCpy(target, val, layout.getTypeAllocSize(val->getType()->getPointerElementType()), 1);
+    }
+    else {
+        return fg.builder.CreateStore(val, target);
+    }
     
     /*
     if (auto* targetVar =
@@ -471,6 +552,9 @@ llvm::Value* StatementVisitor::visit(sem::ReturnStatement& returnStatement,
 {
     fg.builder.SetInsertPoint(fg.getCurrentBlock());
     auto val = returnStatement.value->accept(fg.expressionVisitor, fg.builder);
+    if (returnStatement.value != nullptr && val == nullptr) {
+        throw "internal error: returned type is invalid";
+    }
     fg.builder.CreateRet(val);
     return val;
 }

+ 5 - 1
src/CodegenVisitor.h

@@ -47,6 +47,7 @@ class qlow::ExpressionCodegenVisitor :
         sem::MethodCallExpression,
         sem::FieldAccessExpression,
         sem::AddressExpression,
+        sem::ArrayAccessExpression,
         sem::IntConst,
         sem::ThisExpression
     >
@@ -67,6 +68,7 @@ public:
     llvm::Value* visit(sem::MethodCallExpression& node, llvm::IRBuilder<>&) override;
     llvm::Value* visit(sem::FieldAccessExpression& node, llvm::IRBuilder<>&) override;
     llvm::Value* visit(sem::AddressExpression& node, llvm::IRBuilder<>&) override;
+    llvm::Value* visit(sem::ArrayAccessExpression& node, llvm::IRBuilder<>&) override;
     llvm::Value* visit(sem::IntConst& node, llvm::IRBuilder<>&) override;
     llvm::Value* visit(sem::ThisExpression& node, llvm::IRBuilder<>&) override;
 };
@@ -79,13 +81,15 @@ class qlow::LValueVisitor :
 
         sem::Expression,
         sem::LocalVariableExpression,
-        sem::FieldAccessExpression
+        sem::FieldAccessExpression,
+        sem::ArrayAccessExpression
     >
 {
 public:
     llvm::Value* visit(sem::Expression& node, qlow::gen::FunctionGenerator& fg) override;
     llvm::Value* visit(sem::LocalVariableExpression& node, qlow::gen::FunctionGenerator& fg) override;
     llvm::Value* visit(sem::FieldAccessExpression& node, qlow::gen::FunctionGenerator& fg) override;
+    llvm::Value* visit(sem::ArrayAccessExpression& node, qlow::gen::FunctionGenerator& fg) override;
 };
 
 

+ 1 - 0
src/ast/Ast.cpp

@@ -53,6 +53,7 @@ ACCEPT_DEFINITION(AssignmentStatement, StructureVisitor)
 ACCEPT_DEFINITION(ReturnStatement, StructureVisitor)
 ACCEPT_DEFINITION(LocalVariableStatement, StructureVisitor)
 ACCEPT_DEFINITION(AddressExpression, StructureVisitor)
+ACCEPT_DEFINITION(ArrayAccessExpression, StructureVisitor)
 ACCEPT_DEFINITION(IntConst, StructureVisitor)
 ACCEPT_DEFINITION(StringConst, StructureVisitor)
 ACCEPT_DEFINITION(UnaryOperation, StructureVisitor)

+ 19 - 0
src/ast/Ast.h

@@ -75,6 +75,7 @@ namespace qlow
         struct ReturnStatement;
         struct LocalVariableStatement;
         struct AddressExpression;
+        struct ArrayAccessExpression;
 
         struct IntConst;
         struct StringConst;
@@ -490,6 +491,24 @@ struct qlow::ast::AddressExpression : public Expression
 };
 
 
+struct qlow::ast::ArrayAccessExpression : public Expression
+{
+    std::unique_ptr<Expression> array;
+    std::unique_ptr<Expression> index;
+    inline ArrayAccessExpression(std::unique_ptr<Expression> array,
+                                 std::unique_ptr<Expression> index,
+                                 const CodePosition& cp) :
+        AstObject{ cp },
+        Expression{ cp },
+       array{ std::move(array) },
+       index{ std::move(index) }
+    {
+    } 
+
+    virtual std::unique_ptr<sem::SemanticObject> accept(StructureVisitor& v, sem::Scope&);
+};
+
+
 struct qlow::ast::IntConst : public Expression
 {
     unsigned long long value;

+ 12 - 1
src/ast/AstVisitor.cpp

@@ -362,6 +362,16 @@ std::unique_ptr<sem::SemanticObject> StructureVisitor::visit(
 }
 
 
+std::unique_ptr<sem::SemanticObject> StructureVisitor::visit(
+    ast::ArrayAccessExpression& ast, sem::Scope& scope)
+{
+    auto array = unique_dynamic_cast<sem::Expression>(ast.array->accept(*this, scope));
+    auto index = unique_dynamic_cast<sem::Expression>(ast.index->accept(*this, scope));
+
+    return std::make_unique<sem::ArrayAccessExpression>(std::move(array), std::move(index), ast.pos);
+}
+
+
 std::unique_ptr<sem::SemanticObject> StructureVisitor::visit(ast::IntConst& ast, sem::Scope& scope)
 {
     return std::make_unique<sem::IntConst>(scope.getContext(), ast.value, ast.pos);
@@ -439,7 +449,8 @@ std::unique_ptr<sem::SemanticObject> StructureVisitor::visit(ast::NewExpression&
 
 std::unique_ptr<sem::SemanticObject> StructureVisitor::visit(ast::NewArrayExpression& ast, sem::Scope& scope)
 {
-    auto ret = std::make_unique<sem::NewArrayExpression>(scope.getContext(), scope.getType(ast.type.get()), ast.pos);
+    auto length = unique_dynamic_cast<sem::Expression>(ast.length->accept(*this, scope));
+    auto ret = std::make_unique<sem::NewArrayExpression>(scope.getType(ast.type.get()), std::move(length), ast.pos);
     return ret;
 }
 

+ 2 - 0
src/ast/AstVisitor.h

@@ -48,6 +48,7 @@ class qlow::StructureVisitor :
         ast::ReturnStatement,
         ast::LocalVariableStatement,
         ast::AddressExpression,
+        ast::ArrayAccessExpression,
         ast::IntConst,
         ast::StringConst,
         ast::UnaryOperation,
@@ -75,6 +76,7 @@ public:
     ReturnType visit(ast::ReturnStatement& ast, sem::Scope& scope) override;
     ReturnType visit(ast::LocalVariableStatement& ast, sem::Scope& scope) override;
     ReturnType visit(ast::AddressExpression& ast, sem::Scope& scope) override;
+    ReturnType visit(ast::ArrayAccessExpression& ast, sem::Scope& scope) override;
     ReturnType visit(ast::IntConst& ast, sem::Scope& scope) override;
     ReturnType visit(ast::StringConst& ast, sem::Scope& scope) override;
     ReturnType visit(ast::UnaryOperation& ast, sem::Scope& scope) override;

+ 18 - 2
src/ast/syntax.y

@@ -141,6 +141,7 @@ while (0)
     qlow::ast::ReturnStatement* returnStatement;
     qlow::ast::LocalVariableStatement* localVariableStatement;
     qlow::ast::AddressExpression* addressExpression;
+    qlow::ast::ArrayAccessExpression* arrayAccessExpression;
 
     qlow::ast::UnaryOperation* unaryOperation;
     qlow::ast::BinaryOperation* binaryOperation;
@@ -193,6 +194,7 @@ while (0)
 %type <returnStatement> returnStatement
 %type <localVariableStatement> localVariableStatement
 %type <addressExpression> addressExpression
+%type <arrayAccessExpression> arrayAccessExpression
 %type <string> operator
 %type <unaryOperation> unaryOperation
 %type <binaryOperation> binaryOperation
@@ -218,7 +220,7 @@ while (0)
 
 %start topLevel
 
-%expect 65
+%expect 77
 
 %%
 
@@ -590,6 +592,10 @@ expression:
         $$ = $1;
     }
     |
+    arrayAccessExpression {
+        $$ = $1;
+    }
+    |
     castExpression {
         $$ = $1;
     }
@@ -725,6 +731,12 @@ addressExpression:
         $2 = nullptr;
     };
 
+arrayAccessExpression:
+    expression SQUARE_LEFT expression SQUARE_RIGHT {
+        $$ = new ArrayAccessExpression(std::unique_ptr<Expression>($1), std::unique_ptr<Expression>($3), @$);
+        $1 = nullptr; $3 = nullptr;
+    };
+
 paranthesesExpression:
     ROUND_LEFT expression ROUND_RIGHT {
         $$ = $2;
@@ -738,7 +750,11 @@ newExpression:
 
 newArrayExpression:
     NEW SQUARE_LEFT type SEMICOLON expression SQUARE_RIGHT {
-        $$ = nullptr;
+        $$ = new NewArrayExpression(
+            std::unique_ptr<qlow::ast::Type>($3),
+            std::unique_ptr<qlow::ast::Expression>($5),
+        @$);
+        $3 = nullptr; $5 = nullptr;
     };
     
 castExpression:

+ 9 - 2
src/sem/CodeGeneration.cpp

@@ -212,8 +212,15 @@ llvm::Function* generateStartFunction(llvm::Module* module, llvm::Function* star
     IRBuilder<> builder(context);
     BasicBlock* bb = BasicBlock::Create(context, "entry", startFunction);
     builder.SetInsertPoint(bb);
-    builder.CreateCall(start, {});
-    builder.CreateCall(exitFunction, { llvm::ConstantInt::get(context, llvm::APInt(32, "0", 10)) });
+    auto returnVal = builder.CreateCall(start, {});
+
+    if (start->getReturnType()->isIntegerTy()) {
+        auto rv = builder.CreateIntCast(returnVal, llvm::Type::getInt32Ty(context), true);
+        builder.CreateCall(exitFunction, { rv });
+    }
+    else {
+        builder.CreateCall(exitFunction, { llvm::ConstantInt::get(context, llvm::APInt(32, "0", 10)) });
+    }
     builder.CreateRetVoid();
 
     return startFunction;

+ 9 - 0
src/sem/Context.cpp

@@ -88,6 +88,7 @@ void Context::createLlvmTypes(llvm::LLVMContext& llvmCtxt)
             type->createLlvmTypeDecl(llvmCtxt);
         }
         else {
+            // all structs and array types are structs
             type->llvmType = llvm::StructType::create(llvmCtxt, type->asIdentifier());
         }
     }
@@ -104,6 +105,14 @@ void Context::createLlvmTypes(llvm::LLVMContext& llvmCtxt)
             if (type->getClass()->isReferenceType)
                 type->llvmType = type->llvmType->getPointerTo();
         }
+        if (type->isArrayType()) {
+            ArrayType* arrType = static_cast<ArrayType*>(type.get());
+            std::vector<llvm::Type*> structTypes {
+                arrType->elementType->getLlvmType(llvmCtxt)->getPointerTo(),    // elements pointer
+                llvm::Type::getInt64Ty(llvmCtxt)                                // length
+            };
+            llvm::dyn_cast<llvm::StructType>(type->llvmType)->setBody(llvm::ArrayRef(structTypes));
+        }
     }
 }
 

+ 14 - 0
src/sem/Semantic.cpp

@@ -189,12 +189,14 @@ ACCEPT_DEFINITION(UnaryOperation, ExpressionCodegenVisitor, llvm::Value*, llvm::
 ACCEPT_DEFINITION(MethodCallExpression, ExpressionCodegenVisitor, llvm::Value*, llvm::IRBuilder<>&)
 ACCEPT_DEFINITION(FieldAccessExpression, ExpressionCodegenVisitor, llvm::Value*, llvm::IRBuilder<>&)
 ACCEPT_DEFINITION(AddressExpression, ExpressionCodegenVisitor, llvm::Value*, llvm::IRBuilder<>&)
+ACCEPT_DEFINITION(ArrayAccessExpression, ExpressionCodegenVisitor, llvm::Value*, llvm::IRBuilder<>&)
 ACCEPT_DEFINITION(IntConst, ExpressionCodegenVisitor, llvm::Value*, llvm::IRBuilder<>&)
 ACCEPT_DEFINITION(ThisExpression, ExpressionCodegenVisitor, llvm::Value*, llvm::IRBuilder<>&)
 
 ACCEPT_DEFINITION(Expression, LValueVisitor, llvm::Value*, qlow::gen::FunctionGenerator&)
 ACCEPT_DEFINITION(LocalVariableExpression, LValueVisitor, llvm::Value*, qlow::gen::FunctionGenerator&)
 ACCEPT_DEFINITION(FieldAccessExpression, LValueVisitor, llvm::Value*, qlow::gen::FunctionGenerator&)
+ACCEPT_DEFINITION(ArrayAccessExpression, LValueVisitor, llvm::Value*, qlow::gen::FunctionGenerator&)
 
 ACCEPT_DEFINITION(AssignmentStatement, StatementVisitor, llvm::Value*, qlow::gen::FunctionGenerator&) 
 ACCEPT_DEFINITION(DoEndBlock, StatementVisitor, llvm::Value*, qlow::gen::FunctionGenerator&) 
@@ -222,6 +224,12 @@ std::string LocalVariableExpression::toString(void) const
 }
 
 
+std::string ArrayAccessExpression::toString(void) const
+{
+    return "ArrayAccessExpression[" + array->toString() + "[" + index->toString() + "]]";
+}
+
+
 std::string UnaryOperation::toString(void) const
 {
     return "UnaryOperation[" + arg->toString() + "]";
@@ -276,6 +284,12 @@ std::string FieldAccessExpression::toString(void) const
 }
 
 
+std::string IntConst::toString(void) const
+{
+    return "IntConst[" + std::to_string(value) + "]";
+}
+
+
 std::string FeatureCallStatement::toString(void) const
 {
     return "FeatureCallStatement[" + expr->callee->toString() + "]";

+ 27 - 3
src/sem/Semantic.h

@@ -43,6 +43,7 @@ namespace qlow
         
         struct LocalVariableExpression;
         struct AddressExpression;
+        struct ArrayAccessExpression;
 
         struct Operation;
         struct UnaryOperation;
@@ -359,6 +360,27 @@ struct qlow::sem::AddressExpression : public Expression
 };
 
 
+struct qlow::sem::ArrayAccessExpression : public Expression
+{
+    std::unique_ptr<sem::Expression> array;
+    std::unique_ptr<sem::Expression> index;
+
+    inline ArrayAccessExpression(std::unique_ptr<sem::Expression> array,
+                                 std::unique_ptr<sem::Expression> index,
+                                 const CodePosition& pos) :
+        Expression{ array->context, array->type->getArrayOf(), pos },
+        array{ std::move(array) },
+        index{ std::move(index) }
+    {
+    }
+
+    virtual llvm::Value* accept(ExpressionCodegenVisitor& visitor, llvm::IRBuilder<>& arg2) override;
+    virtual llvm::Value* accept(LValueVisitor& visitor, qlow::gen::FunctionGenerator&) override;
+
+    virtual std::string toString(void) const override;
+};
+
+
 struct qlow::sem::BinaryOperation : public Operation
 {
     std::unique_ptr<Expression> left;
@@ -428,9 +450,10 @@ struct qlow::sem::NewArrayExpression : public Expression
     Type* elementType;
     std::unique_ptr<Expression> length;
     
-    inline NewArrayExpression(Context& context, Type* elementType, const CodePosition& pos) :
-        Expression{ context, context.getArrayType(elementType), pos },
-        elementType{ elementType }
+    inline NewArrayExpression(Type* elementType, std::unique_ptr<Expression> length, const CodePosition& pos) :
+        Expression{ length->context, length->context.getArrayType(elementType), pos },
+        elementType{ elementType },
+        length{ std::move(length) }
     {
     }
     
@@ -508,6 +531,7 @@ struct qlow::sem::IntConst : public Expression
     }
     
     virtual llvm::Value* accept(ExpressionCodegenVisitor& visitor, llvm::IRBuilder<>& arg2) override;
+    virtual std::string toString(void) const override;
 };
 
 

+ 12 - 1
src/sem/Type.cpp

@@ -20,7 +20,6 @@ qlow::sem::SemanticObject::~SemanticObject(void) = default;
 
 std::string qlow::sem::SemanticObject::toString(void) const
 {
-    
     return "SemanticObject [" + util::toString(this) + "]";
 }
 
@@ -41,6 +40,12 @@ qlow::sem::Class* Type::getClass(void) const
 }
 
 
+Type* Type::getArrayOf(void) const
+{
+    return nullptr;
+}
+
+
 void Type::setTypeScope(std::unique_ptr<TypeScope> scope)
 {
     this->typeScope = std::move(scope);
@@ -238,6 +243,12 @@ bool ArrayType::isArrayType(void) const
 }
 
 
+Type* ArrayType::getArrayOf(void) const
+{
+    return elementType;
+}
+
+
 std::string ArrayType::asString(void) const
 {
     return "[" + elementType->asString() + "]";

+ 9 - 0
src/sem/Type.h

@@ -79,6 +79,14 @@ public:
      *       it will not return a <code>nullptr</code>
      */
     virtual Class* getClass(void) const;
+
+    /**
+     * \brief get the type of which this type is an array type of.
+     * 
+     * \return the type of which this type is an array type of, or
+     *         <code>nullptr</code> if this type is not an array type.
+     */
+    virtual Type* getArrayOf(void) const;
     
     /**
      * @brief returns the type scope of this type
@@ -173,6 +181,7 @@ protected:
 public:
     virtual bool equals(const Type& other) const override;
     virtual bool isArrayType(void) const override;
+    virtual Type* getArrayOf(void) const override;
 
     virtual std::string asString(void) const override;
     virtual std::string asIdentifier(void) const override;