diff --git a/12/ast.cpp b/12/ast.cpp index 09fece2..3f6d4e0 100644 --- a/12/ast.cpp +++ b/12/ast.cpp @@ -331,6 +331,70 @@ void ast_let::compile(const env_ptr& env, std::vector& into) co into.push_back(instruction_ptr(new instruction_slide(translated_definitions.size()))); } +void ast_lambda::print(int indent, std::ostream& to) const { + print_indent(indent, to); + to << "LAMBDA"; + for(auto& param : params) { + to << " " << param; + } + to << std::endl; + body->print(indent+1, to); +} + +void ast_lambda::find_free(std::set& into) { + body->find_free(free_variables); + for(auto& param : params) { + free_variables.erase(param); + } + into.insert(free_variables.begin(), free_variables.end()); +} + +type_ptr ast_lambda::typecheck(type_mgr& mgr, type_env_ptr& env) { + this->env = env; + var_env = type_scope(env); + type_ptr return_type = mgr.new_type(); + type_ptr full_type = return_type; + + for(auto it = params.rbegin(); it != params.rend(); it++) { + type_ptr param_type = mgr.new_type(); + var_env->bind(*it, param_type); + full_type = type_ptr(new type_arr(std::move(param_type), full_type)); + } + + mgr.unify(return_type, body->typecheck(mgr, var_env)); + return full_type; +} + +void ast_lambda::translate(global_scope& scope) { + std::vector function_params; + for(auto& free_variable : free_variables) { + if(env->is_global(free_variable)) continue; + function_params.push_back(free_variable); + } + size_t captured_count = function_params.size(); + function_params.insert(function_params.end(), params.begin(), params.end()); + + auto& new_function = scope.add_function("lambda", std::move(function_params), std::move(body)); + type_env_ptr mangled_env = type_scope(env); + mangled_env->bind("lambda", type_scheme_ptr(nullptr), visibility::global); + mangled_env->set_mangled_name("lambda", new_function.name); + ast_ptr new_application = ast_ptr(new ast_lid("lambda")); + new_application->env = mangled_env; + + for(auto& param : new_function.params) { + if(!(captured_count--)) break; + ast_ptr new_arg = ast_ptr(new ast_lid(param)); + new_arg->env = env; + new_application = ast_ptr(new ast_app(std::move(new_application), std::move(new_arg))); + new_application->env = env; + } + translated = std::move(new_application); +} + +void ast_lambda::compile(const env_ptr& env, std::vector& into) const { + translated->compile(env, into); +} + void pattern_var::print(std::ostream& to) const { to << var; } diff --git a/12/ast.hpp b/12/ast.hpp index 29de43e..1e2f8b5 100644 --- a/12/ast.hpp +++ b/12/ast.hpp @@ -146,6 +146,25 @@ struct ast_let : public ast { void compile(const env_ptr& env, std::vector& into) const; }; +struct ast_lambda : public ast { + std::vector params; + ast_ptr body; + + type_env_ptr var_env; + + std::set free_variables; + ast_ptr translated; + + ast_lambda(std::vector ps, ast_ptr b) + : params(std::move(ps)), body(std::move(b)) {} + + void print(int indent, std::ostream& to) const; + void find_free(std::set& into); + type_ptr typecheck(type_mgr& mgr, type_env_ptr& env); + void translate(global_scope& scope); + void compile(const env_ptr& env, std::vector& into) const; +}; + struct pattern_var : public pattern { std::string var; diff --git a/12/parser.y b/12/parser.y index deb0d39..2dc1744 100644 --- a/12/parser.y +++ b/12/parser.y @@ -13,6 +13,7 @@ extern yy::parser::symbol_type yylex(); %} +%token BACKSLASH %token PLUS %token TIMES %token MINUS @@ -44,7 +45,7 @@ extern yy::parser::symbol_type yylex(); %type > typeList %type definitions %type type nonArrowType typeListElement -%type aAdd aMul case let app appBase +%type aAdd aMul case let lambda app appBase %type data %type defn %type branch @@ -100,6 +101,7 @@ appBase | OPAREN aAdd CPAREN { $$ = std::move($2); } | case { $$ = std::move($1); } | let { $$ = std::move($1); } + | lambda { $$ = std::move($1); } ; let @@ -107,6 +109,11 @@ let { $$ = ast_ptr(new ast_let(std::move($3), std::move($7))); } ; +lambda + : BACKSLASH lowercaseParams ARROW OCURLY aAdd CCURLY + { $$ = ast_ptr(new ast_lambda(std::move($2), std::move($5))); } + ; + case : CASE aAdd OF OCURLY branches CCURLY { $$ = ast_ptr(new ast_case(std::move($2), std::move($5))); } diff --git a/12/scanner.l b/12/scanner.l index c417de1..7b61504 100644 --- a/12/scanner.l +++ b/12/scanner.l @@ -13,6 +13,7 @@ %% [ \n]+ {} +\\ { return yy::parser::make_BACKSLASH(); } \+ { return yy::parser::make_PLUS(); } \* { return yy::parser::make_TIMES(); } - { return yy::parser::make_MINUS(); }