diff --git a/code/compiler/10/ast.cpp b/code/compiler/10/ast.cpp index 9ea0822..b55c347 100644 --- a/code/compiler/10/ast.cpp +++ b/code/compiler/10/ast.cpp @@ -36,7 +36,7 @@ void ast_lid::find_free(type_mgr& mgr, type_env_ptr& env, std::set& } type_ptr ast_lid::typecheck(type_mgr& mgr) { - return env->lookup(id); + return env->lookup(id)->instantiate(mgr); } void ast_lid::compile(const env_ptr& env, std::vector& into) const { @@ -56,7 +56,7 @@ void ast_uid::find_free(type_mgr& mgr, type_env_ptr& env, std::set& } type_ptr ast_uid::typecheck(type_mgr& mgr) { - return env->lookup(id); + return env->lookup(id)->instantiate(mgr); } void ast_uid::compile(const env_ptr& env, std::vector& into) const { @@ -79,7 +79,7 @@ void ast_binop::find_free(type_mgr& mgr, type_env_ptr& env, std::settypecheck(mgr); type_ptr rtype = right->typecheck(mgr); - type_ptr ftype = env->lookup(op_name(op)); + type_ptr ftype = env->lookup(op_name(op))->instantiate(mgr); if(!ftype) throw type_error(std::string("unknown binary operator ") + op_name(op)); type_ptr return_type = mgr.new_type(); @@ -232,7 +232,7 @@ void pattern_var::insert_bindings(type_mgr& mgr, type_env_ptr& env) const { } void pattern_var::typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const { - mgr.unify(env->lookup(var), t); + mgr.unify(env->lookup(var)->instantiate(mgr), t); } void pattern_constr::print(std::ostream& to) const { @@ -249,7 +249,7 @@ void pattern_constr::insert_bindings(type_mgr& mgr, type_env_ptr& env) const { } void pattern_constr::typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const { - type_ptr constructor_type = env->lookup(constr); + type_ptr constructor_type = env->lookup(constr)->instantiate(mgr); if(!constructor_type) { throw type_error(std::string("pattern using unknown constructor ") + constr); } @@ -258,7 +258,7 @@ void pattern_constr::typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) con type_arr* arr = dynamic_cast(constructor_type.get()); if(!arr) throw type_error("too many parameters in constructor pattern"); - mgr.unify(env->lookup(param), arr->left); + mgr.unify(env->lookup(param)->instantiate(mgr), arr->left); constructor_type = arr->right; } diff --git a/code/compiler/10/main.cpp b/code/compiler/10/main.cpp index 89d32d6..dabc108 100644 --- a/code/compiler/10/main.cpp +++ b/code/compiler/10/main.cpp @@ -68,6 +68,7 @@ void typecheck_program( for(auto& def_defnn_name : group->members) { auto& def_defn = defs_defn.find(def_defnn_name)->second; def_defn->typecheck(mgr); + env->generalize(def_defnn_name, mgr); } } diff --git a/code/compiler/10/type.cpp b/code/compiler/10/type.cpp index f5868d5..d355e87 100644 --- a/code/compiler/10/type.cpp +++ b/code/compiler/10/type.cpp @@ -1,8 +1,45 @@ #include "type.hpp" +#include #include #include #include "error.hpp" +void type_scheme::print(const type_mgr& mgr, std::ostream& to) const { + if(forall.size() != 0) { + to << "forall "; + for(auto& var : forall) { + to << var << " "; + } + to << ". "; + } + monotype->print(mgr, to); +} + +type_ptr substitute(const type_mgr& mgr, const std::map& subst, const type_ptr& t) { + type_var* var; + type_ptr resolved = mgr.resolve(t, var); + if(var) { + auto subst_it = subst.find(var->name); + if(subst_it == subst.end()) return resolved; + return subst_it->second; + } else if(type_arr* arr = dynamic_cast(t.get())) { + auto left_result = substitute(mgr, subst, arr->left); + auto right_result = substitute(mgr, subst, arr->right); + if(left_result == arr->left && right_result == arr->right) return t; + return type_ptr(new type_arr(left_result, right_result)); + } + return t; +} + +type_ptr type_scheme::instantiate(type_mgr& mgr) const { + if(forall.size() == 0) return monotype; + std::map subst; + for(auto& var : forall) { + subst[var] = mgr.new_type(); + } + return substitute(mgr, subst, monotype); +} + void type_var::print(const type_mgr& mgr, std::ostream& to) const { auto it = mgr.types.find(name); if(it != mgr.types.end()) { @@ -97,3 +134,15 @@ void type_mgr::bind(const std::string& s, type_ptr t) { if(other && other->name == s) return; types[s] = t; } + +void type_mgr::find_free(const type_ptr& t, std::set& into) const { + type_var* var; + type_ptr resolved = resolve(t, var); + + if(var) { + into.insert(var->name); + } else if(type_arr* arr = dynamic_cast(resolved.get())) { + find_free(arr->left, into); + find_free(arr->right, into); + } +} diff --git a/code/compiler/10/type.hpp b/code/compiler/10/type.hpp index 09e525f..1d387d3 100644 --- a/code/compiler/10/type.hpp +++ b/code/compiler/10/type.hpp @@ -1,6 +1,8 @@ #pragma once #include #include +#include +#include struct type_mgr; @@ -12,6 +14,18 @@ struct type { using type_ptr = std::shared_ptr; +struct type_scheme { + std::vector forall; + type_ptr monotype; + + type_scheme(type_ptr type) : forall(), monotype(std::move(type)) {} + + void print(const type_mgr& mgr, std::ostream& to) const; + type_ptr instantiate(type_mgr& mgr) const; +}; + +using type_scheme_ptr = std::shared_ptr; + struct type_var : public type { std::string name; @@ -62,4 +76,5 @@ struct type_mgr { void unify(type_ptr l, type_ptr r); type_ptr resolve(type_ptr t, type_var*& var) const; void bind(const std::string& s, type_ptr t); + void find_free(const type_ptr& t, std::set& into) const; }; diff --git a/code/compiler/10/type_env.cpp b/code/compiler/10/type_env.cpp index 1d5d757..0418af6 100644 --- a/code/compiler/10/type_env.cpp +++ b/code/compiler/10/type_env.cpp @@ -1,6 +1,7 @@ #include "type_env.hpp" +#include "type.hpp" -type_ptr type_env::lookup(const std::string& name) const { +type_scheme_ptr type_env::lookup(const std::string& name) const { auto it = names.find(name); if(it != names.end()) return it->second; if(parent) return parent->lookup(name); @@ -8,9 +9,25 @@ type_ptr type_env::lookup(const std::string& name) const { } void type_env::bind(const std::string& name, type_ptr t) { + names[name] = type_scheme_ptr(new type_scheme(t)); +} + +void type_env::bind(const std::string& name, type_scheme_ptr t) { names[name] = t; } +void type_env::generalize(const std::string& name, type_mgr& mgr) { + auto names_it = names.find(name); + if(names_it == names.end()) throw 0; + if(names_it->second->forall.size() > 0) throw 0; + + std::set free_variables; + mgr.find_free(names_it->second->monotype, free_variables); + for(auto& free : free_variables) { + names_it->second->forall.push_back(free); + } +} + type_env_ptr type_scope(type_env_ptr parent) { return type_env_ptr(new type_env(std::move(parent))); } diff --git a/code/compiler/10/type_env.hpp b/code/compiler/10/type_env.hpp index fe01413..69bc703 100644 --- a/code/compiler/10/type_env.hpp +++ b/code/compiler/10/type_env.hpp @@ -7,13 +7,15 @@ using type_env_ptr = std::shared_ptr; struct type_env { type_env_ptr parent; - std::map names; + std::map names; type_env(type_env_ptr p) : parent(std::move(p)) {} type_env() : type_env(nullptr) {} - type_ptr lookup(const std::string& name) const; + type_scheme_ptr lookup(const std::string& name) const; void bind(const std::string& name, type_ptr t); + void bind(const std::string& name, type_scheme_ptr t); + void generalize(const std::string& name, type_mgr& mgr); };