瀏覽代碼

improving structure

Nicolas Winkler 6 年之前
父節點
當前提交
f27b38e53f
共有 9 個文件被更改,包括 470 次插入26 次删除
  1. 2 1
      src/Driver.cpp
  2. 1 0
      src/Driver.h
  3. 3 1
      src/Scope.h
  4. 336 0
      src/sem/CodeGeneration.cpp
  5. 57 0
      src/sem/CodeGeneration.h
  6. 32 0
      src/sem/Context.cpp
  7. 17 12
      src/sem/Context.h
  8. 6 0
      src/sem/Type.cpp
  9. 16 12
      src/sem/Type.h

+ 2 - 1
src/Driver.cpp

@@ -193,10 +193,11 @@ bool Driver::parseStage(void)
 
 bool Driver::semanticStage(void)
 {
+    Logger& logger = Logger::getInstance();
     bool errorOccurred = false;
 
     try {
-        this->semClasses = qlow::sem::createFromAst(this->ast);
+        this->semClasses = qlow::sem::createFromAst(*this->ast);
     }
     catch(SemanticError& se) {
         se.print(logger);

+ 1 - 0
src/Driver.h

@@ -7,6 +7,7 @@
 #include <string>
 #include <utility>
 #include "Parser.h"
+#include "Scope.h"
 
 namespace qlow
 {

+ 3 - 1
src/Scope.h

@@ -9,6 +9,7 @@
 #include <llvm/IR/Value.h>
 
 #include "Util.h"
+#include "Context.h"
 
 namespace qlow
 {
@@ -66,6 +67,8 @@ public:
     SymbolTable<Class> classes;
     SymbolTable<Method> functions;
     OwningList<Cast> casts;
+
+    Context typeContext;
 public:
     virtual Variable* getVariable(const std::string& name);
     virtual Method* getMethod(const std::string& name);
@@ -139,7 +142,6 @@ public:
     {
     }
     
-    
     virtual Variable* getVariable(const std::string& name);
     virtual Method* getMethod(const std::string& name);
     virtual std::shared_ptr<Type> getType(const ast::Type& name);

+ 336 - 0
src/sem/CodeGeneration.cpp

@@ -0,0 +1,336 @@
+#include "CodeGeneration.h"
+
+#include <llvm/IR/LLVMContext.h>
+#include <llvm/IR/LegacyPassManager.h>
+#include <llvm/Transforms/IPO/PassManagerBuilder.h>
+#include <llvm/IR/Type.h>
+#include <llvm/IR/DerivedTypes.h>
+#include <llvm/IR/Constants.h>
+#include <llvm/IR/BasicBlock.h>
+#include <llvm/IR/Verifier.h>
+#include <llvm/IR/IRBuilder.h>
+#include <llvm/IR/Attributes.h>
+#include <llvm/Target/TargetMachine.h>
+#include <llvm/Support/TargetRegistry.h>
+#include <llvm/Support/TargetSelect.h>
+#include <llvm/Support/FileSystem.h>
+#include <llvm/Support/raw_os_ostream.h>
+
+
+using namespace qlow;
+
+static llvm::LLVMContext context;
+
+namespace qlow
+{
+namespace gen
+{
+
+std::unique_ptr<llvm::Module> generateModule(const sem::Semantic& semantic)
+{
+    using llvm::Module;
+    using llvm::Function;
+    using llvm::Argument;
+    using llvm::Type;
+    using llvm::FunctionType;
+    using llvm::BasicBlock;
+    using llvm::Value;
+    using llvm::IRBuilder;
+    
+    Logger& logger = Logger::getInstance();
+    
+#ifdef DEBUGGING
+        printf("creating llvm module\n"); 
+#endif 
+
+    std::unique_ptr<Module> module = llvm::make_unique<Module>("qlow_module", context);
+
+    // create llvm structs
+    // TODO implement detection of circles
+    for (auto& [name, cl] : semantic.getClasses()) {
+        llvm::StructType* st;
+        std::vector<llvm::Type*> fields;
+#ifdef DEBUGGING
+        printf("creating llvm struct for %s\n", name.c_str());
+#endif
+        int llvmStructIndex = 0;
+        for (auto& [name, field] : cl->fields) {
+            field->llvmStructIndex = llvmStructIndex;
+            fields.push_back(field->type->getLlvmType(context));
+            if (fields[fields.size() - 1] == nullptr)
+                throw "internal error: possible circular dependency";
+            
+            llvmStructIndex++;
+        }
+        st = llvm::StructType::create(context, fields, name);
+        cl->llvmType = st;
+    }
+    
+    llvm::AttrBuilder ab;
+    ab.addAttribute(llvm::Attribute::AttrKind::NoInline);
+    ab.addAttribute(llvm::Attribute::AttrKind::NoUnwind);
+    //ab.addAttribute(llvm::Attribute::AttrKind::OptimizeNone);
+    //ab.addAttribute(llvm::Attribute::AttrKind::UWTable);
+    ab.addAttribute("no-frame-pointer-elim", "true");
+    ab.addAttribute("no-frame-pointer-elim-non-leaf");
+    llvm::AttributeSet as = llvm::AttributeSet::get(context, ab);
+    
+    
+    std::vector<llvm::Function*> functions;
+    auto verifyStream = llvm::raw_os_ostream(logger.debug());
+    
+    // create all llvm functions
+    for (auto& [name, cl] : objects.classes) {
+        for (auto& [name, method] : cl->methods) {
+            Function* func = generateFunction(module.get(), method.get());
+            for (auto a : as) {
+                func->addFnAttr(a);
+            }
+            functions.push_back(func);
+        }
+    }
+    
+    for (auto& [name, method] : objects.functions) {
+        Function* func = generateFunction(module.get(), method.get());
+        for (auto a : as) {
+            func->addFnAttr(a);
+        }
+        functions.push_back(func);
+    }
+
+    for (auto& [name, cl] : objects.classes){
+        for (auto& [name, method] : cl->methods) {
+            if (!method->body)
+                continue;
+            
+            FunctionGenerator fg(*method, module.get(), as);
+            Function* f = fg.generate();
+            logger.debug() << "verifying function: " << method->name << std::endl;
+            bool corrupt = llvm::verifyFunction(*f, &verifyStream);
+            if (corrupt) {
+                module->print(verifyStream, nullptr);
+                throw "corrupt llvm function";
+            }
+#ifdef DEBUGGING
+            printf("verified function: %s\n", method->name.c_str());
+#endif
+        }
+    }
+    for (auto& [name, method] : objects.functions) {
+        if (!method->body)
+            continue;
+        
+        FunctionGenerator fg(*method, module.get(), as);
+        Function* f = fg.generate();
+        logger.debug() << "verifying function: " << method->name << std::endl;
+        bool corrupt = llvm::verifyFunction(*f, &verifyStream);
+        if (corrupt) {
+            module->print(verifyStream, nullptr);
+            throw "corrupt llvm function";
+        }
+#ifdef DEBUGGING
+        printf("verified function: %s\n", method->name.c_str());
+#endif
+    }
+    return module;
+}
+
+
+llvm::Function* generateFunction(llvm::Module* module, sem::Method* method)
+{
+    using llvm::Function;
+    using llvm::Argument;
+    using llvm::Type;
+    using llvm::FunctionType;
+    
+    Type* returnType;
+    if (method->returnType)
+        returnType = method->returnType->getLlvmType(context);
+    else
+        returnType = llvm::Type::getVoidTy(context);
+    
+    std::vector<Type*> argumentTypes;
+    if (method->thisExpression != nullptr) {
+        Type* enclosingType = method->thisExpression->type->getLlvmType(context);
+        argumentTypes.push_back(enclosingType);
+    }
+    
+    for (auto& arg : method->arguments) {
+        Type* argumentType = arg->type->getLlvmType(context);
+        argumentTypes.push_back(argumentType);
+    }
+    
+    FunctionType* funcType = FunctionType::get(
+        returnType, argumentTypes, false);
+#ifdef DEBUGGING
+    printf("looking up llvm type of %s\n", method->name.c_str());
+#endif 
+    if (returnType == nullptr)
+        throw "invalid return type";
+    Function* func = Function::Create(funcType, Function::ExternalLinkage, method->name, module);
+    method->llvmNode = func;
+    
+    // linking alloca instances for funcs
+    auto argIterator = func->arg_begin();
+    if (method->thisExpression != nullptr) {
+        method->thisExpression->allocaInst = &*argIterator;
+        Logger::getInstance().debug() << "allocaInst of this";
+        argIterator++;
+    }
+    
+    size_t argIndex = 0;
+    for (; argIterator != func->arg_end(); argIterator++) {
+        if (argIndex > method->arguments.size())
+            throw "internal error";
+        method->arguments[argIndex]->allocaInst = &*argIterator;
+#ifdef DEBUGGING
+        printf("allocaInst of arg '%s': %p\n", method->arguments[argIndex]->name.c_str(), method->arguments[argIndex]->allocaInst);
+#endif 
+        argIndex++;
+    }
+    
+    //printf("UEEEEEEEE %s\n", method->name.c_str());
+    return func;
+}
+
+
+void generateObjectFile(const std::string& filename, std::unique_ptr<llvm::Module> module, int optLevel)
+{
+    using llvm::legacy::PassManager;
+    using llvm::PassManagerBuilder;
+    using llvm::raw_fd_ostream;
+    using llvm::Target;
+    using llvm::TargetMachine;
+    using llvm::TargetRegistry;
+    using llvm::TargetOptions;
+
+    Logger& logger = Logger::getInstance();
+    logger.debug() << "verifying mod" << std::endl;
+    auto ostr = llvm::raw_os_ostream(logger.debug());
+#ifdef DEBUGGING
+    module->print(ostr, nullptr);
+#endif
+    bool broken = llvm::verifyModule(*module);
+    
+    if (broken)
+        throw "invalid llvm module";
+    
+    logger.debug() << "mod verified" << std::endl;
+
+    llvm::InitializeAllTargetInfos();
+    llvm::InitializeAllTargets();
+    llvm::InitializeAllTargetMCs();
+    llvm::InitializeAllAsmParsers();
+    llvm::InitializeAllAsmPrinters();
+
+    PassManager pm;
+    
+    int sizeLevel = 0;
+    PassManagerBuilder builder;
+    builder.OptLevel = optLevel;
+    builder.SizeLevel = sizeLevel;
+    if (optLevel >= 2) {
+        builder.DisableUnitAtATime = false;
+        builder.DisableUnrollLoops = false;
+        builder.LoopVectorize = true;
+        builder.SLPVectorize = true;
+    }
+
+    builder.populateModulePassManager(pm);
+
+    const char cpu[] = "generic";
+    const char features[] = "";
+
+    std::string error;
+    std::string targetTriple = llvm::sys::getDefaultTargetTriple();
+    const Target* target = TargetRegistry::lookupTarget(targetTriple, error);
+
+    if (!target) {
+        logger.debug() << "could not create target: " << error << std::endl;
+        throw "internal error";
+    }
+
+    TargetOptions targetOptions;
+    auto relocModel = llvm::Optional<llvm::Reloc::Model>(llvm::Reloc::Model::PIC_);
+    std::unique_ptr<TargetMachine> targetMachine(
+        target->createTargetMachine(targetTriple, cpu,
+            features, targetOptions, relocModel));
+
+    std::error_code errorCode;
+    raw_fd_ostream dest(filename, errorCode, llvm::sys::fs::F_None);
+    targetMachine->addPassesToEmitFile(pm, dest,
+//        llvm::LLVMTargetMachine::CGFT_ObjectFile,
+        llvm::TargetMachine::CGFT_ObjectFile);
+
+    pm.run(*module);
+    dest.flush();
+    dest.close();
+
+    return;
+}
+
+} // namespace gen
+} // namespace qlow
+
+
+llvm::Function* qlow::gen::FunctionGenerator::generate(void)
+{
+    using llvm::Function;
+    using llvm::Argument;
+    using llvm::Type;
+    using llvm::FunctionType;
+    using llvm::BasicBlock;
+    using llvm::Value;
+    using llvm::IRBuilder;
+    
+#ifdef DEBUGGING
+    printf("generate function %s\n", method.name.c_str()); 
+#endif
+
+    Function* func = module->getFunction(method.name);
+
+    if (func == nullptr) {
+        throw "internal error: function not found";
+    }
+
+    BasicBlock* bb = BasicBlock::Create(context, "entry", func);
+
+    pushBlock(bb);
+
+    IRBuilder<> builder(context);
+    builder.SetInsertPoint(bb);
+    for (auto& [name, var] : method.body->scope.getLocals()) {
+        if (var.get() == nullptr)
+            throw "wtf null variable";
+        if (var->type == nullptr)
+            throw "wtf null type";
+        
+        llvm::AllocaInst* v = builder.CreateAlloca(var->type->getLlvmType(context));
+        var->allocaInst = v;
+    }
+    
+    for (auto& statement : method.body->statements) {
+#ifdef DEBUGGING
+        printf("statement visit %s\n", statement->toString().c_str());
+#endif
+        statement->accept(statementVisitor, *this);
+    }
+    
+
+#ifdef DEBUGGING
+    printf("End of Function\n");
+#endif
+    
+    //Value* val = llvm::ConstantFP::get(context, llvm::APFloat(5.0));
+    
+    builder.SetInsertPoint(getCurrentBlock());
+    if (method.returnType->equals(sem::NativeType(sem::NativeType::Type::VOID))) {
+        if (!getCurrentBlock()->getTerminator())
+            builder.CreateRetVoid();
+    }
+
+    return func;
+}
+
+
+

+ 57 - 0
src/sem/CodeGeneration.h

@@ -0,0 +1,57 @@
+#ifndef QLOW_CODE_GENERATION_H
+#define QLOW_CODE_GENERATION_H
+
+#include "Semantic.h"
+#include "Builtin.h"
+#include "CodegenVisitor.h"
+
+#include <stack>
+
+#include <llvm/IR/Module.h>
+
+namespace qlow
+{
+namespace gen
+{
+    std::unique_ptr<llvm::Module> generateModule(const sem::GlobalScope& objects);
+    llvm::Function* generateFunction (llvm::Module* module, sem::Method* method);
+    void generateObjectFile(const std::string& name, std::unique_ptr<llvm::Module> module, int optLevel);
+
+    class FunctionGenerator;
+}
+}
+
+class qlow::gen::FunctionGenerator
+{
+    const sem::Method& method;
+    llvm::Module* module;
+    llvm::AttributeSet& attributes;
+
+    std::stack<llvm::BasicBlock*> basicBlocks;
+
+public:
+
+    StatementVisitor statementVisitor;
+    ExpressionCodegenVisitor expressionVisitor;
+    LValueVisitor lvalueVisitor;
+
+    inline FunctionGenerator(const sem::Method& m, llvm::Module* module,
+        llvm::AttributeSet& attributes) :
+        method{ m },
+        module{ module },
+        attributes{ attributes },
+        expressionVisitor{ *this }
+    {
+    }
+
+    llvm::Function* generate(void);
+
+    inline llvm::Module* getModule(void) const { return module; }
+    inline llvm::LLVMContext& getContext(void) const { return module->getContext(); }
+    inline llvm::BasicBlock* getCurrentBlock(void) const { return basicBlocks.top(); }
+    inline void pushBlock(llvm::BasicBlock* bb) { basicBlocks.push(bb); }
+    inline llvm::BasicBlock* popBlock(void) { auto* bb = basicBlocks.top(); basicBlocks.pop(); return bb; }
+};
+
+
+#endif // QLOW_CODE_GENERATION_H

+ 32 - 0
src/sem/Context.cpp

@@ -0,0 +1,32 @@
+#include "Context.h"
+#include "Type.h"
+
+using qlow::sem::Context;
+
+size_t std::hash<std::reference_wrapper<qlow::sem::Type>>::operator() (const std::reference_wrapper<qlow::sem::Type>& t) const
+{
+    return t.get().hash();
+}
+
+
+qlow::sem::TypeId Context::addType(Type&& type) {
+    if (typesMap.contains(type)) {
+        return typesMap[type];
+    }
+    else {
+        types.emplace_back(type);
+        return types.size() - 1;
+    }
+}
+
+
+std::optional<std::reference_wrapper<qlow::sem::Type>> Context::getType(TypeId tid)
+{
+    if (tid >= 0 && tid <= types.size()) {
+        return std::make_optional<std::reference_wrapper<qlow::sem::Type>>(*types[tid]);
+    }
+    else {
+        return std::nullopt;
+    }
+}
+

+ 17 - 12
src/sem/Context.h

@@ -1,34 +1,39 @@
 #ifndef QLOW_SEM_CONTEXT_H
 #define QLOW_SEM_CONTEXT_H
 
-#include "Type.h"
-#include "unordered_map"
+#include <unordered_map>
+#include <memory>
+#include <vector>
+#include <optional>
 
 namespace qlow::sem
 {
+    class Type;
     class Context;
     
     using TypeId = size_t;
 }
 
+namespace std
+{
+    template<>
+    struct std::hash<std::reference_wrapper<qlow::sem::Type>>
+    {
+        size_t operator() (const std::reference_wrapper<qlow::sem::Type>& t) const;
+    };
+}
+
 
 class qlow::sem::Context
 {
 private:
     std::vector<std::unique_ptr<Type>> types;
-    std::unordered_map<Type&, TypeId> typesMap;
+    std::unordered_map<std::reference_wrapper<Type>, TypeId> typesMap;
     
 public:
     
-    TypeId addType(Type&& type) {
-        if (typesMap.contains(type)) {
-            return typesMap[type];
-        }
-        else {
-            types.push_back(std::unique_ptr<Type>(type));
-            return types.size() - 1;
-        }
-    }
+    TypeId addType(Type&& type);
+    std::optional<std::reference_wrapper<Type>> getType(TypeId tid);
 };
 
 #endif // QLOW_SEM_CONTEXT_H

+ 6 - 0
src/sem/Type.cpp

@@ -21,6 +21,12 @@ bool sem::Type::equals(const Type& other) const
 }
 
 
+size_t sem::Type::hash(void) const
+{
+    return std::hash<std::string>()(this->asString());
+}
+
+
 /*std::shared_ptr<sem::Type> sem::Type::VOID =
     std::make_shared<sem::NativeType>(sem::NativeType::Type::VOID);
 std::shared_ptr<sem::Type> sem::Type::INTEGER =

+ 16 - 12
src/sem/Type.h

@@ -1,9 +1,11 @@
 #ifndef QLOW_SEM_TYPE_H
 #define QLOW_SEM_TYPE_H
 
-#include <memory>
 #include "Scope.h"
 
+#include <memory>
+#include <string>
+
 namespace llvm {
     class Value;
     class Type;
@@ -71,6 +73,8 @@ public:
     virtual llvm::Type* getLlvmType(llvm::LLVMContext& context) const = 0;
     
     virtual bool equals(const Type& other) const;
+
+    virtual size_t hash(void) const;
     
 //    static std::shared_ptr<Type> VOID;
 //    static std::shared_ptr<Type> INTEGER;
@@ -115,12 +119,12 @@ public:
     
     inline bool isClassType(void) const override { return true; }
     
-    std::string asString(void) const;
-    Scope& getScope(void);
+    std::string asString(void) const override;
+    Scope& getScope(void) override;
     
     virtual llvm::Type* getLlvmType(llvm::LLVMContext& context) const override;
     inline sem::Class* getClassType(void) { return classType; }
-    virtual bool equals(const Type& other) const;
+    virtual bool equals(const Type& other) const override;
 };
 
 
@@ -138,12 +142,12 @@ public:
     
     inline bool isArrayType(void) const override { return true; }
     
-    std::string asString(void) const;
-    Scope& getScope(void);
+    std::string asString(void) const override;
+    Scope& getScope(void) override;
     
     virtual llvm::Type* getLlvmType(llvm::LLVMContext& context) const override;
     inline std::shared_ptr<sem::Type> getArrayType(void) { return arrayType; }
-    virtual bool equals(const Type& other) const;
+    virtual bool equals(const Type& other) const override;
 };
 
 
@@ -167,20 +171,20 @@ public:
     SymbolTable<NativeMethod> nativeMethods;
     
     inline NativeType(Type type) :
-        type{ type },
-        scope{ *this }
+        scope{ *this },
+        type{ type }
     {
     }
     
     inline bool isNativeType(void) const override { return true; }
     
-    std::string asString(void) const;
-    Scope& getScope(void);
+    std::string asString(void) const override;
+    Scope& getScope(void) override;
     
     bool isIntegerType(void) const;
     
     llvm::Type* getLlvmType(llvm::LLVMContext& context) const override;
-    virtual bool equals(const sem::Type& other) const;
+    virtual bool equals(const sem::Type& other) const override;
     
     /// cast an llvm::Value from another native type to this one
     llvm::Value* generateImplicitCast(llvm::Value* value);