Add prototype impl of case specialization.
Boolean cases could be translated to ifs, and integer cases to jumps. That's still in progress.
This commit is contained in:
		
							parent
							
								
									c1a8dc4557
								
							
						
					
					
						commit
						c7b2a4959f
					
				
							
								
								
									
										215
									
								
								13/ast.cpp
									
									
									
									
									
								
							
							
						
						
									
										215
									
								
								13/ast.cpp
									
									
									
									
									
								
							| @ -1,5 +1,6 @@ | |||||||
| #include "ast.hpp" | #include "ast.hpp" | ||||||
| #include <ostream> | #include <ostream> | ||||||
|  | #include <type_traits> | ||||||
| #include "binop.hpp" | #include "binop.hpp" | ||||||
| #include "error.hpp" | #include "error.hpp" | ||||||
| #include "type_env.hpp" | #include "type_env.hpp" | ||||||
| @ -218,60 +219,184 @@ void ast_case::translate(global_scope& scope) { | |||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | template <typename T> | ||||||
|  | struct case_mappings { | ||||||
|  |     using tag_type = typename T::tag_type; | ||||||
|  |     std::map<tag_type, std::vector<instruction_ptr>> defined_cases; | ||||||
|  |     std::optional<std::vector<instruction_ptr>> default_case; | ||||||
|  | 
 | ||||||
|  |     std::vector<instruction_ptr>& make_case_for(tag_type tag) { | ||||||
|  |         auto existing_case = defined_cases.find(tag); | ||||||
|  |         if(existing_case != defined_cases.end()) return existing_case->second; | ||||||
|  |         if(default_case) | ||||||
|  |             throw type_error("attempted pattern match after catch-all"); | ||||||
|  |         return defined_cases[tag]; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     std::vector<instruction_ptr>& make_default_case() { | ||||||
|  |         if(default_case) | ||||||
|  |             throw type_error("attempted repeated use of catch-all"); | ||||||
|  |         default_case.emplace(std::vector<instruction_ptr>()); | ||||||
|  |         return *default_case; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     std::vector<instruction_ptr>& get_specific_case_for(tag_type tag) { | ||||||
|  |         auto existing_case = defined_cases.find(tag); | ||||||
|  |         assert(existing_case != defined_cases.end()); | ||||||
|  |         return existing_case->second; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     std::vector<instruction_ptr>& get_default_case() { | ||||||
|  |         assert(default_case); | ||||||
|  |         return *default_case; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     bool case_defined_for(tag_type tag) { | ||||||
|  |         return defined_cases.find(tag) != defined_cases.end(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     bool default_case_defined() { return default_case.has_value(); } | ||||||
|  | 
 | ||||||
|  |     size_t defined_cases_count() { return defined_cases.size(); } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | struct case_strategy_bool { | ||||||
|  |     using tag_type = bool; | ||||||
|  |     using repr_type = bool; | ||||||
|  | 
 | ||||||
|  |     tag_type tag_from_repr(repr_type b) { return b; } | ||||||
|  | 
 | ||||||
|  |     repr_type from_typed_pattern(const pattern_ptr& pt, const type* type) { | ||||||
|  |         pattern_constr* cpat; | ||||||
|  |         if(!(cpat = dynamic_cast<pattern_constr*>(pt.get())) || | ||||||
|  |                 (cpat->constr != "True" && cpat->constr != "False") || | ||||||
|  |                 cpat->params.size() != 0) | ||||||
|  |             throw type_error("pattern cannot be converted to a boolean"); | ||||||
|  |         return cpat->constr == "True"; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     void compile_branch( | ||||||
|  |             const branch_ptr& branch, | ||||||
|  |             const env_ptr& env, | ||||||
|  |             repr_type repr, | ||||||
|  |             std::vector<instruction_ptr>& into) { | ||||||
|  |         branch->expr->compile(env_ptr(new env_offset(1, env)), into); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     size_t case_count(const type* type) { | ||||||
|  |         return 2; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     instruction_ptr into_instruction(const type* type, case_mappings<case_strategy_bool>& ms) { | ||||||
|  |         throw std::runtime_error("boolean case unimplemented!"); | ||||||
|  |     } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | struct case_strategy_data { | ||||||
|  |     using tag_type = int; | ||||||
|  |     using repr_type = std::pair<const type_data::constructor*, const std::vector<std::string>*>; | ||||||
|  | 
 | ||||||
|  |     tag_type tag_from_repr(const repr_type& repr) { return repr.first->tag; } | ||||||
|  | 
 | ||||||
|  |     repr_type from_typed_pattern(const pattern_ptr& pt, const type* type) { | ||||||
|  |         pattern_constr* cpat; | ||||||
|  |         if(!(cpat = dynamic_cast<pattern_constr*>(pt.get()))) | ||||||
|  |             throw type_error("pattern cannot be interpreted as constructor."); | ||||||
|  |         return std::make_pair( | ||||||
|  |                 &static_cast<const type_data*>(type)->constructors.find(cpat->constr)->second, | ||||||
|  |                 &cpat->params); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     void compile_branch( | ||||||
|  |             const branch_ptr& branch, | ||||||
|  |             const env_ptr& env, | ||||||
|  |             const repr_type& repr, | ||||||
|  |             std::vector<instruction_ptr>& into) { | ||||||
|  |         env_ptr new_env = env; | ||||||
|  |         for(auto it = repr.second->rbegin(); it != repr.second->rend(); it++) { | ||||||
|  |             new_env = env_ptr(new env_var(branch->expr->env->get_mangled_name(*it), new_env)); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         into.push_back(instruction_ptr(new instruction_split(repr.second->size()))); | ||||||
|  |         branch->expr->compile(new_env, into); | ||||||
|  |         into.push_back(instruction_ptr(new instruction_slide(repr.second->size()))); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     size_t case_count(const type* type) { | ||||||
|  |         return static_cast<const type_data*>(type)->constructors.size(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     instruction_ptr into_instruction(const type* type, case_mappings<case_strategy_data>& ms) { | ||||||
|  |         instruction_jump* jump_instruction = new instruction_jump(); | ||||||
|  |         instruction_ptr inst(jump_instruction); | ||||||
|  | 
 | ||||||
|  |         auto data_type = static_cast<const type_data*>(type); | ||||||
|  |         for(auto& constr : data_type->constructors) { | ||||||
|  |             if(!ms.case_defined_for(constr.second.tag)) continue; | ||||||
|  |             jump_instruction->branches.push_back( | ||||||
|  |                     std::move(ms.get_specific_case_for(constr.second.tag))); | ||||||
|  |             jump_instruction->tag_mappings[constr.second.tag] = | ||||||
|  |                 jump_instruction->branches.size() - 1; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         if(ms.default_case_defined()) { | ||||||
|  |             jump_instruction->branches.push_back( | ||||||
|  |                     std::move(ms.get_default_case())); | ||||||
|  |             for(auto& constr : data_type->constructors) { | ||||||
|  |                 if(ms.case_defined_for(constr.second.tag)) continue; | ||||||
|  |                 jump_instruction->tag_mappings[constr.second.tag] = | ||||||
|  |                     jump_instruction->branches.size(); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         return std::move(inst); | ||||||
|  |     } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | template <typename T> | ||||||
|  | void compile_case(const ast_case& node, const env_ptr& env, const type* type, std::vector<instruction_ptr>& into) { | ||||||
|  |     T strategy; | ||||||
|  |     case_mappings<T> cases; | ||||||
|  |     for(auto& branch : node.branches) { | ||||||
|  |         pattern_var* vpat; | ||||||
|  |         if((vpat = dynamic_cast<pattern_var*>(branch->pat.get()))) { | ||||||
|  |             auto& branch_into = cases.make_default_case(); | ||||||
|  |             env_ptr new_env(new env_var(branch->expr->env->get_mangled_name(vpat->var), env)); | ||||||
|  |             branch->expr->compile(new_env, branch_into); | ||||||
|  |         } else { | ||||||
|  |             auto repr = strategy.from_typed_pattern(branch->pat, type); | ||||||
|  |             auto& branch_into = cases.make_case_for(strategy.tag_from_repr(repr)); | ||||||
|  |             strategy.compile_branch(branch, env, repr, branch_into); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if(!(cases.defined_cases_count() == strategy.case_count(type) || | ||||||
|  |                 cases.default_case_defined())) | ||||||
|  |         throw type_error("incomplete patterns", node.loc); | ||||||
|  | 
 | ||||||
|  |     into.push_back(strategy.into_instruction(type, cases)); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| void ast_case::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const { | void ast_case::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const { | ||||||
|     type_app* app_type = dynamic_cast<type_app*>(input_type.get()); |     type_app* app_type = dynamic_cast<type_app*>(input_type.get()); | ||||||
|     type_data* type = dynamic_cast<type_data*>(app_type->constructor.get()); |     type_data* data; | ||||||
|  |     type_internal* internal; | ||||||
| 
 | 
 | ||||||
|     of->compile(env, into); |     of->compile(env, into); | ||||||
|     into.push_back(instruction_ptr(new instruction_eval())); |     into.push_back(instruction_ptr(new instruction_eval())); | ||||||
| 
 | 
 | ||||||
|     instruction_jump* jump_instruction = new instruction_jump(); |     if((data = dynamic_cast<type_data*>(app_type->constructor.get()))) { | ||||||
|     into.push_back(instruction_ptr(jump_instruction)); |         compile_case<case_strategy_data>(*this, env, data, into); | ||||||
|     for(auto& branch : branches) { |         return; | ||||||
|         std::vector<instruction_ptr> branch_instructions; |     } else if((internal = dynamic_cast<type_internal*>(app_type->constructor.get()))) { | ||||||
|         pattern_var* vpat; |         if(internal->name == "Bool") { | ||||||
|         pattern_constr* cpat; |             compile_case<case_strategy_bool>(*this, env, data, into); | ||||||
| 
 |             return; | ||||||
|         if((vpat = dynamic_cast<pattern_var*>(branch->pat.get()))) { |  | ||||||
|             branch->expr->compile(env_ptr(new env_offset(1, env)), branch_instructions); |  | ||||||
| 
 |  | ||||||
|             for(auto& constr_pair : type->constructors) { |  | ||||||
|                 if(jump_instruction->tag_mappings.find(constr_pair.second.tag) != |  | ||||||
|                         jump_instruction->tag_mappings.end()) |  | ||||||
|                     break; |  | ||||||
| 
 |  | ||||||
|                 jump_instruction->tag_mappings[constr_pair.second.tag] = |  | ||||||
|                     jump_instruction->branches.size(); |  | ||||||
|             } |  | ||||||
|             jump_instruction->branches.push_back(std::move(branch_instructions)); |  | ||||||
|         } else if((cpat = dynamic_cast<pattern_constr*>(branch->pat.get()))) { |  | ||||||
|             env_ptr new_env = env; |  | ||||||
|             for(auto it = cpat->params.rbegin(); it != cpat->params.rend(); it++) { |  | ||||||
|                 new_env = env_ptr(new env_var(branch->expr->env->get_mangled_name(*it), new_env)); |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             branch_instructions.push_back(instruction_ptr(new instruction_split( |  | ||||||
|                             cpat->params.size()))); |  | ||||||
|             branch->expr->compile(new_env, branch_instructions); |  | ||||||
|             branch_instructions.push_back(instruction_ptr(new instruction_slide( |  | ||||||
|                             cpat->params.size()))); |  | ||||||
| 
 |  | ||||||
|             int new_tag = type->constructors[cpat->constr].tag; |  | ||||||
|             if(jump_instruction->tag_mappings.find(new_tag) != |  | ||||||
|                     jump_instruction->tag_mappings.end()) |  | ||||||
|                 throw type_error("technically not a type error: duplicate pattern", cpat->loc); |  | ||||||
| 
 |  | ||||||
|             jump_instruction->tag_mappings[new_tag] = |  | ||||||
|                 jump_instruction->branches.size(); |  | ||||||
|             jump_instruction->branches.push_back(std::move(branch_instructions)); |  | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     for(auto& constr_pair : type->constructors) { |     throw std::runtime_error("no known way to compile case expression"); | ||||||
|         if(jump_instruction->tag_mappings.find(constr_pair.second.tag) == |  | ||||||
|                 jump_instruction->tag_mappings.end()) |  | ||||||
|             throw type_error("non-total pattern", loc); |  | ||||||
|     } |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void ast_let::print(int indent, std::ostream& to) const { | void ast_let::print(int indent, std::ostream& to) const { | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user