#include "ast.hpp"
#include <ostream>
#include "binop.hpp"
#include "error.hpp"
#include "type_env.hpp"
#include "env.hpp"

// Writes `n` indentation levels (two spaces each) to `to`.
// A zero or negative `n` writes nothing. A bounded loop is used because
// `while(n--)` would run ~2^31 times and overflow a signed int for negative n.
static void print_indent(int n, std::ostream& to) {
    for(int level = 0; level < n; level++) to << "  ";
}

void ast_int::print(int indent, std::ostream& to) const {
    // A literal is a leaf: a single "INT: <value>" line at this depth.
    print_indent(indent, to);
    to << "INT: ";
    to << value;
    to << std::endl;
}

void ast_int::find_free(std::set<std::string>& /* into */) {
    // An integer literal references no variables; nothing is added.
}

type_ptr ast_int::typecheck(type_mgr& /* mgr */, type_env_ptr& env) {
    // Remember the scope for later passes. An integer literal always has the
    // type Int, so no fresh type variables are needed.
    this->env = env;
    type_ptr int_type(new type_app(env->lookup_type("Int")));
    return int_type;
}

void ast_int::translate(global_scope& scope) {

}

void ast_int::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
    into.push_back(instruction_ptr(new instruction_pushint(value)));
}

void ast_lid::print(int indent, std::ostream& to) const {
    // Identifier references are leaves: a single "LID: <id>" line.
    print_indent(indent, to);
    to << "LID: ";
    to << id;
    to << std::endl;
}

void ast_lid::find_free(std::set<std::string>& into) {
    // A variable reference is reported as free; enclosing binders
    // (case patterns, let, lambda) remove the names they bind.
    into.emplace(id);
}

// Infers the type of a variable reference by instantiating the type scheme
// bound to `id` in `env`. Records `env` for later passes.
// Throws type_error if `id` is not bound in `env`.
type_ptr ast_lid::typecheck(type_mgr& mgr, type_env_ptr& env) {
    this->env = env;
    // lookup() yields a null scheme for names that are not in scope; report
    // that as a type error instead of dereferencing null.
    type_scheme_ptr scheme = env->lookup(id);
    if(!scheme) throw type_error(std::string("unknown identifier ") + id);
    return scheme->instantiate(mgr);
}

void ast_lid::translate(global_scope& scope) {

}

void ast_lid::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
    auto mangled_name = this->env->get_mangled_name(id);
    into.push_back(instruction_ptr(
        (env->has_variable(mangled_name) && !this->env->is_global(id)) ?
            (instruction*) new instruction_push(env->get_offset(mangled_name)) :
            (instruction*) new instruction_pushglobal(mangled_name)));
}

void ast_uid::print(int indent, std::ostream& to) const {
    // Constructor references are leaves: a single "UID: <id>" line.
    print_indent(indent, to);
    to << "UID: ";
    to << id;
    to << std::endl;
}

void ast_uid::find_free(std::set<std::string>& /* into */) {
    // Constructor names are never collected as free variables.
}

// Infers the type of a constructor reference by instantiating the type scheme
// bound to `id` in `env`. Records `env` for later passes.
// Throws type_error if `id` is not bound in `env`.
type_ptr ast_uid::typecheck(type_mgr& mgr, type_env_ptr& env) {
    this->env = env;
    // lookup() yields a null scheme for names that are not in scope; report
    // that as a type error instead of dereferencing null.
    type_scheme_ptr scheme = env->lookup(id);
    if(!scheme) throw type_error(std::string("unknown constructor ") + id);
    return scheme->instantiate(mgr);
}

void ast_uid::translate(global_scope& scope) {

}

void ast_uid::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
    into.push_back(instruction_ptr(
                new instruction_pushglobal(this->env->get_mangled_name(id))));
}

void ast_binop::print(int indent, std::ostream& to) const {
    print_indent(indent, to);
    to << "BINOP: " << op_name(op) << std::endl;
    left->print(indent + 1, to);
    right->print(indent + 1, to);
}

void ast_binop::find_free(std::set<std::string>& into) {
    left->find_free(into);
    right->find_free(into);
}

