#include "type.hpp"
#include <algorithm>
#include <map>
#include <optional>
#include <ostream>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "error.hpp"

// NOTE(review): this file was recovered from a copy in which everything
// between angle brackets had been stripped (system header names, template
// arguments, dynamic_cast targets). The system headers above are the ones
// the code below actually uses; template arguments were restored from how
// each value is used. The two restorations that are not pinned down by a
// declared variable type are flagged individually below.

// Prints the scheme as "forall a b . <monotype>"; the quantifier prefix is
// omitted entirely when there are no quantified variables.
void type_scheme::print(const type_mgr& mgr, std::ostream& to) const {
    if(!forall.empty()) {
        to << "forall ";
        for(const auto& var : forall) {
            to << var << " ";
        }
        to << ". ";
    }
    monotype->print(mgr, to);
}

// Returns the monotype with every quantified variable replaced by a fresh
// type variable obtained from mgr. A scheme without quantified variables
// returns its monotype unchanged (no copy is made).
type_ptr type_scheme::instantiate(type_mgr& mgr) const {
    if(forall.empty()) return monotype;

    std::map<std::string, type_ptr> subst;
    for(const auto& var : forall) {
        subst[var] = mgr.new_type();
    }
    return mgr.substitute(subst, monotype);
}

// Prints the type this variable is bound to in mgr, or the variable's own
// name when it has no binding.
void type_var::print(const type_mgr& mgr, std::ostream& to) const {
    auto it = mgr.types.find(name);
    if(it != mgr.types.end()) {
        it->second->print(mgr, to);
    } else {
        to << name;
    }
}

// Prints the base type's name; the manager is not consulted.
void type_base::print(const type_mgr& /*mgr*/, std::ostream& to) const {
    to << name;
}

// Prints the type's name prefixed with '!'; the manager is not consulted.
void type_internal::print(const type_mgr& /*mgr*/, std::ostream& to) const {
    to << "!" << name;
}

// Prints "left -> right". The left operand is wrapped in parentheses when it
// resolves to an arrow type itself.
void type_arr::print(const type_mgr& mgr, std::ostream& to) const {
    type_var* var;
    // NOTE(review): the cast target was lost in extraction; type_arr* is
    // assumed since it is the only choice for which the parentheses
    // disambiguate the output -- confirm against the original.
    const bool print_parenths =
        dynamic_cast<type_arr*>(mgr.resolve(left, var).get()) != nullptr;

    if(print_parenths) to << "(";
    left->print(mgr, to);
    if(print_parenths) to << ")";
    to << " -> ";
    right->print(mgr, to);
}

// Prints the constructor followed by '*' and then each argument preceded by
// a single space.
void type_app::print(const type_mgr& mgr, std::ostream& to) const {
    constructor->print(mgr, to);
    to << "*";
    for(const auto& arg : arguments) {
        to << " ";
        arg->print(mgr, to);
    }
}

// Produces the next unused variable name and advances last_id.
// Ids map to a, b, ..., z, aa, ab, ... (0 -> "a", 25 -> "z", 26 -> "aa").
std::string type_mgr::new_type_name() {
    int temp = last_id++;
    std::string str;

    // Digits are generated least-significant first; the "- 1" makes the
    // numbering bijective so that "aa" follows "z" rather than "ba".
    while(temp != -1) {
        str += static_cast<char>('a' + (temp % 26));
        temp = temp / 26 - 1;
    }

    std::reverse(str.begin(), str.end());
    return str;
}

// Creates a type variable with a freshly generated name.
type_ptr type_mgr::new_type() {
    return type_ptr(new type_var(new_type_name()));
}

// Creates an arrow type whose two sides are fresh type variables.
type_ptr type_mgr::new_arrow_type() {
    return type_ptr(new type_arr(new_type(), new_type()));
}

// Follows variable bindings starting at t until reaching either a
// non-variable type or a variable with no binding, and returns that type.
// On return, var points at the unbound variable if the chain ended in one,
// and is nullptr otherwise.
type_ptr type_mgr::resolve(type_ptr t, type_var*& var) const {
    type_var* cast;

    var = nullptr;
    while((cast = dynamic_cast<type_var*>(t.get()))) {
        auto it = types.find(cast->name);

        if(it == types.end()) {
            var = cast;
            break;
        }
        t = it->second;
    }

    return t;
}

