diff --git a/08/ast.cpp b/08/ast.cpp index c04eb62..453a12b 100644 --- a/08/ast.cpp +++ b/08/ast.cpp @@ -1,5 +1,6 @@ #include "ast.hpp" #include +#include "binop.hpp" #include "error.hpp" static void print_indent(int n, std::ostream& to) { @@ -104,7 +105,7 @@ void ast_binop::compile(const env_ptr& env, std::vector& into) right->compile(env, into); left->compile(env_ptr(new env_offset(1, env)), into); - into.push_back(instruction_ptr(new instruction_pushglobal(op_name(op)))); + into.push_back(instruction_ptr(new instruction_pushglobal(op_action(op)))); into.push_back(instruction_ptr(new instruction_mkapp())); into.push_back(instruction_ptr(new instruction_mkapp())); } diff --git a/08/definition.cpp b/08/definition.cpp index b3fb4df..25aabee 100644 --- a/08/definition.cpp +++ b/08/definition.cpp @@ -1,6 +1,10 @@ #include "definition.hpp" #include "error.hpp" #include "ast.hpp" +#include "llvm_context.hpp" +#include +#include +#include void definition_defn::typecheck_first(type_mgr& mgr, type_env& env) { return_type = mgr.new_type(); @@ -50,6 +54,18 @@ void definition_defn::compile() { body->compile(new_env, instructions); instructions.push_back(instruction_ptr(new instruction_update(params.size()))); } +void definition_defn::gen_llvm_first(llvm_context& ctx) { + generated_function = ctx.create_custom_function(name, params.size()); +} + +void definition_defn::gen_llvm_second(llvm_context& ctx) { + ctx.builder.SetInsertPoint(&generated_function->getEntryBlock()); + for(auto& instruction : instructions) { + instruction->gen_llvm(ctx, generated_function); + } + ctx.create_popn(generated_function, ctx.create_size(params.size())); + ctx.builder.CreateRetVoid(); +} void definition_data::typecheck_first(type_mgr& mgr, type_env& env) { type_data* this_type = new type_data(name); @@ -57,6 +73,7 @@ void definition_data::typecheck_first(type_mgr& mgr, type_env& env) { int next_tag = 0; for(auto& constructor : constructors) { + constructor->tag = next_tag; this_type->constructors[constructor->name] = { next_tag++ }; type_ptr full_type = return_type; @@ -80,3 +97,20 @@ void definition_data::resolve(const type_mgr& mgr) { void definition_data::compile() { } + +void definition_data::gen_llvm_first(llvm_context& ctx) { + for(auto& constructor : constructors) { + auto new_function = + ctx.create_custom_function(constructor->name, constructor->types.size()); + ctx.builder.SetInsertPoint(&new_function->getEntryBlock()); + ctx.create_pack(new_function, + ctx.create_size(constructor->types.size()), + ctx.create_i8(constructor->tag) + ); + ctx.builder.CreateRetVoid(); + } +} + +void definition_data::gen_llvm_second(llvm_context& ctx) { + // Nothing +} diff --git a/08/definition.hpp b/08/definition.hpp index 1431b8b..6004d2f 100644 --- a/08/definition.hpp +++ b/08/definition.hpp @@ -2,6 +2,7 @@ #include #include #include "instruction.hpp" +#include "llvm_context.hpp" #include "type_env.hpp" struct ast; @@ -14,6 +15,8 @@ struct definition { virtual void typecheck_second(type_mgr& mgr, const type_env& env) const = 0; virtual void resolve(const type_mgr& mgr) = 0; virtual void compile() = 0; + virtual void gen_llvm_first(llvm_context& ctx) = 0; + virtual void gen_llvm_second(llvm_context& ctx) = 0; }; using definition_ptr = std::unique_ptr; @@ -21,6 +24,7 @@ using definition_ptr = std::unique_ptr; struct constructor { std::string name; std::vector types; + int8_t tag; constructor(std::string n, std::vector ts) : name(std::move(n)), types(std::move(ts)) {} @@ -38,6 +42,8 @@ struct definition_defn : public definition { std::vector instructions; + llvm::Function* generated_function; + definition_defn(std::string n, std::vector p, ast_ptr b) : name(std::move(n)), params(std::move(p)), body(std::move(b)) { @@ -47,6 +53,8 @@ struct definition_defn : public definition { void typecheck_second(type_mgr& mgr, const type_env& env) const; void resolve(const type_mgr& mgr); void compile(); + void gen_llvm_first(llvm_context& ctx); + void gen_llvm_second(llvm_context& ctx); }; struct definition_data : public definition { @@ -60,4 +68,6 @@ struct definition_data : public definition { void typecheck_second(type_mgr& mgr, const type_env& env) const; void resolve(const type_mgr& mgr); void compile(); + void gen_llvm_first(llvm_context& ctx); + void gen_llvm_second(llvm_context& ctx); }; diff --git a/08/instruction.cpp b/08/instruction.cpp index a845da4..2d2854a 100644 --- a/08/instruction.cpp +++ b/08/instruction.cpp @@ -1,5 +1,6 @@ #include "instruction.hpp" #include "llvm_context.hpp" +#include #include using namespace llvm; @@ -23,7 +24,9 @@ void instruction_pushglobal::print(int indent, std::ostream& to) const { } void instruction_pushglobal::gen_llvm(llvm_context& ctx, Function* f) const { - // TODO + auto& global_f = ctx.custom_functions.at("f_" + name); + auto arity = ctx.create_i32(global_f->arity); + ctx.create_push(f, ctx.create_global(global_f->function, arity)); } void instruction_push::print(int indent, std::ostream& to) const { @@ -87,7 +90,27 @@ void instruction_jump::print(int indent, std::ostream& to) const { } void instruction_jump::gen_llvm(llvm_context& ctx, Function* f) const { - // TODO + auto top_node = ctx.create_peek(f, ctx.create_size(0)); + auto tag = ctx.unwrap_data_tag(top_node); + auto safety_block = BasicBlock::Create(ctx.ctx, "safety", f); + auto switch_op = ctx.builder.CreateSwitch(tag, safety_block, tag_mappings.size()); + std::vector blocks; + + for(auto& branch : branches) { + auto branch_block = BasicBlock::Create(ctx.ctx, "branch", f); + ctx.builder.SetInsertPoint(branch_block); + for(auto& instruction : branch) { + instruction->gen_llvm(ctx, f); + } + ctx.builder.CreateBr(safety_block); + blocks.push_back(branch_block); + } + + for(auto& mapping : tag_mappings) { + switch_op->addCase(ctx.create_i8(mapping.first), blocks[mapping.second]); + } + + ctx.builder.SetInsertPoint(safety_block); } void instruction_slide::print(int indent, std::ostream& to) const { @@ -105,8 +128,8 @@ void instruction_binop::print(int indent, std::ostream& to) const { } void instruction_binop::gen_llvm(llvm_context& ctx, Function* f) const { - auto left_int = ctx.unwrap_num(ctx.create_pop(f)); - auto right_int = ctx.unwrap_num(ctx.create_pop(f)); + auto left_int = ctx.unwrap_num(ctx.create_eval(ctx.create_pop(f))); + auto right_int = ctx.unwrap_num(ctx.create_eval(ctx.create_pop(f))); llvm::Value* result; switch(op) { case PLUS: result = ctx.builder.CreateAdd(left_int, right_int); break; diff --git a/08/llvm_context.cpp b/08/llvm_context.cpp index 5aadeef..b2982f9 100644 --- a/08/llvm_context.cpp +++ b/08/llvm_context.cpp @@ -67,49 +67,49 @@ void llvm_context::create_functions() { functions["stack_pop"] = Function::Create( FunctionType::get(node_ptr_type, { stack_ptr_type }, false), Function::LinkageTypes::ExternalLinkage, - "stack_push", + "stack_pop", &module ); functions["stack_peek"] = Function::Create( FunctionType::get(node_ptr_type, { stack_ptr_type, sizet_type }, false), Function::LinkageTypes::ExternalLinkage, - "stack_push", + "stack_peek", &module ); functions["stack_popn"] = Function::Create( FunctionType::get(void_type, { stack_ptr_type, sizet_type }, false), Function::LinkageTypes::ExternalLinkage, - "stack_push", + "stack_popn", &module ); functions["stack_slide"] = Function::Create( FunctionType::get(void_type, { stack_ptr_type, sizet_type }, false), Function::LinkageTypes::ExternalLinkage, - "stack_push", + "stack_slide", &module ); functions["stack_update"] = Function::Create( FunctionType::get(void_type, { stack_ptr_type, sizet_type }, false), Function::LinkageTypes::ExternalLinkage, - "stack_push", + "stack_update", &module ); functions["stack_alloc"] = Function::Create( FunctionType::get(void_type, { stack_ptr_type, sizet_type }, false), Function::LinkageTypes::ExternalLinkage, - "stack_push", + "stack_alloc", &module ); functions["stack_pack"] = Function::Create( FunctionType::get(void_type, { stack_ptr_type, sizet_type, tag_type }, false), Function::LinkageTypes::ExternalLinkage, - "stack_push", + "stack_pack", &module ); functions["stack_split"] = Function::Create( FunctionType::get(node_ptr_type, { stack_ptr_type, sizet_type }, false), Function::LinkageTypes::ExternalLinkage, - "stack_push", + "stack_split", &module ); @@ -147,13 +147,13 @@ void llvm_context::create_functions() { ); } -Value* llvm_context::create_i8(int8_t i) { +ConstantInt* llvm_context::create_i8(int8_t i) { return ConstantInt::get(ctx, APInt(8, i)); } -Value* llvm_context::create_i32(int32_t i) { +ConstantInt* llvm_context::create_i32(int32_t i) { return ConstantInt::get(ctx, APInt(32, i)); } -Value* llvm_context::create_size(size_t i) { +ConstantInt* llvm_context::create_size(size_t i) { return ConstantInt::get(ctx, APInt(sizeof(size_t) * 8, i)); } @@ -202,8 +202,8 @@ Value* llvm_context::create_eval(Value* e) { Value* llvm_context::unwrap_num(Value* v) { auto num_ptr_type = PointerType::getUnqual(struct_types.at("node_num")); auto cast = builder.CreatePointerCast(v, num_ptr_type); - auto offset_0 = create_size(0); - auto offset_1 = create_size(1); + auto offset_0 = create_i32(0); + auto offset_1 = create_i32(1); auto int_ptr = builder.CreateGEP(cast, { offset_0, offset_1 }); return builder.CreateLoad(int_ptr); } @@ -212,6 +212,15 @@ Value* llvm_context::create_num(Value* v) { return builder.CreateCall(alloc_num_f, { v }); } +Value* llvm_context::unwrap_data_tag(Value* v) { + auto data_ptr_type = PointerType::getUnqual(struct_types.at("node_data")); + auto cast = builder.CreatePointerCast(v, data_ptr_type); + auto offset_0 = create_i32(0); + auto offset_1 = create_i32(1); + auto tag_ptr = builder.CreateGEP(cast, { offset_0, offset_1 }); + return builder.CreateLoad(tag_ptr); +} + Value* llvm_context::create_global(Value* f, Value* a) { auto alloc_global_f = functions.at("alloc_global"); return builder.CreateCall(alloc_global_f, { f, a }); @@ -221,3 +230,23 @@ Value* llvm_context::create_app(Value* l, Value* r) { auto alloc_app_f = functions.at("alloc_app"); return builder.CreateCall(alloc_app_f, { l, r }); } + +llvm::Function* llvm_context::create_custom_function(std::string name, int32_t arity) { + auto void_type = llvm::Type::getVoidTy(ctx); + auto function_type = + llvm::FunctionType::get(void_type, { stack_ptr_type }, false); + auto new_function = llvm::Function::Create( + function_type, + llvm::Function::LinkageTypes::ExternalLinkage, + "f_" + name, + &module + ); + auto start_block = llvm::BasicBlock::Create(ctx, "entry", new_function); + + auto new_custom_f = custom_function_ptr(new custom_function()); + new_custom_f->arity = arity; + new_custom_f->function = new_function; + custom_functions["f_" + name] = std::move(new_custom_f); + + return new_function; +} diff --git a/08/llvm_context.hpp b/08/llvm_context.hpp index 5067ba1..ca75dca 100644 --- a/08/llvm_context.hpp +++ b/08/llvm_context.hpp @@ -7,10 +7,18 @@ #include struct llvm_context { + struct custom_function { + llvm::Function* function; + int32_t arity; + }; + + using custom_function_ptr = std::unique_ptr; + llvm::LLVMContext ctx; llvm::IRBuilder<> builder; llvm::Module module; + std::map custom_functions; std::map functions; std::map struct_types; @@ -29,9 +37,9 @@ struct llvm_context { void create_types(); void create_functions(); - llvm::Value* create_i8(int8_t); - llvm::Value* create_i32(int32_t); - llvm::Value* create_size(size_t); + llvm::ConstantInt* create_i8(int8_t); + llvm::ConstantInt* create_i32(int32_t); + llvm::ConstantInt* create_size(size_t); llvm::Value* create_pop(llvm::Function*); llvm::Value* create_peek(llvm::Function*, llvm::Value*); @@ -48,7 +56,11 @@ struct llvm_context { llvm::Value* unwrap_num(llvm::Value*); llvm::Value* create_num(llvm::Value*); + llvm::Value* unwrap_data_tag(llvm::Value*); + llvm::Value* create_global(llvm::Value*, llvm::Value*); llvm::Value* create_app(llvm::Value*, llvm::Value*); + + llvm::Function* create_custom_function(std::string name, int32_t arity); }; diff --git a/08/main.cpp b/08/main.cpp index 728e2a6..c3d6047 100644 --- a/08/main.cpp +++ b/08/main.cpp @@ -1,9 +1,20 @@ #include "ast.hpp" #include +#include "binop.hpp" #include "definition.hpp" +#include "instruction.hpp" +#include "llvm_context.hpp" #include "parser.hpp" #include "error.hpp" #include "type.hpp" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/TargetRegistry.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Target/TargetOptions.h" +#include "llvm/Target/TargetMachine.h" void yy::parser::error(const std::string& msg) { std::cout << "An error occured: " << msg << std::endl; @@ -56,6 +67,73 @@ void compile_program(const std::vector& prog) { } } +void gen_llvm_internal_op(llvm_context& ctx, binop op) { + auto new_function = ctx.create_custom_function(op_action(op), 2); + auto new_instruction = instruction_ptr(new instruction_binop(op)); + ctx.builder.SetInsertPoint(&new_function->getEntryBlock()); + new_instruction->gen_llvm(ctx, new_function); + ctx.builder.CreateRetVoid(); +} + +void output_llvm(llvm_context& ctx, const std::string& filename) { + std::string targetTriple = llvm::sys::getDefaultTargetTriple(); + + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmParser(); + llvm::InitializeNativeTargetAsmPrinter(); + + std::string error; + const llvm::Target* target = + llvm::TargetRegistry::lookupTarget(targetTriple, error); + if (!target) { + std::cerr << error << std::endl; + } else { + std::string cpu = "generic"; + std::string features = ""; + llvm::TargetOptions options; + llvm::TargetMachine* targetMachine = + target->createTargetMachine(targetTriple, cpu, features, + options, llvm::Optional()); + + ctx.module.setDataLayout(targetMachine->createDataLayout()); + ctx.module.setTargetTriple(targetTriple); + + std::error_code ec; + llvm::raw_fd_ostream file(filename, ec, llvm::sys::fs::F_None); + if (ec) { + std::cerr << "Could not open output file: " << ec.message() << std::endl; + } else { + llvm::TargetMachine::CodeGenFileType type = llvm::TargetMachine::CGFT_ObjectFile; + llvm::legacy::PassManager pm; + if (targetMachine->addPassesToEmitFile(pm, file, NULL, type)) { + std::cerr << "Unable to emit target code" << std::endl; + } else { + pm.run(ctx.module); + file.close(); + } + } + } +} + +void gen_llvm(const std::vector& prog) { + llvm_context ctx; + gen_llvm_internal_op(ctx, PLUS); + gen_llvm_internal_op(ctx, MINUS); + gen_llvm_internal_op(ctx, TIMES); + gen_llvm_internal_op(ctx, DIVIDE); + + for(auto& definition : prog) { + definition->gen_llvm_first(ctx); + } + + for(auto& definition : prog) { + definition->gen_llvm_second(ctx); + } + llvm::verifyModule(ctx.module); + ctx.module.print(llvm::outs(), nullptr); + output_llvm(ctx, "program.o"); +} + int main() { yy::parser parser; type_mgr mgr; @@ -75,6 +153,7 @@ int main() { try { typecheck_program(program, mgr, env); compile_program(program); + gen_llvm(program); } catch(unification_error& err) { std::cout << "failed to unify types: " << std::endl; std::cout << " (1) \033[34m"; diff --git a/08/runtime.c b/08/runtime.c index c7bc5b0..ebae7ae 100644 --- a/08/runtime.c +++ b/08/runtime.c @@ -1,6 +1,7 @@ #include #include #include +#include #include "runtime.h" struct node_base* alloc_node() { @@ -97,7 +98,7 @@ void stack_pack(struct stack* s, size_t n, int8_t t) { struct node_base** data = malloc(sizeof(*data) * n); assert(data != NULL); - memcpy(data, &s->data[s->count - 1 - n], n * sizeof(*data)); + memcpy(data, &s->data[s->count - n], n * sizeof(*data)); struct node_data* new_node = (struct node_data*) alloc_node(); new_node->array = data; @@ -153,7 +154,30 @@ struct node_base* eval(struct node_base* n) { extern void f_main(struct stack* s); +void print_node(struct node_base* n) { + if(n->tag == NODE_APP) { + struct node_app* app = (struct node_app*) n; + print_node(app->left); + putchar(' '); + print_node(app->right); + } else if(n->tag == NODE_DATA) { + printf("(Packed)"); + } else if(n->tag == NODE_GLOBAL) { + struct node_global* global = (struct node_global*) n; + printf("(Global: %p)", global->function); + } else if(n->tag == NODE_IND) { + print_node(((struct node_ind*) n)->next); + } else if(n->tag == NODE_NUM) { + struct node_num* num = (struct node_num*) n; + printf("%d", num->value); + } +} + int main(int argc, char** argv) { struct node_global* first_node = alloc_global(f_main, 0); struct node_base* result = eval((struct node_base*) first_node); + + printf("Result: "); + print_node(result); + putchar('\n'); }