diff --git a/code/compiler/12/definition.cpp b/code/compiler/12/definition.cpp index 5f820b7..535a268 100644 --- a/code/compiler/12/definition.cpp +++ b/code/compiler/12/definition.cpp @@ -5,6 +5,7 @@ #include "llvm_context.hpp" #include "type.hpp" #include "type_env.hpp" +#include "graph.hpp" #include #include #include @@ -107,3 +108,58 @@ void definition_data::generate_llvm(llvm_context& ctx) { ctx.builder.CreateRetVoid(); } } + +void definition_group::find_free(type_mgr& mgr, type_env_ptr& env, std::set& into) { + this->env = type_scope(env); + + for(auto& def_pair : defs_defn) { + def_pair.second->find_free(mgr, env); + std::set local_dependencies; + for(auto& free_var : def_pair.second->free_variables) { + if(defs_defn.find(free_var) == defs_defn.end()) { + into.insert(free_var); + } else { + local_dependencies.insert(free_var); + } + } + std::swap(def_pair.second->free_variables, local_dependencies); + } +} + +void definition_group::typecheck(type_mgr& mgr) { + for(auto& def_data : defs_data) { + def_data.second->insert_types(env); + } + for(auto& def_data : defs_data) { + def_data.second->insert_constructors(); + } + + function_graph dependency_graph; + + for(auto& def_defn : defs_defn) { + def_defn.second->find_free(mgr, env); + dependency_graph.add_function(def_defn.second->name); + + for(auto& dependency : def_defn.second->free_variables) { + if(defs_defn.find(dependency) == defs_defn.end()) + throw 0; + dependency_graph.add_edge(def_defn.second->name, dependency); + } + } + + std::vector groups = dependency_graph.compute_order(); + for(auto it = groups.rbegin(); it != groups.rend(); it++) { + auto& group = *it; + for(auto& def_defnn_name : group->members) { + auto& def_defn = defs_defn.find(def_defnn_name)->second; + def_defn->insert_types(mgr); + } + for(auto& def_defnn_name : group->members) { + auto& def_defn = defs_defn.find(def_defnn_name)->second; + def_defn->typecheck(mgr); + } + for(auto& def_defnn_name : group->members) { + env->generalize(def_defnn_name, mgr); + } + } +} diff --git a/code/compiler/12/definition.hpp b/code/compiler/12/definition.hpp index d70d539..a063dec 100644 --- a/code/compiler/12/definition.hpp +++ b/code/compiler/12/definition.hpp @@ -75,4 +75,9 @@ using definition_data_ptr = std::unique_ptr; struct definition_group { std::map defs_data; std::map defs_defn; + + type_env_ptr env; + + void find_free(type_mgr& mgr, type_env_ptr& env, std::set& into); + void typecheck(type_mgr& mgr); }; diff --git a/code/compiler/12/graph.hpp b/code/compiler/12/graph.hpp index 2db8d7c..2807442 100644 --- a/code/compiler/12/graph.hpp +++ b/code/compiler/12/graph.hpp @@ -7,7 +7,6 @@ #include #include #include -#include using function = std::string; diff --git a/code/compiler/12/main.cpp b/code/compiler/12/main.cpp index 6495d92..2ec963c 100644 --- a/code/compiler/12/main.cpp +++ b/code/compiler/12/main.cpp @@ -24,8 +24,7 @@ void yy::parser::error(const std::string& msg) { extern definition_group global_defs; void typecheck_program( - const std::map& defs_data, - const std::map& defs_defn, + definition_group& defs, type_mgr& mgr, type_env_ptr& env) { type_ptr int_type = type_ptr(new type_base("Int")); env->bind_type("Int", int_type); @@ -39,43 +38,11 @@ void typecheck_program( env->bind("*", binop_type); env->bind("/", binop_type); - for(auto& def_data : defs_data) { - def_data.second->insert_types(env); - } - for(auto& def_data : defs_data) { - def_data.second->insert_constructors(); - } + std::set free; + defs.find_free(mgr, env, free); + defs.typecheck(mgr); - function_graph dependency_graph; - - for(auto& def_defn : defs_defn) { - def_defn.second->find_free(mgr, env); - dependency_graph.add_function(def_defn.second->name); - - for(auto& dependency : def_defn.second->free_variables) { - if(defs_defn.find(dependency) == defs_defn.end()) - throw 0; - dependency_graph.add_edge(def_defn.second->name, dependency); - } - } - - std::vector groups = dependency_graph.compute_order(); - for(auto it = groups.rbegin(); it != groups.rend(); it++) { - auto& group = *it; - for(auto& def_defnn_name : group->members) { - auto& def_defn = defs_defn.find(def_defnn_name)->second; - def_defn->insert_types(mgr); - } - for(auto& def_defnn_name : group->members) { - auto& def_defn = defs_defn.find(def_defnn_name)->second; - def_defn->typecheck(mgr); - } - for(auto& def_defnn_name : group->members) { - env->generalize(def_defnn_name, mgr); - } - } - - for(auto& pair : env->names) { + for(auto& pair : defs.env->names) { std::cout << pair.first << ": "; pair.second->print(mgr, std::cout); std::cout << std::endl; @@ -187,7 +154,7 @@ int main() { def_defn.second->body->print(1, std::cout); } try { - typecheck_program(global_defs.defs_data, global_defs.defs_defn, mgr, env); + typecheck_program(global_defs, mgr, env); compile_program(global_defs.defs_defn); gen_llvm(global_defs.defs_data, global_defs.defs_defn); } catch(unification_error& err) {