#include "ast.hpp"
#include <ostream>
#include "binop.hpp"
#include "error.hpp"

// Writes `n` indentation units to `to`. A non-positive `n` writes nothing.
static void print_indent(int n, std::ostream& to) {
    while(n-- > 0) to << " ";
}

// Type checks this node via the virtual typecheck(), caches the result in
// node_type, and returns it.
type_ptr ast::typecheck_common(type_mgr& mgr, const type_env& env) {
    node_type = typecheck(mgr, env);
    return node_type;
}

// Resolves the cached node_type through `mgr`, lets the node resolve its
// children, and then stores the resolved type back into node_type.
// Throws type_error if the resolution reports a leftover type variable.
void ast::resolve_common(const type_mgr& mgr) {
    type_var* var = nullptr;
    type_ptr resolved_type = mgr.resolve(node_type, var);
    if(var) throw type_error("ambiguously typed program");

    resolve(mgr);
    node_type = std::move(resolved_type);
}

// ---------------------------------------------------------------------------
// ast_int
// ---------------------------------------------------------------------------

// Prints "INT: <value>" at the given indentation.
void ast_int::print(int indent, std::ostream& to) const {
    print_indent(indent, to);
    to << "INT: " << value << std::endl;
}

// An integer literal always has the base type "Int".
type_ptr ast_int::typecheck(type_mgr& mgr, const type_env& env) const {
    return type_ptr(new type_base("Int"));
}

// Integer literals have no children to resolve.
void ast_int::resolve(const type_mgr& mgr) const {

}

// Emits a single pushint instruction carrying the literal value.
void ast_int::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
    into.push_back(instruction_ptr(new instruction_pushint(value)));
}

// ---------------------------------------------------------------------------
// ast_lid
// ---------------------------------------------------------------------------

// Prints "LID: <id>" at the given indentation.
void ast_lid::print(int indent, std::ostream& to) const {
    print_indent(indent, to);
    to << "LID: " << id << std::endl;
}

// The type of an identifier is whatever the type environment holds for it.
type_ptr ast_lid::typecheck(type_mgr& mgr, const type_env& env) const {
    return env.lookup(id);
}

// Identifiers have no children to resolve.
void ast_lid::resolve(const type_mgr& mgr) const {

}

// Emits a push of the variable's offset when the compile-time environment
// knows the identifier, and a pushglobal of the name otherwise.
void ast_lid::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
    if(env->has_variable(id)) {
        into.push_back(instruction_ptr(new instruction_push(env->get_offset(id))));
    } else {
        into.push_back(instruction_ptr(new instruction_pushglobal(id)));
    }
}

// ---------------------------------------------------------------------------
// ast_uid
// ---------------------------------------------------------------------------

// Prints "UID: <id>" at the given indentation.
void ast_uid::print(int indent, std::ostream& to) const {
    print_indent(indent, to);
    to << "UID: " << id << std::endl;
}

// The type of an identifier is whatever the type environment holds for it.
type_ptr ast_uid::typecheck(type_mgr& mgr, const type_env& env) const {
    return env.lookup(id);
}

// Identifiers have no children to resolve.
void ast_uid::resolve(const type_mgr& mgr) const {

}

// Always emits a pushglobal of the name; the environment is not consulted.
void ast_uid::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
    into.push_back(instruction_ptr(new instruction_pushglobal(id)));
}

// ---------------------------------------------------------------------------
// ast_binop
// ---------------------------------------------------------------------------

// Prints "BINOP: <operator name>" followed by both operands, one level deeper.
void ast_binop::print(int indent, std::ostream& to) const {
    print_indent(indent, to);
    to << "BINOP: " << op_name(op) << std::endl;
    left->print(indent + 1, to);
    right->print(indent + 1, to);
}

// Type checks both operands, looks the operator up by name in the type
// environment, and unifies its type with `ltype -> rtype -> fresh`.
// Returns the fresh result type. Throws type_error for an unknown operator.
type_ptr ast_binop::typecheck(type_mgr& mgr, const type_env& env) const {
    type_ptr ltype = left->typecheck_common(mgr, env);
    type_ptr rtype = right->typecheck_common(mgr, env);

    type_ptr ftype = env.lookup(op_name(op));
    if(!ftype) throw type_error(std::string("unknown binary operator ") + op_name(op));

    type_ptr return_type = mgr.new_type();
    type_ptr arrow_one = type_ptr(new type_arr(rtype, return_type));
    type_ptr arrow_two = type_ptr(new type_arr(ltype, arrow_one));

    mgr.unify(arrow_two, ftype);
    return return_type;
}

