Compare commits

...

12 Commits

Author SHA1 Message Date
8a352ed3ea Roll back optimization changes. 2020-09-17 20:45:24 -07:00
02f8306c7b Use an instruction instead of a special-case boolean instruction. 2020-09-17 18:33:52 -07:00
cf6f353f20 Change tagging to assume sign extension.
ARM and x86_64 require "real" pointers to be
sign-extended in their top bits. This means
a working pointer is guaranteed to have either "11"
as leading bits, or "00". So, to tag a "fake" pointer
which is an unboxed 32-bit integer, we simply toggle
the leading bit.
2020-09-17 18:30:55 -07:00
7a631b3557 Make a few more things classes. 2020-09-17 18:30:41 -07:00
5e13047846 Make global scope a class. 2020-09-15 19:45:05 -07:00
c17d532802 Make type_mgr a class. 2020-09-15 19:19:58 -07:00
55e4e61906 Make mangler a class and reformat graph. 2020-09-15 19:13:48 -07:00
f2f88ab9ca Make env a class. 2020-09-15 19:12:12 -07:00
ba418d357f Make type_env a class. 2020-09-15 19:10:36 -07:00
0e3f16139d Make llvm_context a class. 2020-09-15 19:08:00 -07:00
55486d511f Make some refactors for name mangling and encapsulation. 2020-09-15 18:51:28 -07:00
6080094c41 Require mangled names for global variables. 2020-09-15 14:39:31 -07:00
30 changed files with 686 additions and 830 deletions

View File

@ -38,6 +38,8 @@ add_executable(compiler
graph.cpp graph.hpp graph.cpp graph.hpp
global_scope.cpp global_scope.hpp global_scope.cpp global_scope.hpp
parse_driver.cpp parse_driver.hpp parse_driver.cpp parse_driver.hpp
mangler.cpp mangler.hpp
compiler.cpp compiler.hpp
${BISON_parser_OUTPUTS} ${BISON_parser_OUTPUTS}
${FLEX_scanner_OUTPUTS} ${FLEX_scanner_OUTPUTS}
main.cpp main.cpp

View File

@ -3,6 +3,7 @@
#include <type_traits> #include <type_traits>
#include "binop.hpp" #include "binop.hpp"
#include "error.hpp" #include "error.hpp"
#include "instruction.hpp"
#include "type.hpp" #include "type.hpp"
#include "type_env.hpp" #include "type_env.hpp"
#include "env.hpp" #include "env.hpp"
@ -55,15 +56,10 @@ void ast_lid::translate(global_scope& scope) {
} }
void ast_lid::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const { void ast_lid::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
auto mangled_name = this->env->get_mangled_name(id);
// Local names shouldn't need mangling.
assert(!(mangled_name != id && !this->env->is_global(id)));
into.push_back(instruction_ptr( into.push_back(instruction_ptr(
(env->has_variable(mangled_name) && !this->env->is_global(id)) ? (this->env->is_global(id)) ?
(instruction*) new instruction_push(env->get_offset(id)) : (instruction*) new instruction_pushglobal(this->env->get_mangled_name(id)) :
(instruction*) new instruction_pushglobal(mangled_name))); (instruction*) new instruction_push(env->get_offset(id))));
} }
void ast_uid::print(int indent, std::ostream& to) const { void ast_uid::print(int indent, std::ostream& to) const {
@ -216,6 +212,10 @@ type_ptr ast_case::typecheck(type_mgr& mgr, type_env_ptr& env) {
input_type = mgr.resolve(case_type, var); input_type = mgr.resolve(case_type, var);
type_app* app_type; type_app* app_type;
if(!(app_type = dynamic_cast<type_app*>(input_type.get())) ||
!dynamic_cast<type_data*>(app_type->constructor.get())) {
throw type_error("attempting case analysis of non-data type");
}
return branch_type; return branch_type;
} }
@ -227,215 +227,60 @@ void ast_case::translate(global_scope& scope) {
} }
} }
template <typename T>
struct case_mappings {
using tag_type = typename T::tag_type;
std::map<tag_type, std::vector<instruction_ptr>> defined_cases;
std::optional<std::vector<instruction_ptr>> default_case;
std::vector<instruction_ptr>& make_case_for(tag_type tag) {
if(default_case)
throw compiler_error("attempted pattern match after catch-all");
return defined_cases[tag];
}
std::vector<instruction_ptr>& make_default_case() {
if(default_case)
throw compiler_error("attempted repeated use of catch-all");
default_case.emplace(std::vector<instruction_ptr>());
return *default_case;
}
std::vector<instruction_ptr>& get_specific_case_for(tag_type tag) {
auto existing_case = defined_cases.find(tag);
assert(existing_case != defined_cases.end());
return existing_case->second;
}
std::vector<instruction_ptr>& get_default_case() {
assert(default_case);
return *default_case;
}
std::vector<instruction_ptr>& get_case_for(tag_type tag) {
if(case_defined_for(tag)) return get_specific_case_for(tag);
return get_default_case();
}
bool case_defined_for(tag_type tag) {
return defined_cases.find(tag) != defined_cases.end();
}
bool default_case_defined() { return default_case.has_value(); }
size_t defined_cases_count() { return defined_cases.size(); }
};
struct case_strategy_bool {
using tag_type = bool;
using repr_type = bool;
case_strategy_bool(const type* type) {}
tag_type tag_from_repr(repr_type b) { return b; }
repr_type repr_from_pattern(const pattern_ptr& pt) {
pattern_constr* cpat;
if(!(cpat = dynamic_cast<pattern_constr*>(pt.get())) ||
(cpat->constr != "True" && cpat->constr != "False") ||
cpat->params.size() != 0)
throw compiler_error(
"pattern cannot be converted to a boolean",
pt->loc);
return cpat->constr == "True";
}
void compile_branch(
const branch_ptr& branch,
const env_ptr& env,
repr_type repr,
std::vector<instruction_ptr>& into) {
branch->expr->compile(env_ptr(new env_offset(1, env)), into);
into.push_back(instruction_ptr(new instruction_slide(1)));
}
size_t case_count() {
return 2;
}
void into_instructions(
case_mappings<case_strategy_bool>& ms,
std::vector<instruction_ptr>& into) {
if(ms.defined_cases_count() == 0) {
for(auto& instruction : ms.get_default_case())
into.push_back(std::move(instruction));
return;
}
into.push_back(instruction_ptr(new instruction_if(
std::move(ms.get_case_for(true)),
std::move(ms.get_case_for(false)))));
}
};
struct case_strategy_data {
using tag_type = int;
using repr_type = std::pair<const type_data::constructor*, const std::vector<std::string>*>;
const type_data* arg_type;
case_strategy_data(const type* t) {
arg_type = dynamic_cast<const type_data*>(t);
assert(arg_type);
}
tag_type tag_from_repr(const repr_type& repr) { return repr.first->tag; }
repr_type repr_from_pattern(const pattern_ptr& pt) {
pattern_constr* cpat;
if(!(cpat = dynamic_cast<pattern_constr*>(pt.get())))
throw compiler_error(
"pattern cannot be interpreted as constructor.",
pt->loc);
return std::make_pair(
&arg_type->constructors.find(cpat->constr)->second,
&cpat->params);
}
void compile_branch(
const branch_ptr& branch,
const env_ptr& env,
const repr_type& repr,
std::vector<instruction_ptr>& into) {
env_ptr new_env = env;
for(auto it = repr.second->rbegin(); it != repr.second->rend(); it++) {
new_env = env_ptr(new env_var(*it, new_env));
}
into.push_back(instruction_ptr(new instruction_split(repr.second->size())));
branch->expr->compile(new_env, into);
into.push_back(instruction_ptr(new instruction_slide(repr.second->size())));
}
size_t case_count() {
return arg_type->constructors.size();
}
void into_instructions(
case_mappings<case_strategy_data>& ms,
std::vector<instruction_ptr>& into) {
instruction_jump* jump_instruction = new instruction_jump();
instruction_ptr inst(jump_instruction);
for(auto& constr : arg_type->constructors) {
if(!ms.case_defined_for(constr.second.tag)) continue;
jump_instruction->branches.push_back(
std::move(ms.get_specific_case_for(constr.second.tag)));
jump_instruction->tag_mappings[constr.second.tag] =
jump_instruction->branches.size() - 1;
}
if(ms.default_case_defined()) {
jump_instruction->branches.push_back(
std::move(ms.get_default_case()));
for(auto& constr : arg_type->constructors) {
if(ms.case_defined_for(constr.second.tag)) continue;
jump_instruction->tag_mappings[constr.second.tag] =
jump_instruction->branches.size();
}
}
into.push_back(std::move(inst));
}
};
template <typename T>
void compile_case(const ast_case& node, const env_ptr& env, const type* type, std::vector<instruction_ptr>& into) {
T strategy(type);
case_mappings<T> cases;
for(auto& branch : node.branches) {
pattern_var* vpat;
if((vpat = dynamic_cast<pattern_var*>(branch->pat.get()))) {
if(cases.defined_cases_count() == strategy.case_count())
throw compiler_error("redundant catch-all pattern", branch->pat->loc);
auto& branch_into = cases.make_default_case();
env_ptr new_env(new env_var(vpat->var, env));
branch->expr->compile(new_env, branch_into);
branch_into.push_back(instruction_ptr(new instruction_slide(1)));
} else {
auto repr = strategy.repr_from_pattern(branch->pat);
auto& branch_into = cases.make_case_for(strategy.tag_from_repr(repr));
strategy.compile_branch(branch, env, repr, branch_into);
}
}
if(!(cases.defined_cases_count() == strategy.case_count() ||
cases.default_case_defined()))
throw compiler_error("incomplete patterns", node.loc);
strategy.into_instructions(cases, into);
}
void ast_case::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const { void ast_case::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
type_app* app_type = dynamic_cast<type_app*>(input_type.get()); type_app* app_type = dynamic_cast<type_app*>(input_type.get());
type_data* data; type_data* type = dynamic_cast<type_data*>(app_type->constructor.get());
type_internal* internal;
of->compile(env, into); of->compile(env, into);
into.push_back(instruction_ptr(new instruction_eval())); into.push_back(instruction_ptr(new instruction_eval()));
if(app_type && (data = dynamic_cast<type_data*>(app_type->constructor.get()))) { instruction_jump* jump_instruction = new instruction_jump();
compile_case<case_strategy_data>(*this, env, data, into); into.push_back(instruction_ptr(jump_instruction));
return; for(auto& branch : branches) {
} else if(app_type && (internal = dynamic_cast<type_internal*>(app_type->constructor.get()))) { std::vector<instruction_ptr> branch_instructions;
if(internal->name == "Bool") { pattern_var* vpat;
compile_case<case_strategy_bool>(*this, env, data, into); pattern_constr* cpat;
return;
if((vpat = dynamic_cast<pattern_var*>(branch->pat.get()))) {
branch->expr->compile(env_ptr(new env_offset(1, env)), branch_instructions);
for(auto& constr_pair : type->constructors) {
if(jump_instruction->tag_mappings.find(constr_pair.second.tag) !=
jump_instruction->tag_mappings.end())
break;
jump_instruction->tag_mappings[constr_pair.second.tag] =
jump_instruction->branches.size();
}
jump_instruction->branches.push_back(std::move(branch_instructions));
} else if((cpat = dynamic_cast<pattern_constr*>(branch->pat.get()))) {
env_ptr new_env = env;
for(auto it = cpat->params.rbegin(); it != cpat->params.rend(); it++) {
new_env = env_ptr(new env_var(*it, new_env));
}
branch_instructions.push_back(instruction_ptr(new instruction_split(
cpat->params.size())));
branch->expr->compile(new_env, branch_instructions);
branch_instructions.push_back(instruction_ptr(new instruction_slide(
cpat->params.size())));
int new_tag = type->constructors[cpat->constr].tag;
if(jump_instruction->tag_mappings.find(new_tag) !=
jump_instruction->tag_mappings.end())
throw type_error("technically not a type error: duplicate pattern");
jump_instruction->tag_mappings[new_tag] =
jump_instruction->branches.size();
jump_instruction->branches.push_back(std::move(branch_instructions));
} }
} }
throw type_error("attempting unsupported case analysis", of->loc); for(auto& constr_pair : type->constructors) {
if(jump_instruction->tag_mappings.find(constr_pair.second.tag) ==
jump_instruction->tag_mappings.end())
throw type_error("non-total pattern");
}
} }
void ast_let::print(int indent, std::ostream& to) const { void ast_let::print(int indent, std::ostream& to) const {

View File

@ -6,9 +6,6 @@ std::string op_name(binop op) {
case MINUS: return "-"; case MINUS: return "-";
case TIMES: return "*"; case TIMES: return "*";
case DIVIDE: return "/"; case DIVIDE: return "/";
case MODULO: return "%";
case EQUALS: return "==";
case LESS_EQUALS: return "<=";
} }
return "??"; return "??";
} }
@ -19,9 +16,6 @@ std::string op_action(binop op) {
case MINUS: return "minus"; case MINUS: return "minus";
case TIMES: return "times"; case TIMES: return "times";
case DIVIDE: return "divide"; case DIVIDE: return "divide";
case MODULO: return "modulo";
case EQUALS: return "equals";
case LESS_EQUALS: return "less_equals";
} }
return "??"; return "??";
} }

