#include "ast.hpp"
#include <iostream>
#include <memory>
#include "binop.hpp"
#include "definition.hpp"
#include "instruction.hpp"
#include "llvm_context.hpp"
#include "parser.hpp"
#include "error.hpp"
#include "type.hpp"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Target/TargetMachine.h"

// Definition of yy::parser::error: prints the given message, prefixed with
// a fixed label, to standard output followed by a newline and a flush.
//
// msg: text describing the error.
void yy::parser::error(const std::string& msg) {
    std::cout << "An error occurred: " << msg << std::endl;
}

// Top-level definitions of the program being compiled. Declared extern, so
// the storage lives in another translation unit; main() reads it after
// calling parser.parse() (presumably the parser fills it in -- confirm
// against the grammar file).
extern std::vector<definition_ptr> program;

void typecheck_program(
        const std::vector<definition_ptr>& prog,
        type_mgr& mgr, type_env& env) {
    type_ptr int_type = type_ptr(new type_base("Int")); 
    type_ptr binop_type = type_ptr(new type_arr(
                int_type,
                type_ptr(new type_arr(int_type, int_type))));

    env.bind("+", binop_type);
    env.bind("-", binop_type);
    env.bind("*", binop_type);
    env.bind("/", binop_type);

    for(auto& def : prog) {
        def->typecheck_first(mgr, env);
    }

    for(auto& def : prog) {
        def->typecheck_second(mgr, env);
    }

    for(auto& pair : env.names) {
        std::cout << pair.first << ": ";
        pair.second->print(mgr, std::cout);
        std::cout << std::endl;
    }

    for(auto& def : prog) {
        def->resolve(mgr);
    }
}

void compile_program(const std::vector<definition_ptr>& prog) {
    for(auto& def : prog) {
        def->compile();

        definition_defn* defn = dynamic_cast<definition_defn*>(def.get());
        if(!defn) continue;
        for(auto& instruction : defn->instructions) {
            instruction->print(0, std::cout);
        }
        std::cout << std::endl;
    }
}

// Generates the LLVM function that implements one built-in binary operator.
//
// A custom function named op_action(op) is created through ctx with 2 as
// its second argument (presumably the arity -- confirm in llvm_context).
// Its entry block is filled with the code generated for the sequence
//   Push(1), Eval, Push(1), Eval, BinOp(op)
// and terminated with a void return.
//
// ctx: LLVM context whose builder and module receive the generated code.
// op:  the operator to generate a function for.
void gen_llvm_internal_op(llvm_context& ctx, binop op) {
    auto op_function = ctx.create_custom_function(op_action(op), 2);

    std::vector<instruction_ptr> body;
    // The same Push(1)/Eval pair is emitted once per operand.
    for(int operand = 0; operand < 2; ++operand) {
        body.push_back(instruction_ptr(new instruction_push(1)));
        body.push_back(instruction_ptr(new instruction_eval()));
    }
    body.push_back(instruction_ptr(new instruction_binop(op)));

    ctx.builder.SetInsertPoint(&op_function->getEntryBlock());
    for(auto& step : body) {
        step->gen_llvm(ctx, op_function);
    }
    ctx.builder.CreateRetVoid();
}

// Writes ctx.module to `filename` as an object file for the host's default
// target triple.
//
// ctx:      LLVM context; its module gets the data layout and target triple
//           set before code is emitted.
// filename: path of the object file to create.
//
// If the target cannot be looked up, or no target machine can be created
// for it, a message is printed to standard error and the function returns
// without writing anything. If the output file cannot be opened, or the
// target cannot emit object files, a message is printed to standard error
// and an int (0) is thrown.
void output_llvm(llvm_context& ctx, const std::string& filename) {
    std::string targetTriple = llvm::sys::getDefaultTargetTriple();

    llvm::InitializeNativeTarget();
    llvm::InitializeNativeTargetAsmParser();
    llvm::InitializeNativeTargetAsmPrinter();

    std::string error;
    const llvm::Target* target =
        llvm::TargetRegistry::lookupTarget(targetTriple, error);
    if (!target) {
        std::cerr << error << std::endl;
        return;
    }

    std::string cpu = "generic";
    std::string features = "";
    llvm::TargetOptions options;
    // createTargetMachine hands back a heap-allocated object; own it here so
    // it is released on every exit path, including the throws below.
    std::unique_ptr<llvm::TargetMachine> targetMachine(
            target->createTargetMachine(targetTriple, cpu, features,
                options, llvm::Optional<llvm::Reloc::Model>()));
    if (!targetMachine) {
        std::cerr << "could not create a target machine for "
            << targetTriple << std::endl;
        return;
    }

    ctx.module.setDataLayout(targetMachine->createDataLayout());
    ctx.module.setTargetTriple(targetTriple);

    std::error_code ec;
    llvm::raw_fd_ostream file(filename, ec, llvm::sys::fs::F_None);
    if (ec) {
        std::cerr << "could not open " << filename << ": "
            << ec.message() << std::endl;
        throw 0;
    }

    llvm::TargetMachine::CodeGenFileType type = llvm::TargetMachine::CGFT_ObjectFile;
    llvm::legacy::PassManager pm;
    // addPassesToEmitFile returns true when the requested file type cannot
    // be emitted by this target.
    if (targetMachine->addPassesToEmitFile(pm, file, nullptr, type)) {
        std::cerr << "target " << targetTriple
            << " cannot emit object files" << std::endl;
        throw 0;
    }

    pm.run(ctx.module);
    file.close();
}

void gen_llvm(const std::vector<definition_ptr>& prog) {
    llvm_context ctx;
    gen_llvm_internal_op(ctx, PLUS);
    gen_llvm_internal_op(ctx, MINUS);
    gen_llvm_internal_op(ctx, TIMES);
    gen_llvm_internal_op(ctx, DIVIDE);

    for(auto& definition : prog) {
        definition->gen_llvm_first(ctx);
    }

    for(auto& definition : prog) {
        definition->gen_llvm_second(ctx);
    }
    ctx.module.print(llvm::outs(), nullptr);
    output_llvm(ctx, "program.o");
}

int main() {
    yy::parser parser;
    type_mgr mgr;
    type_env env;

    parser.parse();
    for(auto& definition : program) {
        definition_defn* def = dynamic_cast<definition_defn*>(definition.get());
        if(!def) continue;

        std::cout << def->name;
        for(auto& param : def->params) std::cout << " " << param;
        std::cout << ":" << std::endl;

        def->body->print(1, std::cout);
    }
    try {
        typecheck_program(program, mgr, env);
        compile_program(program);
        gen_llvm(program);
    } catch(unification_error& err) {
        std::cout << "failed to unify types: " << std::endl;
        std::cout << "  (1) \033[34m";
        err.left->print(mgr, std::cout);
        std::cout << "\033[0m" << std::endl;
        std::cout << "  (2) \033[32m";
        err.right->print(mgr, std::cout);
        std::cout << "\033[0m" << std::endl;
    } catch(type_error& err) {
        std::cout << "failed to type check program: " << err.description << std::endl;
    }
}