Add prototype impl of case specialization.

Boolean cases could be translated to ifs, and
integer cases to jumps. That's still in progress.
This commit is contained in:
Danila Fedorin 2020-09-09 22:49:35 -07:00
parent 86b49f9cc3
commit 85394b185d

View File

@ -1,5 +1,6 @@
#include "ast.hpp" #include "ast.hpp"
#include <ostream> #include <ostream>
#include <type_traits>
#include "binop.hpp" #include "binop.hpp"
#include "error.hpp" #include "error.hpp"
#include "type_env.hpp" #include "type_env.hpp"
@ -218,60 +219,184 @@ void ast_case::translate(global_scope& scope) {
} }
} }
template <typename T>
struct case_mappings {
using tag_type = typename T::tag_type;
std::map<tag_type, std::vector<instruction_ptr>> defined_cases;
std::optional<std::vector<instruction_ptr>> default_case;
std::vector<instruction_ptr>& make_case_for(tag_type tag) {
auto existing_case = defined_cases.find(tag);
if(existing_case != defined_cases.end()) return existing_case->second;
if(default_case)
throw type_error("attempted pattern match after catch-all");
return defined_cases[tag];
}
std::vector<instruction_ptr>& make_default_case() {
if(default_case)
throw type_error("attempted repeated use of catch-all");
default_case.emplace(std::vector<instruction_ptr>());
return *default_case;
}
std::vector<instruction_ptr>& get_specific_case_for(tag_type tag) {
auto existing_case = defined_cases.find(tag);
assert(existing_case != defined_cases.end());
return existing_case->second;
}
std::vector<instruction_ptr>& get_default_case() {
assert(default_case);
return *default_case;
}
bool case_defined_for(tag_type tag) {
return defined_cases.find(tag) != defined_cases.end();
}
bool default_case_defined() { return default_case.has_value(); }
size_t defined_cases_count() { return defined_cases.size(); }
};
struct case_strategy_bool {
using tag_type = bool;
using repr_type = bool;
tag_type tag_from_repr(repr_type b) { return b; }
repr_type from_typed_pattern(const pattern_ptr& pt, const type* type) {
pattern_constr* cpat;
if(!(cpat = dynamic_cast<pattern_constr*>(pt.get())) ||
(cpat->constr != "True" && cpat->constr != "False") ||
cpat->params.size() != 0)
throw type_error("pattern cannot be converted to a boolean");
return cpat->constr == "True";
}
void compile_branch(
const branch_ptr& branch,
const env_ptr& env,
repr_type repr,
std::vector<instruction_ptr>& into) {
branch->expr->compile(env_ptr(new env_offset(1, env)), into);
}
size_t case_count(const type* type) {
return 2;
}
instruction_ptr into_instruction(const type* type, case_mappings<case_strategy_bool>& ms) {
throw std::runtime_error("boolean case unimplemented!");
}
};
struct case_strategy_data {
using tag_type = int;
using repr_type = std::pair<const type_data::constructor*, const std::vector<std::string>*>;
tag_type tag_from_repr(const repr_type& repr) { return repr.first->tag; }
repr_type from_typed_pattern(const pattern_ptr& pt, const type* type) {
pattern_constr* cpat;
if(!(cpat = dynamic_cast<pattern_constr*>(pt.get())))
throw type_error("pattern cannot be interpreted as constructor.");
return std::make_pair(
&static_cast<const type_data*>(type)->constructors.find(cpat->constr)->second,
&cpat->params);
}
void compile_branch(
const branch_ptr& branch,
const env_ptr& env,
const repr_type& repr,
std::vector<instruction_ptr>& into) {
env_ptr new_env = env;
for(auto it = repr.second->rbegin(); it != repr.second->rend(); it++) {
new_env = env_ptr(new env_var(branch->expr->env->get_mangled_name(*it), new_env));
}
into.push_back(instruction_ptr(new instruction_split(repr.second->size())));
branch->expr->compile(new_env, into);
into.push_back(instruction_ptr(new instruction_slide(repr.second->size())));
}
size_t case_count(const type* type) {
return static_cast<const type_data*>(type)->constructors.size();
}
instruction_ptr into_instruction(const type* type, case_mappings<case_strategy_data>& ms) {
instruction_jump* jump_instruction = new instruction_jump();
instruction_ptr inst(jump_instruction);
auto data_type = static_cast<const type_data*>(type);
for(auto& constr : data_type->constructors) {
if(!ms.case_defined_for(constr.second.tag)) continue;
jump_instruction->branches.push_back(
std::move(ms.get_specific_case_for(constr.second.tag)));
jump_instruction->tag_mappings[constr.second.tag] =
jump_instruction->branches.size() - 1;
}
if(ms.default_case_defined()) {
jump_instruction->branches.push_back(
std::move(ms.get_default_case()));
for(auto& constr : data_type->constructors) {
if(ms.case_defined_for(constr.second.tag)) continue;
jump_instruction->tag_mappings[constr.second.tag] =
jump_instruction->branches.size();
}
}
return std::move(inst);
}
};
template <typename T>
void compile_case(const ast_case& node, const env_ptr& env, const type* type, std::vector<instruction_ptr>& into) {
T strategy;
case_mappings<T> cases;
for(auto& branch : node.branches) {
pattern_var* vpat;
if((vpat = dynamic_cast<pattern_var*>(branch->pat.get()))) {
auto& branch_into = cases.make_default_case();
env_ptr new_env(new env_var(branch->expr->env->get_mangled_name(vpat->var), env));
branch->expr->compile(new_env, branch_into);
} else {
auto repr = strategy.from_typed_pattern(branch->pat, type);
auto& branch_into = cases.make_case_for(strategy.tag_from_repr(repr));
strategy.compile_branch(branch, env, repr, branch_into);
}
}
if(!(cases.defined_cases_count() == strategy.case_count(type) ||
cases.default_case_defined()))
throw type_error("incomplete patterns", node.loc);
into.push_back(strategy.into_instruction(type, cases));
}
void ast_case::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const { void ast_case::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
type_app* app_type = dynamic_cast<type_app*>(input_type.get()); type_app* app_type = dynamic_cast<type_app*>(input_type.get());
type_data* type = dynamic_cast<type_data*>(app_type->constructor.get()); type_data* data;
type_internal* internal;
of->compile(env, into); of->compile(env, into);
into.push_back(instruction_ptr(new instruction_eval())); into.push_back(instruction_ptr(new instruction_eval()));
instruction_jump* jump_instruction = new instruction_jump(); if((data = dynamic_cast<type_data*>(app_type->constructor.get()))) {
into.push_back(instruction_ptr(jump_instruction)); compile_case<case_strategy_data>(*this, env, data, into);
for(auto& branch : branches) { return;
std::vector<instruction_ptr> branch_instructions; } else if((internal = dynamic_cast<type_internal*>(app_type->constructor.get()))) {
pattern_var* vpat; if(internal->name == "Bool") {
pattern_constr* cpat; compile_case<case_strategy_bool>(*this, env, data, into);
return;
if((vpat = dynamic_cast<pattern_var*>(branch->pat.get()))) {
branch->expr->compile(env_ptr(new env_offset(1, env)), branch_instructions);
for(auto& constr_pair : type->constructors) {
if(jump_instruction->tag_mappings.find(constr_pair.second.tag) !=
jump_instruction->tag_mappings.end())
break;
jump_instruction->tag_mappings[constr_pair.second.tag] =
jump_instruction->branches.size();
}
jump_instruction->branches.push_back(std::move(branch_instructions));
} else if((cpat = dynamic_cast<pattern_constr*>(branch->pat.get()))) {
env_ptr new_env = env;
for(auto it = cpat->params.rbegin(); it != cpat->params.rend(); it++) {
new_env = env_ptr(new env_var(branch->expr->env->get_mangled_name(*it), new_env));
}
branch_instructions.push_back(instruction_ptr(new instruction_split(
cpat->params.size())));
branch->expr->compile(new_env, branch_instructions);
branch_instructions.push_back(instruction_ptr(new instruction_slide(
cpat->params.size())));
int new_tag = type->constructors[cpat->constr].tag;
if(jump_instruction->tag_mappings.find(new_tag) !=
jump_instruction->tag_mappings.end())
throw type_error("technically not a type error: duplicate pattern", cpat->loc);
jump_instruction->tag_mappings[new_tag] =
jump_instruction->branches.size();
jump_instruction->branches.push_back(std::move(branch_instructions));
} }
} }
for(auto& constr_pair : type->constructors) { throw std::runtime_error("no known way to compile case expression");
if(jump_instruction->tag_mappings.find(constr_pair.second.tag) ==
jump_instruction->tag_mappings.end())
throw type_error("non-total pattern", loc);
}
} }
void ast_let::print(int indent, std::ostream& to) const { void ast_let::print(int indent, std::ostream& to) const {