Write explanations of AST refactor in compiler series

This commit is contained in:
Danila Fedorin 2019-10-08 21:42:25 -07:00
parent e054cb98cc
commit e4bf3a8672
6 changed files with 101 additions and 15 deletions

View File

@ -6,6 +6,20 @@ void print_indent(int n, std::ostream& to) {
while(n--) to << " "; while(n--) to << " ";
} }
type_ptr ast::typecheck_common(type_mgr& mgr, const type_env& env) {
node_type = typecheck(mgr, env);
return node_type;
}
void ast::resolve_common(const type_mgr& mgr) {
type_var* var;
type_ptr resolved_type = mgr.resolve(node_type, var);
if(var) throw type_error("ambiguously typed program");
resolve(mgr);
node_type = std::move(resolved_type);
}
void ast_int::print(int indent, std::ostream& to) const { void ast_int::print(int indent, std::ostream& to) const {
print_indent(indent, to); print_indent(indent, to);
to << "INT: " << value << std::endl; to << "INT: " << value << std::endl;
@ -15,6 +29,10 @@ type_ptr ast_int::typecheck(type_mgr& mgr, const type_env& env) const {
return type_ptr(new type_base("Int")); return type_ptr(new type_base("Int"));
} }
void ast_int::resolve(const type_mgr& mgr) const {
}
void ast_int::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const { void ast_int::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
into.push_back(instruction_ptr(new instruction_pushint(value))); into.push_back(instruction_ptr(new instruction_pushint(value)));
} }
@ -28,6 +46,10 @@ type_ptr ast_lid::typecheck(type_mgr& mgr, const type_env& env) const {
return env.lookup(id); return env.lookup(id);
} }
void ast_lid::resolve(const type_mgr& mgr) const {
}
void ast_lid::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const { void ast_lid::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
into.push_back(instruction_ptr( into.push_back(instruction_ptr(
env->has_variable(id) ? env->has_variable(id) ?
@ -44,6 +66,10 @@ type_ptr ast_uid::typecheck(type_mgr& mgr, const type_env& env) const {
return env.lookup(id); return env.lookup(id);
} }
void ast_uid::resolve(const type_mgr& mgr) const {
}
void ast_uid::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const { void ast_uid::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
into.push_back(instruction_ptr(new instruction_pushglobal(id))); into.push_back(instruction_ptr(new instruction_pushglobal(id)));
} }
@ -56,8 +82,8 @@ void ast_binop::print(int indent, std::ostream& to) const {
} }
type_ptr ast_binop::typecheck(type_mgr& mgr, const type_env& env) const { type_ptr ast_binop::typecheck(type_mgr& mgr, const type_env& env) const {
type_ptr ltype = left->typecheck(mgr, env); type_ptr ltype = left->typecheck_common(mgr, env);
type_ptr rtype = right->typecheck(mgr, env); type_ptr rtype = right->typecheck_common(mgr, env);
type_ptr ftype = env.lookup(op_name(op)); type_ptr ftype = env.lookup(op_name(op));
if(!ftype) throw type_error(std::string("unknown binary operator ") + op_name(op)); if(!ftype) throw type_error(std::string("unknown binary operator ") + op_name(op));
@ -69,6 +95,11 @@ type_ptr ast_binop::typecheck(type_mgr& mgr, const type_env& env) const {
return return_type; return return_type;
} }
void ast_binop::resolve(const type_mgr& mgr) const {
left->resolve_common(mgr);
right->resolve_common(mgr);
}
void ast_binop::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const { void ast_binop::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
right->compile(env, into); right->compile(env, into);
left->compile(env_ptr(new env_offset(1, env)), into); left->compile(env_ptr(new env_offset(1, env)), into);
@ -86,8 +117,8 @@ void ast_app::print(int indent, std::ostream& to) const {
} }
type_ptr ast_app::typecheck(type_mgr& mgr, const type_env& env) const { type_ptr ast_app::typecheck(type_mgr& mgr, const type_env& env) const {
type_ptr ltype = left->typecheck(mgr, env); type_ptr ltype = left->typecheck_common(mgr, env);
type_ptr rtype = right->typecheck(mgr, env); type_ptr rtype = right->typecheck_common(mgr, env);
type_ptr return_type = mgr.new_type(); type_ptr return_type = mgr.new_type();
type_ptr arrow = type_ptr(new type_arr(rtype, return_type)); type_ptr arrow = type_ptr(new type_arr(rtype, return_type));
@ -95,6 +126,11 @@ type_ptr ast_app::typecheck(type_mgr& mgr, const type_env& env) const {
return return_type; return return_type;
} }
void ast_app::resolve(const type_mgr& mgr) const {
left->resolve_common(mgr);
right->resolve_common(mgr);
}
void ast_app::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const { void ast_app::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
right->compile(env, into); right->compile(env, into);
left->compile(env_ptr(new env_offset(1, env)), into); left->compile(env_ptr(new env_offset(1, env)), into);
@ -114,23 +150,31 @@ void ast_case::print(int indent, std::ostream& to) const {
type_ptr ast_case::typecheck(type_mgr& mgr, const type_env& env) const { type_ptr ast_case::typecheck(type_mgr& mgr, const type_env& env) const {
type_var* var; type_var* var;
type_ptr case_type = mgr.resolve(of->typecheck(mgr, env), var); type_ptr case_type = mgr.resolve(of->typecheck_common(mgr, env), var);
type_ptr branch_type = mgr.new_type(); type_ptr branch_type = mgr.new_type();
if(!dynamic_cast<type_base*>(case_type.get())) {
throw type_error("attempting case analysis of non-data type");
}
for(auto& branch : branches) { for(auto& branch : branches) {
type_env new_env = env.scope(); type_env new_env = env.scope();
branch->pat->match(case_type, mgr, new_env); branch->pat->match(case_type, mgr, new_env);
type_ptr curr_branch_type = branch->expr->typecheck(mgr, new_env); type_ptr curr_branch_type = branch->expr->typecheck_common(mgr, new_env);
mgr.unify(branch_type, curr_branch_type); mgr.unify(branch_type, curr_branch_type);
} }
case_type = mgr.resolve(case_type, var);
if(!dynamic_cast<type_data*>(case_type.get())) {
throw type_error("attempting case analysis of non-data type");
}
return branch_type; return branch_type;
} }
void ast_case::resolve(const type_mgr& mgr) const {
of->resolve_common(mgr);
for(auto& branch : branches) {
branch->expr->resolve_common(mgr);
}
}
void ast_case::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const { void ast_case::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
} }

View File

@ -8,12 +8,18 @@
#include "env.hpp" #include "env.hpp"
struct ast { struct ast {
type_ptr node_type;
virtual ~ast() = default; virtual ~ast() = default;
virtual void print(int indent, std::ostream& to) const = 0; virtual void print(int indent, std::ostream& to) const = 0;
virtual type_ptr typecheck(type_mgr& mgr, const type_env& env) const = 0; virtual type_ptr typecheck(type_mgr& mgr, const type_env& env) const = 0;
virtual void resolve(const type_mgr& mgr) const = 0;
virtual void compile(const env_ptr& env, virtual void compile(const env_ptr& env,
std::vector<instruction_ptr>& into) const = 0; std::vector<instruction_ptr>& into) const = 0;
type_ptr typecheck_common(type_mgr& mgr, const type_env& env);
void resolve_common(const type_mgr& mgr);
}; };
using ast_ptr = std::unique_ptr<ast>; using ast_ptr = std::unique_ptr<ast>;
@ -52,6 +58,7 @@ struct definition {
virtual void typecheck_first(type_mgr& mgr, type_env& env) = 0; virtual void typecheck_first(type_mgr& mgr, type_env& env) = 0;
virtual void typecheck_second(type_mgr& mgr, const type_env& env) const = 0; virtual void typecheck_second(type_mgr& mgr, const type_env& env) const = 0;
virtual void resolve(const type_mgr& mgr) const = 0;
}; };
using definition_ptr = std::unique_ptr<definition>; using definition_ptr = std::unique_ptr<definition>;
@ -64,6 +71,7 @@ struct ast_int : public ast {
void print(int indent, std::ostream& to) const; void print(int indent, std::ostream& to) const;
type_ptr typecheck(type_mgr& mgr, const type_env& env) const; type_ptr typecheck(type_mgr& mgr, const type_env& env) const;
void resolve(const type_mgr& mgr) const;
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const; void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
}; };
@ -75,6 +83,7 @@ struct ast_lid : public ast {
void print(int indent, std::ostream& to) const; void print(int indent, std::ostream& to) const;
type_ptr typecheck(type_mgr& mgr, const type_env& env) const; type_ptr typecheck(type_mgr& mgr, const type_env& env) const;
void resolve(const type_mgr& mgr) const;
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const; void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
}; };
@ -86,6 +95,7 @@ struct ast_uid : public ast {
void print(int indent, std::ostream& to) const; void print(int indent, std::ostream& to) const;
type_ptr typecheck(type_mgr& mgr, const type_env& env) const; type_ptr typecheck(type_mgr& mgr, const type_env& env) const;
void resolve(const type_mgr& mgr) const;
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const; void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
}; };
@ -99,6 +109,7 @@ struct ast_binop : public ast {
void print(int indent, std::ostream& to) const; void print(int indent, std::ostream& to) const;
type_ptr typecheck(type_mgr& mgr, const type_env& env) const; type_ptr typecheck(type_mgr& mgr, const type_env& env) const;
void resolve(const type_mgr& mgr) const;
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const; void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
}; };
@ -111,6 +122,7 @@ struct ast_app : public ast {
void print(int indent, std::ostream& to) const; void print(int indent, std::ostream& to) const;
type_ptr typecheck(type_mgr& mgr, const type_env& env) const; type_ptr typecheck(type_mgr& mgr, const type_env& env) const;
void resolve(const type_mgr& mgr) const;
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const; void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
}; };
@ -123,6 +135,7 @@ struct ast_case : public ast {
void print(int indent, std::ostream& to) const; void print(int indent, std::ostream& to) const;
type_ptr typecheck(type_mgr& mgr, const type_env& env) const; type_ptr typecheck(type_mgr& mgr, const type_env& env) const;
void resolve(const type_mgr& mgr) const;
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const; void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
}; };
@ -162,6 +175,7 @@ struct definition_defn : public definition {
void typecheck_first(type_mgr& mgr, type_env& env); void typecheck_first(type_mgr& mgr, type_env& env);
void typecheck_second(type_mgr& mgr, const type_env& env) const; void typecheck_second(type_mgr& mgr, const type_env& env) const;
void resolve(const type_mgr& mgr) const;
}; };
struct definition_data : public definition { struct definition_data : public definition {
@ -173,4 +187,5 @@ struct definition_data : public definition {
void typecheck_first(type_mgr& mgr, type_env& env); void typecheck_first(type_mgr& mgr, type_env& env);
void typecheck_second(type_mgr& mgr, const type_env& env) const; void typecheck_second(type_mgr& mgr, const type_env& env) const;
void resolve(const type_mgr& mgr) const;
}; };

View File

@ -24,16 +24,23 @@ void definition_defn::typecheck_second(type_mgr& mgr, const type_env& env) const
type_it++; type_it++;
} }
type_ptr body_type = body->typecheck(mgr, new_env); type_ptr body_type = body->typecheck_common(mgr, new_env);
mgr.unify(return_type, body_type); mgr.unify(return_type, body_type);
} }
void definition_defn::resolve(const type_mgr& mgr) const {
body->resolve_common(mgr);
}
void definition_data::typecheck_first(type_mgr& mgr, type_env& env) { void definition_data::typecheck_first(type_mgr& mgr, type_env& env) {
type_ptr return_type = type_ptr(new type_base(name)); type_data* this_type = new type_data(name);
type_ptr return_type = type_ptr(this_type);
int next_tag = 0;
for(auto& constructor : constructors) { for(auto& constructor : constructors) {
type_ptr full_type = return_type; this_type->constructors[constructor->name] = { next_tag++ };
type_ptr full_type = return_type;
for(auto it = constructor->types.rbegin(); it != constructor->types.rend(); it++) { for(auto it = constructor->types.rbegin(); it != constructor->types.rend(); it++) {
type_ptr type = type_ptr(new type_base(*it)); type_ptr type = type_ptr(new type_base(*it));
full_type = type_ptr(new type_arr(type, full_type)); full_type = type_ptr(new type_arr(type, full_type));
@ -46,3 +53,8 @@ void definition_data::typecheck_first(type_mgr& mgr, type_env& env) {
void definition_data::typecheck_second(type_mgr& mgr, const type_env& env) const { void definition_data::typecheck_second(type_mgr& mgr, const type_env& env) const {
// Nothing // Nothing
} }
void definition_data::resolve(const type_mgr& mgr) const {
// Nothing
}

View File

@ -36,6 +36,10 @@ void typecheck_program(
pair.second->print(mgr, std::cout); pair.second->print(mgr, std::cout);
std::cout << std::endl; std::cout << std::endl;
} }
for(auto& def : prog) {
def->resolve(mgr);
}
} }
int main() { int main() {

View File

@ -44,7 +44,7 @@ type_ptr type_mgr::new_arrow_type() {
return type_ptr(new type_arr(new_type(), new_type())); return type_ptr(new type_arr(new_type(), new_type()));
} }
type_ptr type_mgr::resolve(type_ptr t, type_var*& var) { type_ptr type_mgr::resolve(type_ptr t, type_var*& var) const {
type_var* cast; type_var* cast;
var = nullptr; var = nullptr;

View File

@ -30,6 +30,17 @@ struct type_base : public type {
void print(const type_mgr& mgr, std::ostream& to) const; void print(const type_mgr& mgr, std::ostream& to) const;
}; };
struct type_data : public type_base {
struct constructor {
int tag;
};
std::map<std::string, constructor> constructors;
type_data(std::string n)
: type_base(std::move(n)) {}
};
struct type_arr : public type { struct type_arr : public type {
type_ptr left; type_ptr left;
type_ptr right; type_ptr right;
@ -49,6 +60,6 @@ struct type_mgr {
type_ptr new_arrow_type(); type_ptr new_arrow_type();
void unify(type_ptr l, type_ptr r); void unify(type_ptr l, type_ptr r);
type_ptr resolve(type_ptr t, type_var*& var); type_ptr resolve(type_ptr t, type_var*& var) const;
void bind(const std::string& s, type_ptr t); void bind(const std::string& s, type_ptr t);
}; };