// Infers the type of `left op right`. The operator is looked up by name in
// `env` and unified with ltype -> rtype -> result; the fresh result type is
// returned. Records `env` for later passes.
// Throws type_error if the operator has no binding in `env`.
type_ptr ast_binop::typecheck(type_mgr& mgr, type_env_ptr& env) {
    this->env = env;
    type_ptr ltype = left->typecheck(mgr, env);
    type_ptr rtype = right->typecheck(mgr, env);

    // Check the scheme itself before instantiating it: an unknown operator
    // must surface as a type_error, not as a null-pointer dereference.
    type_scheme_ptr op_scheme = env->lookup(op_name(op));
    if(!op_scheme) throw type_error(std::string("unknown binary operator ") + op_name(op));
    type_ptr ftype = op_scheme->instantiate(mgr);

    // The operator must have type ltype -> (rtype -> return_type).
    type_ptr return_type = mgr.new_type();
    type_ptr arrow_one = type_ptr(new type_arr(rtype, return_type));
    type_ptr arrow_two = type_ptr(new type_arr(ltype, arrow_one));

    mgr.unify(arrow_two, ftype);
    return return_type;
}

void ast_binop::translate(global_scope& scope) {
    // Lift any lambdas/lets found in either operand, left first.
    this->left->translate(scope);
    this->right->translate(scope);
}

void ast_binop::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
    right->compile(env, into);
    left->compile(env_ptr(new env_offset(1, env)), into);

    into.push_back(instruction_ptr(new instruction_pushglobal(op_action(op))));
    into.push_back(instruction_ptr(new instruction_mkapp()));
    into.push_back(instruction_ptr(new instruction_mkapp()));
}

void ast_app::print(int indent, std::ostream& to) const {
    print_indent(indent, to);
    to << "APP:" << std::endl;
    left->print(indent + 1, to);
    right->print(indent + 1, to);
}

void ast_app::find_free(std::set<std::string>& into) {
    left->find_free(into);
    right->find_free(into);
}

type_ptr ast_app::typecheck(type_mgr& mgr, type_env_ptr& env) {
    this->env = env;
    // The function's type must unify with (argument type -> fresh result).
    type_ptr function_type = left->typecheck(mgr, env);
    type_ptr argument_type = right->typecheck(mgr, env);

    type_ptr result_type = mgr.new_type();
    type_ptr expected_function(new type_arr(argument_type, result_type));
    mgr.unify(expected_function, function_type);
    return result_type;
}

void ast_app::translate(global_scope& scope) {
    // Lift lambdas/lets in the function part first, then in the argument.
    this->left->translate(scope);
    this->right->translate(scope);
}

void ast_app::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
    right->compile(env, into);
    left->compile(env_ptr(new env_offset(1, env)), into);
    into.push_back(instruction_ptr(new instruction_mkapp()));
}

void ast_case::print(int indent, std::ostream& to) const {
    // Header line, then each branch: its pattern one level in, its body two.
    // The scrutinee expression is not printed.
    const int pattern_indent = indent + 1;
    const int body_indent = indent + 2;
    print_indent(indent, to);
    to << "CASE: " << std::endl;
    for(const auto& branch : branches) {
        print_indent(pattern_indent, to);
        branch->pat->print(to);
        to << std::endl;
        branch->expr->print(body_indent, to);
    }
}

void ast_case::find_free(std::set<std::string>& into) {
    // Every free variable of the scrutinee is free in the case expression.
    of->find_free(into);

    // A branch contributes its body's free variables, except the names its
    // pattern binds.
    for(auto& branch : branches) {
        std::set<std::string> bound;
        branch->pat->find_variables(bound);

        std::set<std::string> branch_free;
        branch->expr->find_free(branch_free);

        for(const auto& name : branch_free) {
            if(bound.count(name) == 0) into.insert(name);
        }
    }
}

