Make substitution replace types at every lookup step

This commit is contained in:
Danila Fedorin 2020-04-13 17:59:57 -07:00
parent 2a936927a1
commit 207efc7ded
3 changed files with 38 additions and 32 deletions

View File

@ -21,7 +21,7 @@ type_ptr parsed_type_app::to_type(
type_ptr parsed_type_var::to_type( type_ptr parsed_type_var::to_type(
const std::set<std::string>& vars, const std::set<std::string>& vars,
const type_env& e) const { const type_env& e) const {
if(vars.find(var) != vars.end()) throw 0; if(vars.find(var) == vars.end()) throw 0;
return type_ptr(new type_var(var)); return type_ptr(new type_var(var));
} }

View File

@ -16,43 +16,13 @@ void type_scheme::print(const type_mgr& mgr, std::ostream& to) const {
monotype->print(mgr, to); monotype->print(mgr, to);
} }
type_ptr substitute(const type_mgr& mgr, const std::map<std::string, type_ptr>& subst, const type_ptr& t) {
type_var* var;
type_ptr resolved = mgr.resolve(t, var);
if(var) {
auto subst_it = subst.find(var->name);
if(subst_it == subst.end()) return resolved;
return subst_it->second;
} else if(type_arr* arr = dynamic_cast<type_arr*>(t.get())) {
auto left_result = substitute(mgr, subst, arr->left);
auto right_result = substitute(mgr, subst, arr->right);
if(left_result == arr->left && right_result == arr->right) return t;
return type_ptr(new type_arr(left_result, right_result));
} else if(type_app* app = dynamic_cast<type_app*>(t.get())) {
auto constructor_result = substitute(mgr, subst, app->constructor);
bool arg_changed = false;
std::vector<type_ptr> new_args;
for(auto& arg : app->arguments) {
auto arg_result = substitute(mgr, subst, arg);
arg_changed |= arg_result != arg;
new_args.push_back(std::move(arg_result));
}
if(constructor_result == app->constructor && !arg_changed) return t;
type_app* new_app = new type_app(std::move(constructor_result));
std::swap(new_app->arguments, new_args);
return type_ptr(new_app);
}
return t;
}
type_ptr type_scheme::instantiate(type_mgr& mgr) const { type_ptr type_scheme::instantiate(type_mgr& mgr) const {
if(forall.size() == 0) return monotype; if(forall.size() == 0) return monotype;
std::map<std::string, type_ptr> subst; std::map<std::string, type_ptr> subst;
for(auto& var : forall) { for(auto& var : forall) {
subst[var] = mgr.new_type(); subst[var] = mgr.new_type();
} }
return substitute(mgr, subst, monotype); return mgr.substitute(subst, monotype);
} }
void type_var::print(const type_mgr& mgr, std::ostream& to) const { void type_var::print(const type_mgr& mgr, std::ostream& to) const {
@ -161,6 +131,39 @@ void type_mgr::unify(type_ptr l, type_ptr r) {
throw unification_error(l, r); throw unification_error(l, r);
} }
type_ptr type_mgr::substitute(const std::map<std::string, type_ptr>& subst, const type_ptr& t) const {
type_ptr temp = t;
while(type_var* var = dynamic_cast<type_var*>(temp.get())) {
auto subst_it = subst.find(var->name);
if(subst_it != subst.end()) return subst_it->second;
auto var_it = types.find(var->name);
if(var_it == types.end()) return t;
temp = var_it->second;
}
if(type_arr* arr = dynamic_cast<type_arr*>(temp.get())) {
auto left_result = substitute(subst, arr->left);
auto right_result = substitute(subst, arr->right);
if(left_result == arr->left && right_result == arr->right) return t;
return type_ptr(new type_arr(left_result, right_result));
} else if(type_app* app = dynamic_cast<type_app*>(temp.get())) {
auto constructor_result = substitute(subst, app->constructor);
bool arg_changed = false;
std::vector<type_ptr> new_args;
for(auto& arg : app->arguments) {
auto arg_result = substitute(subst, arg);
arg_changed |= arg_result != arg;
new_args.push_back(std::move(arg_result));
}
if(constructor_result == app->constructor && !arg_changed) return t;
type_app* new_app = new type_app(std::move(constructor_result));
std::swap(new_app->arguments, new_args);
return type_ptr(new_app);
}
return t;
}
void type_mgr::bind(const std::string& s, type_ptr t) { void type_mgr::bind(const std::string& s, type_ptr t) {
type_var* other = dynamic_cast<type_var*>(t.get()); type_var* other = dynamic_cast<type_var*>(t.get());

View File

@ -86,6 +86,9 @@ struct type_mgr {
type_ptr new_arrow_type(); type_ptr new_arrow_type();
void unify(type_ptr l, type_ptr r); void unify(type_ptr l, type_ptr r);
type_ptr substitute(
const std::map<std::string, type_ptr>& subst,
const type_ptr& t) const;
type_ptr resolve(type_ptr t, type_var*& var) const; type_ptr resolve(type_ptr t, type_var*& var) const;
void bind(const std::string& s, type_ptr t); void bind(const std::string& s, type_ptr t);
void find_free(const type_ptr& t, std::set<std::string>& into) const; void find_free(const type_ptr& t, std::set<std::string>& into) const;