#include "ast.hpp"
#include "error.hpp"

// First type-checking pass for a function definition.
//
// Allocates a fresh type from `mgr` for the return value and one more for
// each parameter, chains them into a single arrow type, and binds that type
// to the function's name in `env`.
//
// The arrow type is built from the return type outwards, so the parameter
// types are appended to `param_types` in reverse parameter order: the first
// element pushed belongs to the last parameter. typecheck_second relies on
// this by walking `param_types` backwards.
void definition_defn::typecheck_first(type_mgr& mgr, type_env& env) {
    return_type = mgr.new_type();

    // Start from the innermost type (the result) and wrap one arrow per
    // parameter around it. Only the number of parameters matters here; the
    // parameter names themselves are not inspected.
    type_ptr fn_type = return_type;
    for(auto remaining = params.size(); remaining != 0; --remaining) {
        type_ptr arg_type = mgr.new_type();
        fn_type = type_ptr(new type_arr(arg_type, fn_type));
        param_types.push_back(arg_type);
    }

    env.bind(name, fn_type);
}

// Second type-checking pass for a function definition.
//
// Creates a scope derived from `env`, binds each parameter name to the type
// allocated for it in typecheck_first, type-checks the body in that scope,
// and unifies the body's type with the function's return type.
void definition_defn::typecheck_second(type_mgr& mgr, const type_env& env) const {
    type_env inner_env = env.scope();

    // `param_types` was filled in reverse parameter order, so it is walked
    // backwards while `params` is walked forwards. The loop stops as soon as
    // either sequence is exhausted.
    auto type_cursor = param_types.rbegin();
    for(auto name_cursor = params.begin();
            name_cursor != params.end() && type_cursor != param_types.rend();
            ++name_cursor, ++type_cursor) {
        inner_env.bind(*name_cursor, *type_cursor);
    }

    type_ptr inferred_body_type = body->typecheck_common(mgr, inner_env);
    mgr.unify(return_type, inferred_body_type);
}

// Resolves the types recorded for this function definition.
//
// First asks the body to resolve itself, then replaces `return_type` and
// every entry of `param_types` with the result of `mgr.resolve`.
//
// Throws type_error("ambiguously typed program") if `mgr.resolve` leaves a
// non-null pointer in its out-parameter for the return type or for any
// parameter type.
// NOTE(review): presumably a non-null `var` means the type is still an
// unbound type variable -- confirm against type_mgr::resolve.
void definition_defn::resolve(const type_mgr& mgr) {
    // Initialised so the checks below never read an indeterminate pointer,
    // even if mgr.resolve does not write its out-parameter on some path.
    type_var* var = nullptr;
    body->resolve_common(mgr);

    return_type = mgr.resolve(return_type, var);
    if(var) throw type_error("ambiguously typed program");
    for(auto& param_type : param_types) {
        param_type = mgr.resolve(param_type, var);
        if(var) throw type_error("ambiguously typed program");
    }
}

void definition_defn::compile() {
    env_ptr new_env = env_ptr(new env_offset(0, nullptr));
    for(auto it = params.rbegin(); it != params.rend(); it++) {
        new_env = env_ptr(new env_var(*it, new_env));
    }
    body->compile(new_env, instructions);
    instructions.push_back(instruction_ptr(new instruction_update(params.size())));
    instructions.push_back(instruction_ptr(new instruction_pop(params.size())));
}

// First type-checking pass for a data type definition.
//
// Creates a type_data named after this definition and, for each constructor
// in declaration order:
//   - assigns it the next integer tag, starting at 0,
//   - records that tag in the type_data's `constructors` map under the
//     constructor's name,
//   - binds the constructor's name in `env` to an arrow type that takes the
//     constructor's field types (each a type_base) and yields the data type.
//
// `mgr` is not used by this function.
void definition_data::typecheck_first(type_mgr& mgr, type_env& env) {
    type_data* data_type = new type_data(name);
    type_ptr result_type = type_ptr(data_type);
    int tag = 0;

    for(auto& ctor : constructors) {
        ctor->tag = tag;
        data_type->constructors[ctor->name] = { tag };
        ++tag;

        // Build the constructor's type from the result outwards, so the
        // field types are visited last-to-first.
        type_ptr ctor_type = result_type;
        auto field_it = ctor->types.rbegin();
        while(field_it != ctor->types.rend()) {
            type_ptr field_type = type_ptr(new type_base(*field_it));
            ctor_type = type_ptr(new type_arr(field_type, ctor_type));
            ++field_it;
        }

        env.bind(ctor->name, ctor_type);
    }
}

// Second type-checking pass for a data type definition: a no-op.
// The constructors' types are already bound in the environment by
// typecheck_first, and this function performs no further work.
void definition_data::typecheck_second(type_mgr& mgr, const type_env& env) const {
    // Intentionally empty: neither `mgr` nor `env` is used.
}

// Resolution step for a data type definition: a no-op.
// NOTE(review): presumably nothing needs resolving because typecheck_first
// builds constructor types only from type_base / type_data nodes and never
// requests a fresh type from the manager -- confirm.
void definition_data::resolve(const type_mgr& mgr) {
    // Intentionally empty: `mgr` is not used.
}

// Compile step for a data type definition: a no-op. No instructions are
// produced here.
// NOTE(review): if constructors need generated code, it must come from
// somewhere other than this function -- confirm where.
void definition_data::compile() {

}