diff --git a/code/compiler/13/main.cpp b/code/compiler/13/main.cpp index ec21e2b..de2946f 100644 --- a/code/compiler/13/main.cpp +++ b/code/compiler/13/main.cpp @@ -25,7 +25,7 @@ void yy::parser::error(const yy::location& loc, const std::string& msg) { void typecheck_program( definition_group& defs, type_mgr& mgr, type_env_ptr& env) { - type_ptr int_type = type_ptr(new type_base("Int")); + type_ptr int_type = type_ptr(new type_internal("Int")); env->bind_type("Int", int_type); type_ptr int_type_app = type_ptr(new type_app(int_type)); diff --git a/code/compiler/13/type.cpp b/code/compiler/13/type.cpp index 856c463..4dd18e1 100644 --- a/code/compiler/13/type.cpp +++ b/code/compiler/13/type.cpp @@ -5,8 +5,6 @@ #include #include "error.hpp" -bool type::is_arrow(const type_mgr& mgr) const { return false; } - void type_scheme::print(const type_mgr& mgr, std::ostream& to) const { if(forall.size() != 0) { to << "forall "; @@ -36,21 +34,17 @@ void type_var::print(const type_mgr& mgr, std::ostream& to) const { } } -bool type_var::is_arrow(const type_mgr& mgr) const { - auto it = mgr.types.find(name); - if(it != mgr.types.end()) { - return it->second->is_arrow(mgr); - } else { - return false; - } -} - void type_base::print(const type_mgr& mgr, std::ostream& to) const { to << name; } +void type_internal::print(const type_mgr& mgr, std::ostream& to) const { + to << "!" << name; +} + void type_arr::print(const type_mgr& mgr, std::ostream& to) const { - bool print_parenths = left->is_arrow(mgr); + type_var* var; + bool print_parenths = dynamic_cast(mgr.resolve(left, var).get()) != nullptr; if(print_parenths) to << "("; left->print(mgr, to); if(print_parenths) to << ")"; @@ -58,10 +52,6 @@ void type_arr::print(const type_mgr& mgr, std::ostream& to) const { right->print(mgr, to); } -bool type_arr::is_arrow(const type_mgr& mgr) const { - return true; -} - void type_app::print(const type_mgr& mgr, std::ostream& to) const { constructor->print(mgr, to); to << "*"; @@ -131,7 +121,10 @@ void type_mgr::unify(type_ptr l, type_ptr r, const std::optional& return; } else if((lid = dynamic_cast(l.get())) && (rid = dynamic_cast(r.get()))) { - if(lid->name == rid->name && lid->arity == rid->arity) return; + if(lid->name == rid->name && + lid->arity == rid->arity && + lid->is_internal() == rid->is_internal()) + return; } else if((lapp = dynamic_cast(l.get())) && (rapp = dynamic_cast(r.get()))) { unify(lapp->constructor, rapp->constructor, loc); diff --git a/code/compiler/13/type.hpp b/code/compiler/13/type.hpp index 0edae1c..af417b6 100644 --- a/code/compiler/13/type.hpp +++ b/code/compiler/13/type.hpp @@ -13,7 +13,6 @@ struct type { virtual ~type() = default; virtual void print(const type_mgr& mgr, std::ostream& to) const = 0; - virtual bool is_arrow(const type_mgr& mgr) const; }; using type_ptr = std::shared_ptr; @@ -37,7 +36,6 @@ struct type_var : public type { : name(std::move(n)) {} void print(const type_mgr& mgr, std::ostream& to) const; - bool is_arrow(const type_mgr& mgr) const; }; struct type_base : public type { @@ -48,6 +46,17 @@ struct type_base : public type { : name(std::move(n)), arity(a) {} void print(const type_mgr& mgr, std::ostream& to) const; + + virtual bool is_internal() const { return false; } +}; + +struct type_internal : public type_base { + type_internal(std::string n, int32_t a = 0) + : type_base(std::move(n), a) {} + + void print(const type_mgr& mgr, std::ostream& to) const; + + bool is_internal() const { return true; } }; struct type_data : public type_base { @@ -69,7 +78,6 @@ struct type_arr : public type { : left(std::move(l)), right(std::move(r)) {} void print(const type_mgr& mgr, std::ostream& to) const; - bool is_arrow(const type_mgr& mgr) const; }; struct type_app : public type {