type_ptr ast_case::typecheck(type_mgr& mgr, type_env_ptr& env) {
    this->env = env;
    type_var* resolved_var;
    type_ptr scrutinee_type = mgr.resolve(of->typecheck(mgr, env), resolved_var);
    type_ptr result_type = mgr.new_type();

    // Each branch checks its pattern in a fresh child scope; all branch
    // bodies must agree on a single result type.
    for(auto& branch : branches) {
        type_env_ptr branch_env = type_scope(env);
        branch->pat->typecheck(scrutinee_type, mgr, branch_env);
        mgr.unify(result_type, branch->expr->typecheck(mgr, branch_env));
    }

    // Resolved only after the branches so that their constraints are applied;
    // the stored input_type is what compile() later dispatches on.
    input_type = mgr.resolve(scrutinee_type, resolved_var);
    auto* applied = dynamic_cast<type_app*>(input_type.get());
    const bool is_data = applied && dynamic_cast<type_data*>(applied->constructor.get());
    if(!is_data) throw type_error("attempting case analysis of non-data type");

    return result_type;
}

void ast_case::translate(global_scope& scope) {
    // Only the scrutinee and the branch bodies are translated; patterns are
    // left untouched.
    of->translate(scope);
    for(auto& branch : branches) branch->expr->translate(scope);
}

// Compiles a case expression: evaluate the scrutinee, then emit a jump
// instruction that dispatches on its constructor tag to one instruction
// sequence per branch. Every branch leaves exactly its result on the stack
// in place of the scrutinee.
// Throws type_error for a duplicate constructor pattern or when some
// constructor of the scrutinee's data type is not covered by any branch.
void ast_case::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
    // typecheck() stored the resolved scrutinee type in input_type and
    // verified it is an application of a data type, so these casts succeed.
    type_app* app_type = dynamic_cast<type_app*>(input_type.get());
    type_data* type = dynamic_cast<type_data*>(app_type->constructor.get());

    // The evaluated scrutinee stays on top of the stack while `jump` dispatches.
    of->compile(env, into);
    into.push_back(instruction_ptr(new instruction_eval()));

    // Ownership passes to `into`; the raw pointer is kept to fill in branches.
    instruction_jump* jump_instruction = new instruction_jump();
    into.push_back(instruction_ptr(jump_instruction));
    for(auto& branch : branches) {
        std::vector<instruction_ptr> branch_instructions;
        pattern_var* vpat;
        pattern_constr* cpat;

        if((vpat = dynamic_cast<pattern_var*>(branch->pat.get()))) {
            // The variable names the scrutinee itself, which is the top stack
            // slot: bind it there so the body can refer to it, then slide it
            // away afterwards, just as constructor branches slide away their
            // fields.
            env_ptr new_env(new env_var(branch->expr->env->get_mangled_name(vpat->var), env));
            branch->expr->compile(new_env, branch_instructions);
            branch_instructions.push_back(instruction_ptr(new instruction_slide(1)));

            // Catch-all: claim every tag not handled by an earlier branch.
            // Already-mapped tags are skipped; the scan must not stop at them.
            for(auto& constr_pair : type->constructors) {
                if(jump_instruction->tag_mappings.find(constr_pair.second.tag) !=
                        jump_instruction->tag_mappings.end())
                    continue;

                jump_instruction->tag_mappings[constr_pair.second.tag] =
                    jump_instruction->branches.size();
            }
            jump_instruction->branches.push_back(std::move(branch_instructions));
        } else if((cpat = dynamic_cast<pattern_constr*>(branch->pat.get()))) {
            // Bind the pattern parameters to the fields that `split` pushes;
            // binding in reverse puts the first parameter on top.
            env_ptr new_env = env;
            for(auto it = cpat->params.rbegin(); it != cpat->params.rend(); it++) {
                new_env = env_ptr(new env_var(branch->expr->env->get_mangled_name(*it), new_env));
            }

            branch_instructions.push_back(instruction_ptr(new instruction_split(
                            cpat->params.size())));
            branch->expr->compile(new_env, branch_instructions);
            branch_instructions.push_back(instruction_ptr(new instruction_slide(
                            cpat->params.size())));

            int new_tag = type->constructors[cpat->constr].tag;
            if(jump_instruction->tag_mappings.find(new_tag) !=
                    jump_instruction->tag_mappings.end())
                throw type_error("technically not a type error: duplicate pattern");

            jump_instruction->tag_mappings[new_tag] =
                jump_instruction->branches.size();
            jump_instruction->branches.push_back(std::move(branch_instructions));
        }
    }

    // Every constructor of the data type must be reachable through some branch.
    for(auto& constr_pair : type->constructors) {
        if(jump_instruction->tag_mappings.find(constr_pair.second.tag) ==
                jump_instruction->tag_mappings.end())
            throw type_error("non-total pattern");
    }
}