View File

@ -1,14 +1,16 @@
#pragma once #pragma once
#include <array>
#include <string> #include <string>
enum binop { enum binop {
PLUS, PLUS,
MINUS, MINUS,
TIMES, TIMES,
DIVIDE, DIVIDE
MODULO, };
EQUALS,
LESS_EQUALS, constexpr binop all_binops[] = {
PLUS, MINUS, TIMES, DIVIDE
}; };
std::string op_name(binop op); std::string op_name(binop op);

View File

@ -0,0 +1,153 @@
#include "compiler.hpp"
#include "binop.hpp"
#include "error.hpp"
#include "global_scope.hpp"
#include "parse_driver.hpp"
#include "type.hpp"
#include "type_env.hpp"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Target/TargetMachine.h"
void compiler::add_default_types() {
global_env->bind_type("Int", type_ptr(new type_base("Int")));
}
void compiler::add_binop_type(binop op, type_ptr type) {
auto name = mng.new_mangled_name(op_action(op));
global_env->bind(op_name(op), std::move(type), visibility::global);
global_env->set_mangled_name(op_name(op), name);
}
void compiler::add_default_function_types() {
type_ptr int_type = global_env->lookup_type("Int");
assert(int_type != nullptr);
type_ptr int_type_app = type_ptr(new type_app(int_type));
type_ptr closed_int_op_type(
new type_arr(int_type_app, type_ptr(new type_arr(int_type_app, int_type_app))));
constexpr binop closed_ops[] = { PLUS, MINUS, TIMES, DIVIDE };
for(auto& op : closed_ops) add_binop_type(op, closed_int_op_type);
}
void compiler::parse() {
if(!driver())
throw compiler_error("failed to open file");
}
void compiler::typecheck() {
std::set<std::string> free_variables;
global_defs.find_free(free_variables);
global_defs.typecheck(type_m, global_env);
}
void compiler::translate() {
for(auto& data : global_defs.defs_data) {
data.second->into_globals(global_scp);
}
for(auto& defn : global_defs.defs_defn) {
auto& function = defn.second->into_global(global_scp);
function.body->env->get_parent()->set_mangled_name(defn.first, function.name);
}
}
void compiler::compile() {
global_scp.compile();
}
void compiler::create_llvm_binop(binop op) {
auto new_function =
ctx.create_custom_function(global_env->get_mangled_name(op_name(op)), 2);
std::vector<instruction_ptr> instructions;
instructions.push_back(instruction_ptr(new instruction_push(1)));
instructions.push_back(instruction_ptr(new instruction_eval()));
instructions.push_back(instruction_ptr(new instruction_push(1)));
instructions.push_back(instruction_ptr(new instruction_eval()));
instructions.push_back(instruction_ptr(new instruction_binop(op)));
instructions.push_back(instruction_ptr(new instruction_update(2)));
instructions.push_back(instruction_ptr(new instruction_pop(2)));
ctx.get_builder().SetInsertPoint(&new_function->getEntryBlock());
for(auto& instruction : instructions) {
instruction->gen_llvm(ctx, new_function);
}
ctx.get_builder().CreateRetVoid();
}
void compiler::generate_llvm() {
for(auto op : all_binops) {
create_llvm_binop(op);
}
global_scp.generate_llvm(ctx);
}
void compiler::output_llvm(const std::string& into) {
std::string targetTriple = llvm::sys::getDefaultTargetTriple();
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmParser();
llvm::InitializeNativeTargetAsmPrinter();
std::string error;
const llvm::Target* target =
llvm::TargetRegistry::lookupTarget(targetTriple, error);
if (!target) {
std::cerr << error << std::endl;
} else {
std::string cpu = "generic";
std::string features = "";
llvm::TargetOptions options;
std::unique_ptr<llvm::TargetMachine> targetMachine(
target->createTargetMachine(targetTriple, cpu, features,
options, llvm::Optional<llvm::Reloc::Model>()));
ctx.get_module().setDataLayout(targetMachine->createDataLayout());
ctx.get_module().setTargetTriple(targetTriple);
std::error_code ec;
llvm::raw_fd_ostream file(into, ec, llvm::sys::fs::F_None);
if (ec) {
throw compiler_error("failed to open object file for writing");
} else {
llvm::CodeGenFileType type = llvm::CGFT_ObjectFile;
llvm::legacy::PassManager pm;
if (targetMachine->addPassesToEmitFile(pm, file, NULL, type)) {
throw compiler_error("failed to add passes to pass manager");
} else {
pm.run(ctx.get_module());
file.close();
}
}
}
}
compiler::compiler(const std::string& filename)
: file_m(), global_defs(), driver(file_m, global_defs, filename),
global_env(new type_env), type_m(), mng(), global_scp(mng), ctx() {
add_default_types();
add_default_function_types();
}
void compiler::operator()(const std::string& into) {
parse();
typecheck();
translate();
compile();
generate_llvm();
output_llvm(into);
}
file_mgr& compiler::get_file_manager() {
return file_m;
}
type_mgr& compiler::get_type_manager() {
return type_m;
}

View File

@ -0,0 +1,37 @@
#pragma once
#include "binop.hpp"
#include "parse_driver.hpp"
#include "definition.hpp"
#include "type_env.hpp"
#include "type.hpp"
#include "global_scope.hpp"
#include "mangler.hpp"
#include "llvm_context.hpp"
class compiler {
private:
file_mgr file_m;
definition_group global_defs;
parse_driver driver;
type_env_ptr global_env;
type_mgr type_m;
mangler mng;
global_scope global_scp;
llvm_context ctx;
void add_default_types();
void add_binop_type(binop op, type_ptr type);
void add_default_function_types();
void parse();
void typecheck();
void translate();
void compile();
void create_llvm_binop(binop op);
void generate_llvm();
void output_llvm(const std::string& into);
public:
compiler(const std::string& filename);
void operator()(const std::string& into);
file_mgr& get_file_manager();
type_mgr& get_type_manager();
};

View File

@ -65,8 +65,8 @@ void definition_data::insert_constructors() const {
for(auto& var : vars) { for(auto& var : vars) {
if(var_set.find(var) != var_set.end()) if(var_set.find(var) != var_set.end())
throw compiler_error( throw compiler_error(
std::string("type variable ") + "type variable " + var +
var + std::string(" used twice in data type definition."), loc); " used twice in data type definition.", loc);
var_set.insert(var); var_set.insert(var);
return_app->arguments.push_back(type_ptr(new type_var(var))); return_app->arguments.push_back(type_ptr(new type_var(var)));
} }

View File

@ -3,7 +3,7 @@
int env_var::get_offset(const std::string& name) const { int env_var::get_offset(const std::string& name) const {
if(name == this->name) return 0; if(name == this->name) return 0;
assert(parent); assert(parent != nullptr);
return parent->get_offset(name) + 1; return parent->get_offset(name) + 1;
} }
@ -14,7 +14,7 @@ bool env_var::has_variable(const std::string& name) const {
} }
int env_offset::get_offset(const std::string& name) const { int env_offset::get_offset(const std::string& name) const {
assert(parent); assert(parent != nullptr);
return parent->get_offset(name) + offset; return parent->get_offset(name) + offset;
} }

View File

