diff --git a/03/ast.cpp b/03/ast.cpp index 734a54a..f244ffa 100644 --- a/03/ast.cpp +++ b/03/ast.cpp @@ -1 +1,88 @@ #include "ast.hpp" + +std::string op_name(binop op) { + switch(op) { + case PLUS: return "+"; + case MINUS: return "-"; + case TIMES: return "*"; + case DIVIDE: return "/"; + } + throw 0; +} + +type_ptr ast_int::typecheck(type_mgr& mgr, const type_env& env) const { + return type_ptr(new type_base("Int")); +} + +type_ptr ast_lid::typecheck(type_mgr& mgr, const type_env& env) const { + return env.lookup(id); +} + +type_ptr ast_uid::typecheck(type_mgr& mgr, const type_env& env) const { + return env.lookup(id); +} + +type_ptr ast_binop::typecheck(type_mgr& mgr, const type_env& env) const { + type_ptr ltype = left->typecheck(mgr, env); + type_ptr rtype = right->typecheck(mgr, env); + type_ptr ftype = env.lookup(op_name(op)); + if(!ftype) throw 0; + + type_ptr place_a = mgr.new_type(); + type_ptr place_b = mgr.new_type(); + type_ptr place_c = mgr.new_type(); + type_ptr arrow_one = type_ptr(new type_arr(place_b, place_c)); + type_ptr arrow_two = type_ptr(new type_arr(place_a, arrow_one)); + + mgr.unify(arrow_two, ftype); + mgr.unify(place_a, ltype); + mgr.unify(place_b, rtype); + return place_c; +} + +type_ptr ast_app::typecheck(type_mgr& mgr, const type_env& env) const { + type_ptr ltype = left->typecheck(mgr, env); + type_ptr rtype = right->typecheck(mgr, env); + + type_ptr place_a = mgr.new_type(); + type_ptr place_b = mgr.new_type(); + type_ptr arrow = type_ptr(new type_arr(place_a, place_b)); + mgr.unify(arrow, ltype); + mgr.unify(place_a, rtype); + return place_b; +} + +type_ptr ast_case::typecheck(type_mgr& mgr, const type_env& env) const { + type_ptr case_type = of->typecheck(mgr, env); + type_ptr branch_type = mgr.new_type(); + + for(auto& branch : branches) { + type_env new_env = env.scope(); + branch->pat->match(case_type, mgr, new_env); + type_ptr curr_branch_type = branch->expr->typecheck(mgr, new_env); + mgr.unify(branch_type, curr_branch_type); + } + + return branch_type; +} + +void pattern_var::match(type_ptr t, type_mgr& mgr, type_env& env) const { + env.bind(var, t); +} + +void pattern_constr::match(type_ptr t, type_mgr& mgr, type_env& env) const { + type_ptr constructor_type = env.lookup(constr); + if(!constructor_type) throw 0; + + for(int i = 0; i < params.size(); i++) { + type_arr* arr = dynamic_cast(constructor_type.get()); + if(!arr) throw 0; + + env.bind(params[i], arr->left); + constructor_type = arr->right; + } + + mgr.unify(t, constructor_type); + type_base* result_type = dynamic_cast(constructor_type.get()); + if(!result_type) throw 0; +} diff --git a/03/ast.hpp b/03/ast.hpp index e1629cd..c658be7 100644 --- a/03/ast.hpp +++ b/03/ast.hpp @@ -2,15 +2,20 @@ #include #include #include "type.hpp" +#include "env.hpp" struct ast { virtual ~ast() = default; + + virtual type_ptr typecheck(type_mgr& mgr, const type_env& env) const = 0; }; using ast_ptr = std::unique_ptr; struct pattern { virtual ~pattern() = default; + + virtual void match(type_ptr t, type_mgr& mgr, type_env& env) const = 0; }; using pattern_ptr = std::unique_ptr; @@ -37,6 +42,9 @@ using constructor_ptr = std::unique_ptr; struct definition { virtual ~definition() = default; + + virtual void typecheck_first(type_mgr& mgr, type_env& env) = 0; + virtual void typecheck_second(type_mgr& mgr, const type_env& env) const = 0; }; using definition_ptr = std::unique_ptr; @@ -53,6 +61,8 @@ struct ast_int : public ast { explicit ast_int(int v) : value(v) {} + + type_ptr typecheck(type_mgr& mgr, const type_env& env) const; }; struct ast_lid : public ast { @@ -60,6 +70,8 @@ struct ast_lid : public ast { explicit ast_lid(std::string i) : id(std::move(i)) {} + + type_ptr typecheck(type_mgr& mgr, const type_env& env) const; }; struct ast_uid : public ast { @@ -67,6 +79,8 @@ struct ast_uid : public ast { explicit ast_uid(std::string i) : id(std::move(i)) {} + + type_ptr typecheck(type_mgr& mgr, const type_env& env) const; }; struct ast_binop : public ast { @@ -76,6 +90,8 @@ struct ast_binop : public ast { ast_binop(binop o, ast_ptr l, ast_ptr r) : op(o), left(std::move(l)), right(std::move(r)) {} + + type_ptr typecheck(type_mgr& mgr, const type_env& env) const; }; struct ast_app : public ast { @@ -84,6 +100,8 @@ struct ast_app : public ast { ast_app(ast_ptr l, ast_ptr r) : left(std::move(l)), right(std::move(r)) {} + + type_ptr typecheck(type_mgr& mgr, const type_env& env) const; }; struct ast_case : public ast { @@ -92,6 +110,8 @@ struct ast_case : public ast { ast_case(ast_ptr o, std::vector b) : of(std::move(o)), branches(std::move(b)) {} + + type_ptr typecheck(type_mgr& mgr, const type_env& env) const; }; struct pattern_var : public pattern { @@ -99,6 +119,8 @@ struct pattern_var : public pattern { pattern_var(std::string v) : var(std::move(v)) {} + + void match(type_ptr t, type_mgr& mgr, type_env& env) const; }; struct pattern_constr : public pattern { @@ -107,6 +129,8 @@ struct pattern_constr : public pattern { pattern_constr(std::string c, std::vector p) : constr(std::move(c)), params(std::move(p)) {} + + void match(type_ptr t, type_mgr&, type_env& env) const; }; struct definition_defn : public definition { @@ -114,10 +138,16 @@ struct definition_defn : public definition { std::vector params; ast_ptr body; + type_ptr return_type; + std::vector param_types; + definition_defn(std::string n, std::vector p, ast_ptr b) : name(std::move(n)), params(std::move(p)), body(std::move(b)) { } + + void typecheck_first(type_mgr& mgr, type_env& env); + void typecheck_second(type_mgr& mgr, const type_env& env) const; }; struct definition_data : public definition { @@ -126,4 +156,7 @@ struct definition_data : public definition { definition_data(std::string n, std::vector cs) : name(std::move(n)), constructors(std::move(cs)) {} + + void typecheck_first(type_mgr& mgr, type_env& env); + void typecheck_second(type_mgr& mgr, const type_env& env) const; }; diff --git a/03/clean.sh b/03/clean.sh index b987cf1..7580487 100755 --- a/03/clean.sh +++ b/03/clean.sh @@ -1 +1 @@ -rm -f parser.o parser.cpp parser.hpp stack.hh scanner.cpp scanner.o type.o a.out +rm -f parser.o parser.cpp parser.hpp stack.hh scanner.cpp scanner.o type.o env.o ast.o definition.o a.out diff --git a/03/compile.sh b/03/compile.sh index 7c28115..e67846b 100755 --- a/03/compile.sh +++ b/03/compile.sh @@ -1,6 +1,9 @@ bison -o parser.cpp -d parser.y flex -o scanner.cpp scanner.l -g++ -c -o scanner.o scanner.cpp -g++ -c -o parser.o parser.cpp -g++ -c -o type.o type.cpp -g++ main.cpp parser.o scanner.o type.o +g++ -g -c -o scanner.o scanner.cpp +g++ -g -c -o parser.o parser.cpp +g++ -g -c -o type.o type.cpp +g++ -g -c -o env.o env.cpp +g++ -g -c -o ast.o ast.cpp +g++ -g -c -o definition.o definition.cpp +g++ -g main.cpp parser.o scanner.o type.o env.o ast.o definition.o diff --git a/03/definition.cpp b/03/definition.cpp new file mode 100644 index 0000000..69b3196 --- /dev/null +++ b/03/definition.cpp @@ -0,0 +1,48 @@ +#include "ast.hpp" + +void definition_defn::typecheck_first(type_mgr& mgr, type_env& env) { + return_type = mgr.new_type(); + type_ptr full_type = return_type; + + for(auto it = params.rbegin(); it != params.rend(); it++) { + type_ptr param_type = mgr.new_type(); + full_type = type_ptr(new type_arr(param_type, full_type)); + param_types.push_back(param_type); + } + + env.bind(name, full_type); +} + +void definition_defn::typecheck_second(type_mgr& mgr, const type_env& env) const { + type_env new_env = env.scope(); + auto param_it = params.begin(); + auto type_it = param_types.rbegin(); + + while(param_it != params.end() && type_it != param_types.rend()) { + new_env.bind(*param_it, *type_it); + param_it++; + type_it++; + } + + type_ptr body_type = body->typecheck(mgr, new_env); + mgr.unify(return_type, body_type); +} + +void definition_data::typecheck_first(type_mgr& mgr, type_env& env) { + type_ptr return_type = type_ptr(new type_base(name)); + + for(auto& constructor : constructors) { + type_ptr full_type = return_type; + + for(auto& type_name : constructor->types) { + type_ptr type = type_ptr(new type_base(type_name)); + full_type = type_ptr(new type_arr(type, full_type)); + } + + env.bind(constructor->name, full_type); + } +} + +void definition_data::typecheck_second(type_mgr& mgr, const type_env& env) const { + // Nothing +} diff --git a/03/main.cpp b/03/main.cpp index 6a72a95..43c1b4f 100644 --- a/03/main.cpp +++ b/03/main.cpp @@ -8,8 +8,32 @@ void yy::parser::error(const std::string& msg) { extern std::vector program; +void typecheck_program(const std::vector& prog) { + type_mgr mgr; + type_env env; + + type_ptr int_type = type_ptr(new type_base("Int")); + type_ptr binop_type = type_ptr(new type_arr( + int_type, + type_ptr(new type_arr(int_type, int_type)))); + + env.bind("+", binop_type); + env.bind("-", binop_type); + env.bind("*", binop_type); + env.bind("/", binop_type); + + for(auto& def : prog) { + def->typecheck_first(mgr, env); + } + + for(auto& def : prog) { + def->typecheck_second(mgr, env); + } +} + int main() { yy::parser parser; parser.parse(); + typecheck_program(program); std::cout << program.size() << std::endl; } diff --git a/03/parser.y b/03/parser.y index d8c45b7..fdb72c9 100644 --- a/03/parser.y +++ b/03/parser.y @@ -32,7 +32,7 @@ extern yy::parser::symbol_type yylex(); %define api.value.type variant %define api.token.constructor -%type > lowercaseParams +%type > lowercaseParams uppercaseParams %type > program definitions %type > branches %type > constructors @@ -71,6 +71,11 @@ lowercaseParams | lowercaseParams LID { $$ = std::move($1); $$.push_back(std::move($2)); } ; +uppercaseParams + : %empty { $$ = std::vector(); } + | uppercaseParams UID { $$ = std::move($1); $$.push_back(std::move($2)); } + ; + aAdd : aAdd PLUS aMul { $$ = ast_ptr(new ast_binop(PLUS, std::move($1), std::move($3))); } | aAdd MINUS aMul { $$ = ast_ptr(new ast_binop(MINUS, std::move($1), std::move($3))); } @@ -102,7 +107,7 @@ case ; branches - : branches COMMA branch { $$ = std::move($1); $1.push_back(std::move($3)); } + : branches branch { $$ = std::move($1); $1.push_back(std::move($2)); } | branch { $$ = std::vector(); $$.push_back(std::move($1));} ; @@ -129,7 +134,7 @@ constructors ; constructor - : UID lowercaseParams + : UID uppercaseParams { $$ = constructor_ptr(new constructor(std::move($1), std::move($2))); } ; diff --git a/03/type.cpp b/03/type.cpp index 91f87d3..f370a91 100644 --- a/03/type.cpp +++ b/03/type.cpp @@ -70,7 +70,7 @@ void type_mgr::unify(type_ptr l, type_ptr r) { throw 0; } -void type_mgr::bind(std::string s, type_ptr t) { +void type_mgr::bind(const std::string& s, type_ptr t) { type_var* other = dynamic_cast(t.get()); if(other && other->name == s) return; diff --git a/03/type.hpp b/03/type.hpp index ab358a4..e63040e 100644 --- a/03/type.hpp +++ b/03/type.hpp @@ -40,5 +40,5 @@ struct type_mgr { void unify(type_ptr l, type_ptr r); type_ptr resolve(type_ptr t, type_var*& var); - void bind(std::string s, type_ptr t); + void bind(const std::string& s, type_ptr t); };