diff --git a/code/compiler/13/ast.cpp b/code/compiler/13/ast.cpp index 2e83066..d66ee4e 100644 --- a/code/compiler/13/ast.cpp +++ b/code/compiler/13/ast.cpp @@ -204,10 +204,10 @@ type_ptr ast_case::typecheck(type_mgr& mgr, type_env_ptr& env) { input_type = mgr.resolve(case_type, var); type_app* app_type; - if(!(app_type = dynamic_cast(input_type.get())) || - !dynamic_cast(app_type->constructor.get())) { - throw type_error("attempting case analysis of non-data type", of->loc); - } + // if(!(app_type = dynamic_cast(input_type.get())) || + // !dynamic_cast(app_type->constructor.get())) { + // throw type_error("attempting case analysis of non-data type", of->loc); + // } return branch_type; } @@ -250,6 +250,11 @@ struct case_mappings { assert(default_case); return *default_case; } + + std::vector& get_case_for(tag_type tag) { + if(case_defined_for(tag)) return get_specific_case_for(tag); + return get_default_case(); + } bool case_defined_for(tag_type tag) { return defined_cases.find(tag) != defined_cases.end(); @@ -288,8 +293,19 @@ struct case_strategy_bool { return 2; } - instruction_ptr into_instruction(const type* type, case_mappings& ms) { - throw std::runtime_error("boolean case unimplemented!"); + void into_instructions( + const type* type, + case_mappings& ms, + std::vector& into) { + if(ms.defined_cases_count() == 0) { + for(auto& instruction : ms.get_default_case()) + into.push_back(std::move(instruction)); + return; + } + + into.push_back(instruction_ptr(new instruction_if( + std::move(ms.get_case_for(true)), + std::move(ms.get_case_for(false))))); } }; @@ -327,7 +343,10 @@ struct case_strategy_data { return static_cast(type)->constructors.size(); } - instruction_ptr into_instruction(const type* type, case_mappings& ms) { + void into_instructions( + const type* type, + case_mappings& ms, + std::vector& into) { instruction_jump* jump_instruction = new instruction_jump(); instruction_ptr inst(jump_instruction); @@ -350,7 +369,7 @@ struct case_strategy_data { } } - return std::move(inst); + into.push_back(std::move(inst)); } }; @@ -364,6 +383,7 @@ void compile_case(const ast_case& node, const env_ptr& env, const type* type, st auto& branch_into = cases.make_default_case(); env_ptr new_env(new env_var(branch->expr->env->get_mangled_name(vpat->var), env)); branch->expr->compile(new_env, branch_into); + branch_into.push_back(instruction_ptr(new instruction_slide(1))); } else { auto repr = strategy.from_typed_pattern(branch->pat, type); auto& branch_into = cases.make_case_for(strategy.tag_from_repr(repr)); @@ -375,7 +395,7 @@ void compile_case(const ast_case& node, const env_ptr& env, const type* type, st cases.default_case_defined())) throw type_error("incomplete patterns", node.loc); - into.push_back(strategy.into_instruction(type, cases)); + strategy.into_instructions(type, cases, into); } void ast_case::compile(const env_ptr& env, std::vector& into) const { diff --git a/code/compiler/13/instruction.cpp b/code/compiler/13/instruction.cpp index c2b050a..80ae40c 100644 --- a/code/compiler/13/instruction.cpp +++ b/code/compiler/13/instruction.cpp @@ -122,6 +122,45 @@ void instruction_jump::gen_llvm(llvm_context& ctx, Function* f) const { ctx.builder.SetInsertPoint(safety_block); } +void instruction_if::print(int indent, std::ostream& to) const { + print_indent(indent, to); + to << "If(" << std::endl; + for(auto& instruction : on_true) { + instruction->print(indent + 2, to); + } + to << std::endl; + for(auto& instruction : on_false) { + instruction->print(indent + 2, to); + } + print_indent(indent, to); + to << ")" << std::endl; +} + +void instruction_if::gen_llvm(llvm_context& ctx, llvm::Function* f) const { + auto top_node = ctx.create_peek(f, ctx.create_size(0)); + auto num = ctx.unwrap_num(top_node); + + auto nonzero_block = BasicBlock::Create(ctx.ctx, "nonzero", f); + auto zero_block = BasicBlock::Create(ctx.ctx, "zero", f); + auto resume_block = BasicBlock::Create(ctx.ctx, "resume", f); + auto switch_op = ctx.builder.CreateSwitch(num, nonzero_block, 2); + + switch_op->addCase(ctx.create_i32(0), zero_block); + ctx.builder.SetInsertPoint(nonzero_block); + for(auto& instruction : on_true) { + instruction->gen_llvm(ctx, f); + } + ctx.builder.CreateBr(resume_block); + + ctx.builder.SetInsertPoint(zero_block); + for(auto& instruction : on_true) { + instruction->gen_llvm(ctx, f); + } + ctx.builder.CreateBr(resume_block); + + ctx.builder.SetInsertPoint(resume_block); +} + void instruction_slide::print(int indent, std::ostream& to) const { print_indent(indent, to); to << "Slide(" << offset << ")" << std::endl; diff --git a/code/compiler/13/instruction.hpp b/code/compiler/13/instruction.hpp index abe2409..36b4191 100644 --- a/code/compiler/13/instruction.hpp +++ b/code/compiler/13/instruction.hpp @@ -101,6 +101,19 @@ struct instruction_jump : public instruction { void gen_llvm(llvm_context& ctx, llvm::Function* f) const; }; +struct instruction_if : public instruction { + std::vector on_true; + std::vector on_false; + + instruction_if( + std::vector t, + std::vector f) + : on_true(std::move(t)), on_false(std::move(f)) {} + + void print(int indent, std::ostream& to) const; + void gen_llvm(llvm_context& ctx, llvm::Function* f) const; +}; + struct instruction_slide : public instruction { int offset; diff --git a/code/compiler/13/main.cpp b/code/compiler/13/main.cpp index fa3aeec..aa3bcef 100644 --- a/code/compiler/13/main.cpp +++ b/code/compiler/13/main.cpp @@ -22,13 +22,15 @@ void yy::parser::error(const yy::location& loc, const std::string& msg) { std::cout << "An error occured: " << msg << std::endl; } -void typecheck_program( - definition_group& defs, - type_mgr& mgr, type_env_ptr& env) { +void prelude_types(definition_group& defs, type_env_ptr env) { type_ptr int_type = type_ptr(new type_internal("Int")); env->bind_type("Int", int_type); type_ptr int_type_app = type_ptr(new type_app(int_type)); + type_ptr bool_type = type_ptr(new type_internal("Bool")); + env->bind_type("Bool", bool_type); + type_ptr bool_type_app = type_ptr(new type_app(bool_type)); + type_ptr binop_type = type_ptr(new type_arr( int_type_app, type_ptr(new type_arr(int_type_app, int_type_app)))); @@ -37,6 +39,15 @@ void typecheck_program( env->bind("*", binop_type, visibility::global); env->bind("/", binop_type, visibility::global); + env->bind("True", bool_type_app, visibility::global); + env->bind("False", bool_type_app, visibility::global); +} + +void typecheck_program( + definition_group& defs, + type_mgr& mgr, type_env_ptr& env) { + prelude_types(defs, env); + std::set free; defs.find_free(free); defs.typecheck(mgr, env); @@ -60,23 +71,6 @@ global_scope translate_program(definition_group& group) { return scope; } -void gen_llvm_internal_op(llvm_context& ctx, binop op) { - auto new_function = ctx.create_custom_function(op_action(op), 2); - std::vector instructions; - instructions.push_back(instruction_ptr(new instruction_push(1))); - instructions.push_back(instruction_ptr(new instruction_eval())); - instructions.push_back(instruction_ptr(new instruction_push(1))); - instructions.push_back(instruction_ptr(new instruction_eval())); - instructions.push_back(instruction_ptr(new instruction_binop(op))); - instructions.push_back(instruction_ptr(new instruction_update(2))); - instructions.push_back(instruction_ptr(new instruction_pop(2))); - ctx.builder.SetInsertPoint(&new_function->getEntryBlock()); - for(auto& instruction : instructions) { - instruction->gen_llvm(ctx, new_function); - } - ctx.builder.CreateRetVoid(); -} - void output_llvm(llvm_context& ctx, const std::string& filename) { std::string targetTriple = llvm::sys::getDefaultTargetTriple(); @@ -117,12 +111,43 @@ void output_llvm(llvm_context& ctx, const std::string& filename) { } } +void gen_llvm_internal_op(llvm_context& ctx, binop op) { + auto new_function = ctx.create_custom_function(op_action(op), 2); + std::vector instructions; + instructions.push_back(instruction_ptr(new instruction_push(1))); + instructions.push_back(instruction_ptr(new instruction_eval())); + instructions.push_back(instruction_ptr(new instruction_push(1))); + instructions.push_back(instruction_ptr(new instruction_eval())); + instructions.push_back(instruction_ptr(new instruction_binop(op))); + instructions.push_back(instruction_ptr(new instruction_update(2))); + instructions.push_back(instruction_ptr(new instruction_pop(2))); + ctx.builder.SetInsertPoint(&new_function->getEntryBlock()); + for(auto& instruction : instructions) { + instruction->gen_llvm(ctx, new_function); + } + ctx.builder.CreateRetVoid(); +} + +void gen_llvm_boolean_constructor(llvm_context& ctx, const std::string& s, bool b) { + auto new_function = ctx.create_custom_function(s, 0); + std::vector instructions; + instructions.push_back(instruction_ptr(new instruction_pushint(b))); + instructions.push_back(instruction_ptr(new instruction_update(0))); + ctx.builder.SetInsertPoint(&new_function->getEntryBlock()); + for(auto& instruction : instructions) { + instruction->gen_llvm(ctx, new_function); + } + ctx.builder.CreateRetVoid(); +} + void gen_llvm(global_scope& scope) { llvm_context ctx; gen_llvm_internal_op(ctx, PLUS); gen_llvm_internal_op(ctx, MINUS); gen_llvm_internal_op(ctx, TIMES); gen_llvm_internal_op(ctx, DIVIDE); + gen_llvm_boolean_constructor(ctx, "True", true); + gen_llvm_boolean_constructor(ctx, "False", false); scope.generate_llvm(ctx);