void ast_let::print(int indent, std::ostream& to) const {
    // Header line, then the body one level deeper; the bound definitions
    // themselves are not printed.
    print_indent(indent, to);
    to << "LET: ";
    to << std::endl;
    in->print(indent + 1, to);
}

void ast_let::find_free(std::set<std::string>& into) {
    // Whatever the definitions report as free is added unchanged.
    definitions.find_free(into);

    // The body's free variables count unless they name a function defined
    // by this let.
    std::set<std::string> body_free;
    in->find_free(body_free);
    for(const auto& name : body_free) {
        const bool bound_here =
            definitions.defs_defn.find(name) != definitions.defs_defn.end();
        if(!bound_here) into.insert(name);
    }
}

type_ptr ast_let::typecheck(type_mgr& mgr, type_env_ptr& env) {
    this->env = env;
    // Check the bindings first; the body is then checked in the scope they
    // produce (definitions.env), and its type is the type of the whole let.
    definitions.typecheck(mgr, env);
    type_env_ptr& body_env = definitions.env;
    return in->typecheck(mgr, body_env);
}

// Lifts the let's definitions into the global scope. Data definitions are
// registered as globals. Each function definition becomes a global function,
// and translated_definitions records, per original name, an expression that
// applies that global to the variables it captured. The body is then
// translated recursively.
void ast_let::translate(global_scope& scope) {
    for(auto& def : definitions.defs_data) {
        def.second->into_globals(scope);
    }
    for(auto& def : definitions.defs_defn) {
        size_t original_params = def.second->params.size();
        std::string original_name = def.second->name;
        auto& global_definition = def.second->into_global(scope);
        // Parameters beyond the original ones are captured variables; the loop
        // below assumes into_global puts them first in the parameter list.
        size_t captured = global_definition.params.size() - original_params;

        // A scope where the original name is a global whose mangled name is
        // the lifted function's name. The let-bound name lives in
        // definitions.env (where typecheck() bound it), not in the outer
        // env, so its type scheme is taken from there.
        type_env_ptr mangled_env = type_scope(env);
        mangled_env->bind(def.first, definitions.env->lookup(def.first), visibility::global);
        mangled_env->set_mangled_name(def.first, global_definition.name);

        // Build ((global c1) c2) ... with the captured variables, which are
        // resolved in the let's enclosing scope.
        ast_ptr global_app(new ast_lid(original_name));
        global_app->env = mangled_env;
        for(auto& param : global_definition.params) {
            if(!(captured--)) break;
            ast_ptr new_arg(new ast_lid(param));
            new_arg->env = env;
            global_app = ast_ptr(new ast_app(std::move(global_app), std::move(new_arg)));
            global_app->env = env;
        }
        translated_definitions.push_back({ def.first, std::move(global_app) });
    }
    in->translate(scope);
}

void ast_let::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
    into.push_back(instruction_ptr(new instruction_alloc(translated_definitions.size())));
    env_ptr new_env = env;
    for(auto& def : translated_definitions) {
        new_env = env_ptr(new env_var(definitions.env->get_mangled_name(def.first), std::move(new_env)));
    }
    int offset = translated_definitions.size() - 1;
    for(auto& def : translated_definitions) {
        def.second->compile(new_env, into);
        into.push_back(instruction_ptr(new instruction_update(offset--)));
    }
    in->compile(new_env, into);
    into.push_back(instruction_ptr(new instruction_slide(translated_definitions.size())));
}

void ast_lambda::print(int indent, std::ostream&  to) const {
    print_indent(indent, to);
    to << "LAMBDA";
    for(auto& param : params) {
        to << " " << param;
    }
    to << std::endl;
    body->print(indent+1, to);
}

void ast_lambda::find_free(std::set<std::string>& into) {
    body->find_free(free_variables);
    for(auto& param : params) {
        free_variables.erase(param);
    }
    into.insert(free_variables.begin(), free_variables.end());
}

