CodeGeneration.cpp 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. #include "CodeGeneration.h"
  2. #include <llvm/IR/LLVMContext.h>
  3. #include <llvm/IR/LegacyPassManager.h>
  4. #include <llvm/Transforms/IPO/PassManagerBuilder.h>
  5. #include <llvm/IR/Type.h>
  6. #include <llvm/IR/DerivedTypes.h>
  7. #include <llvm/IR/Constants.h>
  8. #include <llvm/IR/BasicBlock.h>
  9. #include <llvm/IR/Verifier.h>
  10. #include <llvm/IR/IRBuilder.h>
  11. #include <llvm/Target/TargetMachine.h>
  12. #include <llvm/Support/TargetRegistry.h>
  13. #include <llvm/Support/TargetSelect.h>
  14. #include <llvm/Support/FileSystem.h>
  15. #include <llvm/Support/raw_os_ostream.h>
  16. using namespace qlow;
  17. static llvm::LLVMContext context;
  18. namespace qlow
  19. {
  20. namespace gen
  21. {
  22. std::unique_ptr<llvm::Module> generateModule(const sem::GlobalScope& objects)
  23. {
  24. using llvm::Module;
  25. using llvm::Function;
  26. using llvm::Argument;
  27. using llvm::Type;
  28. using llvm::FunctionType;
  29. using llvm::BasicBlock;
  30. using llvm::Value;
  31. using llvm::IRBuilder;
  32. Logger& logger = Logger::getInstance();
  33. #ifdef DEBUGGING
  34. printf("creating llvm module\n");
  35. #endif
  36. std::unique_ptr<Module> module = llvm::make_unique<Module>("qlow_module", context);
  37. // create llvm structs
  38. // TODO implement detection of circles
  39. for (auto& [name, cl] : objects.classes){
  40. llvm::StructType* st;
  41. std::vector<llvm::Type*> fields;
  42. #ifdef DEBUGGING
  43. printf("creating llvm struct for %s\n", name.c_str());
  44. #endif
  45. for (auto& [name, field] : cl->fields) {
  46. fields.push_back(field->type->getLlvmType(context));
  47. if (fields[fields.size() - 1] == nullptr)
  48. throw "internal error: possible circular dependency";
  49. }
  50. st = llvm::StructType::create(context, fields, name);
  51. cl->llvmType = st;
  52. }
  53. std::vector<llvm::Function*> functions;
  54. auto verifyStream = llvm::raw_os_ostream(logger.debug());
  55. // create all llvm functions
  56. for (auto& [name, cl] : objects.classes) {
  57. for (auto& [name, method] : cl->methods) {
  58. functions.push_back(generateFunction(module.get(), method.get()));
  59. }
  60. }
  61. for (auto& [name, method] : objects.functions) {
  62. functions.push_back(generateFunction(module.get(), method.get()));
  63. }
  64. for (auto& [name, cl] : objects.classes){
  65. for (auto& [name, method] : cl->methods) {
  66. if (!method->body)
  67. continue;
  68. FunctionGenerator fg(*method, module.get());
  69. Function* f = fg.generate();
  70. logger.debug() << "verifying function: " << method->name << std::endl;
  71. bool corrupt = llvm::verifyFunction(*f, &verifyStream);
  72. if (corrupt)
  73. throw "corrupt llvm function";
  74. #ifdef DEBUGGING
  75. printf("verified function: %s\n", method->name.c_str());
  76. #endif
  77. }
  78. }
  79. for (auto& [name, method] : objects.functions) {
  80. if (!method->body)
  81. continue;
  82. FunctionGenerator fg(*method, module.get());
  83. Function* f = fg.generate();
  84. logger.debug() << "verifying function: " << method->name << std::endl;
  85. bool corrupt = llvm::verifyFunction(*f, &verifyStream);
  86. if (corrupt)
  87. throw "corrupt llvm function";
  88. #ifdef DEBUGGING
  89. printf("verified function: %s\n", method->name.c_str());
  90. #endif
  91. }
  92. return module;
  93. }
  94. llvm::Function* generateFunction(llvm::Module* module, sem::Method* method)
  95. {
  96. using llvm::Function;
  97. using llvm::Argument;
  98. using llvm::Type;
  99. using llvm::FunctionType;
  100. Type* returnType;
  101. if (method->returnType)
  102. returnType = method->returnType->getLlvmType(context);
  103. else
  104. returnType = llvm::Type::getVoidTy(context);
  105. std::vector<Type*> argumentTypes;
  106. if (method->containingType != nullptr) {
  107. Type* enclosingType = method->containingType->llvmType;
  108. argumentTypes.push_back(enclosingType);
  109. }
  110. for (auto& arg : method->arguments) {
  111. Type* argumentType = arg->type->getLlvmType(context);
  112. argumentTypes.push_back(argumentType);
  113. }
  114. FunctionType* funcType = FunctionType::get(
  115. returnType, argumentTypes, false);
  116. #ifdef DEBUGGING
  117. printf("looking up llvm type of %s\n", method->name.c_str());
  118. #endif
  119. if (returnType == nullptr)
  120. throw "invalid return type";
  121. Function* func = Function::Create(funcType, Function::ExternalLinkage, method->name, module);
  122. method->llvmNode = func;
  123. size_t index = 0;
  124. for (auto& arg : func->args()) {
  125. method->arguments[index]->allocaInst = &arg;
  126. #ifdef DEBUGGING
  127. printf("allocaInst of arg '%s': %p\n", method->arguments[index]->name.c_str(), method->arguments[index]->allocaInst);
  128. #endif
  129. index++;
  130. }
  131. //printf("UEEEEEEEE %s\n", method->name.c_str());
  132. return func;
  133. }
  134. void generateObjectFile(const std::string& filename, std::unique_ptr<llvm::Module> module, int optLevel)
  135. {
  136. using llvm::legacy::PassManager;
  137. using llvm::PassManagerBuilder;
  138. using llvm::raw_fd_ostream;
  139. using llvm::Target;
  140. using llvm::TargetMachine;
  141. using llvm::TargetRegistry;
  142. using llvm::TargetOptions;
  143. Logger& logger = Logger::getInstance();
  144. logger.debug() << "verifying mod" << std::endl;
  145. auto ostr = llvm::raw_os_ostream(logger.debug());
  146. module->print(ostr, nullptr);
  147. bool broken = llvm::verifyModule(*module);
  148. if (broken)
  149. throw "invalid llvm module";
  150. logger.debug() << "mod verified" << std::endl;
  151. llvm::InitializeAllTargetInfos();
  152. llvm::InitializeAllTargets();
  153. llvm::InitializeAllTargetMCs();
  154. llvm::InitializeAllAsmParsers();
  155. llvm::InitializeAllAsmPrinters();
  156. PassManager pm;
  157. int sizeLevel = 0;
  158. PassManagerBuilder builder;
  159. builder.OptLevel = optLevel;
  160. builder.SizeLevel = sizeLevel;
  161. if (optLevel >= 2) {
  162. builder.DisableUnitAtATime = false;
  163. builder.DisableUnrollLoops = false;
  164. builder.LoopVectorize = true;
  165. builder.SLPVectorize = true;
  166. }
  167. builder.populateModulePassManager(pm);
  168. const char cpu[] = "generic";
  169. const char features[] = "";
  170. std::string error;
  171. std::string targetTriple = llvm::sys::getDefaultTargetTriple();
  172. const Target* target = TargetRegistry::lookupTarget(targetTriple, error);
  173. if (!target) {
  174. logger.debug() << "could not create target: " << error << std::endl;
  175. throw "internal error";
  176. }
  177. TargetOptions targetOptions;
  178. auto relocModel = llvm::Optional<llvm::Reloc::Model>(llvm::Reloc::Model::PIC_);
  179. std::unique_ptr<TargetMachine> targetMachine(target->createTargetMachine(targetTriple, cpu,
  180. features, targetOptions, relocModel));
  181. std::error_code errorCode;
  182. raw_fd_ostream dest(filename, errorCode, llvm::sys::fs::F_None);
  183. targetMachine->addPassesToEmitFile(pm, dest, llvm::LLVMTargetMachine::CGFT_ObjectFile,
  184. llvm::TargetMachine::CGFT_ObjectFile);
  185. pm.run(*module);
  186. dest.flush();
  187. dest.close();
  188. return;
  189. }
  190. } // namespace gen
  191. } // namespace qlow
  192. llvm::Function* qlow::gen::FunctionGenerator::generate(void)
  193. {
  194. using llvm::Function;
  195. using llvm::Argument;
  196. using llvm::Type;
  197. using llvm::FunctionType;
  198. using llvm::BasicBlock;
  199. using llvm::Value;
  200. using llvm::IRBuilder;
  201. #ifdef DEBUGGING
  202. printf("generate function %s\n", method.name.c_str());
  203. #endif
  204. Function* func = module->getFunction(method.name);
  205. if (func == nullptr) {
  206. throw "internal error: function not found";
  207. }
  208. BasicBlock* bb = BasicBlock::Create(context, "entry", func);
  209. pushBlock(bb);
  210. IRBuilder<> builder(context);
  211. builder.SetInsertPoint(bb);
  212. for (auto& [name, var] : method.body->scope.getLocals()) {
  213. if (var.get() == nullptr)
  214. throw "wtf null variable";
  215. llvm::AllocaInst* v = builder.CreateAlloca(var->type->getLlvmType(context));
  216. var->allocaInst = v;
  217. }
  218. for (auto& statement : method.body->statements) {
  219. #ifdef DEBUGGING
  220. printf("statement visit %s\n", statement->toString().c_str());
  221. #endif
  222. statement->accept(statementVisitor, *this);
  223. }
  224. #ifdef DEBUGGING
  225. printf("End of Function\n");
  226. #endif
  227. //Value* val = llvm::ConstantFP::get(context, llvm::APFloat(5.0));
  228. builder.SetInsertPoint(getCurrentBlock());
  229. if (method.returnType->equals(sem::NativeType(sem::NativeType::Type::VOID))) {
  230. if (!getCurrentBlock()->getTerminator())
  231. builder.CreateRetVoid();
  232. }
  233. return func;
  234. }