@ -2,33 +2,38 @@
#include <memory> #include <memory>
#include <string> #include <string>
struct env { class env {
virtual ~env() = default; public:
virtual ~env() = default;
virtual int get_offset(const std::string& name) const = 0; virtual int get_offset(const std::string& name) const = 0;
virtual bool has_variable(const std::string& name) const = 0; virtual bool has_variable(const std::string& name) const = 0;
}; };
using env_ptr = std::shared_ptr<env>; using env_ptr = std::shared_ptr<env>;
struct env_var : public env { class env_var : public env {
std::string name; private:
env_ptr parent; std::string name;
env_ptr parent;
env_var(std::string n, env_ptr p) public:
: name(std::move(n)), parent(std::move(p)) {} env_var(std::string n, env_ptr p)
: name(std::move(n)), parent(std::move(p)) {}
int get_offset(const std::string& name) const; int get_offset(const std::string& name) const;
bool has_variable(const std::string& name) const; bool has_variable(const std::string& name) const;
}; };
struct env_offset : public env { class env_offset : public env {
int offset; private:
env_ptr parent; int offset;
env_ptr parent;
env_offset(int o, env_ptr p) public:
: offset(o), parent(std::move(p)) {} env_offset(int o, env_ptr p)
: offset(o), parent(std::move(p)) {}
int get_offset(const std::string& name) const; int get_offset(const std::string& name) const;
bool has_variable(const std::string& name) const; bool has_variable(const std::string& name) const;
}; };

View File

@ -9,28 +9,28 @@ void compiler_error::print_about(std::ostream& to) {
to << description << std::endl; to << description << std::endl;
} }
void compiler_error::print_location(std::ostream& to, parse_driver& drv, bool highlight) { void compiler_error::print_location(std::ostream& to, file_mgr& fm, bool highlight) {
if(!loc) return; if(!loc) return;
to << "occuring on line " << loc->begin.line << ":" << std::endl; to << "occuring on line " << loc->begin.line << ":" << std::endl;
drv.print_location(to, *loc, highlight); fm.print_location(to, *loc, highlight);
} }
void compiler_error::pretty_print(std::ostream& to, parse_driver& drv) { void compiler_error::pretty_print(std::ostream& to, file_mgr& fm) {
print_about(to); print_about(to);
print_location(to, drv); print_location(to, fm);
} }
const char* type_error::what() const noexcept { const char* type_error::what() const noexcept {
return "an error occured while checking the types of the program"; return "an error occured while checking the types of the program";
} }
void type_error::pretty_print(std::ostream& to, parse_driver& drv) { void type_error::pretty_print(std::ostream& to, file_mgr& fm) {
print_about(to); print_about(to);
print_location(to, drv, true); print_location(to, fm, true);
} }
void unification_error::pretty_print(std::ostream& to, parse_driver& drv, type_mgr& mgr) { void unification_error::pretty_print(std::ostream& to, file_mgr& fm, type_mgr& mgr) {
type_error::pretty_print(to, drv); type_error::pretty_print(to, fm);
to << "the expected type was:" << std::endl; to << "the expected type was:" << std::endl;
to << " \033[34m"; to << " \033[34m";
left->print(mgr, to); left->print(mgr, to);

View File

@ -7,38 +7,43 @@
using maybe_location = std::optional<yy::location>; using maybe_location = std::optional<yy::location>;
struct compiler_error : std::exception { class compiler_error : std::exception {
std::string description; private:
maybe_location loc; std::string description;
maybe_location loc;
compiler_error(std::string d, maybe_location l = std::nullopt) public:
: description(std::move(d)), loc(std::move(l)) {} compiler_error(std::string d, maybe_location l = std::nullopt)
: description(std::move(d)), loc(std::move(l)) {}
const char* what() const noexcept override; const char* what() const noexcept override;
void print_about(std::ostream& to); void print_about(std::ostream& to);
void print_location(std::ostream& to, parse_driver& drv, bool highlight = false); void print_location(std::ostream& to, file_mgr& fm, bool highlight = false);
void pretty_print(std::ostream& to, parse_driver& drv); void pretty_print(std::ostream& to, file_mgr& fm);
}; };
struct type_error : compiler_error { class type_error : compiler_error {
std::optional<yy::location> loc; private:
type_error(std::string d, maybe_location l = std::nullopt) public:
: compiler_error(std::move(d), std::move(l)) {} type_error(std::string d, maybe_location l = std::nullopt)
: compiler_error(std::move(d), std::move(l)) {}
const char* what() const noexcept override; const char* what() const noexcept override;
void pretty_print(std::ostream& to, parse_driver& drv); void pretty_print(std::ostream& to, file_mgr& fm);
}; };
struct unification_error : public type_error { class unification_error : public type_error {
type_ptr left; private:
type_ptr right; type_ptr left;
type_ptr right;
unification_error(type_ptr l, type_ptr r, maybe_location loc = std::nullopt) public:
: left(std::move(l)), right(std::move(r)), unification_error(type_ptr l, type_ptr r, maybe_location loc = std::nullopt)
: left(std::move(l)), right(std::move(r)),
type_error("failed to unify types", std::move(loc)) {} type_error("failed to unify types", std::move(loc)) {}
void pretty_print(std::ostream& to, parse_driver& drv, type_mgr& mgr); void pretty_print(std::ostream& to, file_mgr& fm, type_mgr& mgr);
}; };

View File

@ -16,11 +16,11 @@ void global_function::declare_llvm(llvm_context& ctx) {
} }
void global_function::generate_llvm(llvm_context& ctx) { void global_function::generate_llvm(llvm_context& ctx) {
ctx.builder.SetInsertPoint(&generated_function->getEntryBlock()); ctx.get_builder().SetInsertPoint(&generated_function->getEntryBlock());
for(auto& instruction : instructions) { for(auto& instruction : instructions) {
instruction->gen_llvm(ctx, generated_function); instruction->gen_llvm(ctx, generated_function);
} }
ctx.builder.CreateRetVoid(); ctx.get_builder().CreateRetVoid();
} }
void global_constructor::generate_llvm(llvm_context& ctx) { void global_constructor::generate_llvm(llvm_context& ctx) {
@ -29,21 +29,30 @@ void global_constructor::generate_llvm(llvm_context& ctx) {
std::vector<instruction_ptr> instructions; std::vector<instruction_ptr> instructions;
instructions.push_back(instruction_ptr(new instruction_pack(tag, arity))); instructions.push_back(instruction_ptr(new instruction_pack(tag, arity)));
instructions.push_back(instruction_ptr(new instruction_update(0))); instructions.push_back(instruction_ptr(new instruction_update(0)));
ctx.builder.SetInsertPoint(&new_function->getEntryBlock()); ctx.get_builder().SetInsertPoint(&new_function->getEntryBlock());
for (auto& instruction : instructions) { for (auto& instruction : instructions) {
instruction->gen_llvm(ctx, new_function); instruction->gen_llvm(ctx, new_function);
} }
ctx.builder.CreateRetVoid(); ctx.get_builder().CreateRetVoid();
} }
global_function& global_scope::add_function(std::string n, std::vector<std::string> ps, ast_ptr b) { global_function& global_scope::add_function(
global_function* new_function = new global_function(mangle_name(n), std::move(ps), std::move(b)); const std::string& n,
std::vector<std::string> ps,
ast_ptr b) {
auto name = mng->new_mangled_name(n);
global_function* new_function =
new global_function(std::move(name), std::move(ps), std::move(b));
functions.push_back(global_function_ptr(new_function)); functions.push_back(global_function_ptr(new_function));
return *new_function; return *new_function;
} }
global_constructor& global_scope::add_constructor(std::string n, int8_t t, size_t a) { global_constructor& global_scope::add_constructor(
global_constructor* new_constructor = new global_constructor(mangle_name(n), t, a); const std::string& n,
int8_t t,
size_t a) {
auto name = mng->new_mangled_name(n);
global_constructor* new_constructor = new global_constructor(name, t, a);
constructors.push_back(global_constructor_ptr(new_constructor)); constructors.push_back(global_constructor_ptr(new_constructor));
return *new_constructor; return *new_constructor;
} }
@ -65,19 +74,3 @@ void global_scope::generate_llvm(llvm_context& ctx) {
function->generate_llvm(ctx); function->generate_llvm(ctx);
} }
} }
std::string global_scope::mangle_name(const std::string& n) {
auto occurence_it = occurence_count.find(n);
int occurence = 0;
if(occurence_it != occurence_count.end()) {
occurence = occurence_it->second + 1;
}
occurence_count[n] = occurence;
std::string final_name = n;
if (occurence != 0) {
final_name += "_";
final_name += std::to_string(occurence);
}
return final_name;
}

View File

