diff --git a/03/type.cpp b/03/type.cpp index 206a9aa..91f87d3 100644 --- a/03/type.cpp +++ b/03/type.cpp @@ -3,18 +3,17 @@ #include std::string type_mgr::new_type_name() { - std::ostringstream oss; int temp = last_id++; + std::string str = ""; - do { - oss << (char) ('a' + (temp % 26)); - temp /= 26; - } while(temp); - std::string str = oss.str(); + while(temp != -1) { + str += (char) ('a' + (temp % 26)); + temp = temp / 26 - 1; + } std::reverse(str.begin(), str.end()); return str; -}; +} type_ptr type_mgr::new_type() { return type_ptr(new type_var(new_type_name())); @@ -23,3 +22,57 @@ type_ptr type_mgr::new_type() { type_ptr type_mgr::new_arrow_type() { return type_ptr(new type_arr(new_type(), new_type())); } + +type_ptr type_mgr::resolve(type_ptr t, type_var*& var) { + type_var* cast; + + var = nullptr; + while((cast = dynamic_cast(t.get()))) { + auto it = types.find(cast->name); + + if(it == types.end()) { + var = cast; + break; + } + t = it->second; + } + + return t; +} + +void type_mgr::unify(type_ptr l, type_ptr r) { + type_var* lvar; + type_var* rvar; + type_arr* larr; + type_arr* rarr; + type_base* lid; + type_base* rid; + + l = resolve(l, lvar); + r = resolve(r, rvar); + + if(lvar) { + bind(lvar->name, r); + return; + } else if(rvar) { + bind(rvar->name, l); + return; + } else if((larr = dynamic_cast(l.get())) && + (rarr = dynamic_cast(r.get()))) { + unify(larr->left, rarr->left); + unify(larr->right, rarr->right); + return; + } else if((lid = dynamic_cast(l.get())) && + (rid = dynamic_cast(r.get()))) { + if(lid->name == rid->name) return; + } + + throw 0; +} + +void type_mgr::bind(std::string s, type_ptr t) { + type_var* other = dynamic_cast(t.get()); + + if(other && other->name == s) return; + types[s] = t; +} diff --git a/03/type.hpp b/03/type.hpp index 752d778..ab358a4 100644 --- a/03/type.hpp +++ b/03/type.hpp @@ -15,11 +15,11 @@ struct type_var : public type { : name(std::move(n)) {} }; -struct type_id : public type { - int id; +struct type_base : public type { + std::string name; - type_id(int i) - : id(i) {} + type_base(std::string n) + : name(std::move(n)) {} }; struct type_arr : public type { @@ -32,8 +32,13 @@ struct type_arr : public type { struct type_mgr { int last_id = 0; + std::map types; std::string new_type_name(); type_ptr new_type(); type_ptr new_arrow_type(); + + void unify(type_ptr l, type_ptr r); + type_ptr resolve(type_ptr t, type_var*& var); + void bind(std::string s, type_ptr t); };