import os
# from tree_sitter import Language, Parser
# # import pandas as pd
# import openpyxl
import json
import time
import csv
import pathlib
import difflib
import re
from bleu import _bleu
from fuzzywuzzy import fuzz
import random
import numpy as np
from transformers import RobertaTokenizer
#tokens = nltk.word_tokenize(sentence)

# All dataset/output locations are resolved relative to this script's directory.
folder = str(pathlib.Path(__file__).parent.resolve())
isa_type_dir = f"{folder}/../../../Dataset"
src_dir = f"{isa_type_dir}/Code_Generation"
dst_dir = folder

# Module-level split lists (not read by the code in this file).
train_lis, valid_lis, test_lis = [], [], []

target_clf = {}
def get_target_clf_list():
    """Populate the module-level ``target_clf`` from ``comback_isa_type.csv``.

    Rows are grouped under the key "<column 0> <column 2>" (looked up elsewhere as
    "<compiler> <ISA class>", e.g. "GCC CPU"), and column 1 (the target name) is
    appended to that key's list in file order. The held-out targets arc, riscv and
    nvptx (compared case-insensitively) are skipped. Entries are appended, so calling
    this more than once adds duplicate names.
    """
    held_out = ("arc", "riscv", "nvptx")
    with open(isa_type_dir + "/comback_isa_type.csv", "r", encoding="utf-8") as fh:
        for row in csv.reader(fh):
            if row[1].lower() in held_out:
                continue
            # Mutating the shared dict in place, so no ``global`` rebinding is needed.
            target_clf.setdefault(row[0] + " " + row[2], []).append(row[1])


def Calculate_Statements_Ratio(Src_List, Fork_Lis, src_name, fork_name):
    """Return the share of reference statements reproduced unchanged by the generated code.

    Each sequence is re-assembled into code with one "statement" per line. The
    elements are concatenated without a separator, and a newline is inserted after
    every ``;``, ``:``, ``{`` or ``}`` element. Before that, the target name (in its
    given and upper-case spellings) is removed from every element, so that
    target-specific identifiers do not count as differences.

    Args:
        Src_List: Reference (ground-truth) sequence, iterated element by element
            (a str is iterated per character).
        Fork_Lis: Generated sequence, iterated the same way.
        src_name: Target name stripped from the reference.
        fork_name: Target name stripped from the generated code.

    Returns:
        The number of non-blank lines that ``difflib.Differ`` reports as common,
        divided by the number of statement delimiters in the reference, rounded to
        2 decimals. Returns 0.0 when the reference contains no delimiters.
    """
    delimiters = (";", ":", "{", "}")

    def _to_statement_lines(tokens, name):
        # Rebuild the code text, breaking the line after each statement delimiter.
        parts = []
        for tok in tokens:
            parts.append(tok.replace(name, "").replace(name.upper(), ""))
            if tok in delimiters:
                parts.append("\n")
        return "".join(parts)

    cnt_stmt = sum(1 for tok in Src_List if tok in delimiters)
    if cnt_stmt == 0:
        # No statements in the reference: the ratio is undefined, report 0.
        return 0.0

    # Both sequences are rebuilt independently (the original reused a stale index,
    # which silently dropped the first len(Src_List) generated tokens).
    src_code = _to_statement_lines(Src_List, src_name)
    fork_code = _to_statement_lines(Fork_Lis, fork_name)

    diff_code = difflib.Differ().compare(src_code.splitlines(), fork_code.splitlines())
    # Differ prefixes unchanged lines with "  "; blank common lines are ignored.
    code_same = sum(1 for dv in diff_code if dv.startswith(" ") and dv.strip())
    return round(float(code_same) / cnt_stmt, 2)



def _strip_target_name(text, target):
    """Remove the lower- and upper-case spellings of *target* from *text*."""
    return text.replace(target.lower(), "").replace(target.upper(), "")


def _load_ground_truth(jsonl_path, target):
    """Map "<target> <line index>" to the ``ground_truth`` field of each jsonl line."""
    gold = {}
    with open(jsonl_path, "r", encoding="utf-8") as jf:
        for line_idx, line in enumerate(jf):
            gold[target + " " + str(line_idx)] = json.loads(line)["ground_truth"]
    return gold