@ -4,6 +4,7 @@
#include <vector> #include <vector>
#include <llvm/IR/Function.h> #include <llvm/IR/Function.h>
#include "instruction.hpp" #include "instruction.hpp"
#include "mangler.hpp"
struct ast; struct ast;
using ast_ptr = std::unique_ptr<ast>; using ast_ptr = std::unique_ptr<ast>;
@ -39,17 +40,21 @@ struct global_constructor {
using global_constructor_ptr = std::unique_ptr<global_constructor>; using global_constructor_ptr = std::unique_ptr<global_constructor>;
struct global_scope { class global_scope {
std::map<std::string, int> occurence_count;
std::vector<global_function_ptr> functions;
std::vector<global_constructor_ptr> constructors;
global_function& add_function(std::string n, std::vector<std::string> ps, ast_ptr b);
global_constructor& add_constructor(std::string n, int8_t t, size_t a);
void compile();
void generate_llvm(llvm_context& ctx);
private: private:
std::string mangle_name(const std::string& n); std::vector<global_function_ptr> functions;
std::vector<global_constructor_ptr> constructors;
mangler* mng;
public:
global_scope(mangler& m) : mng(&m) {}
global_function& add_function(
const std::string& n,
std::vector<std::string> ps,
ast_ptr b);
global_constructor& add_constructor(const std::string& n, int8_t t, size_t a);
void compile();
void generate_llvm(llvm_context& ctx);
}; };

View File

@ -17,37 +17,38 @@ struct group {
using group_ptr = std::unique_ptr<group>; using group_ptr = std::unique_ptr<group>;
class function_graph { class function_graph {
using group_id = size_t; private:
using group_id = size_t;
struct group_data { struct group_data {
std::set<function> functions; std::set<function> functions;
std::set<group_id> adjacency_list; std::set<group_id> adjacency_list;
size_t indegree; size_t indegree;
group_data() : indegree(0) {} group_data() : indegree(0) {}
}; };
using data_ptr = std::shared_ptr<group_data>; using data_ptr = std::shared_ptr<group_data>;
using edge = std::pair<function, function>; using edge = std::pair<function, function>;
using group_edge = std::pair<group_id, group_id>; using group_edge = std::pair<group_id, group_id>;
std::map<function, std::set<function>> adjacency_lists; std::map<function, std::set<function>> adjacency_lists;
std::set<edge> edges; std::set<edge> edges;
std::set<edge> compute_transitive_edges(); std::set<edge> compute_transitive_edges();
void create_groups( void create_groups(
const std::set<edge>&, const std::set<edge>&,
std::map<function, group_id>&, std::map<function, group_id>&,
std::map<group_id, data_ptr>&); std::map<group_id, data_ptr>&);
void create_edges( void create_edges(
std::map<function, group_id>&, std::map<function, group_id>&,
std::map<group_id, data_ptr>&); std::map<group_id, data_ptr>&);
std::vector<group_ptr> generate_order( std::vector<group_ptr> generate_order(
std::map<function, group_id>&, std::map<function, group_id>&,
std::map<group_id, data_ptr>&); std::map<group_id, data_ptr>&);
public: public:
std::set<function>& add_function(const function& f); std::set<function>& add_function(const function& f);
void add_edge(const function& from, const function& to); void add_edge(const function& from, const function& to);
std::vector<group_ptr> compute_order(); std::vector<group_ptr> compute_order();
}; };

View File

@ -24,9 +24,9 @@ void instruction_pushglobal::print(int indent, std::ostream& to) const {
} }
void instruction_pushglobal::gen_llvm(llvm_context& ctx, Function* f) const { void instruction_pushglobal::gen_llvm(llvm_context& ctx, Function* f) const {
auto& global_f = ctx.custom_functions.at("f_" + name); auto& global_f = ctx.get_custom_function(name);
auto arity = ctx.create_i32(global_f->arity); auto arity = ctx.create_i32(global_f.arity);
ctx.create_push(f, ctx.create_global(f, global_f->function, arity)); ctx.create_push(f, ctx.create_global(f, global_f.function, arity));
} }
void instruction_push::print(int indent, std::ostream& to) const { void instruction_push::print(int indent, std::ostream& to) const {
@ -101,17 +101,17 @@ void instruction_jump::print(int indent, std::ostream& to) const {
void instruction_jump::gen_llvm(llvm_context& ctx, Function* f) const { void instruction_jump::gen_llvm(llvm_context& ctx, Function* f) const {
auto top_node = ctx.create_peek(f, ctx.create_size(0)); auto top_node = ctx.create_peek(f, ctx.create_size(0));
auto tag = ctx.unwrap_data_tag(top_node); auto tag = ctx.unwrap_data_tag(top_node);
auto safety_block = BasicBlock::Create(ctx.ctx, "safety", f); auto safety_block = ctx.create_basic_block("safety", f);
auto switch_op = ctx.builder.CreateSwitch(tag, safety_block, tag_mappings.size()); auto switch_op = ctx.get_builder().CreateSwitch(tag, safety_block, tag_mappings.size());
std::vector<BasicBlock*> blocks; std::vector<BasicBlock*> blocks;
for(auto& branch : branches) { for(auto& branch : branches) {
auto branch_block = BasicBlock::Create(ctx.ctx, "branch", f); auto branch_block = ctx.create_basic_block("branch", f);
ctx.builder.SetInsertPoint(branch_block); ctx.get_builder().SetInsertPoint(branch_block);
for(auto& instruction : branch) { for(auto& instruction : branch) {
instruction->gen_llvm(ctx, f); instruction->gen_llvm(ctx, f);
} }
ctx.builder.CreateBr(safety_block); ctx.get_builder().CreateBr(safety_block);
blocks.push_back(branch_block); blocks.push_back(branch_block);
} }
@ -119,46 +119,7 @@ void instruction_jump::gen_llvm(llvm_context& ctx, Function* f) const {
switch_op->addCase(ctx.create_i8(mapping.first), blocks[mapping.second]); switch_op->addCase(ctx.create_i8(mapping.first), blocks[mapping.second]);
} }
ctx.builder.SetInsertPoint(safety_block); ctx.get_builder().SetInsertPoint(safety_block);
}
void instruction_if::print(int indent, std::ostream& to) const {
print_indent(indent, to);
to << "If(" << std::endl;
for(auto& instruction : on_true) {
instruction->print(indent + 2, to);
}
to << std::endl;
for(auto& instruction : on_false) {
instruction->print(indent + 2, to);
}
print_indent(indent, to);
to << ")" << std::endl;
}
void instruction_if::gen_llvm(llvm_context& ctx, llvm::Function* f) const {
auto top_node = ctx.create_peek(f, ctx.create_size(0));
auto num = ctx.unwrap_num(top_node);
auto nonzero_block = BasicBlock::Create(ctx.ctx, "nonzero", f);
auto zero_block = BasicBlock::Create(ctx.ctx, "zero", f);
auto resume_block = BasicBlock::Create(ctx.ctx, "resume", f);
auto switch_op = ctx.builder.CreateSwitch(num, nonzero_block, 2);
switch_op->addCase(ctx.create_i32(0), zero_block);
ctx.builder.SetInsertPoint(nonzero_block);
for(auto& instruction : on_true) {
instruction->gen_llvm(ctx, f);
}
ctx.builder.CreateBr(resume_block);
ctx.builder.SetInsertPoint(zero_block);
for(auto& instruction : on_false) {
instruction->gen_llvm(ctx, f);
}
ctx.builder.CreateBr(resume_block);
ctx.builder.SetInsertPoint(resume_block);
} }
void instruction_slide::print(int indent, std::ostream& to) const { void instruction_slide::print(int indent, std::ostream& to) const {
@ -180,13 +141,10 @@ void instruction_binop::gen_llvm(llvm_context& ctx, Function* f) const {
auto right_int = ctx.unwrap_num(ctx.create_pop(f)); auto right_int = ctx.unwrap_num(ctx.create_pop(f));
llvm::Value* result; llvm::Value* result;
switch(op) { switch(op) {
case PLUS: result = ctx.builder.CreateAdd(left_int, right_int); break; case PLUS: result = ctx.get_builder().CreateAdd(left_int, right_int); break;
case MINUS: result = ctx.builder.CreateSub(left_int, right_int); break; case MINUS: result = ctx.get_builder().CreateSub(left_int, right_int); break;
case TIMES: result = ctx.builder.CreateMul(left_int, right_int); break; case TIMES: result = ctx.get_builder().CreateMul(left_int, right_int); break;
case DIVIDE: result = ctx.builder.CreateSDiv(left_int, right_int); break; case DIVIDE: result = ctx.get_builder().CreateSDiv(left_int, right_int); break;
case MODULO: result = ctx.builder.CreateSRem(left_int, right_int); break;
case EQUALS: result = ctx.builder.CreateICmpEQ(left_int, right_int); break;
case LESS_EQUALS: result = ctx.builder.CreateICmpSLE(left_int, right_int); break;
} }
ctx.create_push(f, ctx.create_num(f, result)); ctx.create_push(f, ctx.create_num(f, result));
} }

View File

@ -101,19 +101,6 @@ struct instruction_jump : public instruction {
void gen_llvm(llvm_context& ctx, llvm::Function* f) const; void gen_llvm(llvm_context& ctx, llvm::Function* f) const;
}; };
struct instruction_if : public instruction {
std::vector<instruction_ptr> on_true;
std::vector<instruction_ptr> on_false;
instruction_if(
std::vector<instruction_ptr> t,
std::vector<instruction_ptr> f)
: on_true(std::move(t)), on_false(std::move(f)) {}
void print(int indent, std::ostream& to) const;
void gen_llvm(llvm_context& ctx, llvm::Function* f) const;
};
struct instruction_slide : public instruction { struct instruction_slide : public instruction {
int offset; int offset;

View File

@ -163,6 +163,18 @@ void llvm_context::create_functions() {
); );
} }
IRBuilder<>& llvm_context::get_builder() {
return builder;
}
Module& llvm_context::get_module() {
return module;
}
BasicBlock* llvm_context::create_basic_block(const std::string& name, llvm::Function* f) {
return BasicBlock::Create(ctx, name, f);
}
ConstantInt* llvm_context::create_i8(int8_t i) { ConstantInt* llvm_context::create_i8(int8_t i) {
return ConstantInt::get(ctx, APInt(8, i)); return ConstantInt::get(ctx, APInt(8, i));
} }
@ -225,7 +237,12 @@ Value* llvm_context::unwrap_gmachine_stack_ptr(Value* g) {
} }
Value* llvm_context::unwrap_num(Value* v) { Value* llvm_context::unwrap_num(Value* v) {
return builder.CreatePtrToInt(v, IntegerType::getInt32Ty(ctx)); auto num_ptr_type = PointerType::getUnqual(struct_types.at("node_num"));
auto cast = builder.CreatePointerCast(v, num_ptr_type);
auto offset_0 = create_i32(0);
auto offset_1 = create_i32(1);
auto int_ptr = builder.CreateGEP(cast, { offset_0, offset_1 });
return builder.CreateLoad(int_ptr);
} }
Value* llvm_context::create_num(Function* f, Value* v) { Value* llvm_context::create_num(Function* f, Value* v) {
auto alloc_num_f = functions.at("alloc_num"); auto alloc_num_f = functions.at("alloc_num");
@ -254,7 +271,7 @@ Value* llvm_context::create_app(Function* f, Value* l, Value* r) {
return create_track(f, alloc_app_call); return create_track(f, alloc_app_call);
} }
llvm::Function* llvm_context::create_custom_function(std::string name, int32_t arity) { llvm::Function* llvm_context::create_custom_function(const std::string& name, int32_t arity) {
auto void_type = llvm::Type::getVoidTy(ctx); auto void_type = llvm::Type::getVoidTy(ctx);
auto new_function = llvm::Function::Create( auto new_function = llvm::Function::Create(
function_type, function_type,
@ -271,3 +288,7 @@ llvm::Function* llvm_context::create_custom_function(std::string name, int32_t a
return new_function; return new_function;
} }
llvm_context::custom_function& llvm_context::get_custom_function(const std::string& name) {
return *custom_functions.at("f_" + name);
}

View File

@ -7,66 +7,75 @@
#include <llvm/IR/Value.h> #include <llvm/IR/Value.h>
#include <map> #include <map>
struct llvm_context { class llvm_context {
struct custom_function { public:
llvm::Function* function; struct custom_function {
int32_t arity; llvm::Function* function;
}; int32_t arity;
};
using custom_function_ptr = std::unique_ptr<custom_function>; using custom_function_ptr = std::unique_ptr<custom_function>;
llvm::LLVMContext ctx; private:
llvm::IRBuilder<> builder; llvm::LLVMContext ctx;
llvm::Module module; llvm::IRBuilder<> builder;
llvm::Module module;
std::map<std::string, custom_function_ptr> custom_functions; std::map<std::string, custom_function_ptr> custom_functions;
std::map<std::string, llvm::Function*> functions; std::map<std::string, llvm::Function*> functions;
std::map<std::string, llvm::StructType*> struct_types; std::map<std::string, llvm::StructType*> struct_types;
llvm::StructType* stack_type; llvm::StructType* stack_type;
llvm::StructType* gmachine_type; llvm::StructType* gmachine_type;
llvm::PointerType* stack_ptr_type; llvm::PointerType* stack_ptr_type;
llvm::PointerType* gmachine_ptr_type; llvm::PointerType* gmachine_ptr_type;
llvm::PointerType* node_ptr_type; llvm::PointerType* node_ptr_type;
llvm::IntegerType* tag_type; llvm::IntegerType* tag_type;
llvm::FunctionType* function_type; llvm::FunctionType* function_type;
llvm_context() void create_types();
: builder(ctx), module("bloglang", ctx) { void create_functions();
create_types();
create_functions();
}
void create_types(); public:
void create_functions(); llvm_context()
: builder(ctx), module("bloglang", ctx) {
create_types();
create_functions();
}
llvm::ConstantInt* create_i8(int8_t); llvm::IRBuilder<>& get_builder();
llvm::ConstantInt* create_i32(int32_t); llvm::Module& get_module();
llvm::ConstantInt* create_size(size_t);
llvm::Value* create_pop(llvm::Function*); llvm::BasicBlock* create_basic_block(const std::string& name, llvm::Function* f);
llvm::Value* create_peek(llvm::Function*, llvm::Value*);
void create_push(llvm::Function*, llvm::Value*);
void create_popn(llvm::Function*, llvm::Value*);
void create_update(llvm::Function*, llvm::Value*);
void create_pack(llvm::Function*, llvm::Value*, llvm::Value*);
void create_split(llvm::Function*, llvm::Value*);
void create_slide(llvm::Function*, llvm::Value*);
void create_alloc(llvm::Function*, llvm::Value*);
llvm::Value* create_track(llvm::Function*, llvm::Value*);
void create_unwind(llvm::Function*); llvm::ConstantInt* create_i8(int8_t);
llvm::ConstantInt* create_i32(int32_t);
llvm::ConstantInt* create_size(size_t);
llvm::Value* unwrap_gmachine_stack_ptr(llvm::Value*); llvm::Value* create_pop(llvm::Function*);
llvm::Value* create_peek(llvm::Function*, llvm::Value*);
void create_push(llvm::Function*, llvm::Value*);
void create_popn(llvm::Function*, llvm::Value*);
void create_update(llvm::Function*, llvm::Value*);
void create_pack(llvm::Function*, llvm::Value*, llvm::Value*);
void create_split(llvm::Function*, llvm::Value*);
void create_slide(llvm::Function*, llvm::Value*);
void create_alloc(llvm::Function*, llvm::Value*);
llvm::Value* create_track(llvm::Function*, llvm::Value*);
llvm::Value* unwrap_num(llvm::Value*); void create_unwind(llvm::Function*);
llvm::Value* create_num(llvm::Function*, llvm::Value*);
llvm::Value* unwrap_data_tag(llvm::Value*); llvm::Value* unwrap_gmachine_stack_ptr(llvm::Value*);
llvm::Value* create_global(llvm::Function*, llvm::Value*, llvm::Value*); llvm::Value* unwrap_num(llvm::Value*);
llvm::Value* create_num(llvm::Function*, llvm::Value*);
llvm::Value* create_app(llvm::Function*, llvm::Value*, llvm::Value*); llvm::Value* unwrap_data_tag(llvm::Value*);
llvm::Function* create_custom_function(std::string name, int32_t arity); llvm::Value* create_global(llvm::Function*, llvm::Value*, llvm::Value*);
llvm::Value* create_app(llvm::Function*, llvm::Value*, llvm::Value*);
llvm::Function* create_custom_function(const std::string& name, int32_t arity);
custom_function& get_custom_function(const std::string& name);
}; };

View File

@ -1,214 +1,27 @@
#include "ast.hpp" #include "ast.hpp"
#include <iostream> #include <iostream>
#include "binop.hpp"
#include "definition.hpp"
#include "graph.hpp"
#include "instruction.hpp"
#include "llvm_context.hpp"
#include "parser.hpp" #include "parser.hpp"
#include "compiler.hpp"
#include "error.hpp" #include "error.hpp"
#include "type.hpp"
#include "parse_driver.hpp"
#include "type_env.hpp"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Target/TargetMachine.h"
void yy::parser::error(const yy::location& loc, const std::string& msg) { void yy::parser::error(const yy::location& loc, const std::string& msg) {
std::cerr << "An error occured: " << msg << std::endl; std::cerr << "An error occured: " << msg << std::endl;
} }
void prelude_types(definition_group& defs, type_env_ptr env) {
type_ptr int_type = type_ptr(new type_internal("Int"));
env->bind_type("Int", int_type);
type_ptr int_type_app = type_ptr(new type_app(int_type));
type_ptr bool_type = type_ptr(new type_internal("Bool"));
env->bind_type("Bool", bool_type);
type_ptr bool_type_app = type_ptr(new type_app(bool_type));
type_ptr binop_type = type_ptr(new type_arr(
int_type_app,
type_ptr(new type_arr(int_type_app, int_type_app))));
type_ptr cmp_type = type_ptr(new type_arr(
int_type_app,
type_ptr(new type_arr(int_type_app, bool_type_app))));
constexpr binop number_ops[] = { PLUS, MINUS, TIMES, DIVIDE, MODULO };
constexpr binop cmp_ops[] = { EQUALS, LESS_EQUALS };
for(auto& op : number_ops) {
env->bind(op_name(op), binop_type, visibility::global);
env->set_mangled_name(op_name(op), op_action(op));
}
for(auto& op : cmp_ops) {
env->bind(op_name(op), cmp_type, visibility::global);
env->set_mangled_name(op_name(op), op_action(op));
}
env->bind("True", bool_type_app, visibility::global);
env->bind("False", bool_type_app, visibility::global);
}
void typecheck_program(
definition_group& defs,
type_mgr& mgr, type_env_ptr& env) {
prelude_types(defs, env);
std::set<std::string> free;
defs.find_free(free);
defs.typecheck(mgr, env);
#ifdef DEBUG_OUT
for(auto& pair : defs.env->names) {
std::cout << pair.first << ": ";
pair.second.type->print(mgr, std::cout);
std::cout << std::endl;
}
#endif
}
global_scope translate_program(definition_group& group) {
global_scope scope;
for(auto& data : group.defs_data) {
data.second->into_globals(scope);
}
for(auto& defn : group.defs_defn) {
auto& function = defn.second->into_global(scope);
function.body->env->parent->set_mangled_name(defn.first, function.name);
}
return scope;
}
void output_llvm(llvm_context& ctx, const std::string& filename) {
std::string targetTriple = llvm::sys::getDefaultTargetTriple();
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmParser();
llvm::InitializeNativeTargetAsmPrinter();
std::string error;
const llvm::Target* target =
llvm::TargetRegistry::lookupTarget(targetTriple, error);
if (!target) {
std::cerr << error << std::endl;
} else {
std::string cpu = "generic";
std::string features = "";
llvm::TargetOptions options;
std::unique_ptr<llvm::TargetMachine> targetMachine(
target->createTargetMachine(targetTriple, cpu, features,
options, llvm::Optional<llvm::Reloc::Model>()));
ctx.module.setDataLayout(targetMachine->createDataLayout());
ctx.module.setTargetTriple(targetTriple);
std::error_code ec;
llvm::raw_fd_ostream file(filename, ec, llvm::sys::fs::F_None);
if (ec) {
throw compiler_error("failed to open object file for writing");
} else {
llvm::CodeGenFileType type = llvm::CGFT_ObjectFile;
llvm::legacy::PassManager pm;
if (targetMachine->addPassesToEmitFile(pm, file, NULL, type)) {
throw compiler_error("failed to add passes to pass manager");
} else {
pm.run(ctx.module);
file.close();
}
}
}
}
void gen_llvm_internal_op(llvm_context& ctx, binop op) {
auto new_function = ctx.create_custom_function(op_action(op), 2);
std::vector<instruction_ptr> instructions;
instructions.push_back(instruction_ptr(new instruction_push(1)));
instructions.push_back(instruction_ptr(new instruction_eval()));
instructions.push_back(instruction_ptr(new instruction_push(1)));
instructions.push_back(instruction_ptr(new instruction_eval()));
instructions.push_back(instruction_ptr(new instruction_binop(op)));
instructions.push_back(instruction_ptr(new instruction_update(2)));
instructions.push_back(instruction_ptr(new instruction_pop(2)));
ctx.builder.SetInsertPoint(&new_function->getEntryBlock());
for(auto& instruction : instructions) {
instruction->gen_llvm(ctx, new_function);
}
ctx.builder.CreateRetVoid();
}
void gen_llvm_boolean_constructor(llvm_context& ctx, const std::string& s, bool b) {
auto new_function = ctx.create_custom_function(s, 0);
std::vector<instruction_ptr> instructions;
instructions.push_back(instruction_ptr(new instruction_pushint(b)));
instructions.push_back(instruction_ptr(new instruction_update(0)));
ctx.builder.SetInsertPoint(&new_function->getEntryBlock());
for(auto& instruction : instructions) {
instruction->gen_llvm(ctx, new_function);
}
ctx.builder.CreateRetVoid();
}
void gen_llvm(global_scope& scope) {
llvm_context ctx;
gen_llvm_internal_op(ctx, PLUS);
gen_llvm_internal_op(ctx, MINUS);
gen_llvm_internal_op(ctx, TIMES);
gen_llvm_internal_op(ctx, DIVIDE);
gen_llvm_internal_op(ctx, MODULO);
gen_llvm_internal_op(ctx, EQUALS);
gen_llvm_internal_op(ctx, LESS_EQUALS);
gen_llvm_boolean_constructor(ctx, "True", true);
gen_llvm_boolean_constructor(ctx, "False", false);
scope.generate_llvm(ctx);
#ifdef DEBUG_OUT
ctx.module.print(llvm::outs(), nullptr);
#endif
output_llvm(ctx, "program.o");
}
int main(int argc, char** argv) { int main(int argc, char** argv) {
if(argc != 2) { if(argc != 2) {
std::cerr << "please enter a file to compile." << std::endl; std::cerr << "please enter a file to compile." << std::endl;
exit(1); exit(1);
} }
parse_driver driver(argv[1]); compiler cmp(argv[1]);
if(!driver.run_parse()) {
std::cerr << "failed to parse file " << argv[1] << std::endl;
exit(1);
}
type_mgr mgr;
type_env_ptr env(new type_env);
#ifdef DEBUG_OUT
for(auto& def_defn : driver.global_defs.defs_defn) {
std::cout << def_defn.second->name;
for(auto& param : def_defn.second->params) std::cout << " " << param;
std::cout << ":" << std::endl;
def_defn.second->body->print(1, std::cout);
std::cout << std::endl;
}
#endif
try { try {
typecheck_program(driver.global_defs, mgr, env); cmp("program.o");
global_scope scope = translate_program(driver.global_defs);
scope.compile();
gen_llvm(scope);
} catch(unification_error& err) { } catch(unification_error& err) {
err.pretty_print(std::cerr, driver, mgr); err.pretty_print(std::cerr, cmp.get_file_manager(), cmp.get_type_manager());
} catch(type_error& err) { } catch(type_error& err) {
err.pretty_print(std::cerr, driver); err.pretty_print(std::cerr, cmp.get_file_manager());
} catch (compiler_error& err) { } catch (compiler_error& err) {
err.pretty_print(std::cerr, driver); err.pretty_print(std::cerr, cmp.get_file_manager());
} }
} }

View File

@ -0,0 +1,17 @@
#include "mangler.hpp"
std::string mangler::new_mangled_name(const std::string& n) {
auto occurence_it = occurence_count.find(n);
int occurence = 0;
if(occurence_it != occurence_count.end()) {
occurence = occurence_it->second + 1;
}
occurence_count[n] = occurence;
std::string final_name = n;
if (occurence != 0) {
final_name += "_";
final_name += std::to_string(occurence);
}
return final_name;
}

View File

@ -0,0 +1,11 @@
#pragma once
#include <string>
#include <map>
class mangler {
private:
std::map<std::string, int> occurence_count;
public:
std::string new_mangled_name(const std::string& str);
};

View File

@ -2,44 +2,37 @@
#include "scanner.hpp" #include "scanner.hpp"
#include <sstream> #include <sstream>
bool parse_driver::run_parse() { file_mgr::file_mgr() : file_offset(0) {
FILE* stream = fopen(file_name.c_str(), "r");
if(!stream) return false;
line_offsets.push_back(0); line_offsets.push_back(0);
yyscan_t scanner;
yylex_init(&scanner);
yyset_in(stream, scanner);
yy::parser parser(scanner, *this);
parser();
yylex_destroy(scanner);
fclose(stream);
file_contents = string_stream.str();
return true;
} }
void parse_driver::write(const char* buf, size_t len) { void file_mgr::write(const char* buf, size_t len) {
string_stream.write(buf, len); string_stream.write(buf, len);
file_offset += len; file_offset += len;
} }
void parse_driver::mark_line() { void file_mgr::mark_line() {
line_offsets.push_back(file_offset); line_offsets.push_back(file_offset);
} }
size_t parse_driver::get_index(int line, int column) { void file_mgr::finalize() {
assert(line > 0 && line <= line_offsets.size()); file_contents = string_stream.str();
return line_offsets[line-1] + column - 1;
} }
size_t parse_driver::get_line_end(int line) { size_t file_mgr::get_index(int line, int column) const {
assert(line > 0 && line <= line_offsets.size());
return line_offsets.at(line-1) + column - 1;
}
size_t file_mgr::get_line_end(int line) const {
if(line == line_offsets.size()) return file_contents.size(); if(line == line_offsets.size()) return file_contents.size();
return get_index(line+1, 1); return get_index(line+1, 1);
} }
void parse_driver::print_location( void file_mgr::print_location(
std::ostream& stream, std::ostream& stream,
const yy::location& loc, const yy::location& loc,
bool highlight) { bool highlight) const {
size_t print_start = get_index(loc.begin.line, 1); size_t print_start = get_index(loc.begin.line, 1);
size_t highlight_start = get_index(loc.begin.line, loc.begin.column); size_t highlight_start = get_index(loc.begin.line, loc.begin.column);
size_t highlight_end = get_index(loc.end.line, loc.end.column); size_t highlight_end = get_index(loc.end.line, loc.end.column);
@ -51,3 +44,29 @@ void parse_driver::print_location(
if(highlight) stream << "\033[0m"; if(highlight) stream << "\033[0m";
stream.write(content + highlight_end, print_end - highlight_end); stream.write(content + highlight_end, print_end - highlight_end);
} }
bool parse_driver::operator()() {
FILE* stream = fopen(file_name.c_str(), "r");
if(!stream) return false;
yyscan_t scanner;
yylex_init(&scanner);
yyset_in(stream, scanner);
yy::parser parser(scanner, *this);
parser();
yylex_destroy(scanner);
fclose(stream);
file_m->finalize();
return true;
}
yy::location& parse_driver::get_current_location() {
return location;
}
file_mgr& parse_driver::get_file_manager() const {
return *file_m;
}
definition_group& parse_driver::get_global_defs() const {
return *global_defs;
}

View File

@ -11,29 +11,46 @@ struct parse_driver;
void scanner_init(parse_driver* d, yyscan_t* scanner); void scanner_init(parse_driver* d, yyscan_t* scanner);
void scanner_destroy(yyscan_t* scanner); void scanner_destroy(yyscan_t* scanner);
struct parse_driver { class file_mgr {
std::string file_name; private:
std::ostringstream string_stream; std::ostringstream string_stream;
std::string file_contents; std::string file_contents;
yy::location location; size_t file_offset;
size_t file_offset; std::vector<size_t> line_offsets;
std::vector<size_t> line_offsets; public:
file_mgr();
definition_group global_defs; void write(const char* buffer, size_t len);
void mark_line();
void finalize();
parse_driver(const std::string& file) size_t get_index(int line, int column) const;
: file_name(file), file_offset(0) {} size_t get_line_end(int line) const;
void print_location(
std::ostream& stream,
const yy::location& loc,
bool highlight = true) const;
};
bool run_parse(); class parse_driver {
void write(const char* buff, size_t len); private:
void mark_line(); std::string file_name;
size_t get_index(int line, int column); yy::location location;
size_t get_line_end(int line); definition_group* global_defs;
void print_location( file_mgr* file_m;
std::ostream& stream,
const yy::location& loc, public:
bool highlight = true); parse_driver(
file_mgr& mgr,
definition_group& defs,
const std::string& file)
: file_name(file), file_m(&mgr), global_defs(&defs) {}
bool operator()();
yy::location& get_current_location();
file_mgr& get_file_manager() const;
definition_group& get_global_defs() const;
}; };
#define YY_DECL yy::parser::symbol_type yylex(yyscan_t yyscanner, parse_driver& drv) #define YY_DECL yy::parser::symbol_type yylex(yyscan_t yyscanner, parse_driver& drv)

View File

@ -18,14 +18,10 @@ using yyscan_t = void*;
} }
%token BACKSLASH %token BACKSLASH
%token BACKTICK
%token PLUS %token PLUS
%token TIMES %token TIMES
%token MINUS %token MINUS
%token DIVIDE %token DIVIDE
%token MODULO
%token EQUALS
%token LESS_EQUALS
%token <int> INT %token <int> INT
%token DEFN %token DEFN
%token DATA %token DATA
@ -53,10 +49,9 @@ using yyscan_t = void*;
%type <std::vector<branch_ptr>> branches %type <std::vector<branch_ptr>> branches
%type <std::vector<constructor_ptr>> constructors %type <std::vector<constructor_ptr>> constructors
%type <std::vector<parsed_type_ptr>> typeList %type <std::vector<parsed_type_ptr>> typeList
%type <binop> anyBinop
%type <definition_group> definitions %type <definition_group> definitions
%type <parsed_type_ptr> type nonArrowType typeListElement %type <parsed_type_ptr> type nonArrowType typeListElement
%type <ast_ptr> aInfix aEq aAdd aMul case let lambda app appBase %type <ast_ptr> aAdd aMul case let lambda app appBase
%type <definition_data_ptr> data %type <definition_data_ptr> data
%type <definition_defn_ptr> defn %type <definition_defn_ptr> defn
%type <branch_ptr> branch %type <branch_ptr> branch
@ -68,7 +63,7 @@ using yyscan_t = void*;
%% %%
program program
: definitions { $1.vis = visibility::global; std::swap(drv.global_defs, $1); } : definitions { $1.vis = visibility::global; std::swap(drv.get_global_defs(), $1); }
; ;
definitions definitions
@ -78,7 +73,7 @@ definitions
; ;
defn defn
: DEFN LID lowercaseParams EQUAL OCURLY aInfix CCURLY : DEFN LID lowercaseParams EQUAL OCURLY aAdd CCURLY
{ $$ = definition_defn_ptr( { $$ = definition_defn_ptr(
new definition_defn(std::move($2), std::move($3), std::move($6), @$)); } new definition_defn(std::move($2), std::move($3), std::move($6), @$)); }
; ;
@ -88,22 +83,6 @@ lowercaseParams
| lowercaseParams LID { $$ = std::move($1); $$.push_back(std::move($2)); } | lowercaseParams LID { $$ = std::move($1); $$.push_back(std::move($2)); }
; ;
aInfix
: aInfix BACKTICK LID BACKTICK aEq
{ $$ = ast_ptr(new ast_app(
ast_ptr(new ast_app(ast_ptr(new ast_lid(std::move($3))), std::move($1))), std::move($5))); }
| aInfix BACKTICK UID BACKTICK aEq
{ $$ = ast_ptr(new ast_app(
ast_ptr(new ast_app(ast_ptr(new ast_uid(std::move($3))), std::move($1))), std::move($5))); }
| aEq { $$ = std::move($1); }
;
aEq
: aAdd EQUALS aAdd { $$ = ast_ptr(new ast_binop(EQUALS, std::move($1), std::move($3), @$)); }
| aAdd LESS_EQUALS aAdd { $$ = ast_ptr(new ast_binop(LESS_EQUALS, std::move($1), std::move($3), @$)); }
| aAdd { $$ = std::move($1); }
;
aAdd aAdd
: aAdd PLUS aMul { $$ = ast_ptr(new ast_binop(PLUS, std::move($1), std::move($3), @$)); } : aAdd PLUS aMul { $$ = ast_ptr(new ast_binop(PLUS, std::move($1), std::move($3), @$)); }
| aAdd MINUS aMul { $$ = ast_ptr(new ast_binop(MINUS, std::move($1), std::move($3), @$)); } | aAdd MINUS aMul { $$ = ast_ptr(new ast_binop(MINUS, std::move($1), std::move($3), @$)); }
@ -113,7 +92,6 @@ aAdd
aMul aMul
: aMul TIMES app { $$ = ast_ptr(new ast_binop(TIMES, std::move($1), std::move($3), @$)); } : aMul TIMES app { $$ = ast_ptr(new ast_binop(TIMES, std::move($1), std::move($3), @$)); }
| aMul DIVIDE app { $$ = ast_ptr(new ast_binop(DIVIDE, std::move($1), std::move($3), @$)); } | aMul DIVIDE app { $$ = ast_ptr(new ast_binop(DIVIDE, std::move($1), std::move($3), @$)); }
| aMul MODULO app { $$ = ast_ptr(new ast_binop(MODULO, std::move($1), std::move($3), @$)); }
| app { $$ = std::move($1); } | app { $$ = std::move($1); }
; ;
@ -126,35 +104,24 @@ appBase
: INT { $$ = ast_ptr(new ast_int($1, @$)); } : INT { $$ = ast_ptr(new ast_int($1, @$)); }
| LID { $$ = ast_ptr(new ast_lid(std::move($1), @$)); } | LID { $$ = ast_ptr(new ast_lid(std::move($1), @$)); }
| UID { $$ = ast_ptr(new ast_uid(std::move($1), @$)); } | UID { $$ = ast_ptr(new ast_uid(std::move($1), @$)); }
| OPAREN aInfix CPAREN { $$ = std::move($2); } | OPAREN aAdd CPAREN { $$ = std::move($2); }
| OPAREN anyBinop CPAREN { $$ = ast_ptr(new ast_lid(op_name($2))); }
| case { $$ = std::move($1); } | case { $$ = std::move($1); }
| let { $$ = std::move($1); } | let { $$ = std::move($1); }
| lambda { $$ = std::move($1); } | lambda { $$ = std::move($1); }
; ;
anyBinop
: PLUS { $$ = PLUS; }
| MINUS { $$ = MINUS; }
| TIMES { $$ = TIMES; }
| DIVIDE { $$ = DIVIDE; }
| MODULO { $$ = MODULO; }
| EQUALS { $$ = EQUALS; }
| LESS_EQUALS { $$ = LESS_EQUALS; }
;
let let
: LET OCURLY definitions CCURLY IN OCURLY aInfix CCURLY : LET OCURLY definitions CCURLY IN OCURLY aAdd CCURLY
{ $$ = ast_ptr(new ast_let(std::move($3), std::move($7), @$)); } { $$ = ast_ptr(new ast_let(std::move($3), std::move($7), @$)); }
; ;
lambda lambda
: BACKSLASH lowercaseParams ARROW OCURLY aInfix CCURLY : BACKSLASH lowercaseParams ARROW OCURLY aAdd CCURLY
{ $$ = ast_ptr(new ast_lambda(std::move($2), std::move($5), @$)); } { $$ = ast_ptr(new ast_lambda(std::move($2), std::move($5), @$)); }
; ;
case case
: CASE aInfix OF OCURLY branches CCURLY : CASE aAdd OF OCURLY branches CCURLY
{ $$ = ast_ptr(new ast_case(std::move($2), std::move($5), @$)); } { $$ = ast_ptr(new ast_case(std::move($2), std::move($5), @$)); }
; ;
@ -164,7 +131,7 @@ branches
; ;
branch branch
: pattern ARROW OCURLY aInfix CCURLY : pattern ARROW OCURLY aAdd CCURLY
{ $$ = branch_ptr(new branch(std::move($1), std::move($4))); } { $$ = branch_ptr(new branch(std::move($1), std::move($4))); }
; ;

View File

@ -1,13 +1,9 @@
#include <bits/stdint-intn.h>
#include <stdint.h> #include <stdint.h>
#include <assert.h> #include <assert.h>
#include <memory.h> #include <memory.h>
#include <stdio.h> #include <stdio.h>
#include "runtime.h" #include "runtime.h"
#define INT_MARKER (1l << 63)
#define IS_INT(n) ((uint64_t) n & INT_MARKER)
struct node_base* alloc_node() { struct node_base* alloc_node() {
struct node_base* new_node = malloc(sizeof(struct node_app)); struct node_base* new_node = malloc(sizeof(struct node_app));
new_node->gc_next = NULL; new_node->gc_next = NULL;
@ -25,7 +21,10 @@ struct node_app* alloc_app(struct node_base* l, struct node_base* r) {
} }
struct node_num* alloc_num(int32_t n) { struct node_num* alloc_num(int32_t n) {
return (struct node_num*) (INT_MARKER | n); struct node_num* node = (struct node_num*) alloc_node();
node->base.tag = NODE_NUM;
node->value = n;
return node;
} }
struct node_global* alloc_global(void (*f)(struct gmachine*), int32_t a) { struct node_global* alloc_global(void (*f)(struct gmachine*), int32_t a) {
@ -50,7 +49,7 @@ void free_node_direct(struct node_base* n) {
} }
void gc_visit_node(struct node_base* n) { void gc_visit_node(struct node_base* n) {
if(IS_INT(n) || n->gc_reachable) return; if(n->gc_reachable) return;
n->gc_reachable = 1; n->gc_reachable = 1;
if(n->tag == NODE_APP) { if(n->tag == NODE_APP) {
@ -170,7 +169,6 @@ void gmachine_split(struct gmachine* g, size_t n) {
} }
struct node_base* gmachine_track(struct gmachine* g, struct node_base* b) { struct node_base* gmachine_track(struct gmachine* g, struct node_base* b) {
if(IS_INT(b)) return b;
g->gc_node_count++; g->gc_node_count++;
b->gc_next = g->gc_nodes; b->gc_next = g->gc_nodes;
g->gc_nodes = b; g->gc_nodes = b;
@ -210,9 +208,7 @@ void unwind(struct gmachine* g) {
while(1) { while(1) {
struct node_base* peek = stack_peek(s, 0); struct node_base* peek = stack_peek(s, 0);
if(IS_INT(peek)) { if(peek->tag == NODE_APP) {
break;
} else if(peek->tag == NODE_APP) {
struct node_app* n = (struct node_app*) peek; struct node_app* n = (struct node_app*) peek;
stack_push(s, n->left); stack_push(s, n->left);
} else if(peek->tag == NODE_GLOBAL) { } else if(peek->tag == NODE_GLOBAL) {
@ -238,9 +234,7 @@ void unwind(struct gmachine* g) {
extern void f_main(struct gmachine* s); extern void f_main(struct gmachine* s);
void print_node(struct node_base* n) { void print_node(struct node_base* n) {
if(IS_INT(n)) { if(n->tag == NODE_APP) {
printf("%d", (int32_t) n);
} else if(n->tag == NODE_APP) {
struct node_app* app = (struct node_app*) n; struct node_app* app = (struct node_app*) n;
print_node(app->left); print_node(app->left);
putchar(' '); putchar(' ');
@ -252,6 +246,9 @@ void print_node(struct node_base* n) {
printf("(Global: %p)", global->function); printf("(Global: %p)", global->function);
} else if(n->tag == NODE_IND) { } else if(n->tag == NODE_IND) {
print_node(((struct node_ind*) n)->next); print_node(((struct node_ind*) n)->next);
} else if(n->tag == NODE_NUM) {
struct node_num* num = (struct node_num*) n;
printf("%d", num->value);
} }
} }

View File

@ -9,38 +9,37 @@
#include "parse_driver.hpp" #include "parse_driver.hpp"
#include "parser.hpp" #include "parser.hpp"
#define YY_USER_ACTION drv.write(yytext, yyleng); drv.location.step(); drv.location.columns(yyleng); #define YY_USER_ACTION \
drv.get_file_manager().write(yytext, yyleng); \
LOC.step(); LOC.columns(yyleng);
#define LOC drv.get_current_location()
%} %}
%% %%
\n { drv.location.lines(); drv.mark_line(); } \n { drv.get_current_location().lines(); drv.get_file_manager().mark_line(); }
[ ]+ {} [ ]+ {}
\\ { return yy::parser::make_BACKSLASH(drv.location); } \\ { return yy::parser::make_BACKSLASH(LOC); }
\+ { return yy::parser::make_PLUS(drv.location); } \+ { return yy::parser::make_PLUS(LOC); }
\* { return yy::parser::make_TIMES(drv.location); } \* { return yy::parser::make_TIMES(LOC); }
- { return yy::parser::make_MINUS(drv.location); } - { return yy::parser::make_MINUS(LOC); }
\/ { return yy::parser::make_DIVIDE(drv.location); } \/ { return yy::parser::make_DIVIDE(LOC); }
% { return yy::parser::make_MODULO(drv.location); } [0-9]+ { return yy::parser::make_INT(atoi(yytext), LOC); }
== { return yy::parser::make_EQUALS(drv.location); } defn { return yy::parser::make_DEFN(LOC); }
\<= { return yy::parser::make_LESS_EQUALS(drv.location); } data { return yy::parser::make_DATA(LOC); }
` { return yy::parser::make_BACKTICK(drv.location); } case { return yy::parser::make_CASE(LOC); }
[0-9]+ { return yy::parser::make_INT(atoi(yytext), drv.location); } of { return yy::parser::make_OF(LOC); }
defn { return yy::parser::make_DEFN(drv.location); } let { return yy::parser::make_LET(LOC); }
data { return yy::parser::make_DATA(drv.location); } in { return yy::parser::make_IN(LOC); }
case { return yy::parser::make_CASE(drv.location); } \{ { return yy::parser::make_OCURLY(LOC); }
of { return yy::parser::make_OF(drv.location); } \} { return yy::parser::make_CCURLY(LOC); }
let { return yy::parser::make_LET(drv.location); } \( { return yy::parser::make_OPAREN(LOC); }
in { return yy::parser::make_IN(drv.location); } \) { return yy::parser::make_CPAREN(LOC); }
\{ { return yy::parser::make_OCURLY(drv.location); } , { return yy::parser::make_COMMA(LOC); }
\} { return yy::parser::make_CCURLY(drv.location); } -> { return yy::parser::make_ARROW(LOC); }
\( { return yy::parser::make_OPAREN(drv.location); } = { return yy::parser::make_EQUAL(LOC); }
\) { return yy::parser::make_CPAREN(drv.location); } [a-z][a-zA-Z]* { return yy::parser::make_LID(std::string(yytext), LOC); }
, { return yy::parser::make_COMMA(drv.location); } [A-Z][a-zA-Z]* { return yy::parser::make_UID(std::string(yytext), LOC); }
-> { return yy::parser::make_ARROW(drv.location); } <<EOF>> { return yy::parser::make_YYEOF(LOC); }
= { return yy::parser::make_EQUAL(drv.location); }
[a-z][a-zA-Z]* { return yy::parser::make_LID(std::string(yytext), drv.location); }
[A-Z][a-zA-Z]* { return yy::parser::make_UID(std::string(yytext), drv.location); }
<<EOF>> { return yy::parser::make_YYEOF(drv.location); }
%% %%

View File

@ -26,9 +26,9 @@ type_ptr type_scheme::instantiate(type_mgr& mgr) const {
} }
void type_var::print(const type_mgr& mgr, std::ostream& to) const { void type_var::print(const type_mgr& mgr, std::ostream& to) const {
auto it = mgr.types.find(name); auto type = mgr.lookup(name);
if(it != mgr.types.end()) { if(type) {
it->second->print(mgr, to); type->print(mgr, to);
} else { } else {
to << name; to << name;
} }
@ -38,10 +38,6 @@ void type_base::print(const type_mgr& mgr, std::ostream& to) const {
to << name; to << name;
} }
void type_internal::print(const type_mgr& mgr, std::ostream& to) const {
to << "!" << name;
}
void type_arr::print(const type_mgr& mgr, std::ostream& to) const { void type_arr::print(const type_mgr& mgr, std::ostream& to) const {
type_var* var; type_var* var;
bool print_parenths = dynamic_cast<type_arr*>(mgr.resolve(left, var).get()) != nullptr; bool print_parenths = dynamic_cast<type_arr*>(mgr.resolve(left, var).get()) != nullptr;
@ -82,6 +78,12 @@ type_ptr type_mgr::new_arrow_type() {
return type_ptr(new type_arr(new_type(), new_type())); return type_ptr(new type_arr(new_type(), new_type()));
} }
type_ptr type_mgr::lookup(const std::string& var) const {
auto types_it = types.find(var);
if(types_it != types.end()) return types_it->second;
return nullptr;
}
type_ptr type_mgr::resolve(type_ptr t, type_var*& var) const { type_ptr type_mgr::resolve(type_ptr t, type_var*& var) const {
type_var* cast; type_var* cast;
@ -122,8 +124,7 @@ void type_mgr::unify(type_ptr l, type_ptr r, const std::optional<yy::location>&
} else if((lid = dynamic_cast<type_base*>(l.get())) && } else if((lid = dynamic_cast<type_base*>(l.get())) &&
(rid = dynamic_cast<type_base*>(r.get()))) { (rid = dynamic_cast<type_base*>(r.get()))) {
if(lid->name == rid->name && if(lid->name == rid->name &&
lid->arity == rid->arity && lid->arity == rid->arity)
lid->is_internal() == rid->is_internal())
return; return;
} else if((lapp = dynamic_cast<type_app*>(l.get())) && } else if((lapp = dynamic_cast<type_app*>(l.get())) &&
(rapp = dynamic_cast<type_app*>(r.get()))) { (rapp = dynamic_cast<type_app*>(r.get()))) {

View File

@ -7,7 +7,7 @@
#include <optional> #include <optional>
#include "location.hh" #include "location.hh"
struct type_mgr; class type_mgr;
struct type { struct type {
virtual ~type() = default; virtual ~type() = default;
@ -46,17 +46,6 @@ struct type_base : public type {
: name(std::move(n)), arity(a) {} : name(std::move(n)), arity(a) {}
void print(const type_mgr& mgr, std::ostream& to) const; void print(const type_mgr& mgr, std::ostream& to) const;
virtual bool is_internal() const { return false; }
};
struct type_internal : public type_base {
type_internal(std::string n, int32_t a = 0)
: type_base(std::move(n), a) {}
void print(const type_mgr& mgr, std::ostream& to) const;
bool is_internal() const { return true; }
}; };
struct type_data : public type_base { struct type_data : public type_base {
@ -90,20 +79,23 @@ struct type_app : public type {
void print(const type_mgr& mgr, std::ostream& to) const; void print(const type_mgr& mgr, std::ostream& to) const;
}; };
struct type_mgr { class type_mgr {
int last_id = 0; private:
std::map<std::string, type_ptr> types; int last_id = 0;
std::map<std::string, type_ptr> types;
std::string new_type_name(); public:
type_ptr new_type(); std::string new_type_name();
type_ptr new_arrow_type(); type_ptr new_type();
type_ptr new_arrow_type();
void unify(type_ptr l, type_ptr r, const std::optional<yy::location>& loc = std::nullopt); void unify(type_ptr l, type_ptr r, const std::optional<yy::location>& loc = std::nullopt);
type_ptr substitute( type_ptr substitute(
const std::map<std::string, type_ptr>& subst, const std::map<std::string, type_ptr>& subst,
const type_ptr& t) const; const type_ptr& t) const;
type_ptr resolve(type_ptr t, type_var*& var) const; type_ptr lookup(const std::string& var) const;
void bind(const std::string& s, type_ptr t); type_ptr resolve(type_ptr t, type_var*& var) const;
void find_free(const type_ptr& t, std::set<std::string>& into) const; void bind(const std::string& s, type_ptr t);
void find_free(const type_scheme_ptr& t, std::set<std::string>& into) const; void find_free(const type_ptr& t, std::set<std::string>& into) const;
void find_free(const type_scheme_ptr& t, std::set<std::string>& into) const;
}; };

View File

@ -2,6 +2,10 @@
#include "type.hpp" #include "type.hpp"
#include "error.hpp" #include "error.hpp"
type_env_ptr type_env::get_parent() {
return parent;
}
void type_env::find_free(const type_mgr& mgr, std::set<std::string>& into) const { void type_env::find_free(const type_mgr& mgr, std::set<std::string>& into) const {
if(parent != nullptr) parent->find_free(mgr, into); if(parent != nullptr) parent->find_free(mgr, into);
for(auto& binding : names) { for(auto& binding : names) {
@ -44,10 +48,9 @@ void type_env::set_mangled_name(const std::string& name, const std::string& mang
const std::string& type_env::get_mangled_name(const std::string& name) const { const std::string& type_env::get_mangled_name(const std::string& name) const {
auto it = names.find(name); auto it = names.find(name);
if(it != names.end()) if(it != names.end()) return it->second.mangled_name;
return (it->second.mangled_name != "") ? it->second.mangled_name : name; assert(parent != nullptr);
if(parent) return parent->get_mangled_name(name); return parent->get_mangled_name(name);
return name;
} }
type_ptr type_env::lookup_type(const std::string& name) const { type_ptr type_env::lookup_type(const std::string& name) const {

View File

@ -10,39 +10,42 @@ using type_env_ptr = std::shared_ptr<type_env>;
enum class visibility { global,local }; enum class visibility { global,local };
struct type_env { class type_env {
struct variable_data { private:
type_scheme_ptr type; struct variable_data {
visibility vis; type_scheme_ptr type;
std::string mangled_name; visibility vis;
std::string mangled_name;
variable_data() variable_data()
: variable_data(nullptr, visibility::local, "") {} : variable_data(nullptr, visibility::local, "") {}
variable_data(type_scheme_ptr t, visibility v, std::string n) variable_data(type_scheme_ptr t, visibility v, std::string n)
: type(std::move(t)), vis(v), mangled_name(std::move(n)) {} : type(std::move(t)), vis(v), mangled_name(std::move(n)) {}
}; };
type_env_ptr parent; type_env_ptr parent;
std::map<std::string, variable_data> names; std::map<std::string, variable_data> names;
std::map<std::string, type_ptr> type_names; std::map<std::string, type_ptr> type_names;
type_env(type_env_ptr p) : parent(std::move(p)) {} public:
type_env() : type_env(nullptr) {} type_env(type_env_ptr p) : parent(std::move(p)) {}
type_env() : type_env(nullptr) {}
void find_free(const type_mgr& mgr, std::set<std::string>& into) const; type_env_ptr get_parent();
void find_free_except(const type_mgr& mgr, const group& avoid, void find_free(const type_mgr& mgr, std::set<std::string>& into) const;
std::set<std::string>& into) const; void find_free_except(const type_mgr& mgr, const group& avoid,
type_scheme_ptr lookup(const std::string& name) const; std::set<std::string>& into) const;
bool is_global(const std::string& name) const; type_scheme_ptr lookup(const std::string& name) const;
void set_mangled_name(const std::string& name, const std::string& mangled); bool is_global(const std::string& name) const;
const std::string& get_mangled_name(const std::string& name) const; void set_mangled_name(const std::string& name, const std::string& mangled);
type_ptr lookup_type(const std::string& name) const; const std::string& get_mangled_name(const std::string& name) const;
void bind(const std::string& name, type_ptr t, type_ptr lookup_type(const std::string& name) const;
visibility v = visibility::local); void bind(const std::string& name, type_ptr t,
void bind(const std::string& name, type_scheme_ptr t, visibility v = visibility::local);
visibility v = visibility::local); void bind(const std::string& name, type_scheme_ptr t,
void bind_type(const std::string& type_name, type_ptr t); visibility v = visibility::local);
void generalize(const std::string& name, const group& grp, type_mgr& mgr); void bind_type(const std::string& type_name, type_ptr t);
void generalize(const std::string& name, const group& grp, type_mgr& mgr);
}; };