// Resolves both operands.
void ast_binop::resolve(const type_mgr& mgr) const {
    left->resolve_common(mgr);
    right->resolve_common(mgr);
}

// Compiles the right operand, then the left operand under an environment
// offset by one, then pushes the operator's global and emits two mkapp
// instructions.
void ast_binop::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
    right->compile(env, into);
    left->compile(env_ptr(new env_offset(1, env)), into);

    into.push_back(instruction_ptr(new instruction_pushglobal(op_action(op))));
    into.push_back(instruction_ptr(new instruction_mkapp()));
    into.push_back(instruction_ptr(new instruction_mkapp()));
}

// ---------------------------------------------------------------------------
// ast_app
// ---------------------------------------------------------------------------

// Prints "APP:" followed by the function and argument, one level deeper.
void ast_app::print(int indent, std::ostream& to) const {
    print_indent(indent, to);
    to << "APP:" << std::endl;
    left->print(indent + 1, to);
    right->print(indent + 1, to);
}

// Type checks function and argument and unifies the function's type with
// `rtype -> fresh`. Returns the fresh result type.
type_ptr ast_app::typecheck(type_mgr& mgr, const type_env& env) const {
    type_ptr ltype = left->typecheck_common(mgr, env);
    type_ptr rtype = right->typecheck_common(mgr, env);

    type_ptr return_type = mgr.new_type();
    type_ptr arrow = type_ptr(new type_arr(rtype, return_type));
    mgr.unify(arrow, ltype);
    return return_type;
}

// Resolves both the function and the argument.
void ast_app::resolve(const type_mgr& mgr) const {
    left->resolve_common(mgr);
    right->resolve_common(mgr);
}

// Compiles the argument, then the function under an environment offset by
// one, then emits a single mkapp instruction.
void ast_app::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
    right->compile(env, into);
    left->compile(env_ptr(new env_offset(1, env)), into);
    into.push_back(instruction_ptr(new instruction_mkapp()));
}

// ---------------------------------------------------------------------------
// ast_case
// ---------------------------------------------------------------------------

// Prints "CASE: " and then, for every branch, its pattern (one level deeper)
// followed by its body (two levels deeper).
void ast_case::print(int indent, std::ostream& to) const {
    print_indent(indent, to);
    to << "CASE: " << std::endl;
    for(auto& branch : branches) {
        print_indent(indent + 1, to);
        branch->pat->print(to);
        to << std::endl;
        branch->expr->print(indent + 2, to);
    }
}

// Type checks the scrutinee, matches every branch pattern against its type in
// a scope derived from `env`, and unifies all branch body types with one
// fresh type, which is returned. Throws type_error if, after all branches
// were processed, the scrutinee type does not resolve to a type_data.
type_ptr ast_case::typecheck(type_mgr& mgr, const type_env& env) const {
    type_var* var = nullptr;
    type_ptr case_type = mgr.resolve(of->typecheck_common(mgr, env), var);
    type_ptr branch_type = mgr.new_type();

    for(auto& branch : branches) {
        type_env new_env = env.scope();
        branch->pat->match(case_type, mgr, new_env);
        type_ptr curr_branch_type = branch->expr->typecheck_common(mgr, new_env);
        mgr.unify(branch_type, curr_branch_type);
    }

    // Matching the patterns may have refined the scrutinee type; resolve again
    // before checking that it is a data type.
    case_type = mgr.resolve(case_type, var);
    if(!dynamic_cast<type_data*>(case_type.get())) {
        throw type_error("attempting case analysis of non-data type");
    }

    return branch_type;
}

// Resolves the scrutinee and every branch body.
void ast_case::resolve(const type_mgr& mgr) const {
    of->resolve_common(mgr);
    for(auto& branch : branches) {
        branch->expr->resolve_common(mgr);
    }
}

