220 lines
		
	
	
		
			6.4 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
		
		
			
		
	
	
			220 lines
		
	
	
		
			6.4 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| 
								 | 
							
								#include "type.hpp"
							 | 
						||
| 
								 | 
							
								#include <ostream>
							 | 
						||
| 
								 | 
							
								#include <sstream>
							 | 
						||
| 
								 | 
							
								#include <algorithm>
							 | 
						||
| 
								 | 
							
								#include <vector>
							 | 
						||
| 
								 | 
							
								#include "error.hpp"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// Base answer to "is this type a function (arrow) type?": no.
								// type_var and type_arr provide their own definitions further down.
								bool type::is_arrow(const type_mgr& mgr) const {
								    return false;
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								void type_scheme::print(const type_mgr& mgr, std::ostream& to) const {
							 | 
						||
| 
								 | 
							
								    if(forall.size() != 0) {
							 | 
						||
| 
								 | 
							
								        to << "forall ";
							 | 
						||
| 
								 | 
							
								        for(auto& var : forall) {
							 | 
						||
| 
								 | 
							
								            to << var << " ";
							 | 
						||
| 
								 | 
							
								        }
							 | 
						||
| 
								 | 
							
								        to << ". ";
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								    monotype->print(mgr, to);
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// Produces a monotype from this scheme by replacing every quantified
								// variable with a fresh type obtained from `mgr`. A scheme with no
								// quantified variables yields its monotype unchanged.
								type_ptr type_scheme::instantiate(type_mgr& mgr) const {
								    if(forall.empty()) {
								        return monotype;
								    }
								    // Map from quantified variable name to its fresh replacement.
								    std::map<std::string, type_ptr> fresh_for;
								    for(const auto& quantified : forall) {
								        fresh_for[quantified] = mgr.new_type();
								    }
								    return mgr.substitute(fresh_for, monotype);
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								void type_var::print(const type_mgr& mgr, std::ostream& to) const {
							 | 
						||
| 
								 | 
							
								    auto it = mgr.types.find(name);
							 | 
						||
| 
								 | 
							
								    if(it != mgr.types.end()) {
							 | 
						||
| 
								 | 
							
								        it->second->print(mgr, to);
							 | 
						||
| 
								 | 
							
								    } else {
							 | 
						||
| 
								 | 
							
								        to << name;
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								bool type_var::is_arrow(const type_mgr& mgr) const {
							 | 
						||
| 
								 | 
							
								    auto it = mgr.types.find(name);
							 | 
						||
| 
								 | 
							
								    if(it != mgr.types.end()) {
							 | 
						||
| 
								 | 
							
								        return it->second->is_arrow(mgr);
							 | 
						||
| 
								 | 
							
								    } else {
							 | 
						||
| 
								 | 
							
								        return false;
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// A base type prints as just its name; `mgr` is not consulted.
								void type_base::print(const type_mgr& mgr, std::ostream& to) const
								{
								    to << name;
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// Prints "<left> -> <right>". The left operand is wrapped in
								// parentheses when it is itself an arrow, so that "(a -> b) -> c"
								// is not printed as "a -> b -> c".
								void type_arr::print(const type_mgr& mgr, std::ostream& to) const {
								    if(left->is_arrow(mgr)) {
								        to << "(";
								        left->print(mgr, to);
								        to << ")";
								    } else {
								        left->print(mgr, to);
								    }
								    to << " -> ";
								    right->print(mgr, to);
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// An arrow type is, unconditionally, an arrow.
								bool type_arr::is_arrow(const type_mgr& mgr) const
								{
								    return true;
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								void type_app::print(const type_mgr& mgr, std::ostream& to) const {
							 | 
						||
| 
								 | 
							
								    constructor->print(mgr, to);
							 | 
						||
| 
								 | 
							
								    to << "*";
							 | 
						||
| 
								 | 
							
								    for(auto& arg : arguments) {
							 | 
						||
| 
								 | 
							
								        to << " ";
							 | 
						||
| 
								 | 
							
								        arg->print(mgr, to);
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								std::string type_mgr::new_type_name() {
							 | 
						||
| 
								 | 
							
								    int temp = last_id++;
							 | 
						||
| 
								 | 
							
								    std::string str = "";
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    while(temp != -1) {
							 | 
						||
| 
								 | 
							
								        str += (char) ('a' + (temp % 26));
							 | 
						||
| 
								 | 
							
								        temp = temp / 26 - 1;
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    std::reverse(str.begin(), str.end());
							 | 
						||
| 
								 | 
							
								    return str;
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// Creates a type variable carrying a freshly generated name.
								type_ptr type_mgr::new_type() {
								    std::string fresh_name = new_type_name();
								    return type_ptr(new type_var(std::move(fresh_name)));
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// Creates an arrow type whose two sides are fresh type variables.
								//
								// The two variables are created in separate statements, left first.
								// Creating both inside the constructor's argument list would leave
								// the order of the two new_type() calls (and therefore which side
								// receives the earlier generated name) up to the compiler.
								type_ptr type_mgr::new_arrow_type() {
								    type_ptr left = new_type();
								    type_ptr right = new_type();
								    return type_ptr(new type_arr(std::move(left), std::move(right)));
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// Follows variable bindings in `types` starting from `t`.
								//
								// Returns the first type reached that is either not a type variable
								// or is a variable with no binding. In the latter case `var` is set
								// to that unbound variable; otherwise `var` is set to nullptr.
								type_ptr type_mgr::resolve(type_ptr t, type_var*& var) const {
								    var = nullptr;
								    for(;;) {
								        type_var* as_var = dynamic_cast<type_var*>(t.get());
								        if(!as_var) {
								            break;
								        }
								        const auto binding = types.find(as_var->name);
								        if(binding == types.end()) {
								            var = as_var;
								            break;
								        }
								        t = binding->second;
								    }
								    return t;
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// Makes `l` and `r` equal by binding unbound type variables, or
								// throws unification_error when the two types cannot be matched.
								//
								// Both sides are first resolved through existing bindings. Then:
								//   * an unbound variable on either side is bound to the other side;
								//   * two arrows unify their left sides and their right sides;
								//   * two base types match when name and arity are equal;
								//   * two applications unify their constructors and then their
								//     arguments pairwise, and must have the same number of arguments.
								// Any other combination throws.
								//
								// NOTE(review): no occurs check is performed before binding, so a
								// variable can be bound to a type that contains it.
								void type_mgr::unify(type_ptr l, type_ptr r) {
								    type_var* lvar;
								    type_var* rvar;
								    l = resolve(l, lvar);
								    r = resolve(r, rvar);
								    if(lvar) {
								        bind(lvar->name, r);
								        return;
								    }
								    if(rvar) {
								        bind(rvar->name, l);
								        return;
								    }
								    type_arr* larr = dynamic_cast<type_arr*>(l.get());
								    type_arr* rarr = dynamic_cast<type_arr*>(r.get());
								    if(larr && rarr) {
								        unify(larr->left, rarr->left);
								        unify(larr->right, rarr->right);
								        return;
								    }
								    type_base* lid = dynamic_cast<type_base*>(l.get());
								    type_base* rid = dynamic_cast<type_base*>(r.get());
								    if(lid && rid) {
								        if(lid->name == rid->name && lid->arity == rid->arity) return;
								        throw unification_error(l, r);
								    }
								    type_app* lapp = dynamic_cast<type_app*>(l.get());
								    type_app* rapp = dynamic_cast<type_app*>(r.get());
								    if(lapp && rapp) {
								        unify(lapp->constructor, rapp->constructor);
								        // Without this check the pairwise loop below would stop at the
								        // shorter argument list and report success, ignoring the extra
								        // arguments on the other side.
								        if(lapp->arguments.size() != rapp->arguments.size()) {
								            throw unification_error(l, r);
								        }
								        auto left_it = lapp->arguments.begin();
								        auto right_it = rapp->arguments.begin();
								        while(left_it != lapp->arguments.end()) {
								            unify(*left_it, *right_it);
								            ++left_it;
								            ++right_it;
								        }
								        return;
								    }
								    throw unification_error(l, r);
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// Applies the substitution `subst` (variable name -> replacement) to `t`.
								//
								// Variables are followed through the bindings in `types`; the first
								// variable on that chain that appears in `subst` is replaced. Arrows
								// and applications are rebuilt only when some component changed;
								// otherwise the original `t` is returned as-is.
								type_ptr type_mgr::substitute(const std::map<std::string, type_ptr>& subst, const type_ptr& t) const {
								    type_ptr current = t;
								    while(type_var* as_var = dynamic_cast<type_var*>(current.get())) {
								        const auto replacement = subst.find(as_var->name);
								        if(replacement != subst.end()) {
								            return replacement->second;
								        }
								        const auto binding = types.find(as_var->name);
								        if(binding == types.end()) {
								            return t;
								        }
								        current = binding->second;
								    }
								    if(type_arr* arrow = dynamic_cast<type_arr*>(current.get())) {
								        type_ptr new_left = substitute(subst, arrow->left);
								        type_ptr new_right = substitute(subst, arrow->right);
								        const bool unchanged = new_left == arrow->left && new_right == arrow->right;
								        if(unchanged) {
								            return t;
								        }
								        return type_ptr(new type_arr(new_left, new_right));
								    }
								    if(type_app* application = dynamic_cast<type_app*>(current.get())) {
								        type_ptr new_constructor = substitute(subst, application->constructor);
								        std::vector<type_ptr> new_arguments;
								        bool any_argument_changed = false;
								        for(const auto& argument : application->arguments) {
								            type_ptr new_argument = substitute(subst, argument);
								            if(new_argument != argument) {
								                any_argument_changed = true;
								            }
								            new_arguments.push_back(std::move(new_argument));
								        }
								        if(!any_argument_changed && new_constructor == application->constructor) {
								            return t;
								        }
								        type_app* rebuilt = new type_app(std::move(new_constructor));
								        rebuilt->arguments.swap(new_arguments);
								        return type_ptr(rebuilt);
								    }
								    return t;
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								void type_mgr::bind(const std::string& s, type_ptr t) {
							 | 
						||
| 
								 | 
							
								    type_var* other = dynamic_cast<type_var*>(t.get());
							 | 
						||
| 
								 | 
							
								    
							 | 
						||
| 
								 | 
							
								    if(other && other->name == s) return;
							 | 
						||
| 
								 | 
							
								    types[s] = t;
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								void type_mgr::find_free(const type_ptr& t, std::set<std::string>& into) const {
							 | 
						||
| 
								 | 
							
								    type_var* var;
							 | 
						||
| 
								 | 
							
								    type_ptr resolved = resolve(t, var);
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    if(var) {
							 | 
						||
| 
								 | 
							
								        into.insert(var->name);
							 | 
						||
| 
								 | 
							
								    } else if(type_arr* arr = dynamic_cast<type_arr*>(resolved.get())) {
							 | 
						||
| 
								 | 
							
								        find_free(arr->left, into);
							 | 
						||
| 
								 | 
							
								        find_free(arr->right, into);
							 | 
						||
| 
								 | 
							
								    } else if(type_app* app = dynamic_cast<type_app*>(resolved.get())) {
							 | 
						||
| 
								 | 
							
								        find_free(app->constructor, into);
							 | 
						||
| 
								 | 
							
								        for(auto& arg : app->arguments) find_free(arg, into);
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								void type_mgr::find_free(const type_scheme_ptr& t, std::set<std::string>& into) const {
							 | 
						||
| 
								 | 
							
								    std::set<std::string> monotype_free;
							 | 
						||
| 
								 | 
							
								    type_mgr limited_mgr;
							 | 
						||
| 
								 | 
							
								    for(auto& binding : types) {
							 | 
						||
| 
								 | 
							
								        auto existing_position = std::find(t->forall.begin(), t->forall.end(), binding.first);
							 | 
						||
| 
								 | 
							
								        if(existing_position != t->forall.end()) continue;
							 | 
						||
| 
								 | 
							
								        limited_mgr.types[binding.first] = binding.second;
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								    limited_mgr.find_free(t->monotype, monotype_free);
							 | 
						||
| 
								 | 
							
								    for(auto& not_free : t->forall) {
							 | 
						||
| 
								 | 
							
								        monotype_free.erase(not_free);
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								    into.insert(monotype_free.begin(), monotype_free.end());
							 | 
						||
| 
								 | 
							
								}
							 |