#include "ast.hpp"
// NOTE(review): the names of the next two standard headers could not be
// recovered from the listing this file was restored from. <ostream>
// (std::ostream, std::endl) and <cassert> (assert) are what this file visibly
// uses -- confirm against version control.
#include <ostream>
#include <cassert>
#include "binop.hpp"
#include "error.hpp"
#include "type_env.hpp"
#include "env.hpp"

// Writes the indentation unit `n` times to `to`; writes nothing when n <= 0.
static void print_indent(int n, std::ostream& to) {
    for(int i = 0; i < n; ++i) to << " ";
}

// ---- ast_int: integer literal ----

// Prints "INT: <value>" at the given indentation level.
void ast_int::print(int indent, std::ostream& to) const {
    print_indent(indent, to);
    to << "INT: " << value << std::endl;
}

// A literal references no identifiers, so nothing is added to `into`.
void ast_int::find_free(std::set<std::string>& into) {

}

// Remembers the environment and returns an application of the type
// registered under the name "Int".
type_ptr ast_int::typecheck(type_mgr& mgr, type_env_ptr& env) {
    this->env = env;
    return type_ptr(new type_app(env->lookup_type("Int")));
}

// Nothing to translate for a literal.
void ast_int::translate(global_scope& scope) {

}

// Emits a single instruction that pushes the literal value.
void ast_int::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
    into.push_back(instruction_ptr(new instruction_pushint(value)));
}

// ---- ast_lid: lowercase identifier ----

// Prints "LID: <id>" at the given indentation level.
void ast_lid::print(int indent, std::ostream& to) const {
    print_indent(indent, to);
    to << "LID: " << id << std::endl;
}

// The identifier itself is reported as free; enclosing nodes filter out
// the names they bind.
void ast_lid::find_free(std::set<std::string>& into) {
    into.insert(id);
}

// Looks the identifier up in `env` and instantiates its type scheme.
// Throws type_error if the identifier is not bound.
type_ptr ast_lid::typecheck(type_mgr& mgr, type_env_ptr& env) {
    this->env = env;
    type_scheme_ptr lid_type = env->lookup(id);
    if(!lid_type)
        throw type_error(std::string("unknown identifier ") + id, loc);
    return lid_type->instantiate(mgr);
}

// Nothing to translate for an identifier.
void ast_lid::translate(global_scope& scope) {

}

// Emits a stack push when the mangled name is a variable of the compile-time
// environment and the type environment does not mark the identifier as global;
// otherwise emits a push of the global with the mangled name.
void ast_lid::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
    const auto& mangled_name = this->env->get_mangled_name(id);
    const bool is_local =
        env->has_variable(mangled_name) && !this->env->is_global(id);
    if(is_local) {
        into.push_back(instruction_ptr(
            new instruction_push(env->get_offset(mangled_name))));
    } else {
        into.push_back(instruction_ptr(
            new instruction_pushglobal(mangled_name)));
    }
}

// ---- ast_uid: uppercase identifier (constructor) ----

// Prints "UID: <id>" at the given indentation level.
void ast_uid::print(int indent, std::ostream& to) const {
    print_indent(indent, to);
    to << "UID: " << id << std::endl;
}

// Constructors are never reported as free variables.
void ast_uid::find_free(std::set<std::string>& into) {

}

// Looks the constructor up in `env` and instantiates its type scheme.
// Throws type_error if the constructor is not bound.
type_ptr ast_uid::typecheck(type_mgr& mgr, type_env_ptr& env) {
    this->env = env;
    type_scheme_ptr uid_type = env->lookup(id);
    if(!uid_type)
        throw type_error(std::string("unknown constructor ") + id, loc);
    return uid_type->instantiate(mgr);
}

// Nothing to translate for a constructor reference.
void ast_uid::translate(global_scope& scope) {

}

// Constructors are always pushed as globals, by mangled name.
void ast_uid::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
    into.push_back(instruction_ptr(
        new instruction_pushglobal(this->env->get_mangled_name(id))));
}

// ---- ast_binop: binary operator application ----

