import cv2, math
import json, os, torch
import numpy as np
from sklearn.preprocessing import Normalizer
from align import align_filter


def merge_intervals_with_breaks(time_intervals, errors, max_break=1.5):
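    """Merge consecutive time intervals that share the same error label.

    Intervals whose gap is at most ``max_break`` seconds and whose error list is
    identical are collapsed into one interval; boundaries are rounded to whole
    seconds. For example, [(1.0, 1.2), (1.4, 1.6)] with errors
    [['Left Arm'], ['Left Arm']] becomes [((1, 2), ['Left Arm'])].
    """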
    print(f"时间区间: {time_intervals}")
    print(f"错误: {errors}")

    if not time_intervals:
        return []

    # Sort intervals based on starting times (not necessary here as input is sorted but good practice)
    sorted_intervals = sorted(zip(time_intervals, errors), key=lambda x: x[0][0])
    
    merged_intervals = []
    current_interval, current_error = sorted_intervals[0]

    for (start, end), error in sorted_intervals[1:]:
        # Merge if the error label is unchanged and the gap between intervals is <= max_break seconds
        if error == current_error and start - current_interval[1] <= max_break:
            # Merge intervals
            current_interval = (round(current_interval[0]), round(max(current_interval[1], end)))
        else:
            # Save the completed interval
            merged_intervals.append(((round(current_interval[0]), round(current_interval[1])), current_error))
            # Start a new interval
            current_interval, current_error = (round(start), round(end)), error

    # Add the last interval
    merged_intervals.append(((round(current_interval[0]), round(current_interval[1])), current_error))
    
    return merged_intervals


def findcos_single(k1, k2):
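    """Cosine similarity between two small keypoint segments.

    Both inputs are flattened into column vectors; the cosine similarity in
    [-1, 1] is rescaled to [0, 100] (identical direction -> 100). Returns a
    (score, 0) tuple so the shape matches findCosineSimilarity_1.
    """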
    u1 = np.array(k1).reshape(-1, 1)
    u2 = np.array(k2).reshape(-1, 1)
    source_representation, test_representation = u1, u2
    a = np.matmul(np.transpose(source_representation), test_representation)
    b = np.sum(np.multiply(source_representation, source_representation))
    c = np.sum(np.multiply(test_representation, test_representation))
    # return 1 - (a / (np.sqrt(b) * np.sqrt(c)))
    cosine_similarity = a / (np.sqrt(b) * np.sqrt(c))
    return 100 * (1 - (1 - cosine_similarity) / 2), 0


def findCosineSimilarity_1(keypoints1, keypoints2):
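    """Frame-level similarity over the scored keypoint subset.

    Uses body keypoints 5-12 and hand keypoints 91-132 of the 133-keypoint
    whole-body layout, and rescales the cosine similarity to [0, 100] in the
    same way as findcos_single. The second return value is a placeholder.
    """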
    # transformer = Normalizer().fit(keypoints1)  
    # keypoints1 = transformer.transform(keypoints1)
    user1 = np.concatenate((keypoints1[5:13], keypoints1[91:133]), axis=0).reshape(-1, 1)

    # transformer = Normalizer().fit(keypoints2)  
    # keypoints2 = transformer.transform(keypoints2)
    user2 = np.concatenate((keypoints2[5:13], keypoints2[91:133]), axis=0).reshape(-1, 1)

    source_representation, test_representation = user1, user2
    a = np.matmul(np.transpose(source_representation), test_representation)
    b = np.sum(np.multiply(source_representation, source_representation))
    c = np.sum(np.multiply(test_representation, test_representation))
    # return 1 - (a / (np.sqrt(b) * np.sqrt(c)))
    cosine_similarity = a / (np.sqrt(b) * np.sqrt(c))
    return 100 * (1 - (1 - cosine_similarity) / 2), 0

def load_json(path):
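    """Read a JSON file and return its parsed content."""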
    with open(path, 'r') as file:
        return json.load(file)

def eval(test, standard, tmpdir):
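    """Compare the user video against the standard video and score the gestures.

    Runs pose inference on ``tmpdir``/user.mp4 (producing user.json), aligns the
    two videos, then compares the keypoints frame by frame against
    ``tmpdir``/standard.json. Writes an annotated side-by-side output.mp4 and
    returns (mean similarity score, merged major-error time intervals, pacing
    comment code). The ``test`` and ``standard`` arguments are currently unused;
    all paths are derived from ``tmpdir``.
    """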
    test_p = tmpdir + "/user.mp4"
    standard_p = tmpdir + "/standard.mp4"
    os.system('python inferencer_demo.py ' + test_p + ' --pred-out-dir ' + tmpdir) # produce user.json

    scores = []

    align_filter(tmpdir + '/standard', tmpdir + '/user', tmpdir) # frame alignment: produces the aligned videos

    data_00 = load_json(tmpdir + '/standard.json') 
    data_01 = load_json(tmpdir + '/user.json')
    cap_00 = cv2.VideoCapture(standard_p)
    cap_01 = cv2.VideoCapture(test_p)
    # Define keypoint connections for both videos (example indices, you'll need to customize)
    connections1 = [(9,11), (7,9), (6,7), (6,8), (8,10), (7,13), (6,12), (12,13)]
    connections2 = [(130,133), (126,129), (122,125), (118,121), (114,117), (93,96), (97,100), (101,104), (105,108), (109,112)]

    # Determine the minimum length of JSON data to use
    min_length = min(len(data_00), len(data_01))

    frame_width = int(cap_00.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap_00.get(cv2.CAP_PROP_FRAME_HEIGHT))

    out = cv2.VideoWriter(tmpdir + '/output.mp4', cv2.VideoWriter_fourcc(*'H264'), 5, (frame_width*2, frame_height*2))

    cap_00.set(cv2.CAP_PROP_POS_FRAMES, 0) # rewind both videos to the first frame
    cap_01.set(cv2.CAP_PROP_POS_FRAMES, 0)
    comments = -1
    error_dict = {}
    cnt = 0

    line_width = 1 if frame_width // 300 == 0 else frame_width // 300 
    # process the two videos frame by frame
    while True:
        ret_00, frame_00 = cap_00.read() # read the current frame of the standard and user videos
        ret_01, frame_01 = cap_01.read()
        if not ret_00 and ret_01:
            comments = 0  # standard video ended first: ask the user to speed up while keeping each movement clearly visible
            break  # Stop if either video runs out of frames
        elif ret_00 and not ret_01:
            comments = 1  # user video ended first: ask the user to slow down while keeping each movement clearly visible
            break  # Stop if either video runs out of frames
        elif not ret_00 and not ret_01:
            comments = 2
            break
        combined_frame_ori = np.hstack((frame_00, frame_01))

        # current frame index of each video
        frame_id_00 = int(cap_00.get(cv2.CAP_PROP_POS_FRAMES))
        frame_id_01 = int(cap_01.get(cv2.CAP_PROP_POS_FRAMES))

        # process the keypoints of the standard video and draw the keypoint connections
        if frame_id_00 < min_length:
            keypoints_00 = data_00[frame_id_00]["instances"][0]["keypoints"]

            for (start, end) in connections1:
                start = start - 1
                end = end - 1
                if start < len(keypoints_00) and end < len(keypoints_00):
                    start_point = (int(keypoints_00[start][0]), int(keypoints_00[start][1]))
                    end_point = (int(keypoints_00[end][0]), int(keypoints_00[end][1]))
                    cv2.line(frame_00, start_point, end_point, (255, 0, 0), line_width)  # (BGR) Blue line
            for (start, end) in connections2:
                start = start - 1
                end = end - 1
                for i in range(start, end):
                    if i < len(keypoints_00) and i + 1 < len(keypoints_00):
                        start_point = (int(keypoints_00[i][0]), int(keypoints_00[i][1]))
                        end_point = (int(keypoints_00[i + 1][0]), int(keypoints_00[i + 1][1]))
                        cv2.line(frame_00, start_point, end_point, (255, 0, 0), line_width)  # Blue line
            
            # keep the subset of keypoints used for scoring and mark them with dots
            keypoints_00_ori = keypoints_00
            keypoints_00 = keypoints_00[5:13] + keypoints_00[91:133]

            for point in keypoints_00:
                cv2.circle(frame_00, (int(point[0]), int(point[1])), 1, (0, 210, 0), -1)

            
        # process the keypoints of the user video and run the error analysis;
        # error/bigerror are initialised before the check so the overlay code
        # further down also works on frames without user keypoints
        error = []
        bigerror = []
        if frame_id_01 < min_length:
            keypoints_01 = data_01[frame_id_01]["instances"][0]["keypoints"]

            for (start, end) in connections1:
                start = start - 1
                end = end - 1
                if start < len(keypoints_01) and end < len(keypoints_01):
                    start_point = (int(keypoints_01[start][0]), int(keypoints_01[start][1]))
                    end_point = (int(keypoints_01[end][0]), int(keypoints_01[end][1]))
                    cur_score = findcos_single([[int(keypoints_01[start][0]), int(keypoints_01[start][1])], [int(keypoints_01[end][0]), int(keypoints_01[end][1])]], [[int(keypoints_00_ori[start][0]), int(keypoints_00_ori[start][1])], [int(keypoints_00_ori[end][0]), int(keypoints_00_ori[end][1])]])

                    # if the similarity drops below 98.8, record the joint as an error
                    if float(cur_score[0]) < 98.8 and start != 5:
                        error.append(start)
                        cv2.line(frame_01, start_point, end_point, (0, 0, 255), 2)  # Red line
                        # below 97.8 it counts as a major error
                        if float(cur_score[0]) < 97.8:
                            bigerror.append(start)
                    else:
                        cv2.line(frame_01, start_point, end_point, (255, 0, 0), line_width)  # Blue line

            for (start, end) in connections2:
                start = start - 1
                end = end - 1
                for i in range(start, end):
                    if i < len(keypoints_01) and i + 1 < len(keypoints_01):
                        start_point = (int(keypoints_01[i][0]), int(keypoints_01[i][1]))
                        end_point = (int(keypoints_01[i + 1][0]), int(keypoints_01[i + 1][1]))

                        cur_score = findcos_single([[int(keypoints_01[i][0]), int(keypoints_01[i][1])], [int(keypoints_01[i + 1][0]), int(keypoints_01[i + 1][1])]], [[int(keypoints_00_ori[i][0]), int(keypoints_00_ori[i][1])], [int(keypoints_00_ori[i + 1][0]), int(keypoints_00_ori[i + 1][1])]])

                        if float(cur_score[0]) < 98.8:
                            error.append(start)
                            cv2.line(frame_01, start_point, end_point, (0, 0, 255), 2)  # Red line
                            if float(cur_score[0]) < 97.8:
                                bigerror.append(start)
                        else:
                            cv2.line(frame_01, start_point, end_point, (255, 0, 0), line_width)  # Blue line
            
            # draw the user-video keypoints as dots
            keypoints_01 = keypoints_01[5:13] + keypoints_01[91:133]

            for point in keypoints_01:
                cv2.circle(frame_01, (int(point[0]), int(point[1])), 1, (0, 210, 0), -1)
            
        # Concatenate the images horizontally to display side by side
        combined_frame = np.hstack((frame_00, frame_01))

        if frame_id_00 < min_length and frame_id_01 < min_length:
            min_cos, min_idx = findCosineSimilarity_1(data_00[frame_id_00]["instances"][0]["keypoints"], data_01[frame_id_01]["instances"][0]["keypoints"])
            scores.append(float(min_cos))  # record the per-frame similarity score

        # if errors were found, map the offending joint indices to body parts
        if error:
            # print(error)
            content = []
            for i in error:
                if i in [5,7]: content.append('Left Arm')
                if i in [6,8]: content.append('Right Arm')
                if i > 90 and i < 112: content.append('Left Hand')
                if i >= 112: content.append('Right Hand')
            part = ""

            # overlay the detected error body parts on the frame
            cv2.putText(combined_frame, "Please check: ", (int(frame_width*1.75), int(frame_height*0.2)), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 0, 255), 2)
            start_x = int(frame_width*1.75) + 10  # starting x coordinate of the overlay text
            start_y = int(frame_height*0.2) + 50  # starting y coordinate of the overlay text
            line_height = 50  # vertical spacing between text lines

            # draw each affected body part on the frame
            for i, item in enumerate(list(set(content))):    
                text = "- " + item
                y_position = start_y + i * line_height
                cv2.putText(combined_frame, text, (start_x, y_position), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 0, 255), 2)

        # major errors: map joint indices to body parts for the interval report
        if bigerror:
            bigcontent = []
            for i in bigerror:
                if i in [5,7]: bigcontent.append('Left Arm')
                if i in [6,8]: bigcontent.append('Right Arm')
                if i > 90 and i < 112: bigcontent.append('Left Hand')
                if i >= 112: bigcontent.append('Right Hand')
            
            # store the major-error body parts for the current frame in error_dict
            error_dict[cnt] = list(set(bigcontent))

        cnt += 1
        combined_frame = np.vstack((combined_frame_ori, combined_frame))
        out.write(combined_frame)

    fps = 5  # frames per second of the output video
    frame_numbers = list(error_dict.keys())  # frame indices that contain major errors
    time_intervals = [(frame / fps, (frame + 1) / fps) for frame in frame_numbers]  # convert frame indices to time intervals (seconds)
    errors = [error_dict[frame] for frame in frame_numbers]  # major-error body parts for each of those frames
    final_merged_intervals = merge_intervals_with_breaks(time_intervals, errors)  # merge adjacent intervals with identical error labels
    out.release()
    cap_00.release()
    cap_01.release()

    # Returns three values:
    # 1. the mean of scores, the overall gesture-similarity rating
    # 2. final_merged_intervals, the merged major-error time intervals and their body parts
    # 3. comments, the pacing hint code for the user (speed up or slow down)
    return (sum(scores) / len(scores) if scores else 0.0), final_merged_intervals, comments

def install():
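    """Install the runtime dependencies: downgrade numpy below 2, install
    mmengine via mim, and build mmpose and mmdetection from source."""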
    # if torch.cuda.is_available():
    #     cu_version = torch.version.cuda
    #     cu_version = f"cu{cu_version.replace('.', '')}"  # Format it as 'cuXX' (e.g., 'cu113')
    # else:
    #     cu_version = "cpu"  # Fallback to CPU if no CUDA is available

    # torch_version = torch.__version__.split('+')[0]  # Get PyTorch version without build info

    # pip_command = f'pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/{cu_version}/{torch_version}/index.html'


    # os.system(pip_command)
    import subprocess
    subprocess.run(["pip", "uninstall", "-y", "numpy"], check=True)
    subprocess.run(["pip", "install", "numpy<2"], check=True)
    
    os.system('mim install mmengine')
    # os.system('mim install "mmcv"')
    # os.system('mim install "mmdet"')
    # os.system('mim install "mmpose"')
    # os.system('pip3 install mmcv==2.2.0 -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.4/index.html"')
    # os.system('pip3 install mmcv==2.2.0 -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.4/index.html')
    
    os.system('git clone https://github.com/open-mmlab/mmpose.git')
    os.chdir('mmpose')
    os.system('pip install -r requirements.txt')
    os.system('pip install -v -e .')
    os.chdir('../')

    os.system('git clone https://github.com/open-mmlab/mmdetection.git')
    os.chdir('mmdetection')
    os.system('pip install -v -e .')
    os.chdir('../')
    # os.system('mim install "mmpose>=1.1.0"')
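

# Minimal usage sketch (assumptions: a working directory "tmp" already contains
# user.mp4, standard.mp4 and standard.json, and inferencer_demo.py plus the
# mim/mmpose toolchain are available). The directory name and printed wording
# below are illustrative only, not part of the original pipeline.
if __name__ == "__main__":
    tmpdir = "tmp"  # hypothetical working directory
    score, error_intervals, pacing = eval(tmpdir + "/user.mp4", tmpdir + "/standard.mp4", tmpdir)
    print(f"overall similarity: {score:.1f}")
    print(f"major-error intervals: {error_intervals}")
    print(f"pacing comment code: {pacing}")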