diff --git a/code/compiler/11/parsed_type.cpp b/code/compiler/11/parsed_type.cpp index ba6fddd..4644d52 100644 --- a/code/compiler/11/parsed_type.cpp +++ b/code/compiler/11/parsed_type.cpp @@ -21,7 +21,7 @@ type_ptr parsed_type_app::to_type( type_ptr parsed_type_var::to_type( const std::set& vars, const type_env& e) const { - if(vars.find(var) != vars.end()) throw 0; + if(vars.find(var) == vars.end()) throw 0; return type_ptr(new type_var(var)); } diff --git a/code/compiler/11/type.cpp b/code/compiler/11/type.cpp index 6c381b8..8022966 100644 --- a/code/compiler/11/type.cpp +++ b/code/compiler/11/type.cpp @@ -16,43 +16,13 @@ void type_scheme::print(const type_mgr& mgr, std::ostream& to) const { monotype->print(mgr, to); } -type_ptr substitute(const type_mgr& mgr, const std::map& subst, const type_ptr& t) { - type_var* var; - type_ptr resolved = mgr.resolve(t, var); - if(var) { - auto subst_it = subst.find(var->name); - if(subst_it == subst.end()) return resolved; - return subst_it->second; - } else if(type_arr* arr = dynamic_cast(t.get())) { - auto left_result = substitute(mgr, subst, arr->left); - auto right_result = substitute(mgr, subst, arr->right); - if(left_result == arr->left && right_result == arr->right) return t; - return type_ptr(new type_arr(left_result, right_result)); - } else if(type_app* app = dynamic_cast(t.get())) { - auto constructor_result = substitute(mgr, subst, app->constructor); - bool arg_changed = false; - std::vector new_args; - for(auto& arg : app->arguments) { - auto arg_result = substitute(mgr, subst, arg); - arg_changed |= arg_result != arg; - new_args.push_back(std::move(arg_result)); - } - - if(constructor_result == app->constructor && !arg_changed) return t; - type_app* new_app = new type_app(std::move(constructor_result)); - std::swap(new_app->arguments, new_args); - return type_ptr(new_app); - } - return t; -} - type_ptr type_scheme::instantiate(type_mgr& mgr) const { if(forall.size() == 0) return monotype; std::map subst; for(auto& var : forall) { subst[var] = mgr.new_type(); } - return substitute(mgr, subst, monotype); + return mgr.substitute(subst, monotype); } void type_var::print(const type_mgr& mgr, std::ostream& to) const { @@ -161,6 +131,39 @@ void type_mgr::unify(type_ptr l, type_ptr r) { throw unification_error(l, r); } +type_ptr type_mgr::substitute(const std::map& subst, const type_ptr& t) const { + type_ptr temp = t; + while(type_var* var = dynamic_cast(temp.get())) { + auto subst_it = subst.find(var->name); + if(subst_it != subst.end()) return subst_it->second; + auto var_it = types.find(var->name); + if(var_it == types.end()) return t; + temp = var_it->second; + } + + if(type_arr* arr = dynamic_cast(temp.get())) { + auto left_result = substitute(subst, arr->left); + auto right_result = substitute(subst, arr->right); + if(left_result == arr->left && right_result == arr->right) return t; + return type_ptr(new type_arr(left_result, right_result)); + } else if(type_app* app = dynamic_cast(temp.get())) { + auto constructor_result = substitute(subst, app->constructor); + bool arg_changed = false; + std::vector new_args; + for(auto& arg : app->arguments) { + auto arg_result = substitute(subst, arg); + arg_changed |= arg_result != arg; + new_args.push_back(std::move(arg_result)); + } + + if(constructor_result == app->constructor && !arg_changed) return t; + type_app* new_app = new type_app(std::move(constructor_result)); + std::swap(new_app->arguments, new_args); + return type_ptr(new_app); + } + return t; +} + void type_mgr::bind(const std::string& s, type_ptr t) { type_var* other = dynamic_cast(t.get()); diff --git a/code/compiler/11/type.hpp b/code/compiler/11/type.hpp index cd3bd45..6c524aa 100644 --- a/code/compiler/11/type.hpp +++ b/code/compiler/11/type.hpp @@ -86,6 +86,9 @@ struct type_mgr { type_ptr new_arrow_type(); void unify(type_ptr l, type_ptr r); + type_ptr substitute( + const std::map& subst, + const type_ptr& t) const; type_ptr resolve(type_ptr t, type_var*& var) const; void bind(const std::string& s, type_ptr t); void find_free(const type_ptr& t, std::set& into) const;