type_ptr ast_lambda::typecheck(type_mgr& mgr, type_env_ptr& env) {
    this->env = env;
    var_env = type_scope(env);
    type_ptr return_type = mgr.new_type();

    // Build p1 -> p2 -> ... -> return_type, starting from the innermost
    // arrow; each parameter gets a fresh type variable bound in var_env.
    type_ptr full_type = return_type;
    for(auto param = params.rbegin(); param != params.rend(); ++param) {
        type_ptr param_type = mgr.new_type();
        var_env->bind(*param, param_type);
        full_type = type_ptr(new type_arr(std::move(param_type), full_type));
    }

    type_ptr body_type = body->typecheck(mgr, var_env);
    mgr.unify(return_type, body_type);
    return full_type;
}

// Lambda lifting: moves this lambda's body into a new global function whose
// parameters are the captured (non-global) free variables followed by the
// lambda's own parameters. Stores in `translated` an expression that applies
// that global to the captured variables. Reads `free_variables` (filled by
// find_free()) and `env` (set by typecheck()).
// NOTE(review): `body` is moved out here, so it is null afterwards; e.g.
// print() would dereference null if called after translate().
void ast_lambda::translate(global_scope& scope) {
    // Globals are reachable by name, so only non-global free variables are
    // captured as extra parameters.
    std::vector<std::string> function_params;
    for(auto& free_variable : free_variables) {
        if(env->is_global(free_variable)) continue;
        function_params.push_back(free_variable);
    }
    size_t captured_count = function_params.size();
    // Captured variables first, then the lambda's declared parameters.
    function_params.insert(function_params.end(), params.begin(), params.end());

    // Register the lifted function. The name it actually receives is
    // new_function.name (presumably made unique by add_function — confirm).
    auto& new_function = scope.add_function("lambda", std::move(function_params), std::move(body));
    // A private scope where the identifier "lambda" is a global whose mangled
    // name is the lifted function's name.
    type_env_ptr mangled_env = type_scope(env);
    mangled_env->bind("lambda", type_scheme_ptr(nullptr), visibility::global);
    mangled_env->set_mangled_name("lambda", new_function.name);
    ast_ptr new_application = ast_ptr(new ast_lid("lambda"));
    new_application->env = mangled_env;

    // Apply the global to its first captured_count parameters (the captured
    // variables), each resolved in the lambda's enclosing scope.
    for(auto& param : new_function.params) {
        if(!(captured_count--)) break;
        ast_ptr new_arg = ast_ptr(new ast_lid(param));
        new_arg->env = env;
        new_application = ast_ptr(new ast_app(std::move(new_application), std::move(new_arg)));
        new_application->env = env;
    }
    translated = std::move(new_application);
}

void ast_lambda::compile(const env_ptr& env, std::vector<instruction_ptr>& into) const {
    // After translate(), the lambda is represented by the application of its
    // lifted global to the captured variables; compile that instead.
    const auto& lifted = translated;
    lifted->compile(env, into);
}

void pattern_var::print(std::ostream& to) const {
    to << var;
}

void pattern_var::find_variables(std::set<std::string>& into) const {
    // The only name bound by a variable pattern is the variable itself.
    into.emplace(var);
}

void pattern_var::typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const {
    env->bind(var, t);
}

void pattern_constr::print(std::ostream& to) const {
    to << constr;
    for(auto& param : params) {
        to << " " << param;
    }
}

void pattern_constr::find_variables(std::set<std::string>& into) const {
    // Every parameter name in a constructor pattern is a bound variable.
    for(const auto& name : params) into.insert(name);
}

void pattern_constr::typecheck(type_ptr t, type_mgr& mgr, type_env_ptr& env) const {
    // Instantiate the constructor's type, e.g. a -> b -> T.
    type_scheme_ptr scheme = env->lookup(constr);
    if(!scheme) throw type_error(std::string("pattern using unknown constructor ") + constr);
    type_ptr remaining = scheme->instantiate(mgr);

    // Peel off one argument type per pattern parameter and bind it.
    for(const auto& param : params) {
        auto* arrow = dynamic_cast<type_arr*>(remaining.get());
        if(arrow == nullptr) throw type_error("too many parameters in constructor pattern");
        env->bind(param, arrow->left);
        remaining = arrow->right;
    }

    // Whatever is left must be the scrutinee's type.
    mgr.unify(t, remaining);
}