Write explanations of AST refactor in compiler series

This commit is contained in:
Danila Fedorin 2019-10-08 21:42:25 -07:00
parent d3d73e0e9c
commit 7e9bd95846
7 changed files with 180 additions and 16 deletions

View File

@ -6,6 +6,20 @@ void print_indent(int n, std::ostream& to) {
while(n--) to << " "; while(n--) to << " ";
} }
type_ptr ast::typecheck_common(type_mgr& mgr, const type_env& env) {
node_type = typecheck(mgr, env);
return node_type;
}
void ast::resolve_common(const type_mgr& mgr) {
type_var* var;
type_ptr resolved_type = mgr.resolve(node_type, var);
if(var) throw type_error("ambiguously typed program");
resolve(mgr);
node_type = std::move(resolved_type);
}
void ast_int::print(int indent, std::ostream& to) const { void ast_int::print(int indent, std::ostream& to) const {
print_indent(indent, to); print_indent(indent, to);
to << "INT: " << value << std::endl; to << "INT: " << value << std::endl;
@ -15,6 +29,10 @@ type_ptr ast_int::typecheck(type_mgr& mgr, const type_env& env) const {
return type_ptr(new type_base("Int")); return type_ptr(new type_base("Int"));
} }
void ast_int::resolve(const type_mgr& mgr) const {
}
void ast_int::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const { void ast_int::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
into.push_back(instruction_ptr(new instruction_pushint(value))); into.push_back(instruction_ptr(new instruction_pushint(value)));
} }
@ -28,6 +46,10 @@ type_ptr ast_lid::typecheck(type_mgr& mgr, const type_env& env) const {
return env.lookup(id); return env.lookup(id);
} }
void ast_lid::resolve(const type_mgr& mgr) const {
}
void ast_lid::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const { void ast_lid::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
into.push_back(instruction_ptr( into.push_back(instruction_ptr(
env->has_variable(id) ? env->has_variable(id) ?
@ -44,6 +66,10 @@ type_ptr ast_uid::typecheck(type_mgr& mgr, const type_env& env) const {
return env.lookup(id); return env.lookup(id);
} }
void ast_uid::resolve(const type_mgr& mgr) const {
}
void ast_uid::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const { void ast_uid::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
into.push_back(instruction_ptr(new instruction_pushglobal(id))); into.push_back(instruction_ptr(new instruction_pushglobal(id)));
} }
@ -56,8 +82,8 @@ void ast_binop::print(int indent, std::ostream& to) const {
} }
type_ptr ast_binop::typecheck(type_mgr& mgr, const type_env& env) const { type_ptr ast_binop::typecheck(type_mgr& mgr, const type_env& env) const {
type_ptr ltype = left->typecheck(mgr, env); type_ptr ltype = left->typecheck_common(mgr, env);
type_ptr rtype = right->typecheck(mgr, env); type_ptr rtype = right->typecheck_common(mgr, env);
type_ptr ftype = env.lookup(op_name(op)); type_ptr ftype = env.lookup(op_name(op));
if(!ftype) throw type_error(std::string("unknown binary operator ") + op_name(op)); if(!ftype) throw type_error(std::string("unknown binary operator ") + op_name(op));
@ -69,6 +95,11 @@ type_ptr ast_binop::typecheck(type_mgr& mgr, const type_env& env) const {
return return_type; return return_type;
} }
void ast_binop::resolve(const type_mgr& mgr) const {
left->resolve_common(mgr);
right->resolve_common(mgr);
}
void ast_binop::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const { void ast_binop::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
right->compile(env, into); right->compile(env, into);
left->compile(env_ptr(new env_offset(1, env)), into); left->compile(env_ptr(new env_offset(1, env)), into);
@ -86,8 +117,8 @@ void ast_app::print(int indent, std::ostream& to) const {
} }
type_ptr ast_app::typecheck(type_mgr& mgr, const type_env& env) const { type_ptr ast_app::typecheck(type_mgr& mgr, const type_env& env) const {
type_ptr ltype = left->typecheck(mgr, env); type_ptr ltype = left->typecheck_common(mgr, env);
type_ptr rtype = right->typecheck(mgr, env); type_ptr rtype = right->typecheck_common(mgr, env);
type_ptr return_type = mgr.new_type(); type_ptr return_type = mgr.new_type();
type_ptr arrow = type_ptr(new type_arr(rtype, return_type)); type_ptr arrow = type_ptr(new type_arr(rtype, return_type));
@ -95,6 +126,11 @@ type_ptr ast_app::typecheck(type_mgr& mgr, const type_env& env) const {
return return_type; return return_type;
} }
void ast_app::resolve(const type_mgr& mgr) const {
left->resolve_common(mgr);
right->resolve_common(mgr);
}
void ast_app::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const { void ast_app::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
right->compile(env, into); right->compile(env, into);
left->compile(env_ptr(new env_offset(1, env)), into); left->compile(env_ptr(new env_offset(1, env)), into);
@ -114,23 +150,31 @@ void ast_case::print(int indent, std::ostream& to) const {
type_ptr ast_case::typecheck(type_mgr& mgr, const type_env& env) const { type_ptr ast_case::typecheck(type_mgr& mgr, const type_env& env) const {
type_var* var; type_var* var;
type_ptr case_type = mgr.resolve(of->typecheck(mgr, env), var); type_ptr case_type = mgr.resolve(of->typecheck_common(mgr, env), var);
type_ptr branch_type = mgr.new_type(); type_ptr branch_type = mgr.new_type();
if(!dynamic_cast<type_base*>(case_type.get())) {
throw type_error("attempting case analysis of non-data type");
}
for(auto& branch : branches) { for(auto& branch : branches) {
type_env new_env = env.scope(); type_env new_env = env.scope();
branch->pat->match(case_type, mgr, new_env); branch->pat->match(case_type, mgr, new_env);
type_ptr curr_branch_type = branch->expr->typecheck(mgr, new_env); type_ptr curr_branch_type = branch->expr->typecheck_common(mgr, new_env);
mgr.unify(branch_type, curr_branch_type); mgr.unify(branch_type, curr_branch_type);
} }
case_type = mgr.resolve(case_type, var);
if(!dynamic_cast<type_data*>(case_type.get())) {
throw type_error("attempting case analysis of non-data type");
}
return branch_type; return branch_type;
} }
void ast_case::resolve(const type_mgr& mgr) const {
of->resolve_common(mgr);
for(auto& branch : branches) {
branch->expr->resolve_common(mgr);
}
}
void ast_case::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const { void ast_case::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
} }

View File

@ -8,12 +8,18 @@
#include "env.hpp" #include "env.hpp"
struct ast { struct ast {
type_ptr node_type;
virtual ~ast() = default; virtual ~ast() = default;
virtual void print(int indent, std::ostream& to) const = 0; virtual void print(int indent, std::ostream& to) const = 0;
virtual type_ptr typecheck(type_mgr& mgr, const type_env& env) const = 0; virtual type_ptr typecheck(type_mgr& mgr, const type_env& env) const = 0;
virtual void resolve(const type_mgr& mgr) const = 0;
virtual void compile(const env_ptr& env, virtual void compile(const env_ptr& env,
std::vector<instruction_ptr>& into) const = 0; std::vector<instruction_ptr>& into) const = 0;
type_ptr typecheck_common(type_mgr& mgr, const type_env& env);
void resolve_common(const type_mgr& mgr);
}; };
using ast_ptr = std::unique_ptr<ast>; using ast_ptr = std::unique_ptr<ast>;
@ -52,6 +58,7 @@ struct definition {
virtual void typecheck_first(type_mgr& mgr, type_env& env) = 0; virtual void typecheck_first(type_mgr& mgr, type_env& env) = 0;
virtual void typecheck_second(type_mgr& mgr, const type_env& env) const = 0; virtual void typecheck_second(type_mgr& mgr, const type_env& env) const = 0;
virtual void resolve(const type_mgr& mgr) const = 0;
}; };
using definition_ptr = std::unique_ptr<definition>; using definition_ptr = std::unique_ptr<definition>;
@ -64,6 +71,7 @@ struct ast_int : public ast {
void print(int indent, std::ostream& to) const; void print(int indent, std::ostream& to) const;
type_ptr typecheck(type_mgr& mgr, const type_env& env) const; type_ptr typecheck(type_mgr& mgr, const type_env& env) const;
void resolve(const type_mgr& mgr) const;
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const; void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
}; };
@ -75,6 +83,7 @@ struct ast_lid : public ast {
void print(int indent, std::ostream& to) const; void print(int indent, std::ostream& to) const;
type_ptr typecheck(type_mgr& mgr, const type_env& env) const; type_ptr typecheck(type_mgr& mgr, const type_env& env) const;
void resolve(const type_mgr& mgr) const;
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const; void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
}; };
@ -86,6 +95,7 @@ struct ast_uid : public ast {
void print(int indent, std::ostream& to) const; void print(int indent, std::ostream& to) const;
type_ptr typecheck(type_mgr& mgr, const type_env& env) const; type_ptr typecheck(type_mgr& mgr, const type_env& env) const;
void resolve(const type_mgr& mgr) const;
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const; void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
}; };
@ -99,6 +109,7 @@ struct ast_binop : public ast {
void print(int indent, std::ostream& to) const; void print(int indent, std::ostream& to) const;
type_ptr typecheck(type_mgr& mgr, const type_env& env) const; type_ptr typecheck(type_mgr& mgr, const type_env& env) const;
void resolve(const type_mgr& mgr) const;
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const; void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
}; };
@ -111,6 +122,7 @@ struct ast_app : public ast {
void print(int indent, std::ostream& to) const; void print(int indent, std::ostream& to) const;
type_ptr typecheck(type_mgr& mgr, const type_env& env) const; type_ptr typecheck(type_mgr& mgr, const type_env& env) const;
void resolve(const type_mgr& mgr) const;
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const; void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
}; };
@ -123,6 +135,7 @@ struct ast_case : public ast {
void print(int indent, std::ostream& to) const; void print(int indent, std::ostream& to) const;
type_ptr typecheck(type_mgr& mgr, const type_env& env) const; type_ptr typecheck(type_mgr& mgr, const type_env& env) const;
void resolve(const type_mgr& mgr) const;
void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const; void compile(const env_ptr& env, std::vector<instruction_ptr>& into) const;
}; };
@ -162,6 +175,7 @@ struct definition_defn : public definition {
void typecheck_first(type_mgr& mgr, type_env& env); void typecheck_first(type_mgr& mgr, type_env& env);
void typecheck_second(type_mgr& mgr, const type_env& env) const; void typecheck_second(type_mgr& mgr, const type_env& env) const;
void resolve(const type_mgr& mgr) const;
}; };
struct definition_data : public definition { struct definition_data : public definition {
@ -173,4 +187,5 @@ struct definition_data : public definition {
void typecheck_first(type_mgr& mgr, type_env& env); void typecheck_first(type_mgr& mgr, type_env& env);
void typecheck_second(type_mgr& mgr, const type_env& env) const; void typecheck_second(type_mgr& mgr, const type_env& env) const;
void resolve(const type_mgr& mgr) const;
}; };

View File

@ -24,16 +24,23 @@ void definition_defn::typecheck_second(type_mgr& mgr, const type_env& env) const
type_it++; type_it++;
} }
type_ptr body_type = body->typecheck(mgr, new_env); type_ptr body_type = body->typecheck_common(mgr, new_env);
mgr.unify(return_type, body_type); mgr.unify(return_type, body_type);
} }
void definition_defn::resolve(const type_mgr& mgr) const {
body->resolve_common(mgr);
}
void definition_data::typecheck_first(type_mgr& mgr, type_env& env) { void definition_data::typecheck_first(type_mgr& mgr, type_env& env) {
type_ptr return_type = type_ptr(new type_base(name)); type_data* this_type = new type_data(name);
type_ptr return_type = type_ptr(this_type);
int next_tag = 0;
for(auto& constructor : constructors) { for(auto& constructor : constructors) {
type_ptr full_type = return_type; this_type->constructors[constructor->name] = { next_tag++ };
type_ptr full_type = return_type;
for(auto it = constructor->types.rbegin(); it != constructor->types.rend(); it++) { for(auto it = constructor->types.rbegin(); it != constructor->types.rend(); it++) {
type_ptr type = type_ptr(new type_base(*it)); type_ptr type = type_ptr(new type_base(*it));
full_type = type_ptr(new type_arr(type, full_type)); full_type = type_ptr(new type_arr(type, full_type));
@ -46,3 +53,8 @@ void definition_data::typecheck_first(type_mgr& mgr, type_env& env) {
void definition_data::typecheck_second(type_mgr& mgr, const type_env& env) const { void definition_data::typecheck_second(type_mgr& mgr, const type_env& env) const {
// Nothing // Nothing
} }
void definition_data::resolve(const type_mgr& mgr) const {
// Nothing
}

View File

@ -36,6 +36,10 @@ void typecheck_program(
pair.second->print(mgr, std::cout); pair.second->print(mgr, std::cout);
std::cout << std::endl; std::cout << std::endl;
} }
for(auto& def : prog) {
def->resolve(mgr);
}
} }
int main() { int main() {

View File

@ -44,7 +44,7 @@ type_ptr type_mgr::new_arrow_type() {
return type_ptr(new type_arr(new_type(), new_type())); return type_ptr(new type_arr(new_type(), new_type()));
} }
type_ptr type_mgr::resolve(type_ptr t, type_var*& var) { type_ptr type_mgr::resolve(type_ptr t, type_var*& var) const {
type_var* cast; type_var* cast;
var = nullptr; var = nullptr;

View File

@ -30,6 +30,17 @@ struct type_base : public type {
void print(const type_mgr& mgr, std::ostream& to) const; void print(const type_mgr& mgr, std::ostream& to) const;
}; };
struct type_data : public type_base {
struct constructor {
int tag;
};
std::map<std::string, constructor> constructors;
type_data(std::string n)
: type_base(std::move(n)) {}
};
struct type_arr : public type { struct type_arr : public type {
type_ptr left; type_ptr left;
type_ptr right; type_ptr right;
@ -49,6 +60,6 @@ struct type_mgr {
type_ptr new_arrow_type(); type_ptr new_arrow_type();
void unify(type_ptr l, type_ptr r); void unify(type_ptr l, type_ptr r);
type_ptr resolve(type_ptr t, type_var*& var); type_ptr resolve(type_ptr t, type_var*& var) const;
void bind(const std::string& s, type_ptr t); void bind(const std::string& s, type_ptr t);
}; };

View File

@ -253,7 +253,7 @@ We do not have to do this for `ast_uid`:
{{< codelines "C++" "compiler/06/ast.cpp" 47 49 >}} {{< codelines "C++" "compiler/06/ast.cpp" 47 49 >}}
On to `ast_binop`! This is the first time we have to change our environment. On to `ast_binop`! This is the first time we have to change our environment.
Once we build the right operand on the stack, every offset that we counted As we said earlier, once we build the right operand on the stack, every offset that we counted
from the top of the stack will have been shifted by 1 (we see this from the top of the stack will have been shifted by 1 (we see this
in our compilation scheme for function application). So, in our compilation scheme for function application). So,
we create a new environment with `env_offset`, and use that we create a new environment with `env_offset`, and use that
@ -274,3 +274,81 @@ for the exact same reason as before.
Case expressions are the only thing left on the agenda. This Case expressions are the only thing left on the agenda. This
is the time during which we have to perform desugaring. Here, is the time during which we have to perform desugaring. Here,
though, we run into an issue: we don't have tags assigned to constructors! though, we run into an issue: we don't have tags assigned to constructors!
We need to adjust our code to keep track of the tags of the various
constructors of a type. To do this, we add a subclass for the `type_base`
struct, called `type_data`:
{{< todo >}}Link code{{< /todo >}}
When we create types from `definition_data`, we tag the corresponding constructors:
{{< todo >}}Link code{{< /todo >}}
Ah, but that doesn't solve the problem. Once we performed type checking, we don't keep
the types that we computed for an AST node in the node. And obviously, we don't want
to go looking for them again. Furthermore, we can't just look up a constructor
in the environment, since we can well have patterns that don't have __any__ constructors:
```
match l {
l -> { 0 }
}
```
So, we want each `ast` node to store its type (well, in practice we only need this for
`ast_case`, but we might as well store it for all nodes). We can add it, no problem:
{{< todo >}}Link code{{< /todo >}}
Now, we can add another, non-virtual `typecheck` method (let's call it `typecheck_common`,
since naming is hard). This method will call `typecheck`, and store the output into
the `node_type` field.
The signature is identical to `typecheck`, except it's neither virtual nor const:
```
type_ptr typecheck_common(type_mgr& mgr, const type_env& env);
```
And the implementation is as simple as you think:
{{< todo >}}Link code{{< /todo >}}
In client code (`definition_defn::typecheck_first` for instance), we should now
use `typecheck_common` instead of `typecheck`. With that done, we're almost there.
However, we're still missing something: most likely, the initial type assigned to any
node is a `type_var`, or a type variable. In this case, `type_var` __needs__ the information
from `type_mgr`, which we will not be keeping around. Besides, it's cleaner to keep the actual type
as a member of the node, not a variable type that references it. In order
to address this, we write two conversion functions that call `resolve` on all
types in an AST, given a type manager. After this is done, the type manager can be thrown away.
The signatures of the functions are as follows:
```
void resolve_common(const type_mgr& mgr);
virtual void resolve(const type_mgr& mgr) const = 0;
```
We also add the `resolve` method to `definition`, so that we can call it
without having to run `dynamic_cast`. The implementation for `resolve_common`
just resolves the type:
{{< todo >}}Link code{{< /todo >}}
The virtual `resolve` just calls `resolve_common` on an all `ast` children
of a node. Here's a sample implementation from `ast_binop`:
{{< todo >}}Link code{{< /todo >}}
And here's the implementation of `resolve` on `definition_defn`:
{{< todo >}}Link code{{< /todo >}}
Finally, we call `resolve` from inside `typecheck_program` in `main.cpp`:
{{< todo >}}Link code{{< /todo >}}
Finally, we're ready to implement the code for compiling `ast_case`.
{{< todo >}}Figure out how to keep all trees not requiring a type manager. {{< /todo >}}
{{< todo >}}Backport bugfix in case's typecheck{{< /todo >}}