diff --git a/code/compiler/10/ast.cpp b/code/compiler/10/ast.cpp index 0abbb10..9ea0822 100644 --- a/code/compiler/10/ast.cpp +++ b/code/compiler/10/ast.cpp @@ -2,6 +2,7 @@ #include #include "binop.hpp" #include "error.hpp" +#include "type_env.hpp" static void print_indent(int n, std::ostream& to) { while(n--) to << " "; @@ -12,7 +13,11 @@ void ast_int::print(int indent, std::ostream& to) const { to << "INT: " << value << std::endl; } -type_ptr ast_int::typecheck(type_mgr& mgr, const type_env& env) { +void ast_int::find_free(type_mgr& mgr, type_env_ptr& env, std::set& into) { + this->env = env; +} + +type_ptr ast_int::typecheck(type_mgr& mgr) { return type_ptr(new type_base("Int")); } @@ -25,8 +30,13 @@ void ast_lid::print(int indent, std::ostream& to) const { to << "LID: " << id << std::endl; } -type_ptr ast_lid::typecheck(type_mgr& mgr, const type_env& env) { - return env.lookup(id); +void ast_lid::find_free(type_mgr& mgr, type_env_ptr& env, std::set& into) { + this->env = env; + if(env->lookup(id) == nullptr) into.insert(id); +} + +type_ptr ast_lid::typecheck(type_mgr& mgr) { + return env->lookup(id); } void ast_lid::compile(const env_ptr& env, std::vector& into) const { @@ -41,8 +51,12 @@ void ast_uid::print(int indent, std::ostream& to) const { to << "UID: " << id << std::endl; } -type_ptr ast_uid::typecheck(type_mgr& mgr, const type_env& env) { - return env.lookup(id); +void ast_uid::find_free(type_mgr& mgr, type_env_ptr& env, std::set& into) { + this->env = env; +} + +type_ptr ast_uid::typecheck(type_mgr& mgr) { + return env->lookup(id); } void ast_uid::compile(const env_ptr& env, std::vector& into) const { @@ -56,10 +70,16 @@ void ast_binop::print(int indent, std::ostream& to) const { right->print(indent + 1, to); } -type_ptr ast_binop::typecheck(type_mgr& mgr, const type_env& env) { - type_ptr ltype = left->typecheck(mgr, env); - type_ptr rtype = right->typecheck(mgr, env); - type_ptr ftype = env.lookup(op_name(op)); +void ast_binop::find_free(type_mgr& mgr, type_env_ptr& env, std::set& into) { + this->env = env; + left->find_free(mgr, env, into); + right->find_free(mgr, env, into); +} + +type_ptr ast_binop::typecheck(type_mgr& mgr) { + type_ptr ltype = left->typecheck(mgr); + type_ptr rtype = right->typecheck(mgr); + type_ptr ftype = env->lookup(op_name(op)); if(!ftype) throw type_error(std::string("unknown binary operator ") + op_name(op)); type_ptr return_type = mgr.new_type(); @@ -86,9 +106,15 @@ void ast_app::print(int indent, std::ostream& to) const { right->print(indent + 1, to); } -type_ptr ast_app::typecheck(type_mgr& mgr, const type_env& env) { - type_ptr ltype = left->typecheck(mgr, env); - type_ptr rtype = right->typecheck(mgr, env); +void ast_app::find_free(type_mgr& mgr, type_env_ptr& env, std::set& into) { + this->env = env; + left->find_free(mgr, env, into); + right->find_free(mgr, env, into); +} + +type_ptr ast_app::typecheck(type_mgr& mgr) { + type_ptr ltype = left->typecheck(mgr); + type_ptr rtype = right->typecheck(mgr); type_ptr return_type = mgr.new_type(); type_ptr arrow = type_ptr(new type_arr(rtype, return_type)); @@ -113,15 +139,24 @@ void ast_case::print(int indent, std::ostream& to) const { } } -type_ptr ast_case::typecheck(type_mgr& mgr, const type_env& env) { +void ast_case::find_free(type_mgr& mgr, type_env_ptr& env, std::set& into) { + this->env = env; + of->find_free(mgr, env, into); + for(auto& branch : branches) { + type_env_ptr new_env = type_scope(env); + branch->pat->insert_bindings(mgr, new_env); + branch->expr->find_free(mgr, new_env, into); + } +} + +type_ptr ast_case::typecheck(type_mgr& mgr) { type_var* var; - type_ptr case_type = mgr.resolve(of->typecheck(mgr, env), var); + type_ptr case_type = mgr.resolve(of->typecheck(mgr), var); type_ptr branch_type = mgr.new_type(); for(auto& branch : branches) { - type_env new_env = env.scope(); - branch->pat->match(case_type, mgr, new_env); - type_ptr curr_branch_type = branch->expr->typecheck(mgr, new_env); + branch->pat->typecheck(case_type, mgr, branch->expr->env); + type_ptr curr_branch_type = branch->expr->typecheck(mgr); mgr.unify(branch_type, curr_branch_type); } @@ -192,8 +227,12 @@ void pattern_var::print(std::ostream& to) const { to << var; } -void pattern_var::match(type_ptr t, type_mgr& mgr, type_env& env) const { - env.bind(var, t); +void pattern_var::insert_bindings(type_mgr& mgr, type_env_ptr& env) const { + env->bind(var, mgr.new_type()); +} + +void pattern_var::typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const { + mgr.unify(env->lookup(var), t); } void pattern_constr::print(std::ostream& to) const { @@ -203,17 +242,23 @@ void pattern_constr::print(std::ostream& to) const { } } -void pattern_constr::match(type_ptr t, type_mgr& mgr, type_env& env) const { - type_ptr constructor_type = env.lookup(constr); +void pattern_constr::insert_bindings(type_mgr& mgr, type_env_ptr& env) const { + for(auto& param : params) { + env->bind(param, mgr.new_type()); + } +} + +void pattern_constr::typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const { + type_ptr constructor_type = env->lookup(constr); if(!constructor_type) { throw type_error(std::string("pattern using unknown constructor ") + constr); } - for(int i = 0; i < params.size(); i++) { + for(auto& param : params) { type_arr* arr = dynamic_cast(constructor_type.get()); if(!arr) throw type_error("too many parameters in constructor pattern"); - env.bind(params[i], arr->left); + mgr.unify(env->lookup(param), arr->left); constructor_type = arr->right; } diff --git a/code/compiler/10/ast.hpp b/code/compiler/10/ast.hpp index 98b6b24..6c66636 100644 --- a/code/compiler/10/ast.hpp +++ b/code/compiler/10/ast.hpp @@ -1,6 +1,7 @@ #pragma once #include #include +#include #include "type.hpp" #include "type_env.hpp" #include "binop.hpp" @@ -8,10 +9,14 @@ #include "env.hpp" struct ast { + type_env_ptr env; + virtual ~ast() = default; virtual void print(int indent, std::ostream& to) const = 0; - virtual type_ptr typecheck(type_mgr& mgr, const type_env& env) = 0; + virtual void find_free(type_mgr& mgr, + type_env_ptr& env, std::set& into) = 0; + virtual type_ptr typecheck(type_mgr& mgr) = 0; virtual void compile(const env_ptr& env, std::vector& into) const = 0; }; @@ -22,7 +27,8 @@ struct pattern { virtual ~pattern() = default; virtual void print(std::ostream& to) const = 0; - virtual void match(type_ptr t, type_mgr& mgr, type_env& env) const = 0; + virtual void insert_bindings(type_mgr& mgr, type_env_ptr& env) const = 0; + virtual void typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const = 0; }; using pattern_ptr = std::unique_ptr; @@ -44,7 +50,8 @@ struct ast_int : public ast { : value(v) {} void print(int indent, std::ostream& to) const; - type_ptr typecheck(type_mgr& mgr, const type_env& env); + void find_free(type_mgr& mgr, type_env_ptr& env, std::set& into); + type_ptr typecheck(type_mgr& mgr); void compile(const env_ptr& env, std::vector& into) const; }; @@ -55,7 +62,8 @@ struct ast_lid : public ast { : id(std::move(i)) {} void print(int indent, std::ostream& to) const; - type_ptr typecheck(type_mgr& mgr, const type_env& env); + void find_free(type_mgr& mgr, type_env_ptr& env, std::set& into); + type_ptr typecheck(type_mgr& mgr); void compile(const env_ptr& env, std::vector& into) const; }; @@ -66,7 +74,8 @@ struct ast_uid : public ast { : id(std::move(i)) {} void print(int indent, std::ostream& to) const; - type_ptr typecheck(type_mgr& mgr, const type_env& env); + void find_free(type_mgr& mgr, type_env_ptr& env, std::set& into); + type_ptr typecheck(type_mgr& mgr); void compile(const env_ptr& env, std::vector& into) const; }; @@ -79,7 +88,8 @@ struct ast_binop : public ast { : op(o), left(std::move(l)), right(std::move(r)) {} void print(int indent, std::ostream& to) const; - type_ptr typecheck(type_mgr& mgr, const type_env& env); + void find_free(type_mgr& mgr, type_env_ptr& env, std::set& into); + type_ptr typecheck(type_mgr& mgr); void compile(const env_ptr& env, std::vector& into) const; }; @@ -91,7 +101,8 @@ struct ast_app : public ast { : left(std::move(l)), right(std::move(r)) {} void print(int indent, std::ostream& to) const; - type_ptr typecheck(type_mgr& mgr, const type_env& env); + void find_free(type_mgr& mgr, type_env_ptr& env, std::set& into); + type_ptr typecheck(type_mgr& mgr); void compile(const env_ptr& env, std::vector& into) const; }; @@ -104,7 +115,8 @@ struct ast_case : public ast { : of(std::move(o)), branches(std::move(b)) {} void print(int indent, std::ostream& to) const; - type_ptr typecheck(type_mgr& mgr, const type_env& env); + void find_free(type_mgr& mgr, type_env_ptr& env, std::set& into); + type_ptr typecheck(type_mgr& mgr); void compile(const env_ptr& env, std::vector& into) const; }; @@ -115,7 +127,8 @@ struct pattern_var : public pattern { : var(std::move(v)) {} void print(std::ostream &to) const; - void match(type_ptr t, type_mgr& mgr, type_env& env) const; + void insert_bindings(type_mgr& mgr, type_env_ptr& env) const; + void typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const; }; struct pattern_constr : public pattern { @@ -126,5 +139,6 @@ struct pattern_constr : public pattern { : constr(std::move(c)), params(std::move(p)) {} void print(std::ostream &to) const; - void match(type_ptr t, type_mgr&, type_env& env) const; + virtual void insert_bindings(type_mgr& mgr, type_env_ptr& env) const; + virtual void typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const; }; diff --git a/code/compiler/10/definition.cpp b/code/compiler/10/definition.cpp index 4a14571..c4bf4f7 100644 --- a/code/compiler/10/definition.cpp +++ b/code/compiler/10/definition.cpp @@ -3,35 +3,33 @@ #include "ast.hpp" #include "instruction.hpp" #include "llvm_context.hpp" +#include "type_env.hpp" #include #include #include -void definition_defn::typecheck_first(type_mgr& mgr, type_env& env) { +void definition_defn::find_free(type_mgr& mgr, type_env_ptr& env) { + this->env = env; + + var_env = type_scope(env); return_type = mgr.new_type(); - type_ptr full_type = return_type; + full_type = return_type; for(auto it = params.rbegin(); it != params.rend(); it++) { type_ptr param_type = mgr.new_type(); full_type = type_ptr(new type_arr(param_type, full_type)); - param_types.push_back(param_type); + var_env->bind(*it, param_type); } - env.bind(name, full_type); + body->find_free(mgr, var_env, free_variables); } -void definition_defn::typecheck_second(type_mgr& mgr, const type_env& env) const { - type_env new_env = env.scope(); - auto param_it = params.begin(); - auto type_it = param_types.rbegin(); +void definition_defn::insert_types(type_mgr& mgr) { + env->bind(name, full_type); +} - while(param_it != params.end() && type_it != param_types.rend()) { - new_env.bind(*param_it, *type_it); - param_it++; - type_it++; - } - - type_ptr body_type = body->typecheck(mgr, new_env); +void definition_defn::typecheck(type_mgr& mgr) { + type_ptr body_type = body->typecheck(mgr); mgr.unify(return_type, body_type); } @@ -44,11 +42,12 @@ void definition_defn::compile() { instructions.push_back(instruction_ptr(new instruction_update(params.size()))); instructions.push_back(instruction_ptr(new instruction_pop(params.size()))); } -void definition_defn::gen_llvm_first(llvm_context& ctx) { + +void definition_defn::declare_llvm(llvm_context& ctx) { generated_function = ctx.create_custom_function(name, params.size()); } -void definition_defn::gen_llvm_second(llvm_context& ctx) { +void definition_defn::generate_llvm(llvm_context& ctx) { ctx.builder.SetInsertPoint(&generated_function->getEntryBlock()); for(auto& instruction : instructions) { instruction->gen_llvm(ctx, generated_function); @@ -56,7 +55,11 @@ void definition_defn::gen_llvm_second(llvm_context& ctx) { ctx.builder.CreateRetVoid(); } -void definition_data::typecheck_first(type_mgr& mgr, type_env& env) { +void definition_data::insert_types(type_mgr& mgr, type_env_ptr& env) { + this->env = env; +} + +void definition_data::insert_constructors() const { type_data* this_type = new type_data(name); type_ptr return_type = type_ptr(this_type); int next_tag = 0; @@ -71,19 +74,11 @@ void definition_data::typecheck_first(type_mgr& mgr, type_env& env) { full_type = type_ptr(new type_arr(type, full_type)); } - env.bind(constructor->name, full_type); + env->bind(constructor->name, full_type); } } -void definition_data::typecheck_second(type_mgr& mgr, const type_env& env) const { - // Nothing -} - -void definition_data::compile() { - -} - -void definition_data::gen_llvm_first(llvm_context& ctx) { +void definition_data::generate_llvm(llvm_context& ctx) { for(auto& constructor : constructors) { auto new_function = ctx.create_custom_function(constructor->name, constructor->types.size()); @@ -99,7 +94,3 @@ void definition_data::gen_llvm_first(llvm_context& ctx) { ctx.builder.CreateRetVoid(); } } - -void definition_data::gen_llvm_second(llvm_context& ctx) { - // Nothing -} diff --git a/code/compiler/10/definition.hpp b/code/compiler/10/definition.hpp index 74a2b69..b72bed6 100644 --- a/code/compiler/10/definition.hpp +++ b/code/compiler/10/definition.hpp @@ -1,6 +1,7 @@ #pragma once #include #include +#include #include "instruction.hpp" #include "llvm_context.hpp" #include "type_env.hpp" @@ -8,18 +9,6 @@ struct ast; using ast_ptr = std::unique_ptr; -struct definition { - virtual ~definition() = default; - - virtual void typecheck_first(type_mgr& mgr, type_env& env) = 0; - virtual void typecheck_second(type_mgr& mgr, const type_env& env) const = 0; - virtual void compile() = 0; - virtual void gen_llvm_first(llvm_context& ctx) = 0; - virtual void gen_llvm_second(llvm_context& ctx) = 0; -}; - -using definition_ptr = std::unique_ptr; - struct constructor { std::string name; std::vector types; @@ -31,13 +20,16 @@ struct constructor { using constructor_ptr = std::unique_ptr; -struct definition_defn : public definition { +struct definition_defn { std::string name; std::vector params; ast_ptr body; + type_env_ptr env; + type_env_ptr var_env; + std::set free_variables; + type_ptr full_type; type_ptr return_type; - std::vector param_types; std::vector instructions; @@ -48,23 +40,28 @@ struct definition_defn : public definition { } - void typecheck_first(type_mgr& mgr, type_env& env); - void typecheck_second(type_mgr& mgr, const type_env& env) const; + void find_free(type_mgr& mgr, type_env_ptr& env); + void insert_types(type_mgr& mgr); + void typecheck(type_mgr& mgr); void compile(); - void gen_llvm_first(llvm_context& ctx); - void gen_llvm_second(llvm_context& ctx); + void declare_llvm(llvm_context& ctx); + void generate_llvm(llvm_context& ctx); }; -struct definition_data : public definition { +using definition_defn_ptr = std::unique_ptr; + +struct definition_data { std::string name; std::vector constructors; + type_env_ptr env; + definition_data(std::string n, std::vector cs) : name(std::move(n)), constructors(std::move(cs)) {} - void typecheck_first(type_mgr& mgr, type_env& env); - void typecheck_second(type_mgr& mgr, const type_env& env) const; - void compile(); - void gen_llvm_first(llvm_context& ctx); - void gen_llvm_second(llvm_context& ctx); + void insert_types(type_mgr& mgr, type_env_ptr& env); + void insert_constructors() const; + void generate_llvm(llvm_context& ctx); }; + +using definition_data_ptr = std::unique_ptr; diff --git a/code/compiler/10/main.cpp b/code/compiler/10/main.cpp index 27529b7..dd59316 100644 --- a/code/compiler/10/main.cpp +++ b/code/compiler/10/main.cpp @@ -20,43 +20,52 @@ void yy::parser::error(const std::string& msg) { std::cout << "An error occured: " << msg << std::endl; } -extern std::vector program; +extern std::vector defs_data; +extern std::vector defs_defn; void typecheck_program( - const std::vector& prog, - type_mgr& mgr, type_env& env) { + const std::vector& defs_data, + const std::vector& defs_defn, + type_mgr& mgr, type_env_ptr& env) { type_ptr int_type = type_ptr(new type_base("Int")); type_ptr binop_type = type_ptr(new type_arr( int_type, type_ptr(new type_arr(int_type, int_type)))); - env.bind("+", binop_type); - env.bind("-", binop_type); - env.bind("*", binop_type); - env.bind("/", binop_type); + env->bind("+", binop_type); + env->bind("-", binop_type); + env->bind("*", binop_type); + env->bind("/", binop_type); - for(auto& def : prog) { - def->typecheck_first(mgr, env); + for(auto& def_data : defs_data) { + def_data->insert_types(mgr, env); + } + for(auto& def_data : defs_data) { + def_data->insert_constructors(); } - for(auto& def : prog) { - def->typecheck_second(mgr, env); + for(auto& def_defn : defs_defn) { + def_defn->find_free(mgr, env); + } + for(auto& def_defn : defs_defn) { + def_defn->insert_types(mgr); + } + for(auto& def_defn : defs_defn) { + def_defn->typecheck(mgr); } - for(auto& pair : env.names) { + for(auto& pair : env->names) { std::cout << pair.first << ": "; pair.second->print(mgr, std::cout); std::cout << std::endl; } } -void compile_program(const std::vector& prog) { - for(auto& def : prog) { - def->compile(); +void compile_program(const std::vector& defs_defn) { + for(auto& def_defn : defs_defn) { + def_defn->compile(); - definition_defn* defn = dynamic_cast(def.get()); - if(!defn) continue; - for(auto& instruction : defn->instructions) { + for(auto& instruction : def_defn->instructions) { instruction->print(0, std::cout); } std::cout << std::endl; @@ -120,20 +129,25 @@ void output_llvm(llvm_context& ctx, const std::string& filename) { } } -void gen_llvm(const std::vector& prog) { +void gen_llvm( + const std::vector& defs_data, + const std::vector& defs_defn) { llvm_context ctx; gen_llvm_internal_op(ctx, PLUS); gen_llvm_internal_op(ctx, MINUS); gen_llvm_internal_op(ctx, TIMES); gen_llvm_internal_op(ctx, DIVIDE); - for(auto& definition : prog) { - definition->gen_llvm_first(ctx); + for(auto& def_data : defs_data) { + def_data->generate_llvm(ctx); + } + for(auto& def_defn : defs_defn) { + def_defn->declare_llvm(ctx); + } + for(auto& def_defn : defs_defn) { + def_defn->generate_llvm(ctx); } - for(auto& definition : prog) { - definition->gen_llvm_second(ctx); - } ctx.module.print(llvm::outs(), nullptr); output_llvm(ctx, "program.o"); } @@ -141,23 +155,20 @@ void gen_llvm(const std::vector& prog) { int main() { yy::parser parser; type_mgr mgr; - type_env env; + type_env_ptr env(new type_env); parser.parse(); - for(auto& definition : program) { - definition_defn* def = dynamic_cast(definition.get()); - if(!def) continue; - - std::cout << def->name; - for(auto& param : def->params) std::cout << " " << param; + for(auto& def_defn : defs_defn) { + std::cout << def_defn->name; + for(auto& param : def_defn->params) std::cout << " " << param; std::cout << ":" << std::endl; - def->body->print(1, std::cout); + def_defn->body->print(1, std::cout); } try { - typecheck_program(program, mgr, env); - compile_program(program); - gen_llvm(program); + typecheck_program(defs_data, defs_defn, mgr, env); + compile_program(defs_defn); + gen_llvm(defs_data, defs_defn); } catch(unification_error& err) { std::cout << "failed to unify types: " << std::endl; std::cout << " (1) \033[34m"; diff --git a/code/compiler/10/parser.y b/code/compiler/10/parser.y index 088648d..9614126 100644 --- a/code/compiler/10/parser.y +++ b/code/compiler/10/parser.y @@ -5,7 +5,9 @@ #include "definition.hpp" #include "parser.hpp" -std::vector program; +std::vector defs_data; +std::vector defs_defn; + extern yy::parser::symbol_type yylex(); %} @@ -34,11 +36,11 @@ extern yy::parser::symbol_type yylex(); %define api.token.constructor %type > lowercaseParams uppercaseParams -%type > program definitions %type > branches %type > constructors %type aAdd aMul case app appBase -%type definition defn data +%type data +%type defn %type branch %type pattern %type constructor @@ -48,22 +50,22 @@ extern yy::parser::symbol_type yylex(); %% program - : definitions { program = std::move($1); } + : definitions { } ; definitions - : definitions definition { $$ = std::move($1); $$.push_back(std::move($2)); } - | definition { $$ = std::vector(); $$.push_back(std::move($1)); } + : definitions definition { } + | definition { } ; definition - : defn { $$ = std::move($1); } - | data { $$ = std::move($1); } + : defn { defs_defn.push_back(std::move($1)); } + | data { defs_data.push_back(std::move($1)); } ; defn : DEFN LID lowercaseParams EQUAL OCURLY aAdd CCURLY - { $$ = definition_ptr( + { $$ = definition_defn_ptr( new definition_defn(std::move($2), std::move($3), std::move($6))); } ; @@ -125,7 +127,7 @@ pattern data : DATA UID EQUAL OCURLY constructors CCURLY - { $$ = definition_ptr(new definition_data(std::move($2), std::move($5))); } + { $$ = definition_data_ptr(new definition_data(std::move($2), std::move($5))); } ; constructors