// Compiles the scrutinee followed by an eval and a jump instruction, then
// fills the jump's branch table from the case branches.
//
// Throws type_error when the scrutinee's node_type is not a type_data, when a
// constructor pattern names a constructor the data type does not have, when
// two constructor patterns target the same tag, or when some constructor of
// the data type is left without a branch.
void ast_case::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
    type_data* type = dynamic_cast<type_data*>(of->node_type.get());
    // Guard against dereferencing a null pointer when node_type is not (or is
    // not yet resolved to) a data type.
    if(!type) throw type_error("attempting case analysis of non-data type");

    of->compile(env, into);
    into.push_back(instruction_ptr(new instruction_eval()));

    // `into` owns the jump instruction; the raw pointer is kept only to fill
    // in its tables below.
    instruction_jump* jump_instruction = new instruction_jump();
    into.push_back(instruction_ptr(jump_instruction));

    for(auto& branch : branches) {
        std::vector<instruction_ptr> branch_instructions;
        pattern_var* vpat;
        pattern_constr* cpat;

        if((vpat = dynamic_cast<pattern_var*>(branch->pat.get()))) {
            // NOTE(review): the pattern variable itself is not added to the
            // compile-time environment here - confirm whether branch bodies
            // are expected to be able to reference it.
            branch->expr->compile(env_ptr(new env_offset(1, env)), branch_instructions);

            // A variable pattern catches every constructor that no earlier
            // branch has claimed. Already-mapped tags are skipped (not a
            // reason to stop), so the remaining constructors still get mapped.
            for(auto& constr_pair : type->constructors) {
                const auto& tag = constr_pair.second.tag;
                if(jump_instruction->tag_mappings.find(tag) !=
                        jump_instruction->tag_mappings.end())
                    continue;

                jump_instruction->tag_mappings[tag] =
                    jump_instruction->branches.size();
            }
            jump_instruction->branches.push_back(std::move(branch_instructions));
        } else if((cpat = dynamic_cast<pattern_constr*>(branch->pat.get()))) {
            // Register the constructor parameters in reverse order on top of
            // the enclosing environment.
            env_ptr new_env = env;
            for(auto it = cpat->params.rbegin(); it != cpat->params.rend(); ++it) {
                new_env = env_ptr(new env_var(*it, new_env));
            }

            branch_instructions.push_back(instruction_ptr(new instruction_split(
                            cpat->params.size())));
            branch->expr->compile(new_env, branch_instructions);
            branch_instructions.push_back(instruction_ptr(new instruction_slide(
                            cpat->params.size())));

            // Look the constructor up without inserting: indexing with [] would
            // silently create an entry for a name the data type does not have.
            auto constr_it = type->constructors.find(cpat->constr);
            if(constr_it == type->constructors.end())
                throw type_error(std::string("pattern using unknown constructor ")
                        + cpat->constr);

            int new_tag = constr_it->second.tag;
            if(jump_instruction->tag_mappings.find(new_tag) !=
                    jump_instruction->tag_mappings.end())
                throw type_error("technically not a type error: duplicate pattern");

            jump_instruction->tag_mappings[new_tag] =
                jump_instruction->branches.size();
            jump_instruction->branches.push_back(std::move(branch_instructions));
        }
    }

    // Every constructor of the data type must be handled by some branch.
    for(auto& constr_pair : type->constructors) {
        if(jump_instruction->tag_mappings.find(constr_pair.second.tag) ==
                jump_instruction->tag_mappings.end())
            throw type_error("non-total pattern");
    }
}

// ---------------------------------------------------------------------------
// patterns
// ---------------------------------------------------------------------------

// Prints the variable name.
void pattern_var::print(std::ostream& to) const {
    to << var;
}

// A variable pattern matches anything: it binds the variable to `t` in `env`.
void pattern_var::match(type_ptr t, type_mgr& mgr, type_env& env) const {
    env.bind(var, t);
}

// Prints the constructor name followed by its space-separated parameters.
void pattern_constr::print(std::ostream& to) const {
    to << constr;
    for(auto& param : params) {
        to << " " << param;
    }
}

// Looks up the constructor's type, binds each pattern parameter to the
// corresponding argument type of that (arrow) type, and unifies what remains
// with `t`. Throws type_error if the constructor is unknown or the pattern
// lists more parameters than the constructor type has arguments.
void pattern_constr::match(type_ptr t, type_mgr& mgr, type_env& env) const {
    type_ptr constructor_type = env.lookup(constr);
    if(!constructor_type) {
        throw type_error(std::string("pattern using unknown constructor ") + constr);
    }

    // Peel one arrow off the constructor type per pattern parameter.
    for(const auto& param : params) {
        type_arr* arr = dynamic_cast<type_arr*>(constructor_type.get());
        if(!arr) throw type_error("too many parameters in constructor pattern");

        env.bind(param, arr->left);
        constructor_type = arr->right;
    }

    mgr.unify(t, constructor_type);
}