// Prints "BINOP: <operator>" followed by both operands, one level deeper.
void ast_binop::print(int indent, std::ostream& to) const {
    print_indent(indent, to);
    to << "BINOP: " << op_name(op) << std::endl;
    left->print(indent + 1, to);
    right->print(indent + 1, to);
}

// Free variables are those of both operands.
void ast_binop::find_free(std::set<std::string>& into) {
    left->find_free(into);
    right->find_free(into);
}

// Typechecks both operands, then unifies the operator's type with a
// two-argument function type and returns the result type.
// Throws type_error if the operator is not bound in `env`.
type_ptr ast_binop::typecheck(type_mgr& mgr, type_env_ptr& env) {
    this->env = env;
    type_ptr ltype = left->typecheck(mgr, env);
    type_ptr rtype = right->typecheck(mgr, env);

    // The scheme must be checked *before* it is instantiated: instantiating
    // through a null scheme would dereference a null pointer and the
    // "unknown binary operator" error could never be reported.
    type_scheme_ptr op_scheme = env->lookup(op_name(op));
    if(!op_scheme)
        throw type_error(std::string("unknown binary operator ") + op_name(op), loc);
    type_ptr ftype = op_scheme->instantiate(mgr);

    // For better type errors, we first require binary function,
    // and only then unify each argument. This way, we can
    // precisely point out which argument is "wrong".
    type_ptr return_type = mgr.new_type();
    type_ptr second_type = mgr.new_type();
    type_ptr first_type = mgr.new_type();
    type_ptr arrow_one = type_ptr(new type_arr(second_type, return_type));
    type_ptr arrow_two = type_ptr(new type_arr(first_type, arrow_one));

    mgr.unify(ftype, arrow_two, loc);
    mgr.unify(first_type, ltype, left->loc);
    mgr.unify(second_type, rtype, right->loc);
    return return_type;
}

// Translates both operands.
void ast_binop::translate(global_scope& scope) {
    left->translate(scope);
    right->translate(scope);
}

// Compiles the right operand, then the left operand in an environment
// offset by one (presumably to account for the value the right operand
// leaves on the stack), then pushes the operator's global and applies it
// twice.
void ast_binop::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
    right->compile(env, into);
    left->compile(env_ptr(new env_offset(1, env)), into);

    into.push_back(instruction_ptr(new instruction_pushglobal(op_action(op))));
    into.push_back(instruction_ptr(new instruction_mkapp()));
    into.push_back(instruction_ptr(new instruction_mkapp()));
}

// ---- ast_app: function application ----

// Prints "APP:" followed by the function and the argument, one level deeper.
void ast_app::print(int indent, std::ostream& to) const {
    print_indent(indent, to);
    to << "APP:" << std::endl;
    left->print(indent + 1, to);
    right->print(indent + 1, to);
}

// Free variables are those of the function and of the argument.
void ast_app::find_free(std::set<std::string>& into) {
    left->find_free(into);
    right->find_free(into);
}

// Requires the left side to be a function from the argument's type to a
// fresh type variable, and returns that type variable.
type_ptr ast_app::typecheck(type_mgr& mgr, type_env_ptr& env) {
    this->env = env;
    type_ptr ltype = left->typecheck(mgr, env);
    type_ptr rtype = right->typecheck(mgr, env);

    type_ptr return_type = mgr.new_type();
    type_ptr arrow = type_ptr(new type_arr(rtype, return_type));
    mgr.unify(arrow, ltype, left->loc);
    return return_type;
}

// Translates both sub-expressions.
void ast_app::translate(global_scope& scope) {
    left->translate(scope);
    right->translate(scope);
}

// Compiles the argument, then the function in an environment offset by one,
// then emits a single application instruction.
void ast_app::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
    right->compile(env, into);
    left->compile(env_ptr(new env_offset(1, env)), into);
    into.push_back(instruction_ptr(new instruction_mkapp()));
}

// ---- ast_case: case analysis (printing) ----

