From 85394b185d0e0d36b8e283380a13659bd615732b Mon Sep 17 00:00:00 2001 From: Danila Fedorin Date: Wed, 9 Sep 2020 22:49:35 -0700 Subject: [PATCH] Add prototype impl of case specialization. Boolean cases could be translated to ifs, and integer cases to jumps. That's still in progress. --- code/compiler/13/ast.cpp | 215 +++++++++++++++++++++++++++++++-------- 1 file changed, 170 insertions(+), 45 deletions(-) diff --git a/code/compiler/13/ast.cpp b/code/compiler/13/ast.cpp index 3f6b9f1..2e83066 100644 --- a/code/compiler/13/ast.cpp +++ b/code/compiler/13/ast.cpp @@ -1,5 +1,6 @@ #include "ast.hpp" #include +#include #include "binop.hpp" #include "error.hpp" #include "type_env.hpp" @@ -218,60 +219,184 @@ void ast_case::translate(global_scope& scope) { } } +template +struct case_mappings { + using tag_type = typename T::tag_type; + std::map> defined_cases; + std::optional> default_case; + + std::vector& make_case_for(tag_type tag) { + auto existing_case = defined_cases.find(tag); + if(existing_case != defined_cases.end()) return existing_case->second; + if(default_case) + throw type_error("attempted pattern match after catch-all"); + return defined_cases[tag]; + } + + std::vector& make_default_case() { + if(default_case) + throw type_error("attempted repeated use of catch-all"); + default_case.emplace(std::vector()); + return *default_case; + } + + std::vector& get_specific_case_for(tag_type tag) { + auto existing_case = defined_cases.find(tag); + assert(existing_case != defined_cases.end()); + return existing_case->second; + } + + std::vector& get_default_case() { + assert(default_case); + return *default_case; + } + + bool case_defined_for(tag_type tag) { + return defined_cases.find(tag) != defined_cases.end(); + } + + bool default_case_defined() { return default_case.has_value(); } + + size_t defined_cases_count() { return defined_cases.size(); } +}; + + +struct case_strategy_bool { + using tag_type = bool; + using repr_type = bool; + + tag_type tag_from_repr(repr_type b) { return b; } + + repr_type from_typed_pattern(const pattern_ptr& pt, const type* type) { + pattern_constr* cpat; + if(!(cpat = dynamic_cast(pt.get())) || + (cpat->constr != "True" && cpat->constr != "False") || + cpat->params.size() != 0) + throw type_error("pattern cannot be converted to a boolean"); + return cpat->constr == "True"; + } + + void compile_branch( + const branch_ptr& branch, + const env_ptr& env, + repr_type repr, + std::vector& into) { + branch->expr->compile(env_ptr(new env_offset(1, env)), into); + } + + size_t case_count(const type* type) { + return 2; + } + + instruction_ptr into_instruction(const type* type, case_mappings& ms) { + throw std::runtime_error("boolean case unimplemented!"); + } +}; + +struct case_strategy_data { + using tag_type = int; + using repr_type = std::pair*>; + + tag_type tag_from_repr(const repr_type& repr) { return repr.first->tag; } + + repr_type from_typed_pattern(const pattern_ptr& pt, const type* type) { + pattern_constr* cpat; + if(!(cpat = dynamic_cast(pt.get()))) + throw type_error("pattern cannot be interpreted as constructor."); + return std::make_pair( + &static_cast(type)->constructors.find(cpat->constr)->second, + &cpat->params); + } + + void compile_branch( + const branch_ptr& branch, + const env_ptr& env, + const repr_type& repr, + std::vector& into) { + env_ptr new_env = env; + for(auto it = repr.second->rbegin(); it != repr.second->rend(); it++) { + new_env = env_ptr(new env_var(branch->expr->env->get_mangled_name(*it), new_env)); + } + + into.push_back(instruction_ptr(new instruction_split(repr.second->size()))); + branch->expr->compile(new_env, into); + into.push_back(instruction_ptr(new instruction_slide(repr.second->size()))); + } + + size_t case_count(const type* type) { + return static_cast(type)->constructors.size(); + } + + instruction_ptr into_instruction(const type* type, case_mappings& ms) { + instruction_jump* jump_instruction = new instruction_jump(); + instruction_ptr inst(jump_instruction); + + auto data_type = static_cast(type); + for(auto& constr : data_type->constructors) { + if(!ms.case_defined_for(constr.second.tag)) continue; + jump_instruction->branches.push_back( + std::move(ms.get_specific_case_for(constr.second.tag))); + jump_instruction->tag_mappings[constr.second.tag] = + jump_instruction->branches.size() - 1; + } + + if(ms.default_case_defined()) { + jump_instruction->branches.push_back( + std::move(ms.get_default_case())); + for(auto& constr : data_type->constructors) { + if(ms.case_defined_for(constr.second.tag)) continue; + jump_instruction->tag_mappings[constr.second.tag] = + jump_instruction->branches.size(); + } + } + + return std::move(inst); + } +}; + +template +void compile_case(const ast_case& node, const env_ptr& env, const type* type, std::vector& into) { + T strategy; + case_mappings cases; + for(auto& branch : node.branches) { + pattern_var* vpat; + if((vpat = dynamic_cast(branch->pat.get()))) { + auto& branch_into = cases.make_default_case(); + env_ptr new_env(new env_var(branch->expr->env->get_mangled_name(vpat->var), env)); + branch->expr->compile(new_env, branch_into); + } else { + auto repr = strategy.from_typed_pattern(branch->pat, type); + auto& branch_into = cases.make_case_for(strategy.tag_from_repr(repr)); + strategy.compile_branch(branch, env, repr, branch_into); + } + } + + if(!(cases.defined_cases_count() == strategy.case_count(type) || + cases.default_case_defined())) + throw type_error("incomplete patterns", node.loc); + + into.push_back(strategy.into_instruction(type, cases)); +} + void ast_case::compile(const env_ptr& env, std::vector& into) const { type_app* app_type = dynamic_cast(input_type.get()); - type_data* type = dynamic_cast(app_type->constructor.get()); + type_data* data; + type_internal* internal; of->compile(env, into); into.push_back(instruction_ptr(new instruction_eval())); - instruction_jump* jump_instruction = new instruction_jump(); - into.push_back(instruction_ptr(jump_instruction)); - for(auto& branch : branches) { - std::vector branch_instructions; - pattern_var* vpat; - pattern_constr* cpat; - - if((vpat = dynamic_cast(branch->pat.get()))) { - branch->expr->compile(env_ptr(new env_offset(1, env)), branch_instructions); - - for(auto& constr_pair : type->constructors) { - if(jump_instruction->tag_mappings.find(constr_pair.second.tag) != - jump_instruction->tag_mappings.end()) - break; - - jump_instruction->tag_mappings[constr_pair.second.tag] = - jump_instruction->branches.size(); - } - jump_instruction->branches.push_back(std::move(branch_instructions)); - } else if((cpat = dynamic_cast(branch->pat.get()))) { - env_ptr new_env = env; - for(auto it = cpat->params.rbegin(); it != cpat->params.rend(); it++) { - new_env = env_ptr(new env_var(branch->expr->env->get_mangled_name(*it), new_env)); - } - - branch_instructions.push_back(instruction_ptr(new instruction_split( - cpat->params.size()))); - branch->expr->compile(new_env, branch_instructions); - branch_instructions.push_back(instruction_ptr(new instruction_slide( - cpat->params.size()))); - - int new_tag = type->constructors[cpat->constr].tag; - if(jump_instruction->tag_mappings.find(new_tag) != - jump_instruction->tag_mappings.end()) - throw type_error("technically not a type error: duplicate pattern", cpat->loc); - - jump_instruction->tag_mappings[new_tag] = - jump_instruction->branches.size(); - jump_instruction->branches.push_back(std::move(branch_instructions)); + if((data = dynamic_cast(app_type->constructor.get()))) { + compile_case(*this, env, data, into); + return; + } else if((internal = dynamic_cast(app_type->constructor.get()))) { + if(internal->name == "Bool") { + compile_case(*this, env, data, into); + return; } } - for(auto& constr_pair : type->constructors) { - if(jump_instruction->tag_mappings.find(constr_pair.second.tag) == - jump_instruction->tag_mappings.end()) - throw type_error("non-total pattern", loc); - } + throw std::runtime_error("no known way to compile case expression"); } void ast_let::print(int indent, std::ostream& to) const {