CodeGeneration.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. #include "CodeGeneration.h"
  2. #include <llvm/IR/LLVMContext.h>
  3. #include <llvm/IR/LegacyPassManager.h>
  4. #include <llvm/Transforms/IPO/PassManagerBuilder.h>
  5. #include <llvm/IR/Type.h>
  6. #include <llvm/IR/DerivedTypes.h>
  7. #include <llvm/IR/Constants.h>
  8. #include <llvm/IR/BasicBlock.h>
  9. #include <llvm/IR/Verifier.h>
  10. #include <llvm/IR/IRBuilder.h>
  11. #include <llvm/IR/Attributes.h>
  12. #include <llvm/Target/TargetMachine.h>
  13. #include <llvm/Support/TargetRegistry.h>
  14. #include <llvm/Support/TargetSelect.h>
  15. #include <llvm/Support/FileSystem.h>
  16. #include <llvm/Support/raw_os_ostream.h>
  17. using namespace qlow;
  18. static llvm::LLVMContext context;
  19. namespace qlow
  20. {
  21. namespace gen
  22. {
  23. std::unique_ptr<llvm::Module> generateModule(sem::GlobalScope& semantic)
  24. {
  25. using llvm::Module;
  26. using llvm::Function;
  27. using llvm::Argument;
  28. using llvm::Type;
  29. using llvm::FunctionType;
  30. using llvm::BasicBlock;
  31. using llvm::Value;
  32. using llvm::IRBuilder;
  33. Printer& printer = Printer::getInstance();
  34. #ifdef DEBUGGING
  35. printf("creating llvm module\n");
  36. #endif
  37. std::unique_ptr<Module> module = llvm::make_unique<Module>("qlow_module", context);
  38. // create llvm structs
  39. // TODO implement detection of circles
  40. /*
  41. for (const auto& [name, cl] : semantic.getClasses()) {
  42. llvm::StructType* st;
  43. std::vector<llvm::Type*> fields;
  44. #ifdef DEBUGGING
  45. printf("creating llvm struct for %s\n", name.c_str());
  46. #endif
  47. int llvmStructIndex = 0;
  48. // TODO: rewrite
  49. /*for (auto& [name, field] : cl->fields) {
  50. field->llvmStructIndex = llvmStructIndex;
  51. fields.push_back(field->type->getLlvmType(context));
  52. if (fields[fields.size() - 1] == nullptr)
  53. throw "internal error: possible circular dependency";
  54. llvmStructIndex++;
  55. }*//*
  56. st = llvm::StructType::create(context, fields, name);
  57. cl->llvmType = st;
  58. }*/
  59. semantic.getContext().createLlvmTypes(context);
  60. llvm::AttrBuilder ab;
  61. ab.addAttribute(llvm::Attribute::AttrKind::NoInline);
  62. ab.addAttribute(llvm::Attribute::AttrKind::NoUnwind);
  63. //ab.addAttribute(llvm::Attribute::AttrKind::OptimizeNone);
  64. //ab.addAttribute(llvm::Attribute::AttrKind::UWTable);
  65. ab.addAttribute("no-frame-pointer-elim", "true");
  66. ab.addAttribute("no-frame-pointer-elim-non-leaf");
  67. llvm::AttributeSet as = llvm::AttributeSet::get(context, ab);
  68. std::vector<llvm::Function*> functions;
  69. auto verifyStream = llvm::raw_os_ostream(printer);
  70. // create all llvm functions
  71. for (const auto& [name, cl] : semantic.getClasses()) {
  72. for (const auto& [name, method] : cl->methods) {
  73. Function* func = generateFunction(module.get(), method.get());
  74. for (auto a : as) {
  75. func->addFnAttr(a);
  76. }
  77. functions.push_back(func);
  78. }
  79. }
  80. for (const auto& [name, method] : semantic.getMethods()) {
  81. Function* func = generateFunction(module.get(), method.get());
  82. for (auto a : as) {
  83. func->addFnAttr(a);
  84. }
  85. functions.push_back(func);
  86. }
  87. for (const auto& [name, cl] : semantic.getClasses()){
  88. for (const auto& [name, method] : cl->methods) {
  89. if (!method->body)
  90. continue;
  91. FunctionGenerator fg(*method, module.get(), as);
  92. Function* f = fg.generate();
  93. // printer << "verifying function: " << method->name << std::endl;
  94. bool corrupt = llvm::verifyFunction(*f, &verifyStream);
  95. if (corrupt) {
  96. #ifdef DEBUGGING
  97. //module->print(verifyStream, nullptr);
  98. #endif
  99. throw (std::string("corrupt llvm function: ") + method->name).c_str();
  100. }
  101. #ifdef DEBUGGING
  102. printf("verified function: %s\n", method->name.c_str());
  103. #endif
  104. }
  105. }
  106. for (const auto& [name, method] : semantic.getMethods()) {
  107. if (!method->body)
  108. continue;
  109. FunctionGenerator fg(*method, module.get(), as);
  110. Function* f = fg.generate();
  111. //printer.debug() << "verifying function: " << method->name << std::endl;
  112. bool corrupt = llvm::verifyFunction(*f, &verifyStream);
  113. if (corrupt) {
  114. f->print(verifyStream, nullptr);
  115. throw (std::string("corrupt llvm function: ") + method->name).c_str();
  116. }
  117. #ifdef DEBUGGING
  118. printf("verified function: %s\n", method->name.c_str());
  119. #endif
  120. }
  121. return module;
  122. }
  123. llvm::Function* generateFunction(llvm::Module* module, sem::Method* method)
  124. {
  125. sem::Context& semCtxt = method->context;
  126. using llvm::Function;
  127. using llvm::Argument;
  128. using llvm::Type;
  129. using llvm::FunctionType;
  130. Type* returnType;
  131. if (method->returnType)
  132. returnType = semCtxt.getLlvmType(method->returnType, context);
  133. else
  134. returnType = llvm::Type::getVoidTy(context);
  135. std::vector<Type*> argumentTypes;
  136. if (method->thisExpression != nullptr) {
  137. Type* enclosingType = semCtxt.getLlvmType(method->thisExpression->type, context);
  138. argumentTypes.push_back(enclosingType);
  139. }
  140. for (auto& arg : method->arguments) {
  141. Type* argumentType = semCtxt.getLlvmType(arg->type, context);
  142. argumentTypes.push_back(argumentType);
  143. }
  144. FunctionType* funcType = FunctionType::get(
  145. returnType, argumentTypes, false);
  146. #ifdef DEBUGGING
  147. printf("looking up llvm type of %s\n", method->name.c_str());
  148. #endif
  149. if (returnType == nullptr)
  150. throw "invalid return type";
  151. Function* func = Function::Create(funcType, Function::ExternalLinkage, method->name, module);
  152. method->llvmNode = func;
  153. // linking alloca instances for funcs
  154. auto argIterator = func->arg_begin();
  155. if (method->thisExpression != nullptr) {
  156. method->thisExpression->allocaInst = &*argIterator;
  157. #ifdef DEBUGGING
  158. Printer::getInstance() << "allocaInst of this";
  159. #endif
  160. argIterator++;
  161. }
  162. size_t argIndex = 0;
  163. for (; argIterator != func->arg_end(); argIterator++) {
  164. if (argIndex > method->arguments.size())
  165. throw "internal error";
  166. method->arguments[argIndex]->allocaInst = &*argIterator;
  167. #ifdef DEBUGGING
  168. printf("allocaInst of arg '%s': %p\n", method->arguments[argIndex]->name.c_str(), method->arguments[argIndex]->allocaInst);
  169. #endif
  170. argIndex++;
  171. }
  172. //printf("UEEEEEEEE %s\n", method->name.c_str());
  173. return func;
  174. }
  175. void generateObjectFile(const std::string& filename, std::unique_ptr<llvm::Module> module, int optLevel)
  176. {
  177. using llvm::legacy::PassManager;
  178. using llvm::PassManagerBuilder;
  179. using llvm::raw_fd_ostream;
  180. using llvm::Target;
  181. using llvm::TargetMachine;
  182. using llvm::TargetRegistry;
  183. using llvm::TargetOptions;
  184. Printer& printer = Printer::getInstance();
  185. #ifdef DEBUGGING
  186. printer << "verifying mod" << std::endl;
  187. #endif
  188. auto ostr = llvm::raw_os_ostream(printer);
  189. #ifdef DEBUGGING
  190. module->print(ostr, nullptr);
  191. #endif
  192. bool broken = llvm::verifyModule(*module);
  193. if (broken)
  194. throw "invalid llvm module";
  195. llvm::InitializeAllTargetInfos();
  196. llvm::InitializeAllTargets();
  197. llvm::InitializeAllTargetMCs();
  198. llvm::InitializeAllAsmParsers();
  199. llvm::InitializeAllAsmPrinters();
  200. PassManager pm;
  201. int sizeLevel = 0;
  202. PassManagerBuilder builder;
  203. builder.OptLevel = optLevel;
  204. builder.SizeLevel = sizeLevel;
  205. if (optLevel >= 2) {
  206. builder.DisableUnitAtATime = false;
  207. builder.DisableUnrollLoops = false;
  208. builder.LoopVectorize = true;
  209. builder.SLPVectorize = true;
  210. }
  211. builder.populateModulePassManager(pm);
  212. const char cpu[] = "generic";
  213. const char features[] = "";
  214. std::string error;
  215. std::string targetTriple = llvm::sys::getDefaultTargetTriple();
  216. const Target* target = TargetRegistry::lookupTarget(targetTriple, error);
  217. if (!target) {
  218. #ifdef DEBUGGING
  219. printer << "could not create target: " << error << std::endl;
  220. #endif
  221. throw "internal error";
  222. }
  223. TargetOptions targetOptions;
  224. auto relocModel = llvm::Optional<llvm::Reloc::Model>(llvm::Reloc::Model::PIC_);
  225. TargetMachine* targetMachine = target->createTargetMachine(targetTriple, cpu,
  226. features, targetOptions, relocModel);
  227. std::error_code errorCode;
  228. raw_fd_ostream dest(filename, errorCode, llvm::sys::fs::F_None);
  229. #ifdef DEBUGGING
  230. printer << "adding passes" << std::endl;
  231. #endif
  232. targetMachine->addPassesToEmitFile(pm, dest,
  233. // llvm::LLVMTargetMachine::CGFT_ObjectFile,
  234. nullptr,
  235. llvm::TargetMachine::CGFT_ObjectFile);
  236. pm.run(*module);
  237. dest.flush();
  238. dest.close();
  239. return;
  240. }
  241. } // namespace gen
  242. } // namespace qlow
  243. llvm::Function* qlow::gen::FunctionGenerator::generate(void)
  244. {
  245. using llvm::Function;
  246. using llvm::Argument;
  247. using llvm::Type;
  248. using llvm::FunctionType;
  249. using llvm::BasicBlock;
  250. using llvm::Value;
  251. using llvm::IRBuilder;
  252. sem::Context& semCtxt = this->method.context;
  253. #ifdef DEBUGGING
  254. printf("generate function %s\n", method.name.c_str());
  255. #endif
  256. Function* func = module->getFunction(method.name);
  257. if (func == nullptr) {
  258. throw "internal error: function not found";
  259. }
  260. BasicBlock* bb = BasicBlock::Create(context, "entry", func);
  261. pushBlock(bb);
  262. IRBuilder<> builder(context);
  263. builder.SetInsertPoint(bb);
  264. for (auto& [name, var] : method.body->scope.getLocals()) {
  265. if (var.get() == nullptr)
  266. throw "wtf null variable";
  267. if (var->type == sem::NO_TYPE)
  268. throw "wtf null type";
  269. llvm::AllocaInst* v = builder.CreateAlloca(semCtxt.getLlvmType(var->type, context));
  270. var->allocaInst = v;
  271. }
  272. for (auto& statement : method.body->statements) {
  273. #ifdef DEBUGGING
  274. printf("statement visit %s\n", statement->toString().c_str());
  275. #endif
  276. statement->accept(statementVisitor, *this);
  277. }
  278. #ifdef DEBUGGING
  279. printf("End of Function\n");
  280. #endif
  281. //Value* val = llvm::ConstantFP::get(context, llvm::APFloat(5.0));
  282. builder.SetInsertPoint(getCurrentBlock());
  283. //if (method.returnType->equals(sem::NativeType(sem::NativeType::Type::VOID))) {
  284. if (method.returnType == sem::NO_TYPE ||
  285. (semCtxt.getType(method.returnType).getKind() == sem::Type::Kind::NATIVE &&
  286. semCtxt.getType(method.returnType).getNativeKind() == sem::Type::Native::VOID)) {
  287. if (!getCurrentBlock()->getTerminator())
  288. builder.CreateRetVoid();
  289. }
  290. return func;
  291. }