diff --git a/10/ast.cpp b/10/ast.cpp index 453a12b..0abbb10 100644 --- a/10/ast.cpp +++ b/10/ast.cpp @@ -7,33 +7,15 @@ static void print_indent(int n, std::ostream& to) { while(n--) to << " "; } -type_ptr ast::typecheck_common(type_mgr& mgr, const type_env& env) { - node_type = typecheck(mgr, env); - return node_type; -} - -void ast::resolve_common(const type_mgr& mgr) { - type_var* var; - type_ptr resolved_type = mgr.resolve(node_type, var); - if(var) throw type_error("ambiguously typed program"); - - resolve(mgr); - node_type = std::move(resolved_type); -} - void ast_int::print(int indent, std::ostream& to) const { print_indent(indent, to); to << "INT: " << value << std::endl; } -type_ptr ast_int::typecheck(type_mgr& mgr, const type_env& env) const { +type_ptr ast_int::typecheck(type_mgr& mgr, const type_env& env) { return type_ptr(new type_base("Int")); } -void ast_int::resolve(const type_mgr& mgr) const { - -} - void ast_int::compile(const env_ptr& env, std::vector& into) const { into.push_back(instruction_ptr(new instruction_pushint(value))); } @@ -43,14 +25,10 @@ void ast_lid::print(int indent, std::ostream& to) const { to << "LID: " << id << std::endl; } -type_ptr ast_lid::typecheck(type_mgr& mgr, const type_env& env) const { +type_ptr ast_lid::typecheck(type_mgr& mgr, const type_env& env) { return env.lookup(id); } -void ast_lid::resolve(const type_mgr& mgr) const { - -} - void ast_lid::compile(const env_ptr& env, std::vector& into) const { into.push_back(instruction_ptr( env->has_variable(id) ? @@ -63,14 +41,10 @@ void ast_uid::print(int indent, std::ostream& to) const { to << "UID: " << id << std::endl; } -type_ptr ast_uid::typecheck(type_mgr& mgr, const type_env& env) const { +type_ptr ast_uid::typecheck(type_mgr& mgr, const type_env& env) { return env.lookup(id); } -void ast_uid::resolve(const type_mgr& mgr) const { - -} - void ast_uid::compile(const env_ptr& env, std::vector& into) const { into.push_back(instruction_ptr(new instruction_pushglobal(id))); } @@ -82,9 +56,9 @@ void ast_binop::print(int indent, std::ostream& to) const { right->print(indent + 1, to); } -type_ptr ast_binop::typecheck(type_mgr& mgr, const type_env& env) const { - type_ptr ltype = left->typecheck_common(mgr, env); - type_ptr rtype = right->typecheck_common(mgr, env); +type_ptr ast_binop::typecheck(type_mgr& mgr, const type_env& env) { + type_ptr ltype = left->typecheck(mgr, env); + type_ptr rtype = right->typecheck(mgr, env); type_ptr ftype = env.lookup(op_name(op)); if(!ftype) throw type_error(std::string("unknown binary operator ") + op_name(op)); @@ -96,11 +70,6 @@ type_ptr ast_binop::typecheck(type_mgr& mgr, const type_env& env) const { return return_type; } -void ast_binop::resolve(const type_mgr& mgr) const { - left->resolve_common(mgr); - right->resolve_common(mgr); -} - void ast_binop::compile(const env_ptr& env, std::vector& into) const { right->compile(env, into); left->compile(env_ptr(new env_offset(1, env)), into); @@ -117,9 +86,9 @@ void ast_app::print(int indent, std::ostream& to) const { right->print(indent + 1, to); } -type_ptr ast_app::typecheck(type_mgr& mgr, const type_env& env) const { - type_ptr ltype = left->typecheck_common(mgr, env); - type_ptr rtype = right->typecheck_common(mgr, env); +type_ptr ast_app::typecheck(type_mgr& mgr, const type_env& env) { + type_ptr ltype = left->typecheck(mgr, env); + type_ptr rtype = right->typecheck(mgr, env); type_ptr return_type = mgr.new_type(); type_ptr arrow = type_ptr(new type_arr(rtype, return_type)); @@ -127,11 +96,6 @@ type_ptr ast_app::typecheck(type_mgr& mgr, const type_env& env) const { return return_type; } -void ast_app::resolve(const type_mgr& mgr) const { - left->resolve_common(mgr); - right->resolve_common(mgr); -} - void ast_app::compile(const env_ptr& env, std::vector& into) const { right->compile(env, into); left->compile(env_ptr(new env_offset(1, env)), into); @@ -149,35 +113,28 @@ void ast_case::print(int indent, std::ostream& to) const { } } -type_ptr ast_case::typecheck(type_mgr& mgr, const type_env& env) const { +type_ptr ast_case::typecheck(type_mgr& mgr, const type_env& env) { type_var* var; - type_ptr case_type = mgr.resolve(of->typecheck_common(mgr, env), var); + type_ptr case_type = mgr.resolve(of->typecheck(mgr, env), var); type_ptr branch_type = mgr.new_type(); for(auto& branch : branches) { type_env new_env = env.scope(); branch->pat->match(case_type, mgr, new_env); - type_ptr curr_branch_type = branch->expr->typecheck_common(mgr, new_env); + type_ptr curr_branch_type = branch->expr->typecheck(mgr, new_env); mgr.unify(branch_type, curr_branch_type); } - case_type = mgr.resolve(case_type, var); - if(!dynamic_cast(case_type.get())) { + input_type = mgr.resolve(case_type, var); + if(!dynamic_cast(input_type.get())) { throw type_error("attempting case analysis of non-data type"); } return branch_type; } -void ast_case::resolve(const type_mgr& mgr) const { - of->resolve_common(mgr); - for(auto& branch : branches) { - branch->expr->resolve_common(mgr); - } -} - void ast_case::compile(const env_ptr& env, std::vector& into) const { - type_data* type = dynamic_cast(of->node_type.get()); + type_data* type = dynamic_cast(input_type.get()); of->compile(env, into); into.push_back(instruction_ptr(new instruction_eval())); diff --git a/10/ast.hpp b/10/ast.hpp index c88fc99..98b6b24 100644 --- a/10/ast.hpp +++ b/10/ast.hpp @@ -8,18 +8,12 @@ #include "env.hpp" struct ast { - type_ptr node_type; - virtual ~ast() = default; virtual void print(int indent, std::ostream& to) const = 0; - virtual type_ptr typecheck(type_mgr& mgr, const type_env& env) const = 0; - virtual void resolve(const type_mgr& mgr) const = 0; + virtual type_ptr typecheck(type_mgr& mgr, const type_env& env) = 0; virtual void compile(const env_ptr& env, std::vector& into) const = 0; - - type_ptr typecheck_common(type_mgr& mgr, const type_env& env); - void resolve_common(const type_mgr& mgr); }; using ast_ptr = std::unique_ptr; @@ -50,8 +44,7 @@ struct ast_int : public ast { : value(v) {} void print(int indent, std::ostream& to) const; - type_ptr typecheck(type_mgr& mgr, const type_env& env) const; - void resolve(const type_mgr& mgr) const; + type_ptr typecheck(type_mgr& mgr, const type_env& env); void compile(const env_ptr& env, std::vector& into) const; }; @@ -62,8 +55,7 @@ struct ast_lid : public ast { : id(std::move(i)) {} void print(int indent, std::ostream& to) const; - type_ptr typecheck(type_mgr& mgr, const type_env& env) const; - void resolve(const type_mgr& mgr) const; + type_ptr typecheck(type_mgr& mgr, const type_env& env); void compile(const env_ptr& env, std::vector& into) const; }; @@ -74,8 +66,7 @@ struct ast_uid : public ast { : id(std::move(i)) {} void print(int indent, std::ostream& to) const; - type_ptr typecheck(type_mgr& mgr, const type_env& env) const; - void resolve(const type_mgr& mgr) const; + type_ptr typecheck(type_mgr& mgr, const type_env& env); void compile(const env_ptr& env, std::vector& into) const; }; @@ -88,8 +79,7 @@ struct ast_binop : public ast { : op(o), left(std::move(l)), right(std::move(r)) {} void print(int indent, std::ostream& to) const; - type_ptr typecheck(type_mgr& mgr, const type_env& env) const; - void resolve(const type_mgr& mgr) const; + type_ptr typecheck(type_mgr& mgr, const type_env& env); void compile(const env_ptr& env, std::vector& into) const; }; @@ -101,21 +91,20 @@ struct ast_app : public ast { : left(std::move(l)), right(std::move(r)) {} void print(int indent, std::ostream& to) const; - type_ptr typecheck(type_mgr& mgr, const type_env& env) const; - void resolve(const type_mgr& mgr) const; + type_ptr typecheck(type_mgr& mgr, const type_env& env); void compile(const env_ptr& env, std::vector& into) const; }; struct ast_case : public ast { ast_ptr of; + type_ptr input_type; std::vector branches; ast_case(ast_ptr o, std::vector b) : of(std::move(o)), branches(std::move(b)) {} void print(int indent, std::ostream& to) const; - type_ptr typecheck(type_mgr& mgr, const type_env& env) const; - void resolve(const type_mgr& mgr) const; + type_ptr typecheck(type_mgr& mgr, const type_env& env); void compile(const env_ptr& env, std::vector& into) const; }; diff --git a/10/definition.cpp b/10/definition.cpp index 9c100b0..4a14571 100644 --- a/10/definition.cpp +++ b/10/definition.cpp @@ -31,22 +31,10 @@ void definition_defn::typecheck_second(type_mgr& mgr, const type_env& env) const type_it++; } - type_ptr body_type = body->typecheck_common(mgr, new_env); + type_ptr body_type = body->typecheck(mgr, new_env); mgr.unify(return_type, body_type); } -void definition_defn::resolve(const type_mgr& mgr) { - type_var* var; - body->resolve_common(mgr); - - return_type = mgr.resolve(return_type, var); - if(var) throw type_error("ambiguously typed program"); - for(auto& param_type : param_types) { - param_type = mgr.resolve(param_type, var); - if(var) throw type_error("ambiguously typed program"); - } -} - void definition_defn::compile() { env_ptr new_env = env_ptr(new env_offset(0, nullptr)); for(auto it = params.rbegin(); it != params.rend(); it++) { @@ -91,10 +79,6 @@ void definition_data::typecheck_second(type_mgr& mgr, const type_env& env) const // Nothing } -void definition_data::resolve(const type_mgr& mgr) { - // Nothing -} - void definition_data::compile() { } diff --git a/10/definition.hpp b/10/definition.hpp index 6004d2f..74a2b69 100644 --- a/10/definition.hpp +++ b/10/definition.hpp @@ -13,7 +13,6 @@ struct definition { virtual void typecheck_first(type_mgr& mgr, type_env& env) = 0; virtual void typecheck_second(type_mgr& mgr, const type_env& env) const = 0; - virtual void resolve(const type_mgr& mgr) = 0; virtual void compile() = 0; virtual void gen_llvm_first(llvm_context& ctx) = 0; virtual void gen_llvm_second(llvm_context& ctx) = 0; @@ -51,7 +50,6 @@ struct definition_defn : public definition { void typecheck_first(type_mgr& mgr, type_env& env); void typecheck_second(type_mgr& mgr, const type_env& env) const; - void resolve(const type_mgr& mgr); void compile(); void gen_llvm_first(llvm_context& ctx); void gen_llvm_second(llvm_context& ctx); @@ -66,7 +64,6 @@ struct definition_data : public definition { void typecheck_first(type_mgr& mgr, type_env& env); void typecheck_second(type_mgr& mgr, const type_env& env) const; - void resolve(const type_mgr& mgr); void compile(); void gen_llvm_first(llvm_context& ctx); void gen_llvm_second(llvm_context& ctx); diff --git a/10/main.cpp b/10/main.cpp index 2b01f61..27529b7 100644 --- a/10/main.cpp +++ b/10/main.cpp @@ -48,10 +48,6 @@ void typecheck_program( pair.second->print(mgr, std::cout); std::cout << std::endl; } - - for(auto& def : prog) { - def->resolve(mgr); - } } void compile_program(const std::vector& prog) {