Danila Fedorin
c7b2a4959f
Boolean cases could be translated to ifs, and integer cases to jumps. That's still in progress.
571 lines
19 KiB
C++
571 lines
19 KiB
C++
#include "ast.hpp"
|
|
#include <ostream>
|
|
#include <type_traits>
|
|
#include "binop.hpp"
|
|
#include "error.hpp"
|
|
#include "type_env.hpp"
|
|
#include "env.hpp"
|
|
|
|
// Writes `n` indentation units (one space each) to `to`.
// A zero or negative `n` writes nothing; the previous `while(n--)` loop
// kept counting down past zero when handed a negative value.
static void print_indent(int n, std::ostream& to) {
    for(int level = 0; level < n; ++level) {
        to << " ";
    }
}
|
|
|
|
// Prints this integer literal as one indented "INT: <value>" line.
void ast_int::print(int indent, std::ostream& to) const {
    print_indent(indent, to);
    to << "INT: ";
    to << value;
    to << std::endl;
}
|
|
|
|
// Adds nothing to `into`: the body is empty for integer literals.
void ast_int::find_free(std::set<std::string>& into) {
    // No variable names to report.
}
|
|
|
|
// Remembers the environment on the node and returns a type application
// built from whatever the environment has registered under the name "Int".
type_ptr ast_int::typecheck(type_mgr& mgr, type_env_ptr& env) {
    this->env = env;
    type_ptr int_type(new type_app(env->lookup_type("Int")));
    return int_type;
}
|
|
|
|
// Has an empty body: this node has no children to translate.
void ast_int::translate(global_scope& scope) {
    // Nothing to do.
}
|
|
|
|
void ast_int::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
|
|
into.push_back(instruction_ptr(new instruction_pushint(value)));
|
|
}
|
|
|
|
// Prints this lowercase identifier as one indented "LID: <id>" line.
void ast_lid::print(int indent, std::ostream& to) const {
    print_indent(indent, to);
    to << "LID: ";
    to << id;
    to << std::endl;
}
|
|
|
|
// Reports this identifier's own name as a free variable.
void ast_lid::find_free(std::set<std::string>& into) {
    into.emplace(id);
}
|
|
|
|
// Remembers the environment on the node, looks the identifier up in it and
// returns a fresh instantiation of the type scheme that was found.
//
// Throws type_error when the environment has no binding for the identifier.
// Previously the lookup result was dereferenced unconditionally, so an
// unknown name dereferenced a null scheme instead of reporting an error.
type_ptr ast_lid::typecheck(type_mgr& mgr, type_env_ptr& env) {
    this->env = env;
    type_scheme_ptr lid_scheme = env->lookup(id);
    if(!lid_scheme) {
        throw type_error(std::string("unknown identifier ") + id, loc);
    }
    return lid_scheme->instantiate(mgr);
}
|
|
|
|
// Has an empty body: this node has no children to translate.
void ast_lid::translate(global_scope& scope) {
    // Nothing to do.
}
|
|
|
|
void ast_lid::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
|
|
auto mangled_name = this->env->get_mangled_name(id);
|
|
into.push_back(instruction_ptr(
|
|
(env->has_variable(mangled_name) && !this->env->is_global(id)) ?
|
|
(instruction*) new instruction_push(env->get_offset(mangled_name)) :
|
|
(instruction*) new instruction_pushglobal(mangled_name)));
|
|
}
|
|
|
|
// Prints this uppercase identifier as one indented "UID: <id>" line.
void ast_uid::print(int indent, std::ostream& to) const {
    print_indent(indent, to);
    to << "UID: ";
    to << id;
    to << std::endl;
}
|
|
|
|
// Adds nothing to `into`: the body is empty for uppercase identifiers.
void ast_uid::find_free(std::set<std::string>& into) {
    // No variable names to report.
}
|
|
|
|
// Remembers the environment on the node, looks the identifier up in it and
// returns a fresh instantiation of the type scheme that was found.
//
// Throws type_error when the environment has no binding for the identifier.
// Previously the lookup result was dereferenced unconditionally, so an
// unknown name dereferenced a null scheme instead of reporting an error.
type_ptr ast_uid::typecheck(type_mgr& mgr, type_env_ptr& env) {
    this->env = env;
    type_scheme_ptr uid_scheme = env->lookup(id);
    if(!uid_scheme) {
        throw type_error(std::string("unknown constructor ") + id, loc);
    }
    return uid_scheme->instantiate(mgr);
}
|
|
|
|
// Has an empty body: this node has no children to translate.
void ast_uid::translate(global_scope& scope) {
    // Nothing to do.
}
|
|
|
|
void ast_uid::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
|
|
into.push_back(instruction_ptr(
|
|
new instruction_pushglobal(this->env->get_mangled_name(id))));
|
|
}
|
|
|
|
void ast_binop::print(int indent, std::ostream& to) const {
|
|
print_indent(indent, to);
|
|
to << "BINOP: " << op_name(op) << std::endl;
|
|
left->print(indent + 1, to);
|
|
right->print(indent + 1, to);
|
|
}
|
|
|
|
// Collects the free variables of both operands, left operand first.
void ast_binop::find_free(std::set<std::string>& into) {
    // Both operands write into the same output set.
    left->find_free(into);
    right->find_free(into);
}
|
|
|
|
// Type checks a binary operator application and returns its result type.
//
// Both operands are checked first (left, then right). The operator is then
// looked up by name in the environment and instantiated.
//
// Throws type_error("unknown binary operator ...") when the operator has no
// binding. The lookup result is now tested before it is instantiated; the
// old code called instantiate() on the lookup result and only then tested
// for null, so a missing operator dereferenced a null scheme.
type_ptr ast_binop::typecheck(type_mgr& mgr, type_env_ptr& env) {
    this->env = env;
    type_ptr ltype = left->typecheck(mgr, env);
    type_ptr rtype = right->typecheck(mgr, env);

    type_scheme_ptr op_scheme = env->lookup(op_name(op));
    type_ptr ftype = op_scheme ? op_scheme->instantiate(mgr) : type_ptr(nullptr);
    if(!ftype) throw type_error(std::string("unknown binary operator ") + op_name(op), loc);

    // For better type errors, we first require binary function,
    // and only then unify each argument. This way, we can
    // precisely point out which argument is "wrong".

    type_ptr return_type = mgr.new_type();
    type_ptr second_type = mgr.new_type();
    type_ptr first_type = mgr.new_type();
    type_ptr arrow_one = type_ptr(new type_arr(second_type, return_type));
    type_ptr arrow_two = type_ptr(new type_arr(first_type, arrow_one));

    mgr.unify(ftype, arrow_two, loc);
    mgr.unify(first_type, ltype, left->loc);
    mgr.unify(second_type, rtype, right->loc);
    return return_type;
}
|
|
|
|
// Translates both operands, left operand first.
void ast_binop::translate(global_scope& scope) {
    // The node itself needs no work beyond visiting its children.
    left->translate(scope);
    right->translate(scope);
}
|
|
|
|
void ast_binop::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
|
|
right->compile(env, into);
|
|
left->compile(env_ptr(new env_offset(1, env)), into);
|
|
|
|
into.push_back(instruction_ptr(new instruction_pushglobal(op_action(op))));
|
|
into.push_back(instruction_ptr(new instruction_mkapp()));
|
|
into.push_back(instruction_ptr(new instruction_mkapp()));
|
|
}
|
|
|
|
void ast_app::print(int indent, std::ostream& to) const {
|
|
print_indent(indent, to);
|
|
to << "APP:" << std::endl;
|
|
left->print(indent + 1, to);
|
|
right->print(indent + 1, to);
|
|
}
|
|
|
|
// Collects the free variables of the function and of the argument, in that order.
void ast_app::find_free(std::set<std::string>& into) {
    // Both children write into the same output set.
    left->find_free(into);
    right->find_free(into);
}
|
|
|
|
// Type checks a function application and returns its result type.
//
// The function (left) and argument (right) are checked in that order. A
// fresh type is created for the result, and the arrow "argument -> result"
// is unified with the function's type, reporting at the function's location.
type_ptr ast_app::typecheck(type_mgr& mgr, type_env_ptr& env) {
    this->env = env;
    type_ptr function_type = left->typecheck(mgr, env);
    type_ptr argument_type = right->typecheck(mgr, env);

    type_ptr result_type = mgr.new_type();
    type_ptr expected_function_type(new type_arr(argument_type, result_type));
    mgr.unify(expected_function_type, function_type, left->loc);
    return result_type;
}
|
|
|
|
// Translates the function and then the argument.
void ast_app::translate(global_scope& scope) {
    // The node itself needs no work beyond visiting its children.
    left->translate(scope);
    right->translate(scope);
}
|
|
|
|
void ast_app::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
|
|
right->compile(env, into);
|
|
left->compile(env_ptr(new env_offset(1, env)), into);
|
|
into.push_back(instruction_ptr(new instruction_mkapp()));
|
|
}
|
|
|
|
void ast_case::print(int indent, std::ostream& to) const {
|
|
print_indent(indent, to);
|
|
to << "CASE: " << std::endl;
|
|
for(auto& branch : branches) {
|
|
print_indent(indent + 1, to);
|
|
branch->pat->print(to);
|
|
to << std::endl;
|
|
branch->expr->print(indent + 2, to);
|
|
}
|
|
}
|
|
|
|
void ast_case::find_free(std::set<std::string>& into) {
|
|
of->find_free(into);
|
|
for(auto& branch : branches) {
|
|
std::set<std::string> free_in_branch;
|
|
std::set<std::string> pattern_variables;
|
|
branch->pat->find_variables(pattern_variables);
|
|
branch->expr->find_free(free_in_branch);
|
|
for(auto& free : free_in_branch) {
|
|
if(pattern_variables.find(free) == pattern_variables.end())
|
|
into.insert(free);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Type checks a case expression and returns the type shared by its branches.
//
// The scrutinee is checked and its type resolved. Each branch is checked in
// its own scope created from `env`: the pattern is checked against the
// scrutinee type, and the body's type is unified with a single fresh branch
// type. Afterwards the scrutinee type is resolved again and stored in
// `input_type`.
//
// Throws type_error when the resolved scrutinee type is not a type
// application whose constructor is a data type.
type_ptr ast_case::typecheck(type_mgr& mgr, type_env_ptr& env) {
    this->env = env;
    type_var* var;
    type_ptr case_type = mgr.resolve(of->typecheck(mgr, env), var);
    type_ptr branch_type = mgr.new_type();

    for(auto& branch : branches) {
        type_env_ptr branch_env = type_scope(env);
        branch->pat->typecheck(case_type, mgr, branch_env);
        type_ptr body_type = branch->expr->typecheck(mgr, branch_env);
        mgr.unify(body_type, branch_type, branch->expr->loc);
    }

    input_type = mgr.resolve(case_type, var);
    type_app* app_type = dynamic_cast<type_app*>(input_type.get());
    const bool is_data_type =
        app_type && dynamic_cast<type_data*>(app_type->constructor.get());
    if(!is_data_type) {
        throw type_error("attempting case analysis of non-data type", of->loc);
    }

    return branch_type;
}
|
|
|
|
// Translates the scrutinee and then every branch body, in branch order.
void ast_case::translate(global_scope& scope) {
    of->translate(scope);
    for(auto& branch : branches) branch->expr->translate(scope);
}
|
|
|
|
template <typename T>
|
|
struct case_mappings {
|
|
using tag_type = typename T::tag_type;
|
|
std::map<tag_type, std::vector<instruction_ptr>> defined_cases;
|
|
std::optional<std::vector<instruction_ptr>> default_case;
|
|
|
|
std::vector<instruction_ptr>& make_case_for(tag_type tag) {
|
|
auto existing_case = defined_cases.find(tag);
|
|
if(existing_case != defined_cases.end()) return existing_case->second;
|
|
if(default_case)
|
|
throw type_error("attempted pattern match after catch-all");
|
|
return defined_cases[tag];
|
|
}
|
|
|
|
std::vector<instruction_ptr>& make_default_case() {
|
|
if(default_case)
|
|
throw type_error("attempted repeated use of catch-all");
|
|
default_case.emplace(std::vector<instruction_ptr>());
|
|
return *default_case;
|
|
}
|
|
|
|
std::vector<instruction_ptr>& get_specific_case_for(tag_type tag) {
|
|
auto existing_case = defined_cases.find(tag);
|
|
assert(existing_case != defined_cases.end());
|
|
return existing_case->second;
|
|
}
|
|
|
|
std::vector<instruction_ptr>& get_default_case() {
|
|
assert(default_case);
|
|
return *default_case;
|
|
}
|
|
|
|
bool case_defined_for(tag_type tag) {
|
|
return defined_cases.find(tag) != defined_cases.end();
|
|
}
|
|
|
|
bool default_case_defined() { return default_case.has_value(); }
|
|
|
|
size_t defined_cases_count() { return defined_cases.size(); }
|
|
};
|
|
|
|
|
|
struct case_strategy_bool {
|
|
using tag_type = bool;
|
|
using repr_type = bool;
|
|
|
|
tag_type tag_from_repr(repr_type b) { return b; }
|
|
|
|
repr_type from_typed_pattern(const pattern_ptr& pt, const type* type) {
|
|
pattern_constr* cpat;
|
|
if(!(cpat = dynamic_cast<pattern_constr*>(pt.get())) ||
|
|
(cpat->constr != "True" && cpat->constr != "False") ||
|
|
cpat->params.size() != 0)
|
|
throw type_error("pattern cannot be converted to a boolean");
|
|
return cpat->constr == "True";
|
|
}
|
|
|
|
void compile_branch(
|
|
const branch_ptr& branch,
|
|
const env_ptr& env,
|
|
repr_type repr,
|
|
std::vector<instruction_ptr>& into) {
|
|
branch->expr->compile(env_ptr(new env_offset(1, env)), into);
|
|
}
|
|
|
|
size_t case_count(const type* type) {
|
|
return 2;
|
|
}
|
|
|
|
instruction_ptr into_instruction(const type* type, case_mappings<case_strategy_bool>& ms) {
|
|
throw std::runtime_error("boolean case unimplemented!");
|
|
}
|
|
};
|
|
|
|
struct case_strategy_data {
|
|
using tag_type = int;
|
|
using repr_type = std::pair<const type_data::constructor*, const std::vector<std::string>*>;
|
|
|
|
tag_type tag_from_repr(const repr_type& repr) { return repr.first->tag; }
|
|
|
|
repr_type from_typed_pattern(const pattern_ptr& pt, const type* type) {
|
|
pattern_constr* cpat;
|
|
if(!(cpat = dynamic_cast<pattern_constr*>(pt.get())))
|
|
throw type_error("pattern cannot be interpreted as constructor.");
|
|
return std::make_pair(
|
|
&static_cast<const type_data*>(type)->constructors.find(cpat->constr)->second,
|
|
&cpat->params);
|
|
}
|
|
|
|
void compile_branch(
|
|
const branch_ptr& branch,
|
|
const env_ptr& env,
|
|
const repr_type& repr,
|
|
std::vector<instruction_ptr>& into) {
|
|
env_ptr new_env = env;
|
|
for(auto it = repr.second->rbegin(); it != repr.second->rend(); it++) {
|
|
new_env = env_ptr(new env_var(branch->expr->env->get_mangled_name(*it), new_env));
|
|
}
|
|
|
|
into.push_back(instruction_ptr(new instruction_split(repr.second->size())));
|
|
branch->expr->compile(new_env, into);
|
|
into.push_back(instruction_ptr(new instruction_slide(repr.second->size())));
|
|
}
|
|
|
|
size_t case_count(const type* type) {
|
|
return static_cast<const type_data*>(type)->constructors.size();
|
|
}
|
|
|
|
instruction_ptr into_instruction(const type* type, case_mappings<case_strategy_data>& ms) {
|
|
instruction_jump* jump_instruction = new instruction_jump();
|
|
instruction_ptr inst(jump_instruction);
|
|
|
|
auto data_type = static_cast<const type_data*>(type);
|
|
for(auto& constr : data_type->constructors) {
|
|
if(!ms.case_defined_for(constr.second.tag)) continue;
|
|
jump_instruction->branches.push_back(
|
|
std::move(ms.get_specific_case_for(constr.second.tag)));
|
|
jump_instruction->tag_mappings[constr.second.tag] =
|
|
jump_instruction->branches.size() - 1;
|
|
}
|
|
|
|
if(ms.default_case_defined()) {
|
|
jump_instruction->branches.push_back(
|
|
std::move(ms.get_default_case()));
|
|
for(auto& constr : data_type->constructors) {
|
|
if(ms.case_defined_for(constr.second.tag)) continue;
|
|
jump_instruction->tag_mappings[constr.second.tag] =
|
|
jump_instruction->branches.size();
|
|
}
|
|
}
|
|
|
|
return std::move(inst);
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
void compile_case(const ast_case& node, const env_ptr& env, const type* type, std::vector<instruction_ptr>& into) {
|
|
T strategy;
|
|
case_mappings<T> cases;
|
|
for(auto& branch : node.branches) {
|
|
pattern_var* vpat;
|
|
if((vpat = dynamic_cast<pattern_var*>(branch->pat.get()))) {
|
|
auto& branch_into = cases.make_default_case();
|
|
env_ptr new_env(new env_var(branch->expr->env->get_mangled_name(vpat->var), env));
|
|
branch->expr->compile(new_env, branch_into);
|
|
} else {
|
|
auto repr = strategy.from_typed_pattern(branch->pat, type);
|
|
auto& branch_into = cases.make_case_for(strategy.tag_from_repr(repr));
|
|
strategy.compile_branch(branch, env, repr, branch_into);
|
|
}
|
|
}
|
|
|
|
if(!(cases.defined_cases_count() == strategy.case_count(type) ||
|
|
cases.default_case_defined()))
|
|
throw type_error("incomplete patterns", node.loc);
|
|
|
|
into.push_back(strategy.into_instruction(type, cases));
|
|
}
|
|
|
|
void ast_case::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
|
|
type_app* app_type = dynamic_cast<type_app*>(input_type.get());
|
|
type_data* data;
|
|
type_internal* internal;
|
|
|
|
of->compile(env, into);
|
|
into.push_back(instruction_ptr(new instruction_eval()));
|
|
|
|
if((data = dynamic_cast<type_data*>(app_type->constructor.get()))) {
|
|
compile_case<case_strategy_data>(*this, env, data, into);
|
|
return;
|
|
} else if((internal = dynamic_cast<type_internal*>(app_type->constructor.get()))) {
|
|
if(internal->name == "Bool") {
|
|
compile_case<case_strategy_bool>(*this, env, data, into);
|
|
return;
|
|
}
|
|
}
|
|
|
|
throw std::runtime_error("no known way to compile case expression");
|
|
}
|
|
|
|
// Prints "LET: " at the given indent, followed by the let body one level
// deeper. The definitions themselves are not printed.
void ast_let::print(int indent, std::ostream& to) const {
    print_indent(indent, to);
    to << "LET: ";
    to << std::endl;
    in->print(indent + 1, to);
}
|
|
|
|
// Collects the free variables of the definitions, plus those free variables
// of the body that are not names of functions defined by this let.
void ast_let::find_free(std::set<std::string>& into) {
    definitions.find_free(into);

    std::set<std::string> free_in_body;
    in->find_free(free_in_body);
    for(auto& name : free_in_body) {
        const bool defined_here =
            definitions.defs_defn.find(name) != definitions.defs_defn.end();
        if(!defined_here) into.insert(name);
    }
}
|
|
|
|
// Type checks the definitions in `env`, then checks the body in the
// definitions' own environment and returns the body's type.
type_ptr ast_let::typecheck(type_mgr& mgr, type_env_ptr& env) {
    this->env = env;
    definitions.typecheck(mgr, env);
    type_ptr body_type = in->typecheck(mgr, definitions.env);
    return body_type;
}
|
|
|
|
// Moves this let's definitions into the global scope and records, for each
// function definition, the expression that stands in for it locally.
//
// Data definitions are handed to the scope as-is. Each function definition
// is turned into a global; the parameters the global gained beyond the
// declared ones are treated as captured variables. The stand-in expression
// is a reference to the global (through an environment that maps the
// original name to the global's name) applied to the captured variables in
// order. Finally the let body is translated.
void ast_let::translate(global_scope& scope) {
    for(auto& data_def : definitions.defs_data) {
        data_def.second->into_globals(scope);
    }

    for(auto& fn_def : definitions.defs_defn) {
        size_t declared_params = fn_def.second->params.size();
        std::string local_name = fn_def.second->name;
        auto& lifted = fn_def.second->into_global(scope);
        // Leading parameters of the global that were not declared by the user.
        size_t remaining_captures = lifted.params.size() - declared_params;

        type_env_ptr mangled_env = type_scope(env);
        mangled_env->bind(fn_def.first, env->lookup(fn_def.first), visibility::global);
        mangled_env->set_mangled_name(fn_def.first, lifted.name);

        ast_ptr call(new ast_lid(local_name));
        call->env = mangled_env;
        for(auto& param : lifted.params) {
            if(remaining_captures == 0) break;
            --remaining_captures;

            ast_ptr captured_arg(new ast_lid(param));
            captured_arg->env = env;
            call = ast_ptr(new ast_app(std::move(call), std::move(captured_arg)));
            call->env = env;
        }
        translated_definitions.push_back({ fn_def.first, std::move(call) });
    }

    in->translate(scope);
}
|
|
|
|
void ast_let::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
|
|
into.push_back(instruction_ptr(new instruction_alloc(translated_definitions.size())));
|
|
env_ptr new_env = env;
|
|
for(auto& def : translated_definitions) {
|
|
new_env = env_ptr(new env_var(definitions.env->get_mangled_name(def.first), std::move(new_env)));
|
|
}
|
|
int offset = translated_definitions.size() - 1;
|
|
for(auto& def : translated_definitions) {
|
|
def.second->compile(new_env, into);
|
|
into.push_back(instruction_ptr(new instruction_update(offset--)));
|
|
}
|
|
in->compile(new_env, into);
|
|
into.push_back(instruction_ptr(new instruction_slide(translated_definitions.size())));
|
|
}
|
|
|
|
void ast_lambda::print(int indent, std::ostream& to) const {
|
|
print_indent(indent, to);
|
|
to << "LAMBDA";
|
|
for(auto& param : params) {
|
|
to << " " << param;
|
|
}
|
|
to << std::endl;
|
|
body->print(indent+1, to);
|
|
}
|
|
|
|
void ast_lambda::find_free(std::set<std::string>& into) {
|
|
body->find_free(free_variables);
|
|
for(auto& param : params) {
|
|
free_variables.erase(param);
|
|
}
|
|
into.insert(free_variables.begin(), free_variables.end());
|
|
}
|
|
|
|
// Type checks a lambda and returns its function type.
//
// A new scope (`var_env`) is created from `env`. Walking the parameters from
// last to first, each one is bound to a fresh type in that scope and
// prepended to the arrow type under construction, which starts out as a
// fresh return type. The body is checked in the new scope and its type is
// unified with the return type.
type_ptr ast_lambda::typecheck(type_mgr& mgr, type_env_ptr& env) {
    this->env = env;
    var_env = type_scope(env);
    type_ptr return_type = mgr.new_type();
    type_ptr full_type = return_type;

    for(auto param = params.rbegin(); param != params.rend(); ++param) {
        type_ptr param_type = mgr.new_type();
        var_env->bind(*param, param_type);
        full_type = type_ptr(new type_arr(std::move(param_type), full_type));
    }

    type_ptr body_type = body->typecheck(mgr, var_env);
    mgr.unify(return_type, body_type, body->loc);
    return full_type;
}
|
|
|
|
// Turns the lambda into a function in the global scope and stores the
// expression that replaces it in `translated`.
//
// The new function's parameters are the lambda's non-global free variables
// followed by the lambda's own parameters. The body is moved into the new
// function. The replacement expression is a reference to that function
// (through an environment that maps "lambda" to the function's name)
// applied to the captured variables in order.
void ast_lambda::translate(global_scope& scope) {
    std::vector<std::string> function_params;
    for(auto& free_variable : free_variables) {
        if(!env->is_global(free_variable)) function_params.push_back(free_variable);
    }
    // Only the leading, captured parameters are supplied at the use site.
    size_t remaining_captures = function_params.size();
    function_params.insert(function_params.end(), params.begin(), params.end());

    auto& lifted = scope.add_function("lambda", std::move(function_params), std::move(body));
    type_env_ptr mangled_env = type_scope(env);
    mangled_env->bind("lambda", type_scheme_ptr(nullptr), visibility::global);
    mangled_env->set_mangled_name("lambda", lifted.name);
    ast_ptr call(new ast_lid("lambda"));
    call->env = mangled_env;

    for(auto& param : lifted.params) {
        if(remaining_captures == 0) break;
        --remaining_captures;

        ast_ptr captured_arg(new ast_lid(param));
        captured_arg->env = env;
        call = ast_ptr(new ast_app(std::move(call), std::move(captured_arg)));
        call->env = env;
    }
    translated = std::move(call);
}
|
|
|
|
// Compiles the replacement expression stored in `translated`.
void ast_lambda::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
    // All code generation is delegated to the translated expression.
    translated->compile(env, into);
}
|
|
|
|
// Prints the variable name, with no surrounding whitespace or newline.
void pattern_var::print(std::ostream& to) const {
    // Just the bare name.
    to << var;
}
|
|
|
|
// Reports the single variable this pattern introduces.
void pattern_var::find_variables(std::set<std::string>& into) const {
    into.emplace(var);
}
|
|
|
|
// Binds the pattern's variable to the matched type `t` in `env`.
void pattern_var::typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const {
    // The variable takes on the matched type directly.
    env->bind(var, t);
}
|
|
|
|
void pattern_constr::print(std::ostream& to) const {
|
|
to << constr;
|
|
for(auto& param : params) {
|
|
to << " " << param;
|
|
}
|
|
}
|
|
|
|
void pattern_constr::find_variables(std::set<std::string>& into) const {
|
|
into.insert(params.begin(), params.end());
|
|
}
|
|
|
|
// Type checks a constructor pattern against the matched type `t`.
//
// The constructor's type scheme is looked up and instantiated. For each
// pattern parameter the instantiated type must still be an arrow: the
// parameter is bound to the arrow's left side and checking continues with
// its right side. Whatever type remains is unified with `t`.
//
// Throws type_error when the constructor is unknown or when the pattern has
// more parameters than the constructor's type has arrows.
void pattern_constr::typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const {
    type_scheme_ptr scheme = env->lookup(constr);
    if(!scheme) {
        throw type_error(std::string("pattern using unknown constructor ") + constr, loc);
    }

    type_ptr remaining_type = scheme->instantiate(mgr);
    for(auto& param : params) {
        auto* arrow = dynamic_cast<type_arr*>(remaining_type.get());
        if(arrow == nullptr) {
            throw type_error("too many parameters in constructor pattern", loc);
        }

        env->bind(param, arrow->left);
        remaining_type = arrow->right;
    }

    mgr.unify(remaining_type, t, loc);
}
|