diff --git a/code/compiler/11/ast.cpp b/code/compiler/11/ast.cpp index b55c347..5bdd00c 100644 --- a/code/compiler/11/ast.cpp +++ b/code/compiler/11/ast.cpp @@ -18,7 +18,7 @@ void ast_int::find_free(type_mgr& mgr, type_env_ptr& env, std::set& } type_ptr ast_int::typecheck(type_mgr& mgr) { - return type_ptr(new type_base("Int")); + return type_ptr(new type_app(env->lookup_type("Int"))); } void ast_int::compile(const env_ptr& env, std::vector& into) const { @@ -161,7 +161,9 @@ type_ptr ast_case::typecheck(type_mgr& mgr) { } input_type = mgr.resolve(case_type, var); - if(!dynamic_cast(input_type.get())) { + type_app* app_type; + if(!(app_type = dynamic_cast(input_type.get())) || + !dynamic_cast(app_type->constructor.get())) { throw type_error("attempting case analysis of non-data type"); } @@ -169,7 +171,8 @@ type_ptr ast_case::typecheck(type_mgr& mgr) { } void ast_case::compile(const env_ptr& env, std::vector& into) const { - type_data* type = dynamic_cast(input_type.get()); + type_app* app_type = dynamic_cast(input_type.get()); + type_data* type = dynamic_cast(app_type->constructor.get()); of->compile(env, into); into.push_back(instruction_ptr(new instruction_eval())); diff --git a/code/compiler/11/definition.cpp b/code/compiler/11/definition.cpp index 91b0178..79d53af 100644 --- a/code/compiler/11/definition.cpp +++ b/code/compiler/11/definition.cpp @@ -56,28 +56,38 @@ void definition_defn::generate_llvm(llvm_context& ctx) { ctx.builder.CreateRetVoid(); } -void definition_data::insert_types(type_mgr& mgr, type_env_ptr& env) { +void definition_data::insert_types(type_env_ptr& env) { this->env = env; env->bind_type(name, type_ptr(new type_data(name))); } void definition_data::insert_constructors() const { - type_ptr return_type = env->lookup_type(name); - type_data* this_type = static_cast(return_type.get()); + type_ptr this_type_ptr = env->lookup_type(name); + type_data* this_type = static_cast(this_type_ptr.get()); int next_tag = 0; + std::set var_set; + type_app* return_app = new type_app(std::move(this_type_ptr)); + type_ptr return_type(return_app); + for(auto& var : vars) { + if(var_set.find(var) != var_set.end()) throw 0; + var_set.insert(var); + return_app->arguments.push_back(type_ptr(new type_var(var))); + } + for(auto& constructor : constructors) { constructor->tag = next_tag; this_type->constructors[constructor->name] = { next_tag++ }; type_ptr full_type = return_type; for(auto it = constructor->types.rbegin(); it != constructor->types.rend(); it++) { - type_ptr type = env->lookup_type(*it); - if(!type) throw 0; + type_ptr type = (*it)->to_type(var_set, env); full_type = type_ptr(new type_arr(type, full_type)); } - env->bind(constructor->name, full_type); + type_scheme_ptr full_scheme(new type_scheme(std::move(full_type))); + full_scheme->forall.insert(full_scheme->forall.begin(), vars.begin(), vars.end()); + env->bind(constructor->name, full_scheme); } } diff --git a/code/compiler/11/definition.hpp b/code/compiler/11/definition.hpp index b72bed6..c14ec0f 100644 --- a/code/compiler/11/definition.hpp +++ b/code/compiler/11/definition.hpp @@ -4,6 +4,7 @@ #include #include "instruction.hpp" #include "llvm_context.hpp" +#include "parsed_type.hpp" #include "type_env.hpp" struct ast; @@ -11,10 +12,10 @@ using ast_ptr = std::unique_ptr; struct constructor { std::string name; - std::vector types; + std::vector types; int8_t tag; - constructor(std::string n, std::vector ts) + constructor(std::string n, std::vector ts) : name(std::move(n)), types(std::move(ts)) {} }; @@ -52,14 +53,18 @@ using definition_defn_ptr = std::unique_ptr; struct definition_data { std::string name; + std::vector vars; std::vector constructors; type_env_ptr env; - definition_data(std::string n, std::vector cs) - : name(std::move(n)), constructors(std::move(cs)) {} + definition_data( + std::string n, + std::vector vs, + std::vector cs) + : name(std::move(n)), vars(std::move(vs)), constructors(std::move(cs)) {} - void insert_types(type_mgr& mgr, type_env_ptr& env); + void insert_types(type_env_ptr& env); void insert_constructors() const; void generate_llvm(llvm_context& ctx); }; diff --git a/code/compiler/11/main.cpp b/code/compiler/11/main.cpp index bc181d6..5cd3aa1 100644 --- a/code/compiler/11/main.cpp +++ b/code/compiler/11/main.cpp @@ -30,17 +30,18 @@ void typecheck_program( type_mgr& mgr, type_env_ptr& env) { type_ptr int_type = type_ptr(new type_base("Int")); env->bind_type("Int", int_type); + type_ptr int_type_app = type_ptr(new type_app(int_type)); type_ptr binop_type = type_ptr(new type_arr( - int_type, - type_ptr(new type_arr(int_type, int_type)))); + int_type_app, + type_ptr(new type_arr(int_type_app, int_type_app)))); env->bind("+", binop_type); env->bind("-", binop_type); env->bind("*", binop_type); env->bind("/", binop_type); for(auto& def_data : defs_data) { - def_data.second->insert_types(mgr, env); + def_data.second->insert_types(env); } for(auto& def_data : defs_data) { def_data.second->insert_constructors(); diff --git a/code/compiler/11/parser.y b/code/compiler/11/parser.y index 57f4be9..81ef4ac 100644 --- a/code/compiler/11/parser.y +++ b/code/compiler/11/parser.y @@ -5,6 +5,7 @@ #include "ast.hpp" #include "definition.hpp" #include "parser.hpp" +#include "parsed_type.hpp" std::map defs_data; std::map defs_defn; @@ -36,9 +37,11 @@ extern yy::parser::symbol_type yylex(); %define api.value.type variant %define api.token.constructor -%type > lowercaseParams uppercaseParams +%type > lowercaseParams %type > branches %type > constructors +%type > typeList +%type type nullaryType %type aAdd aMul case app appBase %type data %type defn @@ -75,11 +78,6 @@ lowercaseParams | lowercaseParams LID { $$ = std::move($1); $$.push_back(std::move($2)); } ; -uppercaseParams - : %empty { $$ = std::vector(); } - | uppercaseParams UID { $$ = std::move($1); $$.push_back(std::move($2)); } - ; - aAdd : aAdd PLUS aMul { $$ = ast_ptr(new ast_binop(PLUS, std::move($1), std::move($3))); } | aAdd MINUS aMul { $$ = ast_ptr(new ast_binop(MINUS, std::move($1), std::move($3))); } @@ -127,8 +125,8 @@ pattern ; data - : DATA UID EQUAL OCURLY constructors CCURLY - { $$ = definition_data_ptr(new definition_data(std::move($2), std::move($5))); } + : DATA UID lowercaseParams EQUAL OCURLY constructors CCURLY + { $$ = definition_data_ptr(new definition_data(std::move($2), std::move($3), std::move($6))); } ; constructors @@ -138,7 +136,22 @@ constructors ; constructor - : UID uppercaseParams + : UID typeList { $$ = constructor_ptr(new constructor(std::move($1), std::move($2))); } ; +type + : nullaryType ARROW type { $$ = parsed_type_ptr(new parsed_type_arr(std::move($1), std::move($3))); } + | nullaryType { $$ = std::move($1); } + ; + +nullaryType + : OPAREN UID typeList CPAREN { $$ = parsed_type_ptr(new parsed_type_app(std::move($2), std::move($3))); } + | UID { $$ = parsed_type_ptr(new parsed_type_app(std::move($1), { })); } + | LID { $$ = parsed_type_ptr(new parsed_type_var(std::move($1))); } + ; + +typeList + : %empty { $$ = std::vector(); } + | typeList type { $$ = std::move($1); $$.push_back(std::move($2)); } + ;