CodeGeneration.cpp 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. #include "CodeGeneration.h"
  2. #include <llvm/IR/LLVMContext.h>
  3. #include <llvm/IR/LegacyPassManager.h>
  4. #include <llvm/Transforms/IPO/PassManagerBuilder.h>
  5. #include <llvm/IR/Type.h>
  6. #include <llvm/IR/DerivedTypes.h>
  7. #include <llvm/IR/Constants.h>
  8. #include <llvm/IR/BasicBlock.h>
  9. #include <llvm/IR/Verifier.h>
  10. #include <llvm/IR/IRBuilder.h>
  11. #include <llvm/Target/TargetMachine.h>
  12. #include <llvm/Support/TargetRegistry.h>
  13. #include <llvm/Support/TargetSelect.h>
  14. #include <llvm/Support/FileSystem.h>
  15. #include <llvm/Support/raw_os_ostream.h>
  16. using namespace qlow;
  17. static llvm::LLVMContext context;
  18. namespace qlow
  19. {
  20. namespace gen
  21. {
  22. std::unique_ptr<llvm::Module> generateModule(const sem::GlobalScope& objects)
  23. {
  24. using llvm::Module;
  25. using llvm::Function;
  26. using llvm::Argument;
  27. using llvm::Type;
  28. using llvm::FunctionType;
  29. using llvm::BasicBlock;
  30. using llvm::Value;
  31. using llvm::IRBuilder;
  32. Logger& logger = Logger::getInstance();
  33. #ifdef DEBUGGING
  34. printf("creating llvm module\n");
  35. #endif
  36. std::unique_ptr<Module> module = llvm::make_unique<Module>("qlow_module", context);
  37. // create llvm structs
  38. // TODO implement detection of circles
  39. for (auto& [name, cl] : objects.classes){
  40. llvm::StructType* st;
  41. std::vector<llvm::Type*> fields;
  42. #ifdef DEBUGGING
  43. printf("creating llvm struct for %s\n", name.c_str());
  44. #endif
  45. for (auto& [name, field] : cl->fields) {
  46. fields.push_back(field->type->getLlvmType(context));
  47. if (fields[fields.size() - 1] == nullptr)
  48. throw "internal error: possible circular dependency";
  49. }
  50. st = llvm::StructType::create(context, fields, name);
  51. cl->llvmType = st;
  52. }
  53. std::vector<llvm::Function*> functions;
  54. auto verifyStream = llvm::raw_os_ostream(logger.debug());
  55. // create all llvm functions
  56. for (auto& [name, cl] : objects.classes) {
  57. for (auto& [name, method] : cl->methods) {
  58. functions.push_back(generateFunction(module.get(), method.get()));
  59. }
  60. }
  61. for (auto& [name, method] : objects.functions) {
  62. functions.push_back(generateFunction(module.get(), method.get()));
  63. }
  64. for (auto& [name, cl] : objects.classes){
  65. for (auto& [name, method] : cl->methods) {
  66. if (!method->body)
  67. continue;
  68. FunctionGenerator fg(*method, module.get());
  69. Function* f = fg.generate();
  70. logger.debug() << "verifying function: " << method->name << std::endl;
  71. bool corrupt = llvm::verifyFunction(*f, &verifyStream);
  72. if (corrupt)
  73. throw "corrupt llvm function";
  74. #ifdef DEBUGGING
  75. printf("verified function: %s\n", method->name.c_str());
  76. #endif
  77. }
  78. }
  79. for (auto& [name, method] : objects.functions) {
  80. if (!method->body)
  81. continue;
  82. FunctionGenerator fg(*method, module.get());
  83. Function* f = fg.generate();
  84. logger.debug() << "verifying function: " << method->name << std::endl;
  85. bool corrupt = llvm::verifyFunction(*f, &verifyStream);
  86. if (corrupt)
  87. throw "corrupt llvm function";
  88. #ifdef DEBUGGING
  89. printf("verified function: %s\n", method->name.c_str());
  90. #endif
  91. }
  92. return module;
  93. }
  94. llvm::Function* generateFunction(llvm::Module* module, sem::Method* method)
  95. {
  96. using llvm::Function;
  97. using llvm::Argument;
  98. using llvm::Type;
  99. using llvm::FunctionType;
  100. std::vector<Type*> argumentTypes;
  101. Type* returnType;
  102. if (method->returnType)
  103. returnType = method->returnType->getLlvmType(context);
  104. else
  105. returnType = llvm::Type::getVoidTy(context);
  106. for (auto& arg : method->arguments) {
  107. Type* argumentType = arg->type->getLlvmType(context);
  108. argumentTypes.push_back(argumentType);
  109. }
  110. FunctionType* funcType = FunctionType::get(
  111. returnType, argumentTypes, false);
  112. #ifdef DEBUGGING
  113. printf("looking up llvm type of %s\n", method->name.c_str());
  114. #endif
  115. if (returnType == nullptr)
  116. throw "invalid return type";
  117. Function* func = Function::Create(funcType, Function::ExternalLinkage, method->name, module);
  118. method->llvmNode = func;
  119. size_t index = 0;
  120. for (auto& arg : func->args()) {
  121. method->arguments[index]->allocaInst = &arg;
  122. #ifdef DEBUGGING
  123. printf("allocaInst of arg '%s': %p\n", method->arguments[index]->name.c_str(), method->arguments[index]->allocaInst);
  124. #endif
  125. index++;
  126. }
  127. //printf("UEEEEEEEE %s\n", method->name.c_str());
  128. return func;
  129. }
  130. void generateObjectFile(const std::string& filename, std::unique_ptr<llvm::Module> module, int optLevel)
  131. {
  132. using llvm::legacy::PassManager;
  133. using llvm::PassManagerBuilder;
  134. using llvm::raw_fd_ostream;
  135. using llvm::Target;
  136. using llvm::TargetMachine;
  137. using llvm::TargetRegistry;
  138. using llvm::TargetOptions;
  139. Logger& logger = Logger::getInstance();
  140. logger.debug() << "verifying mod" << std::endl;
  141. auto ostr = llvm::raw_os_ostream(logger.debug());
  142. module->print(ostr, nullptr);
  143. bool broken = llvm::verifyModule(*module);
  144. if (broken)
  145. throw "invalid llvm module";
  146. logger.debug() << "mod verified" << std::endl;
  147. llvm::InitializeAllTargetInfos();
  148. llvm::InitializeAllTargets();
  149. llvm::InitializeAllTargetMCs();
  150. llvm::InitializeAllAsmParsers();
  151. llvm::InitializeAllAsmPrinters();
  152. PassManager pm;
  153. int sizeLevel = 0;
  154. PassManagerBuilder builder;
  155. builder.OptLevel = optLevel;
  156. builder.SizeLevel = sizeLevel;
  157. if (optLevel >= 2) {
  158. builder.DisableUnitAtATime = false;
  159. builder.DisableUnrollLoops = false;
  160. builder.LoopVectorize = true;
  161. builder.SLPVectorize = true;
  162. }
  163. builder.populateModulePassManager(pm);
  164. const char cpu[] = "generic";
  165. const char features[] = "";
  166. std::string error;
  167. std::string targetTriple = llvm::sys::getDefaultTargetTriple();
  168. const Target* target = TargetRegistry::lookupTarget(targetTriple, error);
  169. if (!target) {
  170. logger.debug() << "could not create target: " << error << std::endl;
  171. throw "internal error";
  172. }
  173. TargetOptions targetOptions;
  174. auto relocModel = llvm::Optional<llvm::Reloc::Model>(llvm::Reloc::Model::PIC_);
  175. std::unique_ptr<TargetMachine> targetMachine(target->createTargetMachine(targetTriple, cpu,
  176. features, targetOptions, relocModel));
  177. std::error_code errorCode;
  178. raw_fd_ostream dest(filename, errorCode, llvm::sys::fs::F_None);
  179. targetMachine->addPassesToEmitFile(pm, dest, llvm::LLVMTargetMachine::CGFT_ObjectFile,
  180. llvm::TargetMachine::CGFT_ObjectFile);
  181. pm.run(*module);
  182. dest.flush();
  183. dest.close();
  184. return;
  185. }
  186. } // namespace gen
  187. } // namespace qlow
  188. llvm::Function* qlow::gen::FunctionGenerator::generate(void)
  189. {
  190. using llvm::Function;
  191. using llvm::Argument;
  192. using llvm::Type;
  193. using llvm::FunctionType;
  194. using llvm::BasicBlock;
  195. using llvm::Value;
  196. using llvm::IRBuilder;
  197. #ifdef DEBUGGING
  198. printf("generate function %s\n", method.name.c_str());
  199. #endif
  200. Function* func = module->getFunction(method.name);
  201. if (func == nullptr) {
  202. throw "internal error: function not found";
  203. }
  204. BasicBlock* bb = BasicBlock::Create(context, "entry", func);
  205. pushBlock(bb);
  206. IRBuilder<> builder(context);
  207. builder.SetInsertPoint(bb);
  208. for (auto& [name, var] : method.body->scope.getLocals()) {
  209. if (var.get() == nullptr)
  210. throw "wtf null variable";
  211. llvm::AllocaInst* v = builder.CreateAlloca(var->type->getLlvmType(context));
  212. var->allocaInst = v;
  213. }
  214. for (auto& statement : method.body->statements) {
  215. #ifdef DEBUGGING
  216. printf("statement visit %s\n", statement->toString().c_str());
  217. #endif
  218. statement->accept(statementVisitor, *this);
  219. }
  220. #ifdef DEBUGGING
  221. printf("End of Function\n");
  222. #endif
  223. //Value* val = llvm::ConstantFP::get(context, llvm::APFloat(5.0));
  224. builder.SetInsertPoint(getCurrentBlock());
  225. if (method.returnType->equals(sem::NativeType(sem::NativeType::Type::VOID))) {
  226. if (!getCurrentBlock()->getTerminator())
  227. builder.CreateRetVoid();
  228. }
  229. return func;
  230. }