diff --git a/13/ast.cpp b/13/ast.cpp index 10db5cb..3f6b9f1 100644 --- a/13/ast.cpp +++ b/13/ast.cpp @@ -99,11 +99,19 @@ type_ptr ast_binop::typecheck(type_mgr& mgr, type_env_ptr& env) { type_ptr ftype = env->lookup(op_name(op))->instantiate(mgr); if(!ftype) throw type_error(std::string("unknown binary operator ") + op_name(op), loc); - type_ptr return_type = mgr.new_type(); - type_ptr arrow_one = type_ptr(new type_arr(rtype, return_type)); - type_ptr arrow_two = type_ptr(new type_arr(ltype, arrow_one)); + // For better type errors, we first require binary function, + // and only then unify each argument. This way, we can + // precisely point out which argument is "wrong". - mgr.unify(arrow_two, ftype); + type_ptr return_type = mgr.new_type(); + type_ptr second_type = mgr.new_type(); + type_ptr first_type = mgr.new_type(); + type_ptr arrow_one = type_ptr(new type_arr(second_type, return_type)); + type_ptr arrow_two = type_ptr(new type_arr(first_type, arrow_one)); + + mgr.unify(ftype, arrow_two, loc); + mgr.unify(first_type, ltype, left->loc); + mgr.unify(second_type, rtype, right->loc); return return_type; } @@ -140,7 +148,7 @@ type_ptr ast_app::typecheck(type_mgr& mgr, type_env_ptr& env) { type_ptr return_type = mgr.new_type(); type_ptr arrow = type_ptr(new type_arr(rtype, return_type)); - mgr.unify(arrow, ltype); + mgr.unify(arrow, ltype, left->loc); return return_type; } @@ -190,7 +198,7 @@ type_ptr ast_case::typecheck(type_mgr& mgr, type_env_ptr& env) { type_env_ptr new_env = type_scope(env); branch->pat->typecheck(case_type, mgr, new_env); type_ptr curr_branch_type = branch->expr->typecheck(mgr, new_env); - mgr.unify(branch_type, curr_branch_type); + mgr.unify(curr_branch_type, branch_type, branch->expr->loc); } input_type = mgr.resolve(case_type, var); @@ -361,7 +369,7 @@ type_ptr ast_lambda::typecheck(type_mgr& mgr, type_env_ptr& env) { full_type = type_ptr(new type_arr(std::move(param_type), full_type)); } - mgr.unify(return_type, body->typecheck(mgr, var_env)); + mgr.unify(return_type, body->typecheck(mgr, var_env), body->loc); return full_type; } @@ -433,5 +441,5 @@ void pattern_constr::typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) con constructor_type = arr->right; } - mgr.unify(t, constructor_type); + mgr.unify(constructor_type, t, loc); } diff --git a/13/type.cpp b/13/type.cpp index c4ff35a..856c463 100644 --- a/13/type.cpp +++ b/13/type.cpp @@ -109,7 +109,7 @@ type_ptr type_mgr::resolve(type_ptr t, type_var*& var) const { return t; } -void type_mgr::unify(type_ptr l, type_ptr r) { +void type_mgr::unify(type_ptr l, type_ptr r, const std::optional& loc) { type_var *lvar, *rvar; type_arr *larr, *rarr; type_base *lid, *rid; @@ -126,26 +126,26 @@ void type_mgr::unify(type_ptr l, type_ptr r) { return; } else if((larr = dynamic_cast(l.get())) && (rarr = dynamic_cast(r.get()))) { - unify(larr->left, rarr->left); - unify(larr->right, rarr->right); + unify(larr->left, rarr->left, loc); + unify(larr->right, rarr->right, loc); return; } else if((lid = dynamic_cast(l.get())) && (rid = dynamic_cast(r.get()))) { if(lid->name == rid->name && lid->arity == rid->arity) return; } else if((lapp = dynamic_cast(l.get())) && (rapp = dynamic_cast(r.get()))) { - unify(lapp->constructor, rapp->constructor); + unify(lapp->constructor, rapp->constructor, loc); auto left_it = lapp->arguments.begin(); auto right_it = rapp->arguments.begin(); while(left_it != lapp->arguments.end() && right_it != rapp->arguments.end()) { - unify(*left_it, *right_it); + unify(*left_it, *right_it, loc); left_it++, right_it++; } return; } - throw unification_error(l, r); + throw unification_error(l, r, loc); } type_ptr type_mgr::substitute(const std::map& subst, const type_ptr& t) const { diff --git a/13/type.hpp b/13/type.hpp index 3ab4714..0edae1c 100644 --- a/13/type.hpp +++ b/13/type.hpp @@ -4,6 +4,8 @@ #include #include #include +#include +#include "location.hh" struct type_mgr; @@ -88,7 +90,7 @@ struct type_mgr { type_ptr new_type(); type_ptr new_arrow_type(); - void unify(type_ptr l, type_ptr r); + void unify(type_ptr l, type_ptr r, const std::optional& loc = std::nullopt); type_ptr substitute( const std::map& subst, const type_ptr& t) const;