def _score_target(comp_type, target, gold, preds):
    """Score one target and append per-sample rows plus an "average" row to result.csv.

    Args:
        comp_type: Compiler label written to the CSV ("GCC" or "LLVM").
        target: Target label, also the key prefix used in *gold* and *preds*.
        gold: Mapping "<target> <idx>" -> ground-truth code.
        preds: Mapping "<target> <idx>" -> generated code.

    Returns:
        The average metrics as strings:
        [BLEU4, Exact Match %, fuzz.ratio similarity, statement ratio %].
        Samples without a generation count as 0 in every metric.
    """
    total_BLEU4 = total_EM = total_ED = total_PoVS = 0.0
    out_path = dst_dir + "/test.output"
    gold_path = dst_dir + "/test.gold"
    name = target.lower()
    with open(dst_dir + "/result.csv", "a", newline="") as result_file:
        writer = csv.writer(result_file)
        for k, gold_code in gold.items():
            if k not in preds:
                # Report the missing sample; it still counts in the denominator.
                print(k)
                continue
            pred_code = preds[k]
            # NOTE(review): csv.reader yields str, so " ".join(pred_code) spaces out
            # every character. If ground_truth is a token list, the two sides are joined
            # at different granularities. Confirm the intended format.
            src_code = _strip_target_name(" ".join(gold_code), target)
            chat_code = _strip_target_name(" ".join(pred_code), target)
            stmt_mod = Calculate_Statements_Ratio(gold_code, pred_code, name, name)
            # _bleu compares files, so write the single hypothesis/reference pair out.
            with open(out_path, "w") as f_out, open(gold_path, "w") as f_gold:
                f_out.write(chat_code + "\n")
                f_gold.write(src_code + "\n")
            EM = 1.0 if chat_code == src_code else 0.0
            edit_dis = fuzz.ratio(chat_code, src_code)
            bleu4 = 0 if chat_code.strip() == "" else _bleu(gold_path, out_path)
            total_BLEU4 += bleu4
            total_ED += edit_dis
            total_PoVS += stmt_mod
            total_EM += EM
            writer.writerow([comp_type, target, k.split(" ")[1], str(round(float(bleu4), 2)), str(round(EM * 100, 2)), str(round(float(edit_dis), 2)), str(round(float(stmt_mod) * 100, 2))])
        # Guard against an empty jsonl file (the totals are then all 0).
        count = max(len(gold), 1)
        averages = [
            str(round(float(total_BLEU4 / count), 2)),
            str(round((total_EM / count) * 100, 2)),
            str(round(float(total_ED / count), 2)),
            str(round(float(total_PoVS / count) * 100, 2)),
        ]
        writer.writerow([comp_type, target, "average"] + averages)
    return averages


def Calculate_Gen():
    """Score Code-LLaMA generations against the held-out test targets.

    Generated code is read from ``<dst_dir>/Input/codellama_gen_output_cleaned.csv``.
    Each row is read as: compiler, target, sample index, code. Ground truth is read from
    ``<src_dir>/<compiler>/<target>.jsonl``. For each compiler (GCC, LLVM) and ISA
    class (GPU, MPU, CPU), one held-out target is scored. Per-sample rows and an
    "average" row are appended to ``<dst_dir>/result.csv``.

    Returns:
        dict mapping "<compiler> <target>" to the list of average metric strings
        [BLEU4, Exact Match %, fuzz.ratio similarity, statement ratio %].
    """
    get_target_clf_list()
    print("############## Exp 2: Calculate Code-LLaMA Gen ################\n")

    # Held-out target per (compiler, ISA class). The name is also the jsonl file stem
    # and the label written to result.csv.
    test_targets = {
        "GCC": {"GPU": "nvptx", "MPU": "arc", "CPU": "riscv"},
        "LLVM": {"GPU": "NVPTX", "MPU": "ARC", "CPU": "RISCV"},
    }

    generated = {"GCC": {}, "LLVM": {}}
    dst_file = dst_dir + "/Input/codellama_gen_output_cleaned.csv"
    with open(dst_file, encoding="utf-8") as f:
        for row in csv.reader(f):
            # Any row not tagged "GCC" is treated as LLVM.
            bucket = "GCC" if row[0] == "GCC" else "LLVM"
            generated[bucket][row[1] + " " + str(row[2])] = row[3]

    avg_accuracy = {}
    for comp_type in ["GCC", "LLVM"]:
        for isa_type in ["GPU", "MPU", "CPU"]:
            target = test_targets[comp_type][isa_type]
            gold = _load_ground_truth(src_dir + "/" + comp_type + "/" + target + ".jsonl", target)
            avg_accuracy[comp_type + " " + target] = _score_target(comp_type, target, gold, generated[comp_type])
    return avg_accuracy





if __name__ == "__main__":
    # Column names shared by the CSV header and the console summary
    # (fixes the "Edit Didtance" typo in both places).
    metric_names = ["BLEU4", "Exact Match", "Edit Distance", "Stmt_Ratio"]

    # Start a fresh result.csv; Calculate_Gen appends its rows to it.
    with open(dst_dir + '/result.csv', 'w', newline='') as file:
        csv.writer(file).writerow(["Compiler Type", "Target", "Idx"] + metric_names)

    avg_dic = Calculate_Gen()

    for k, averages in avg_dic.items():
        print("########################")
        print(k)
        print(" ".join(metric_names))
        print(" ".join(averages))