import os
import glob
import argparse
from code_efficiency_calculator import run_model_task


def calculate_memory_usage(dat_file_path):
    """Integrate memory over time for one ``.dat`` profile (trapezoidal rule).

    The first line is treated as a header and skipped. Lines containing
    ``"__main__."`` are ignored. Every other line must carry a memory reading
    in its second whitespace-separated field and a timestamp in its third.
    Consecutive samples are combined with the trapezoidal rule. The result
    is in memory-unit x time-unit (MB*s, going by this code's variable
    naming).

    Returns 0 when the file has fewer than two samples. Raises StopIteration
    on an empty file, and IndexError/ValueError on malformed sample lines.
    """
    with open(dat_file_path, 'r') as file:
        next(file)  # skip the header line
        # None (rather than 0) marks "no previous sample yet". A timestamp of
        # 0 or below is then still integrated instead of being mistaken for
        # the start marker.
        prev_time = None
        prev_mem_mb = 0
        mem_time_mb_s = 0
        for line in file:
            if "__main__." in line:
                continue
            parts = line.split()
            mem_in_mb = float(parts[1])
            timestamp = float(parts[2])
            if prev_time is not None:
                time_interval_s = timestamp - prev_time
                mem_time_mb_s += (prev_mem_mb + mem_in_mb) / 2 * time_interval_s
            prev_time = timestamp
            prev_mem_mb = mem_in_mb
        return mem_time_mb_s


def calculate_runtime(dat_file_path):
    """Return the time span covered by the samples of one ``.dat`` profile.

    The header (first) line is skipped, as is any line containing
    ``"__main__."``. The span is the latest minus the earliest timestamp,
    where the timestamp is the third whitespace-separated field. A file with
    a header but no sample lines gives 0.
    """
    with open(dat_file_path, 'r') as handle:
        next(handle)  # header line
        stamps = [
            float(row.split()[2])
            for row in handle
            if "__main__." not in row
        ]
    if not stamps:
        return 0
    return max(max(stamps) - min(stamps), 0)

def report_max_memory_usage(dat_file_path):
    """Return the peak memory reading of one ``.dat`` profile.

    The header (first) line is skipped, as is any line containing
    ``"__main__."``. The reading is the second whitespace-separated field
    (MB, going by this code's naming). The result never drops below 0, so a
    file with no sample lines yields 0.
    """
    with open(dat_file_path, 'r') as handle:
        next(handle)  # header line
        readings = (
            float(row.split()[1])
            for row in handle
            if "__main__." not in row
        )
        # Seeding with 0 mirrors a running max that starts at 0.
        return max([0, *readings])

def _load_profiles(directory):
    """Parse every ``*.dat`` profile in *directory*, keyed by problem index.

    The problem index is the file name up to the first ``.`` and must be an
    integer. Each value is a ``(memory_usage, execution_time,
    max_memory_usage)`` tuple produced by calculate_memory_usage,
    calculate_runtime and report_max_memory_usage. Files that fail to parse
    are reported on stdout and skipped.
    """
    profiles = {}
    for dat_file in glob.glob(os.path.join(directory, "*.dat")):
        try:
            problem_idx = int(os.path.basename(dat_file).split('.')[0])
            # Compute all three metrics before storing anything. A failure
            # part-way through then never leaves a problem with only some
            # metrics recorded.
            profiles[problem_idx] = (
                calculate_memory_usage(dat_file),
                calculate_runtime(dat_file),
                report_max_memory_usage(dat_file),
            )
        except Exception as e:
            print(f"{dat_file}: {e!r}")
    return profiles


def _summarize(values, ratios, threshold=5):
    """Summarize one metric across problems.

    Returns ``(mean of values, mean of ratios, max ratio, percentage of
    ratios strictly greater than threshold)``. Both sequences must be
    non-empty and of equal length.
    """
    count = len(ratios)
    return (
        sum(values) / count,
        sum(ratios) / count,
        max(ratios),
        sum(1 for r in ratios if r > threshold) / count * 100,
    )


def report_results(task, model, file, num_problems=1000):
    """Profile *model* on *task* and return one LaTeX table row of metrics.

    Calls run_model_task(task, model, file). It then loads the ``.dat``
    profiles under ``./results/{task}_{model}`` and compares each with the
    profile of the same problem index under
    ``./results/{task}_canonical_solution``.

    The row is ``&``-separated and ends with a LaTeX row break. It contains
    the model name, then three column groups: execution time, max memory
    usage, and memory usage (time integral). Each group holds:

    - the mean raw value;
    - the mean ratio to the canonical solution;
    - the max ratio;
    - the percentage of problems whose ratio exceeds 5.

    The last column is the number of parsed completion profiles as a
    percentage of *num_problems*.

    Args:
        task: task name used in the results directory names.
        model: model name used in the results directory name and the row.
        file: passed through to run_model_task.
        num_problems: denominator of the last column. Defaults to 1000, the
            value that used to be hard-coded.

    Returns:
        The row string. Returns None if no completion profile has a usable
        canonical counterpart; the model name is printed in that case.
    """
    run_model_task(task, model, file)
    canonical = _load_profiles(f"./results/{task}_canonical_solution")
    completions = _load_profiles(f"./results/{task}_{model}")

    times, max_mems, mems = [], [], []
    time_ratios, max_mem_ratios, mem_ratios = [], [], []
    for idx, (mem, exec_time, max_mem) in completions.items():
        if idx not in canonical:
            continue
        base_mem, base_time, base_max_mem = canonical[idx]
        # A zero baseline cannot be normalized against. A single-sample
        # profile, for example, has zero runtime and zero integral. Skip the
        # problem instead of letting a ZeroDivisionError abort the report.
        if base_mem == 0 or base_time == 0 or base_max_mem == 0:
            print(f"skipping problem {idx}: canonical solution has a zero metric")
            continue
        times.append(exec_time)
        max_mems.append(max_mem)
        mems.append(mem)
        time_ratios.append(exec_time / base_time)
        max_mem_ratios.append(max_mem / base_max_mem)
        mem_ratios.append(mem / base_mem)

    if not time_ratios:
        print(model)
        return None

    et, net, max_net, over_net = _summarize(times, time_ratios)
    mu, nmu, max_nmu, over_nmu = _summarize(max_mems, max_mem_ratios)
    tmu, ntmu, max_tmu, over_tmu = _summarize(mems, mem_ratios)
    pass1 = len(completions) / num_problems * 100

    return (
        f"{model}&{et:.2f}&{net:.2f}&{max_net:.2f}&{over_net:.1f}"
        f"&{mu:.2f}&{nmu:.2f}&{max_nmu:.2f}&{over_nmu:.1f}"
        f"&{tmu:.2f}&{ntmu:.2f}&{max_tmu:.2f}&{over_tmu:.1f}"
        f"&{pass1:.1f}\\\\"
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="EffiBench")
    parser.add_argument("--model", type=str, default="gpt-4")
    parser.add_argument("--file", type=str, default="")
    args = parser.parse_args()

    # Default input file is derived from the task and model names.
    if not args.file:
        args.file = f"./{args.task}_{args.model}.json"

    # report_results returns the LaTeX row, or None when nothing could be
    # compared. Print it; previously the result was silently discarded.
    row = report_results(args.task, args.model, args.file)
    if row is not None:
        print(row)