// Prints "CASE: " followed by every branch: the pattern one level deeper and
// the branch body two levels deeper.
void ast_case::print(int indent, std::ostream& to) const {
    print_indent(indent, to);
    to << "CASE: " << std::endl;
    for(auto& branch : branches) {
        print_indent(indent + 1, to);
        branch->pat->print(to);
        to << std::endl;
        branch->expr->print(indent + 2, to);
    }
}
void ast_case::find_free(std::set& into) { of->find_free(into); for(auto& branch : branches) { std::set free_in_branch; std::set pattern_variables; branch->pat->find_variables(pattern_variables); branch->expr->find_free(free_in_branch); for(auto& free : free_in_branch) { if(pattern_variables.find(free) == pattern_variables.end()) into.insert(free); } } } type_ptr ast_case::typecheck(type_mgr& mgr, type_env_ptr& env) { this->env = env; type_var* var; type_ptr case_type = mgr.resolve(of->typecheck(mgr, env), var); type_ptr branch_type = mgr.new_type(); for(auto& branch : branches) { type_env_ptr new_env = type_scope(env); branch->pat->typecheck(case_type, mgr, new_env); type_ptr curr_branch_type = branch->expr->typecheck(mgr, new_env); mgr.unify(curr_branch_type, branch_type, branch->expr->loc); } input_type = mgr.resolve(case_type, var); type_app* app_type; return branch_type; } void ast_case::translate(global_scope& scope) { of->translate(scope); for(auto& branch : branches) { branch->expr->translate(scope); } } template struct case_mappings { using tag_type = typename T::tag_type; std::map> defined_cases; std::optional> default_case; std::vector& make_case_for(tag_type tag) { auto existing_case = defined_cases.find(tag); if(existing_case != defined_cases.end()) return existing_case->second; if(default_case) throw type_error("attempted pattern match after catch-all"); return defined_cases[tag]; } std::vector& make_default_case() { if(default_case) throw type_error("attempted repeated use of catch-all"); default_case.emplace(std::vector()); return *default_case; } std::vector& get_specific_case_for(tag_type tag) { auto existing_case = defined_cases.find(tag); assert(existing_case != defined_cases.end()); return existing_case->second; } std::vector& get_default_case() { assert(default_case); return *default_case; } std::vector& get_case_for(tag_type tag) { if(case_defined_for(tag)) return get_specific_case_for(tag); return get_default_case(); } bool 
case_defined_for(tag_type tag) { return defined_cases.find(tag) != defined_cases.end(); } bool default_case_defined() { return default_case.has_value(); } size_t defined_cases_count() { return defined_cases.size(); } }; struct case_strategy_bool { using tag_type = bool; using repr_type = bool; tag_type tag_from_repr(repr_type b) { return b; } repr_type from_typed_pattern(const pattern_ptr& pt, const type* type) { pattern_constr* cpat; if(!(cpat = dynamic_cast(pt.get())) || (cpat->constr != "True" && cpat->constr != "False") || cpat->params.size() != 0) throw type_error("pattern cannot be converted to a boolean"); return cpat->constr == "True"; } void compile_branch( const branch_ptr& branch, const env_ptr& env, repr_type repr, std::vector& into) { branch->expr->compile(env_ptr(new env_offset(1, env)), into); into.push_back(instruction_ptr(new instruction_slide(1))); } size_t case_count(const type* type) { return 2; } void into_instructions( const type* type, case_mappings& ms, std::vector& into) { if(ms.defined_cases_count() == 0) { for(auto& instruction : ms.get_default_case()) into.push_back(std::move(instruction)); return; } into.push_back(instruction_ptr(new instruction_if( std::move(ms.get_case_for(true)), std::move(ms.get_case_for(false))))); } }; struct case_strategy_data { using tag_type = int; using repr_type = std::pair*>; tag_type tag_from_repr(const repr_type& repr) { return repr.first->tag; } repr_type from_typed_pattern(const pattern_ptr& pt, const type* type) { pattern_constr* cpat; if(!(cpat = dynamic_cast(pt.get()))) throw type_error("pattern cannot be interpreted as constructor."); return std::make_pair( &static_cast(type)->constructors.find(cpat->constr)->second, &cpat->params); } void compile_branch( const branch_ptr& branch, const env_ptr& env, const repr_type& repr, std::vector& into) { env_ptr new_env = env; for(auto it = repr.second->rbegin(); it != repr.second->rend(); it++) { new_env = env_ptr(new 
env_var(branch->expr->env->get_mangled_name(*it), new_env)); } into.push_back(instruction_ptr(new instruction_split(repr.second->size()))); branch->expr->compile(new_env, into); into.push_back(instruction_ptr(new instruction_slide(repr.second->size()))); } size_t case_count(const type* type) { return static_cast(type)->constructors.size(); } void into_instructions( const type* type, case_mappings& ms, std::vector& into) { instruction_jump* jump_instruction = new instruction_jump(); instruction_ptr inst(jump_instruction); auto data_type = static_cast(type); for(auto& constr : data_type->constructors) { if(!ms.case_defined_for(constr.second.tag)) continue; jump_instruction->branches.push_back( std::move(ms.get_specific_case_for(constr.second.tag))); jump_instruction->tag_mappings[constr.second.tag] = jump_instruction->branches.size() - 1; } if(ms.default_case_defined()) { jump_instruction->branches.push_back( std::move(ms.get_default_case())); for(auto& constr : data_type->constructors) { if(ms.case_defined_for(constr.second.tag)) continue; jump_instruction->tag_mappings[constr.second.tag] = jump_instruction->branches.size(); } } into.push_back(std::move(inst)); } }; template void compile_case(const ast_case& node, const env_ptr& env, const type* type, std::vector& into) { T strategy; case_mappings cases; for(auto& branch : node.branches) { pattern_var* vpat; if((vpat = dynamic_cast(branch->pat.get()))) { auto& branch_into = cases.make_default_case(); env_ptr new_env(new env_var(branch->expr->env->get_mangled_name(vpat->var), env)); branch->expr->compile(new_env, branch_into); branch_into.push_back(instruction_ptr(new instruction_slide(1))); } else { auto repr = strategy.from_typed_pattern(branch->pat, type); auto& branch_into = cases.make_case_for(strategy.tag_from_repr(repr)); strategy.compile_branch(branch, env, repr, branch_into); } } if(!(cases.defined_cases_count() == strategy.case_count(type) || cases.default_case_defined())) throw type_error("incomplete 
patterns", node.loc); strategy.into_instructions(type, cases, into); } void ast_case::compile(const env_ptr& env, std::vector& into) const { type_app* app_type = dynamic_cast(input_type.get()); type_data* data; type_internal* internal; of->compile(env, into); into.push_back(instruction_ptr(new instruction_eval())); if(app_type && (data = dynamic_cast(app_type->constructor.get()))) { compile_case(*this, env, data, into); return; } else if(app_type && (internal = dynamic_cast(app_type->constructor.get()))) { if(internal->name == "Bool") { compile_case(*this, env, data, into); return; } } throw type_error("attempting unsupported case analysis", of->loc); } void ast_let::print(int indent, std::ostream& to) const { print_indent(indent, to); to << "LET: " << std::endl; in->print(indent + 1, to); } void ast_let::find_free(std::set& into) { definitions.find_free(into); std::set all_free; in->find_free(all_free); for(auto& free_var : all_free) { if(definitions.defs_defn.find(free_var) == definitions.defs_defn.end()) into.insert(free_var); } } type_ptr ast_let::typecheck(type_mgr& mgr, type_env_ptr& env) { this->env = env; definitions.typecheck(mgr, env); return in->typecheck(mgr, definitions.env); } void ast_let::translate(global_scope& scope) { for(auto& def : definitions.defs_data) { def.second->into_globals(scope); } for(auto& def : definitions.defs_defn) { size_t original_params = def.second->params.size(); std::string original_name = def.second->name; auto& global_definition = def.second->into_global(scope); size_t captured = global_definition.params.size() - original_params; type_env_ptr mangled_env = type_scope(env); mangled_env->bind(def.first, env->lookup(def.first), visibility::global); mangled_env->set_mangled_name(def.first, global_definition.name); ast_ptr global_app(new ast_lid(original_name)); global_app->env = mangled_env; for(auto& param : global_definition.params) { if(!(captured--)) break; ast_ptr new_arg(new ast_lid(param)); new_arg->env = env; 
global_app = ast_ptr(new ast_app(std::move(global_app), std::move(new_arg))); global_app->env = env; } translated_definitions.push_back({ def.first, std::move(global_app) }); } in->translate(scope); } void ast_let::compile(const env_ptr& env, std::vector& into) const { into.push_back(instruction_ptr(new instruction_alloc(translated_definitions.size()))); env_ptr new_env = env; for(auto& def : translated_definitions) { new_env = env_ptr(new env_var(definitions.env->get_mangled_name(def.first), std::move(new_env))); } int offset = translated_definitions.size() - 1; for(auto& def : translated_definitions) { def.second->compile(new_env, into); into.push_back(instruction_ptr(new instruction_update(offset--))); } in->compile(new_env, into); into.push_back(instruction_ptr(new instruction_slide(translated_definitions.size()))); } void ast_lambda::print(int indent, std::ostream& to) const { print_indent(indent, to); to << "LAMBDA"; for(auto& param : params) { to << " " << param; } to << std::endl; body->print(indent+1, to); } void ast_lambda::find_free(std::set& into) { body->find_free(free_variables); for(auto& param : params) { free_variables.erase(param); } into.insert(free_variables.begin(), free_variables.end()); } type_ptr ast_lambda::typecheck(type_mgr& mgr, type_env_ptr& env) { this->env = env; var_env = type_scope(env); type_ptr return_type = mgr.new_type(); type_ptr full_type = return_type; for(auto it = params.rbegin(); it != params.rend(); it++) { type_ptr param_type = mgr.new_type(); var_env->bind(*it, param_type); full_type = type_ptr(new type_arr(std::move(param_type), full_type)); } mgr.unify(return_type, body->typecheck(mgr, var_env), body->loc); return full_type; } void ast_lambda::translate(global_scope& scope) { std::vector function_params; for(auto& free_variable : free_variables) { if(env->is_global(free_variable)) continue; function_params.push_back(free_variable); } size_t captured_count = function_params.size(); 
function_params.insert(function_params.end(), params.begin(), params.end()); auto& new_function = scope.add_function("lambda", std::move(function_params), std::move(body)); type_env_ptr mangled_env = type_scope(env); mangled_env->bind("lambda", type_scheme_ptr(nullptr), visibility::global); mangled_env->set_mangled_name("lambda", new_function.name); ast_ptr new_application = ast_ptr(new ast_lid("lambda")); new_application->env = mangled_env; for(auto& param : new_function.params) { if(!(captured_count--)) break; ast_ptr new_arg = ast_ptr(new ast_lid(param)); new_arg->env = env; new_application = ast_ptr(new ast_app(std::move(new_application), std::move(new_arg))); new_application->env = env; } translated = std::move(new_application); } void ast_lambda::compile(const env_ptr& env, std::vector& into) const { translated->compile(env, into); } void pattern_var::print(std::ostream& to) const { to << var; } void pattern_var::find_variables(std::set& into) const { into.insert(var); } void pattern_var::typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const { env->bind(var, t); } void pattern_constr::print(std::ostream& to) const { to << constr; for(auto& param : params) { to << " " << param; } } void pattern_constr::find_variables(std::set& into) const { into.insert(params.begin(), params.end()); } void pattern_constr::typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const { type_scheme_ptr constructor_type_scheme = env->lookup(constr); if(!constructor_type_scheme) { throw type_error(std::string("pattern using unknown constructor ") + constr, loc); } type_ptr constructor_type = constructor_type_scheme->instantiate(mgr); for(auto& param : params) { type_arr* arr = dynamic_cast(constructor_type.get()); if(!arr) throw type_error("too many parameters in constructor pattern", loc); env->bind(param, arr->left); constructor_type = arr->right; } mgr.unify(constructor_type, t, loc); }