Factor type into case strategy constructor.

This commit is contained in:
Danila Fedorin 2020-09-11 13:03:00 -07:00
parent 9a591d6da6
commit 4cdb9360fb

View File

@ -3,6 +3,7 @@
#include <type_traits> #include <type_traits>
#include "binop.hpp" #include "binop.hpp"
#include "error.hpp" #include "error.hpp"
#include "type.hpp"
#include "type_env.hpp" #include "type_env.hpp"
#include "env.hpp" #include "env.hpp"
@ -275,9 +276,11 @@ struct case_strategy_bool {
using tag_type = bool; using tag_type = bool;
using repr_type = bool; using repr_type = bool;
case_strategy_bool(const type* type) {}
tag_type tag_from_repr(repr_type b) { return b; } tag_type tag_from_repr(repr_type b) { return b; }
repr_type from_typed_pattern(const pattern_ptr& pt, const type* type) { repr_type repr_from_pattern(const pattern_ptr& pt) {
pattern_constr* cpat; pattern_constr* cpat;
if(!(cpat = dynamic_cast<pattern_constr*>(pt.get())) || if(!(cpat = dynamic_cast<pattern_constr*>(pt.get())) ||
(cpat->constr != "True" && cpat->constr != "False") || (cpat->constr != "True" && cpat->constr != "False") ||
@ -297,12 +300,11 @@ struct case_strategy_bool {
into.push_back(instruction_ptr(new instruction_slide(1))); into.push_back(instruction_ptr(new instruction_slide(1)));
} }
size_t case_count(const type* type) { size_t case_count() {
return 2; return 2;
} }
void into_instructions( void into_instructions(
const type* type,
case_mappings<case_strategy_bool>& ms, case_mappings<case_strategy_bool>& ms,
std::vector<instruction_ptr>& into) { std::vector<instruction_ptr>& into) {
if(ms.defined_cases_count() == 0) { if(ms.defined_cases_count() == 0) {
@ -321,16 +323,23 @@ struct case_strategy_data {
using tag_type = int; using tag_type = int;
using repr_type = std::pair<const type_data::constructor*, const std::vector<std::string>*>; using repr_type = std::pair<const type_data::constructor*, const std::vector<std::string>*>;
const type_data* arg_type;
case_strategy_data(const type* t) {
arg_type = dynamic_cast<const type_data*>(t);
assert(arg_type);
}
tag_type tag_from_repr(const repr_type& repr) { return repr.first->tag; } tag_type tag_from_repr(const repr_type& repr) { return repr.first->tag; }
repr_type from_typed_pattern(const pattern_ptr& pt, const type* type) { repr_type repr_from_pattern(const pattern_ptr& pt) {
pattern_constr* cpat; pattern_constr* cpat;
if(!(cpat = dynamic_cast<pattern_constr*>(pt.get()))) if(!(cpat = dynamic_cast<pattern_constr*>(pt.get())))
throw type_error( throw type_error(
"pattern cannot be interpreted as constructor.", "pattern cannot be interpreted as constructor.",
pt->loc); pt->loc);
return std::make_pair( return std::make_pair(
&static_cast<const type_data*>(type)->constructors.find(cpat->constr)->second, &arg_type->constructors.find(cpat->constr)->second,
&cpat->params); &cpat->params);
} }
@ -349,19 +358,17 @@ struct case_strategy_data {
into.push_back(instruction_ptr(new instruction_slide(repr.second->size()))); into.push_back(instruction_ptr(new instruction_slide(repr.second->size())));
} }
size_t case_count(const type* type) { size_t case_count() {
return static_cast<const type_data*>(type)->constructors.size(); return arg_type->constructors.size();
} }
void into_instructions( void into_instructions(
const type* type,
case_mappings<case_strategy_data>& ms, case_mappings<case_strategy_data>& ms,
std::vector<instruction_ptr>& into) { std::vector<instruction_ptr>& into) {
instruction_jump* jump_instruction = new instruction_jump(); instruction_jump* jump_instruction = new instruction_jump();
instruction_ptr inst(jump_instruction); instruction_ptr inst(jump_instruction);
auto data_type = static_cast<const type_data*>(type); for(auto& constr : arg_type->constructors) {
for(auto& constr : data_type->constructors) {
if(!ms.case_defined_for(constr.second.tag)) continue; if(!ms.case_defined_for(constr.second.tag)) continue;
jump_instruction->branches.push_back( jump_instruction->branches.push_back(
std::move(ms.get_specific_case_for(constr.second.tag))); std::move(ms.get_specific_case_for(constr.second.tag)));
@ -372,7 +379,7 @@ struct case_strategy_data {
if(ms.default_case_defined()) { if(ms.default_case_defined()) {
jump_instruction->branches.push_back( jump_instruction->branches.push_back(
std::move(ms.get_default_case())); std::move(ms.get_default_case()));
for(auto& constr : data_type->constructors) { for(auto& constr : arg_type->constructors) {
if(ms.case_defined_for(constr.second.tag)) continue; if(ms.case_defined_for(constr.second.tag)) continue;
jump_instruction->tag_mappings[constr.second.tag] = jump_instruction->tag_mappings[constr.second.tag] =
jump_instruction->branches.size(); jump_instruction->branches.size();
@ -385,29 +392,29 @@ struct case_strategy_data {
template <typename T> template <typename T>
void compile_case(const ast_case& node, const env_ptr& env, const type* type, std::vector<instruction_ptr>& into) { void compile_case(const ast_case& node, const env_ptr& env, const type* type, std::vector<instruction_ptr>& into) {
T strategy; T strategy(type);
case_mappings<T> cases; case_mappings<T> cases;
for(auto& branch : node.branches) { for(auto& branch : node.branches) {
pattern_var* vpat; pattern_var* vpat;
if((vpat = dynamic_cast<pattern_var*>(branch->pat.get()))) { if((vpat = dynamic_cast<pattern_var*>(branch->pat.get()))) {
if(cases.defined_cases_count() == strategy.case_count(type)) if(cases.defined_cases_count() == strategy.case_count())
throw type_error("redundant catch-all pattern", branch->pat->loc); throw type_error("redundant catch-all pattern", branch->pat->loc);
auto& branch_into = cases.make_default_case(); auto& branch_into = cases.make_default_case();
env_ptr new_env(new env_var(vpat->var, env)); env_ptr new_env(new env_var(vpat->var, env));
branch->expr->compile(new_env, branch_into); branch->expr->compile(new_env, branch_into);
branch_into.push_back(instruction_ptr(new instruction_slide(1))); branch_into.push_back(instruction_ptr(new instruction_slide(1)));
} else { } else {
auto repr = strategy.from_typed_pattern(branch->pat, type); auto repr = strategy.repr_from_pattern(branch->pat);
auto& branch_into = cases.make_case_for(strategy.tag_from_repr(repr)); auto& branch_into = cases.make_case_for(strategy.tag_from_repr(repr));
strategy.compile_branch(branch, env, repr, branch_into); strategy.compile_branch(branch, env, repr, branch_into);
} }
} }
if(!(cases.defined_cases_count() == strategy.case_count(type) || if(!(cases.defined_cases_count() == strategy.case_count() ||
cases.default_case_defined())) cases.default_case_defined()))
throw type_error("incomplete patterns", node.loc); throw type_error("incomplete patterns", node.loc);
strategy.into_instructions(type, cases, into); strategy.into_instructions(cases, into);
} }
void ast_case::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const { void ast_case::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {