From ef86d60bda1054cdbc11ef65112f756a3c493181 Mon Sep 17 00:00:00 2001 From: Danila Fedorin Date: Tue, 1 Oct 2019 11:05:21 -0700 Subject: [PATCH] Create new 'branch' for part 6 of compiler series --- 06/CMakeLists.txt | 25 ++++++ 06/ast.cpp | 144 ++++++++++++++++++++++++++++++++++ 06/ast.hpp | 172 +++++++++++++++++++++++++++++++++++++++++ 06/definition.cpp | 48 ++++++++++++ 06/env.cpp | 16 ++++ 06/env.hpp | 16 ++++ 06/error.cpp | 5 ++ 06/error.hpp | 21 +++++ 06/examples/bad1.txt | 2 + 06/examples/bad2.txt | 1 + 06/examples/bad3.txt | 8 ++ 06/examples/works1.txt | 2 + 06/examples/works2.txt | 3 + 06/examples/works3.txt | 7 ++ 06/main.cpp | 70 +++++++++++++++++ 06/parser.y | 140 +++++++++++++++++++++++++++++++++ 06/scanner.l | 34 ++++++++ 06/type.cpp | 99 ++++++++++++++++++++++++ 06/type.hpp | 54 +++++++++++++ 19 files changed, 867 insertions(+) create mode 100644 06/CMakeLists.txt create mode 100644 06/ast.cpp create mode 100644 06/ast.hpp create mode 100644 06/definition.cpp create mode 100644 06/env.cpp create mode 100644 06/env.hpp create mode 100644 06/error.cpp create mode 100644 06/error.hpp create mode 100644 06/examples/bad1.txt create mode 100644 06/examples/bad2.txt create mode 100644 06/examples/bad3.txt create mode 100644 06/examples/works1.txt create mode 100644 06/examples/works2.txt create mode 100644 06/examples/works3.txt create mode 100644 06/main.cpp create mode 100644 06/parser.y create mode 100644 06/scanner.l create mode 100644 06/type.cpp create mode 100644 06/type.hpp diff --git a/06/CMakeLists.txt b/06/CMakeLists.txt new file mode 100644 index 0000000..9d2d571 --- /dev/null +++ b/06/CMakeLists.txt @@ -0,0 +1,25 @@ +cmake_minimum_required(VERSION 3.1) +project(compiler) + +find_package(BISON) +find_package(FLEX) +bison_target(parser + ${CMAKE_CURRENT_SOURCE_DIR}/parser.y + ${CMAKE_CURRENT_BINARY_DIR}/parser.cpp + COMPILE_FLAGS "-d") +flex_target(scanner + ${CMAKE_CURRENT_SOURCE_DIR}/scanner.l + ${CMAKE_CURRENT_BINARY_DIR}/scanner.cpp) +add_flex_bison_dependency(scanner parser) + +add_executable(compiler + ast.cpp ast.hpp definition.cpp + env.cpp env.hpp + type.cpp type.hpp + error.cpp error.hpp + ${BISON_parser_OUTPUTS} + ${FLEX_scanner_OUTPUTS} + main.cpp +) +target_include_directories(compiler PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) +target_include_directories(compiler PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/06/ast.cpp b/06/ast.cpp new file mode 100644 index 0000000..f8a88df --- /dev/null +++ b/06/ast.cpp @@ -0,0 +1,144 @@ +#include "ast.hpp" +#include +#include "error.hpp" + +std::string op_name(binop op) { + switch(op) { + case PLUS: return "+"; + case MINUS: return "-"; + case TIMES: return "*"; + case DIVIDE: return "/"; + } + return "??"; +} + +void print_indent(int n, std::ostream& to) { + while(n--) to << " "; +} + +void ast_int::print(int indent, std::ostream& to) const { + print_indent(indent, to); + to << "INT: " << value << std::endl; +} + +type_ptr ast_int::typecheck(type_mgr& mgr, const type_env& env) const { + return type_ptr(new type_base("Int")); +} + +void ast_lid::print(int indent, std::ostream& to) const { + print_indent(indent, to); + to << "LID: " << id << std::endl; +} + +type_ptr ast_lid::typecheck(type_mgr& mgr, const type_env& env) const { + return env.lookup(id); +} + +void ast_uid::print(int indent, std::ostream& to) const { + print_indent(indent, to); + to << "UID: " << id << std::endl; +} + +type_ptr ast_uid::typecheck(type_mgr& mgr, const type_env& env) const { + return env.lookup(id); +} + +void ast_binop::print(int indent, std::ostream& to) const { + print_indent(indent, to); + to << "BINOP: " << op_name(op) << std::endl; + left->print(indent + 1, to); + right->print(indent + 1, to); +} + +type_ptr ast_binop::typecheck(type_mgr& mgr, const type_env& env) const { + type_ptr ltype = left->typecheck(mgr, env); + type_ptr rtype = right->typecheck(mgr, env); + type_ptr ftype = env.lookup(op_name(op)); + if(!ftype) throw type_error(std::string("unknown binary operator ") + op_name(op)); + + type_ptr return_type = mgr.new_type(); + type_ptr arrow_one = type_ptr(new type_arr(rtype, return_type)); + type_ptr arrow_two = type_ptr(new type_arr(ltype, arrow_one)); + + mgr.unify(arrow_two, ftype); + return return_type; +} + +void ast_app::print(int indent, std::ostream& to) const { + print_indent(indent, to); + to << "APP:" << std::endl; + left->print(indent + 1, to); + right->print(indent + 1, to); +} + +type_ptr ast_app::typecheck(type_mgr& mgr, const type_env& env) const { + type_ptr ltype = left->typecheck(mgr, env); + type_ptr rtype = right->typecheck(mgr, env); + + type_ptr return_type = mgr.new_type(); + type_ptr arrow = type_ptr(new type_arr(rtype, return_type)); + mgr.unify(arrow, ltype); + return return_type; +} + +void ast_case::print(int indent, std::ostream& to) const { + print_indent(indent, to); + to << "CASE: " << std::endl; + for(auto& branch : branches) { + print_indent(indent + 1, to); + branch->pat->print(to); + to << std::endl; + branch->expr->print(indent + 2, to); + } +} + +type_ptr ast_case::typecheck(type_mgr& mgr, const type_env& env) const { + type_var* var; + type_ptr case_type = mgr.resolve(of->typecheck(mgr, env), var); + type_ptr branch_type = mgr.new_type(); + + if(!dynamic_cast(case_type.get())) { + throw type_error("attempting case analysis of non-data type"); + } + + for(auto& branch : branches) { + type_env new_env = env.scope(); + branch->pat->match(case_type, mgr, new_env); + type_ptr curr_branch_type = branch->expr->typecheck(mgr, new_env); + mgr.unify(branch_type, curr_branch_type); + } + + return branch_type; +} + +void pattern_var::print(std::ostream& to) const { + to << var; +} + +void pattern_var::match(type_ptr t, type_mgr& mgr, type_env& env) const { + env.bind(var, t); +} + +void pattern_constr::print(std::ostream& to) const { + to << constr; + for(auto& param : params) { + to << " " << param; + } +} + +void pattern_constr::match(type_ptr t, type_mgr& mgr, type_env& env) const { + type_ptr constructor_type = env.lookup(constr); + if(!constructor_type) { + throw type_error(std::string("pattern using unknown constructor ") + constr); + } + + for(int i = 0; i < params.size(); i++) { + type_arr* arr = dynamic_cast(constructor_type.get()); + if(!arr) throw type_error("too many parameters in constructor pattern"); + + env.bind(params[i], arr->left); + constructor_type = arr->right; + } + + mgr.unify(t, constructor_type); +} diff --git a/06/ast.hpp b/06/ast.hpp new file mode 100644 index 0000000..fcfed19 --- /dev/null +++ b/06/ast.hpp @@ -0,0 +1,172 @@ +#pragma once +#include +#include +#include "type.hpp" +#include "env.hpp" + +struct ast { + virtual ~ast() = default; + + virtual void print(int indent, std::ostream& to) const = 0; + virtual type_ptr typecheck(type_mgr& mgr, const type_env& env) const = 0; +}; + +using ast_ptr = std::unique_ptr; + +struct pattern { + virtual ~pattern() = default; + + virtual void print(std::ostream& to) const = 0; + virtual void match(type_ptr t, type_mgr& mgr, type_env& env) const = 0; +}; + +using pattern_ptr = std::unique_ptr; + +struct branch { + pattern_ptr pat; + ast_ptr expr; + + branch(pattern_ptr p, ast_ptr a) + : pat(std::move(p)), expr(std::move(a)) {} +}; + +using branch_ptr = std::unique_ptr; + +struct constructor { + std::string name; + std::vector types; + + constructor(std::string n, std::vector ts) + : name(std::move(n)), types(std::move(ts)) {} +}; + +using constructor_ptr = std::unique_ptr; + +struct definition { + virtual ~definition() = default; + + virtual void typecheck_first(type_mgr& mgr, type_env& env) = 0; + virtual void typecheck_second(type_mgr& mgr, const type_env& env) const = 0; +}; + +using definition_ptr = std::unique_ptr; + +enum binop { + PLUS, + MINUS, + TIMES, + DIVIDE +}; + +struct ast_int : public ast { + int value; + + explicit ast_int(int v) + : value(v) {} + + void print(int indent, std::ostream& to) const; + type_ptr typecheck(type_mgr& mgr, const type_env& env) const; +}; + +struct ast_lid : public ast { + std::string id; + + explicit ast_lid(std::string i) + : id(std::move(i)) {} + + void print(int indent, std::ostream& to) const; + type_ptr typecheck(type_mgr& mgr, const type_env& env) const; +}; + +struct ast_uid : public ast { + std::string id; + + explicit ast_uid(std::string i) + : id(std::move(i)) {} + + void print(int indent, std::ostream& to) const; + type_ptr typecheck(type_mgr& mgr, const type_env& env) const; +}; + +struct ast_binop : public ast { + binop op; + ast_ptr left; + ast_ptr right; + + ast_binop(binop o, ast_ptr l, ast_ptr r) + : op(o), left(std::move(l)), right(std::move(r)) {} + + void print(int indent, std::ostream& to) const; + type_ptr typecheck(type_mgr& mgr, const type_env& env) const; +}; + +struct ast_app : public ast { + ast_ptr left; + ast_ptr right; + + ast_app(ast_ptr l, ast_ptr r) + : left(std::move(l)), right(std::move(r)) {} + + void print(int indent, std::ostream& to) const; + type_ptr typecheck(type_mgr& mgr, const type_env& env) const; +}; + +struct ast_case : public ast { + ast_ptr of; + std::vector branches; + + ast_case(ast_ptr o, std::vector b) + : of(std::move(o)), branches(std::move(b)) {} + + void print(int indent, std::ostream& to) const; + type_ptr typecheck(type_mgr& mgr, const type_env& env) const; +}; + +struct pattern_var : public pattern { + std::string var; + + pattern_var(std::string v) + : var(std::move(v)) {} + + void print(std::ostream &to) const; + void match(type_ptr t, type_mgr& mgr, type_env& env) const; +}; + +struct pattern_constr : public pattern { + std::string constr; + std::vector params; + + pattern_constr(std::string c, std::vector p) + : constr(std::move(c)), params(std::move(p)) {} + + void print(std::ostream &to) const; + void match(type_ptr t, type_mgr&, type_env& env) const; +}; + +struct definition_defn : public definition { + std::string name; + std::vector params; + ast_ptr body; + + type_ptr return_type; + std::vector param_types; + + definition_defn(std::string n, std::vector p, ast_ptr b) + : name(std::move(n)), params(std::move(p)), body(std::move(b)) { + + } + + void typecheck_first(type_mgr& mgr, type_env& env); + void typecheck_second(type_mgr& mgr, const type_env& env) const; +}; + +struct definition_data : public definition { + std::string name; + std::vector constructors; + + definition_data(std::string n, std::vector cs) + : name(std::move(n)), constructors(std::move(cs)) {} + + void typecheck_first(type_mgr& mgr, type_env& env); + void typecheck_second(type_mgr& mgr, const type_env& env) const; +}; diff --git a/06/definition.cpp b/06/definition.cpp new file mode 100644 index 0000000..f0889b6 --- /dev/null +++ b/06/definition.cpp @@ -0,0 +1,48 @@ +#include "ast.hpp" + +void definition_defn::typecheck_first(type_mgr& mgr, type_env& env) { + return_type = mgr.new_type(); + type_ptr full_type = return_type; + + for(auto it = params.rbegin(); it != params.rend(); it++) { + type_ptr param_type = mgr.new_type(); + full_type = type_ptr(new type_arr(param_type, full_type)); + param_types.push_back(param_type); + } + + env.bind(name, full_type); +} + +void definition_defn::typecheck_second(type_mgr& mgr, const type_env& env) const { + type_env new_env = env.scope(); + auto param_it = params.begin(); + auto type_it = param_types.rbegin(); + + while(param_it != params.end() && type_it != param_types.rend()) { + new_env.bind(*param_it, *type_it); + param_it++; + type_it++; + } + + type_ptr body_type = body->typecheck(mgr, new_env); + mgr.unify(return_type, body_type); +} + +void definition_data::typecheck_first(type_mgr& mgr, type_env& env) { + type_ptr return_type = type_ptr(new type_base(name)); + + for(auto& constructor : constructors) { + type_ptr full_type = return_type; + + for(auto it = constructor->types.rbegin(); it != constructor->types.rend(); it++) { + type_ptr type = type_ptr(new type_base(*it)); + full_type = type_ptr(new type_arr(type, full_type)); + } + + env.bind(constructor->name, full_type); + } +} + +void definition_data::typecheck_second(type_mgr& mgr, const type_env& env) const { + // Nothing +} diff --git a/06/env.cpp b/06/env.cpp new file mode 100644 index 0000000..74059a8 --- /dev/null +++ b/06/env.cpp @@ -0,0 +1,16 @@ +#include "env.hpp" + +type_ptr type_env::lookup(const std::string& name) const { + auto it = names.find(name); + if(it != names.end()) return it->second; + if(parent) return parent->lookup(name); + return nullptr; +} + +void type_env::bind(const std::string& name, type_ptr t) { + names[name] = t; +} + +type_env type_env::scope() const { + return type_env(this); +} diff --git a/06/env.hpp b/06/env.hpp new file mode 100644 index 0000000..6470bdd --- /dev/null +++ b/06/env.hpp @@ -0,0 +1,16 @@ +#pragma once +#include +#include "type.hpp" + +struct type_env { + std::map names; + type_env const* parent = nullptr; + + type_env(type_env const* p) + : parent(p) {} + type_env() : type_env(nullptr) {} + + type_ptr lookup(const std::string& name) const; + void bind(const std::string& name, type_ptr t); + type_env scope() const; +}; diff --git a/06/error.cpp b/06/error.cpp new file mode 100644 index 0000000..f5125e3 --- /dev/null +++ b/06/error.cpp @@ -0,0 +1,5 @@ +#include "error.hpp" + +const char* type_error::what() const noexcept { + return "an error occured while checking the types of the program"; +} diff --git a/06/error.hpp b/06/error.hpp new file mode 100644 index 0000000..5bfbc7e --- /dev/null +++ b/06/error.hpp @@ -0,0 +1,21 @@ +#pragma once +#include +#include "type.hpp" + +struct type_error : std::exception { + std::string description; + + type_error(std::string d) + : description(std::move(d)) {} + + const char* what() const noexcept override; +}; + +struct unification_error : public type_error { + type_ptr left; + type_ptr right; + + unification_error(type_ptr l, type_ptr r) + : left(std::move(l)), right(std::move(r)), + type_error("failed to unify types") {} +}; diff --git a/06/examples/bad1.txt b/06/examples/bad1.txt new file mode 100644 index 0000000..86d4bc4 --- /dev/null +++ b/06/examples/bad1.txt @@ -0,0 +1,2 @@ +data Bool = { True, False } +defn main = { 3 + True } diff --git a/06/examples/bad2.txt b/06/examples/bad2.txt new file mode 100644 index 0000000..def8785 --- /dev/null +++ b/06/examples/bad2.txt @@ -0,0 +1 @@ +defn main = { 1 2 3 4 5 } diff --git a/06/examples/bad3.txt b/06/examples/bad3.txt new file mode 100644 index 0000000..6f82b3d --- /dev/null +++ b/06/examples/bad3.txt @@ -0,0 +1,8 @@ +data List = { Nil, Cons Int List } + +defn head l = { + case l of { + Nil -> { 0 } + Cons x y z -> { x } + } +} diff --git a/06/examples/works1.txt b/06/examples/works1.txt new file mode 100644 index 0000000..bedb5d8 --- /dev/null +++ b/06/examples/works1.txt @@ -0,0 +1,2 @@ +defn main = { plus 320 6 } +defn plus x y = { x + y } diff --git a/06/examples/works2.txt b/06/examples/works2.txt new file mode 100644 index 0000000..8332fde --- /dev/null +++ b/06/examples/works2.txt @@ -0,0 +1,3 @@ +defn add x y = { x + y } +defn double x = { add x x } +defn main = { double 163 } diff --git a/06/examples/works3.txt b/06/examples/works3.txt new file mode 100644 index 0000000..cfffd20 --- /dev/null +++ b/06/examples/works3.txt @@ -0,0 +1,7 @@ +data List = { Nil, Cons Int List } +defn length l = { + case l of { + Nil -> { 0 } + Cons x xs -> { 1 + length xs } + } +} diff --git a/06/main.cpp b/06/main.cpp new file mode 100644 index 0000000..60dd9c9 --- /dev/null +++ b/06/main.cpp @@ -0,0 +1,70 @@ +#include "ast.hpp" +#include +#include "parser.hpp" +#include "error.hpp" +#include "type.hpp" + +void yy::parser::error(const std::string& msg) { + std::cout << "An error occured: " << msg << std::endl; +} + +extern std::vector program; + +void typecheck_program( + const std::vector& prog, + type_mgr& mgr, type_env& env) { + type_ptr int_type = type_ptr(new type_base("Int")); + type_ptr binop_type = type_ptr(new type_arr( + int_type, + type_ptr(new type_arr(int_type, int_type)))); + + env.bind("+", binop_type); + env.bind("-", binop_type); + env.bind("*", binop_type); + env.bind("/", binop_type); + + for(auto& def : prog) { + def->typecheck_first(mgr, env); + } + + for(auto& def : prog) { + def->typecheck_second(mgr, env); + } + + for(auto& pair : env.names) { + std::cout << pair.first << ": "; + pair.second->print(mgr, std::cout); + std::cout << std::endl; + } +} + +int main() { + yy::parser parser; + type_mgr mgr; + type_env env; + + parser.parse(); + for(auto& definition : program) { + definition_defn* def = dynamic_cast(definition.get()); + if(!def) continue; + + std::cout << def->name; + for(auto& param : def->params) std::cout << " " << param; + std::cout << ":" << std::endl; + + def->body->print(1, std::cout); + } + try { + typecheck_program(program, mgr, env); + } catch(unification_error& err) { + std::cout << "failed to unify types: " << std::endl; + std::cout << " (1) \033[34m"; + err.left->print(mgr, std::cout); + std::cout << "\033[0m" << std::endl; + std::cout << " (2) \033[32m"; + err.right->print(mgr, std::cout); + std::cout << "\033[0m" << std::endl; + } catch(type_error& err) { + std::cout << "failed to type check program: " << err.description << std::endl; + } +} diff --git a/06/parser.y b/06/parser.y new file mode 100644 index 0000000..3874aca --- /dev/null +++ b/06/parser.y @@ -0,0 +1,140 @@ +%{ +#include +#include +#include "ast.hpp" +#include "parser.hpp" + +std::vector program; +extern yy::parser::symbol_type yylex(); + +%} + +%token PLUS +%token TIMES +%token MINUS +%token DIVIDE +%token INT +%token DEFN +%token DATA +%token CASE +%token OF +%token OCURLY +%token CCURLY +%token OPAREN +%token CPAREN +%token COMMA +%token ARROW +%token EQUAL +%token LID +%token UID + +%language "c++" +%define api.value.type variant +%define api.token.constructor + +%type > lowercaseParams uppercaseParams +%type > program definitions +%type > branches +%type > constructors +%type aAdd aMul case app appBase +%type definition defn data +%type branch +%type pattern +%type constructor + +%start program + +%% + +program + : definitions { program = std::move($1); } + ; + +definitions + : definitions definition { $$ = std::move($1); $$.push_back(std::move($2)); } + | definition { $$ = std::vector(); $$.push_back(std::move($1)); } + ; + +definition + : defn { $$ = std::move($1); } + | data { $$ = std::move($1); } + ; + +defn + : DEFN LID lowercaseParams EQUAL OCURLY aAdd CCURLY + { $$ = definition_ptr( + new definition_defn(std::move($2), std::move($3), std::move($6))); } + ; + +lowercaseParams + : %empty { $$ = std::vector(); } + | lowercaseParams LID { $$ = std::move($1); $$.push_back(std::move($2)); } + ; + +uppercaseParams + : %empty { $$ = std::vector(); } + | uppercaseParams UID { $$ = std::move($1); $$.push_back(std::move($2)); } + ; + +aAdd + : aAdd PLUS aMul { $$ = ast_ptr(new ast_binop(PLUS, std::move($1), std::move($3))); } + | aAdd MINUS aMul { $$ = ast_ptr(new ast_binop(MINUS, std::move($1), std::move($3))); } + | aMul { $$ = std::move($1); } + ; + +aMul + : aMul TIMES app { $$ = ast_ptr(new ast_binop(TIMES, std::move($1), std::move($3))); } + | aMul DIVIDE app { $$ = ast_ptr(new ast_binop(DIVIDE, std::move($1), std::move($3))); } + | app { $$ = std::move($1); } + ; + +app + : app appBase { $$ = ast_ptr(new ast_app(std::move($1), std::move($2))); } + | appBase { $$ = std::move($1); } + ; + +appBase + : INT { $$ = ast_ptr(new ast_int($1)); } + | LID { $$ = ast_ptr(new ast_lid(std::move($1))); } + | UID { $$ = ast_ptr(new ast_uid(std::move($1))); } + | OPAREN aAdd CPAREN { $$ = std::move($2); } + | case { $$ = std::move($1); } + ; + +case + : CASE aAdd OF OCURLY branches CCURLY + { $$ = ast_ptr(new ast_case(std::move($2), std::move($5))); } + ; + +branches + : branches branch { $$ = std::move($1); $$.push_back(std::move($2)); } + | branch { $$ = std::vector(); $$.push_back(std::move($1));} + ; + +branch + : pattern ARROW OCURLY aAdd CCURLY + { $$ = branch_ptr(new branch(std::move($1), std::move($4))); } + ; + +pattern + : LID { $$ = pattern_ptr(new pattern_var(std::move($1))); } + | UID lowercaseParams + { $$ = pattern_ptr(new pattern_constr(std::move($1), std::move($2))); } + ; + +data + : DATA UID EQUAL OCURLY constructors CCURLY + { $$ = definition_ptr(new definition_data(std::move($2), std::move($5))); } + ; + +constructors + : constructors COMMA constructor { $$ = std::move($1); $$.push_back(std::move($3)); } + | constructor + { $$ = std::vector(); $$.push_back(std::move($1)); } + ; + +constructor + : UID uppercaseParams + { $$ = constructor_ptr(new constructor(std::move($1), std::move($2))); } + ; + diff --git a/06/scanner.l b/06/scanner.l new file mode 100644 index 0000000..683deeb --- /dev/null +++ b/06/scanner.l @@ -0,0 +1,34 @@ +%option noyywrap + +%{ +#include +#include "ast.hpp" +#include "parser.hpp" + +#define YY_DECL yy::parser::symbol_type yylex() + +%} + +%% + +[ \n]+ {} +\+ { return yy::parser::make_PLUS(); } +\* { return yy::parser::make_TIMES(); } +- { return yy::parser::make_MINUS(); } +\/ { return yy::parser::make_DIVIDE(); } +[0-9]+ { return yy::parser::make_INT(atoi(yytext)); } +defn { return yy::parser::make_DEFN(); } +data { return yy::parser::make_DATA(); } +case { return yy::parser::make_CASE(); } +of { return yy::parser::make_OF(); } +\{ { return yy::parser::make_OCURLY(); } +\} { return yy::parser::make_CCURLY(); } +\( { return yy::parser::make_OPAREN(); } +\) { return yy::parser::make_CPAREN(); } +, { return yy::parser::make_COMMA(); } +-> { return yy::parser::make_ARROW(); } += { return yy::parser::make_EQUAL(); } +[a-z][a-zA-Z]* { return yy::parser::make_LID(std::string(yytext)); } +[A-Z][a-zA-Z]* { return yy::parser::make_UID(std::string(yytext)); } + +%% diff --git a/06/type.cpp b/06/type.cpp new file mode 100644 index 0000000..0fc7364 --- /dev/null +++ b/06/type.cpp @@ -0,0 +1,99 @@ +#include "type.hpp" +#include +#include +#include "error.hpp" + +void type_var::print(const type_mgr& mgr, std::ostream& to) const { + auto it = mgr.types.find(name); + if(it != mgr.types.end()) { + it->second->print(mgr, to); + } else { + to << name; + } +} + +void type_base::print(const type_mgr& mgr, std::ostream& to) const { + to << name; +} + +void type_arr::print(const type_mgr& mgr, std::ostream& to) const { + left->print(mgr, to); + to << " -> ("; + right->print(mgr, to); + to << ")"; +} + +std::string type_mgr::new_type_name() { + int temp = last_id++; + std::string str = ""; + + while(temp != -1) { + str += (char) ('a' + (temp % 26)); + temp = temp / 26 - 1; + } + + std::reverse(str.begin(), str.end()); + return str; +} + +type_ptr type_mgr::new_type() { + return type_ptr(new type_var(new_type_name())); +} + +type_ptr type_mgr::new_arrow_type() { + return type_ptr(new type_arr(new_type(), new_type())); +} + +type_ptr type_mgr::resolve(type_ptr t, type_var*& var) { + type_var* cast; + + var = nullptr; + while((cast = dynamic_cast(t.get()))) { + auto it = types.find(cast->name); + + if(it == types.end()) { + var = cast; + break; + } + t = it->second; + } + + return t; +} + +void type_mgr::unify(type_ptr l, type_ptr r) { + type_var* lvar; + type_var* rvar; + type_arr* larr; + type_arr* rarr; + type_base* lid; + type_base* rid; + + l = resolve(l, lvar); + r = resolve(r, rvar); + + if(lvar) { + bind(lvar->name, r); + return; + } else if(rvar) { + bind(rvar->name, l); + return; + } else if((larr = dynamic_cast(l.get())) && + (rarr = dynamic_cast(r.get()))) { + unify(larr->left, rarr->left); + unify(larr->right, rarr->right); + return; + } else if((lid = dynamic_cast(l.get())) && + (rid = dynamic_cast(r.get()))) { + if(lid->name == rid->name) return; + } + + throw unification_error(l, r); +} + +void type_mgr::bind(const std::string& s, type_ptr t) { + type_var* other = dynamic_cast(t.get()); + + if(other && other->name == s) return; + types[s] = t; +} diff --git a/06/type.hpp b/06/type.hpp new file mode 100644 index 0000000..2774c29 --- /dev/null +++ b/06/type.hpp @@ -0,0 +1,54 @@ +#pragma once +#include +#include + +struct type_mgr; + +struct type { + virtual ~type() = default; + + virtual void print(const type_mgr& mgr, std::ostream& to) const = 0; +}; + +using type_ptr = std::shared_ptr; + +struct type_var : public type { + std::string name; + + type_var(std::string n) + : name(std::move(n)) {} + + void print(const type_mgr& mgr, std::ostream& to) const; +}; + +struct type_base : public type { + std::string name; + + type_base(std::string n) + : name(std::move(n)) {} + + void print(const type_mgr& mgr, std::ostream& to) const; +}; + +struct type_arr : public type { + type_ptr left; + type_ptr right; + + type_arr(type_ptr l, type_ptr r) + : left(std::move(l)), right(std::move(r)) {} + + void print(const type_mgr& mgr, std::ostream& to) const; +}; + +struct type_mgr { + int last_id = 0; + std::map types; + + std::string new_type_name(); + type_ptr new_type(); + type_ptr new_arrow_type(); + + void unify(type_ptr l, type_ptr r); + type_ptr resolve(type_ptr t, type_var*& var); + void bind(const std::string& s, type_ptr t); +};