diff --git a/12/type.cpp b/12/type.cpp index 237c9fb..30852f8 100644 --- a/12/type.cpp +++ b/12/type.cpp @@ -5,6 +5,8 @@ #include #include "error.hpp" +bool type::is_arrow(const type_mgr& mgr) const { return false; } + void type_scheme::print(const type_mgr& mgr, std::ostream& to) const { if(forall.size() != 0) { to << "forall "; @@ -34,20 +36,35 @@ void type_var::print(const type_mgr& mgr, std::ostream& to) const { } } +bool type_var::is_arrow(const type_mgr& mgr) const { + auto it = mgr.types.find(name); + if(it != mgr.types.end()) { + return it->second->is_arrow(mgr); + } else { + return false; + } +} + void type_base::print(const type_mgr& mgr, std::ostream& to) const { to << name; } void type_arr::print(const type_mgr& mgr, std::ostream& to) const { + bool print_parenths = left->is_arrow(mgr); + if(print_parenths) to << "("; left->print(mgr, to); - to << " -> ("; + if(print_parenths) to << ")"; + to << " -> "; right->print(mgr, to); - to << ")"; +} + +bool type_arr::is_arrow(const type_mgr& mgr) const { + return true; } void type_app::print(const type_mgr& mgr, std::ostream& to) const { constructor->print(mgr, to); - to << "* "; + to << "*"; for(auto& arg : arguments) { to << " "; arg->print(mgr, to); diff --git a/12/type.hpp b/12/type.hpp index abf2b55..3ab4714 100644 --- a/12/type.hpp +++ b/12/type.hpp @@ -11,6 +11,7 @@ struct type { virtual ~type() = default; virtual void print(const type_mgr& mgr, std::ostream& to) const = 0; + virtual bool is_arrow(const type_mgr& mgr) const; }; using type_ptr = std::shared_ptr; @@ -34,6 +35,7 @@ struct type_var : public type { : name(std::move(n)) {} void print(const type_mgr& mgr, std::ostream& to) const; + bool is_arrow(const type_mgr& mgr) const; }; struct type_base : public type { @@ -65,6 +67,7 @@ struct type_arr : public type { : left(std::move(l)), right(std::move(r)) {} void print(const type_mgr& mgr, std::ostream& to) const; + bool is_arrow(const type_mgr& mgr) const; }; struct type_app : public type {