// Makes l and r equal by binding type variables, or throws
// unification_error(l, r, loc) (with l and r resolved) when that is not
// possible. Bindings made before a failure is detected are not rolled back.
//
// NOTE(review): the element type of std::optional was lost in extraction;
// yy::location is assumed and must match the declaration in type.hpp.
// NOTE(review): no occurs check is performed, so a variable can be bound to
// a type that contains it; type_var::print would then recurse without end.
// Confirm callers cannot produce such a pair.
void type_mgr::unify(type_ptr l, type_ptr r, const std::optional<yy::location>& loc) {
    type_var *lvar, *rvar;
    type_arr *larr, *rarr;
    type_base *lid, *rid;
    type_app *lapp, *rapp;

    l = resolve(l, lvar);
    r = resolve(r, rvar);

    if(lvar) {
        bind(lvar->name, r);
        return;
    } else if(rvar) {
        bind(rvar->name, l);
        return;
    } else if((larr = dynamic_cast<type_arr*>(l.get())) &&
            (rarr = dynamic_cast<type_arr*>(r.get()))) {
        unify(larr->left, rarr->left, loc);
        unify(larr->right, rarr->right, loc);
        return;
    } else if((lid = dynamic_cast<type_base*>(l.get())) &&
            (rid = dynamic_cast<type_base*>(r.get()))) {
        // A mismatch deliberately falls through to the throw below.
        if(lid->name == rid->name &&
                lid->arity == rid->arity &&
                lid->is_internal() == rid->is_internal())
            return;
    } else if((lapp = dynamic_cast<type_app*>(l.get())) &&
            (rapp = dynamic_cast<type_app*>(r.get()))) {
        unify(lapp->constructor, rapp->constructor, loc);

        // Applications with different numbers of arguments cannot be equal;
        // without this check the pairwise loop would stop at the shorter
        // list and report success.
        if(lapp->arguments.size() != rapp->arguments.size())
            throw unification_error(l, r, loc);

        auto left_it = lapp->arguments.begin();
        auto right_it = rapp->arguments.begin();
        while(left_it != lapp->arguments.end() &&
                right_it != rapp->arguments.end()) {
            unify(*left_it, *right_it, loc);
            ++left_it;
            ++right_it;
        }
        return;
    }

    throw unification_error(l, r, loc);
}

// Applies subst to t, looking through this manager's bindings on the way.
// Variables named in subst are replaced; arrow and application types are
// rebuilt only when one of their components changed, otherwise t itself is
// returned so unchanged subtrees stay shared.
type_ptr type_mgr::substitute(const std::map<std::string, type_ptr>& subst, const type_ptr& t) const {
    type_ptr temp = t;

    // Walk the binding chain; subst takes priority over the manager's own
    // bindings at every step.
    while(type_var* var = dynamic_cast<type_var*>(temp.get())) {
        auto subst_it = subst.find(var->name);
        if(subst_it != subst.end()) return subst_it->second;

        auto var_it = types.find(var->name);
        if(var_it == types.end()) return t;
        temp = var_it->second;
    }

    if(type_arr* arr = dynamic_cast<type_arr*>(temp.get())) {
        auto left_result = substitute(subst, arr->left);
        auto right_result = substitute(subst, arr->right);

        if(left_result == arr->left && right_result == arr->right) return t;
        return type_ptr(new type_arr(left_result, right_result));
    } else if(type_app* app = dynamic_cast<type_app*>(temp.get())) {
        auto constructor_result = substitute(subst, app->constructor);

        bool arg_changed = false;
        std::vector<type_ptr> new_args;
        new_args.reserve(app->arguments.size());
        for(const auto& arg : app->arguments) {
            auto arg_result = substitute(subst, arg);
            arg_changed |= arg_result != arg;
            new_args.push_back(std::move(arg_result));
        }

        if(constructor_result == app->constructor && !arg_changed) return t;

        type_app* new_app = new type_app(std::move(constructor_result));
        std::swap(new_app->arguments, new_args);
        return type_ptr(new_app);
    }

    return t;
}

// Records that the variable named s stands for t. Binding a variable to
// itself is ignored, since it would create a one-element cycle that
// resolve() could never leave.
void type_mgr::bind(const std::string& s, type_ptr t) {
    type_var* other = dynamic_cast<type_var*>(t.get());

    if(other && other->name == s) return;
    types[s] = t;
}

// Adds to `into` the name of every unbound type variable reachable from t
// through arrow sides, application constructors and application arguments.
void type_mgr::find_free(const type_ptr& t, std::set<std::string>& into) const {
    type_var* var;
    type_ptr resolved = resolve(t, var);

    if(var) {
        into.insert(var->name);
    } else if(type_arr* arr = dynamic_cast<type_arr*>(resolved.get())) {
        find_free(arr->left, into);
        find_free(arr->right, into);
    } else if(type_app* app = dynamic_cast<type_app*>(resolved.get())) {
        find_free(app->constructor, into);
        for(const auto& arg : app->arguments) find_free(arg, into);
    }
}

// Adds to `into` the free variables of a type scheme: those free in its
// monotype, minus the variables the scheme quantifies over.
void type_mgr::find_free(const type_scheme_ptr& t, std::set<std::string>& into) const {
    std::set<std::string> monotype_free;

    // Work on a copy of the bindings that omits the quantified names, so a
    // binding this manager holds for such a name is not followed while
    // walking the monotype.
    type_mgr limited_mgr;
    for(const auto& binding : types) {
        auto existing_position = std::find(t->forall.begin(), t->forall.end(), binding.first);
        if(existing_position != t->forall.end()) continue;
        limited_mgr.types[binding.first] = binding.second;
    }

    limited_mgr.find_free(t->monotype, monotype_free);
    for(const auto& not_free : t->forall) {
        monotype_free.erase(not_free);
    }

    into.insert(monotype_free.begin(), monotype_free.end());
}