"""
input: json file with video, audio, motion paths
output: igraph object with nodes containing video, audio, motion, position, velocity, axis_angle, previous, next, frame, fps

preprocess:
1. assume you have a video for one speaker in folder, listed in
    -- video_a.mp4
    -- video_b.mp4
    run process_video.py to extract frames and audio
"""

import os
import json
import smplx
import torch
import igraph
import numpy as np
import subprocess
import utils.rotation_conversions as rc
from moviepy.editor import VideoClip, AudioFileClip
from tqdm import tqdm
import imageio
import tempfile
import argparse
import time

SCRIPT_PATH = os.path.dirname(os.path.realpath(__file__))


def _finite_difference_torch(values, dt):
    """Differentiate ``values`` along dim 1 (the frame axis) with step ``dt``.

    A one-sided difference is used for the first and last frame and a central
    difference for every frame in between, so the result has the same shape as
    ``values``. With fewer than two frames no difference can be formed and
    zeros are returned.
    """
    if values.shape[1] < 2:
        return torch.zeros_like(values)
    first = (values[:, 1:2] - values[:, 0:1]) / dt
    middle = (values[:, 2:] - values[:, :-2]) / (2 * dt)
    last = (values[:, -1:] - values[:, -2:-1]) / dt
    return torch.cat([first, middle, last], dim=1)


def get_motion_reps_tensor(motion_tensor, smplx_model, pose_fps=30, device="cuda"):
    """Compute batched motion representations from axis-angle poses.

    Args:
        motion_tensor: tensor of shape (bs, n, 165), i.e. 55 joints x 3
            axis-angle values per frame.
        smplx_model: callable body model returning a mapping with a "joints"
            entry that can be reshaped to (bs, n, 127, 3).
        pose_fps: frame rate of the poses; ``1 / pose_fps`` is the time step
            of the finite differences.
        device: device the computation runs on. ``smplx_model`` is not moved
            here, so it must already live on this device.

    Returns:
        dict with
            "position": (bs, n, 55, 3) joint positions,
            "velocity": (bs, n, 55, 3) finite-difference joint velocities,
            "rotation": (bs, n, 55, 6) 6D rotations,
            "axis_angle": (bs, n, 165) the input as float on ``device``,
            "angular_velocity": (bs, n, 55, 3) finite difference of the
                axis-angle values,
            "rep15d": (bs, n, 825) concatenation of the four per-joint
                features above (3 + 3 + 6 + 3 = 15 values per joint).
    """
    bs, n, _ = motion_tensor.shape
    motion_tensor = motion_tensor.float().to(device)
    flat_poses = motion_tensor.reshape(bs * n, 165)
    num_frames = bs * n

    # Shape, translation, expression, jaw, eyes and the global orientation are
    # all zeroed; only the body (columns 3:66) and hand (columns 75:165) poses
    # are taken from the input.
    output = smplx_model(
        betas=torch.zeros(num_frames, 300, device=device),
        transl=torch.zeros(num_frames, 3, device=device),
        expression=torch.zeros(num_frames, 100, device=device),
        jaw_pose=torch.zeros(num_frames, 3, device=device),
        global_orient=torch.zeros(num_frames, 3, device=device),
        body_pose=flat_poses[:, 3 : 21 * 3 + 3],
        left_hand_pose=flat_poses[:, 25 * 3 : 40 * 3],
        right_hand_pose=flat_poses[:, 40 * 3 : 55 * 3],
        return_joints=True,
        leye_pose=torch.zeros(num_frames, 3, device=device),
        reye_pose=torch.zeros(num_frames, 3, device=device),
    )

    # Keep only the first 55 of the 127 joints returned by the model.
    position = output["joints"].reshape(bs, n, 127, 3)[:, :, :55, :]
    dt = 1 / pose_fps
    vel = _finite_difference_torch(position, dt)

    rot_matrices = rc.axis_angle_to_matrix(motion_tensor.reshape(bs, n, 55, 3))
    rot6d = rc.matrix_to_rotation_6d(rot_matrices).reshape(bs, n, 55, 6)

    # NOTE: this differentiates the raw axis-angle components, not a proper
    # rotational velocity.
    angular_velocity = _finite_difference_torch(motion_tensor, dt).reshape(bs, n, 55, 3)

    rep15d = torch.cat([position, vel, rot6d, angular_velocity], dim=3).reshape(bs, n, 55 * 15)

    return {
        "position": position,
        "velocity": vel,
        "rotation": rot6d,
        "axis_angle": motion_tensor,
        "angular_velocity": angular_velocity,
        "rep15d": rep15d,
    }


def _module_device(module):
    """Return the device of the first parameter (or buffer) of ``module``.

    Falls back to the CPU when the module owns no tensors at all.
    """
    for tensor in module.parameters():
        return tensor.device
    for tensor in module.buffers():
        return tensor.device
    return torch.device("cpu")


def _finite_difference_np(values, dt):
    """Differentiate ``values`` along axis 0 (the frame axis) with step ``dt``.

    A one-sided difference is used for the first and last frame and a central
    difference in between. With fewer than two frames zeros are returned.
    """
    if values.shape[0] < 2:
        return np.zeros_like(values)
    first = (values[1:2] - values[0:1]) / dt
    middle = (values[2:] - values[:-2]) / (2 * dt)
    last = (values[-1:] - values[-2:-1]) / dt
    return np.concatenate([first, middle, last], axis=0)


def get_motion_reps(motion, smplx_model, pose_fps=30):
    """Compute numpy motion representations for a single motion clip.

    Args:
        motion: mapping with a "poses" array of shape (n, 165) (55 joints x 3
            axis-angle values) and a "trans" entry, e.g. a loaded ``.npz``.
        smplx_model: torch body model; the forward pass runs on the device
            its parameters live on.
        pose_fps: frame rate of the poses; ``1 / pose_fps`` is the time step
            of the finite differences.

    Returns:
        dict with "position" (n, 55, 3), "velocity" (n, 55, 3), "rotation"
        (n, 55, 6), "axis_angle" (the unmodified "poses" array),
        "angular_velocity" (n, 55, 3), "rep15d" (n, 825) and "trans" (passed
        through from ``motion``).
    """
    # Read the array once: indexing an NpzFile re-reads it from disk each time.
    poses = motion["poses"]
    n = poses.shape[0]
    # Use the model's own device instead of a module-level global, so the
    # function also works when this file is imported rather than run.
    device = _module_device(smplx_model)
    pose_tensor = torch.from_numpy(poses).float().to(device).reshape(n, 165)

    # Everything except the body and hand poses is zeroed. The result is
    # converted to numpy straight away, so no autograd graph is needed.
    with torch.no_grad():
        output = smplx_model(
            betas=torch.zeros(n, 300, device=device),
            transl=torch.zeros(n, 3, device=device),
            expression=torch.zeros(n, 100, device=device),
            jaw_pose=torch.zeros(n, 3, device=device),
            global_orient=torch.zeros(n, 3, device=device),
            body_pose=pose_tensor[:, 3 : 21 * 3 + 3],
            left_hand_pose=pose_tensor[:, 25 * 3 : 40 * 3],
            right_hand_pose=pose_tensor[:, 40 * 3 : 55 * 3],
            return_joints=True,
            leye_pose=torch.zeros(n, 3, device=device),
            reye_pose=torch.zeros(n, 3, device=device),
        )

    # Keep only the first 55 of the 127 joints returned by the model.
    position = output["joints"].detach().cpu().numpy().reshape(n, 127, 3)[:, :55, :]
    dt = 1 / pose_fps
    vel = _finite_difference_np(position, dt)

    rot_matrices = rc.axis_angle_to_matrix(pose_tensor.reshape(1, n, 55, 3))[0]
    rot6d = rc.matrix_to_rotation_6d(rot_matrices).reshape(n, 55, 6).cpu().numpy()

    # Finite difference of the raw axis-angle values (kept in the dtype of the
    # stored poses), not a proper rotational velocity.
    angular_velocity = _finite_difference_np(poses, dt).reshape(n, 55, 3)

    rep15d = np.concatenate([position, vel, rot6d, angular_velocity], axis=2).reshape(n, 55 * 15)
    return {
        "position": position,
        "velocity": vel,
        "rotation": rot6d,
        "axis_angle": poses,
        "angular_velocity": angular_velocity,
        "rep15d": rep15d,
        "trans": motion["trans"],
    }


def create_graph(json_path, smplx_model):
    """Build a directed graph with one vertex per video frame and no edges.

    Args:
        json_path: path to a JSON list; every item provides "video_path",
            "motion_path" and "video_id". The clip is read from
            ``<video_path>/<video_id>.mp4`` and ``<motion_path>/<video_id>.npz``.
        smplx_model: body model forwarded to ``get_motion_reps``.

    Returns:
        igraph.Graph whose vertices carry the attributes idx, name, motion,
        position, velocity, axis_angle, trans, video, previous, next, frame
        and fps. ``previous`` / ``next`` hold the vertex index of the
        neighbouring frame of the same clip, or -1 at the clip boundaries.
    """
    fps = 30
    with open(json_path, "r") as meta_file:
        data_meta = json.load(meta_file)

    graph = igraph.Graph(directed=True)
    global_i = 0
    for data_item in data_meta:
        video_id = data_item["video_id"]
        video_path = os.path.join(data_item["video_path"], video_id + ".mp4")
        # audio_path = os.path.join(data_item['audio_path'], data_item['video_id'] +  ".wav")
        motion_path = os.path.join(data_item["motion_path"], video_id + ".npz")

        # The arrays are materialised inside get_motion_reps, so the archive
        # can be closed right afterwards.
        with np.load(motion_path, allow_pickle=True) as motion:
            motion_reps = get_motion_reps(motion, smplx_model)
        position = motion_reps["position"]
        velocity = motion_reps["velocity"]
        trans = motion_reps["trans"]
        axis_angle = motion_reps["axis_angle"]
        # audio, sr = librosa.load(audio_path, sr=None)
        # audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)

        # Frames beyond the motion length would be dropped below anyway, so
        # stop decoding once enough frames are available. ``range`` comes
        # first in the zip so no surplus frame is pulled from the reader.
        reader = imageio.get_reader(video_path)
        try:
            all_frames = [frame for _, frame in zip(range(position.shape[0]), reader)]
        finally:
            reader.close()
        video_frames = np.array(all_frames)

        min_frames = min(len(video_frames), position.shape[0])
        position = position[:min_frames]
        velocity = velocity[:min_frames]
        video_frames = video_frames[:min_frames]

        for i in tqdm(range(min_frames)):
            # -1 marks a clip boundary. Testing both ends independently also
            # covers a one-frame clip, which must not link into the next clip.
            previous = global_i - 1 if i > 0 else -1
            next_node = global_i + 1 if i < min_frames - 1 else -1
            graph.add_vertex(
                idx=global_i,
                name=video_id,
                motion=motion_reps,
                position=position[i],
                velocity=velocity[i],
                axis_angle=axis_angle[i],
                trans=trans[i],
                # audio=audio[],
                video=video_frames[i],
                previous=previous,
                next=next_node,
                frame=i,
                fps=fps,
            )
            global_i += 1
    return graph


def create_edges(graph):
    """Connect every vertex to its sequential neighbours and to similar poses.

    For each vertex an adaptive threshold is derived from the mean distance
    (per joint for position and velocity, scalar for translation) to the
    vertices up to four indices before and after it. Edges are then added:

    * ``is_continue=1`` to the vertices stored in its "previous" and "next"
      attributes;
    * ``is_continue=0`` to every other vertex whose translation distance is
      below the threshold and for which at least 45 joints are below both the
      position and the velocity threshold.

    Vertices without any usable neighbour for the threshold get no outgoing
    edges at all. Degree statistics are printed at the end.

    Args:
        graph: directed igraph.Graph whose vertices carry the attributes
            "position", "velocity", "trans", "previous" and "next".
            Assumes position/velocity are (joints, 3) arrays and trans is a
            1-D vector -- TODO confirm against the graph producer.

    Returns:
        The same graph object, with the new edges added.
    """
    adaptive_length = [-4, -3, -2, -1, 1, 2, 3, 4]
    min_similar_joints = 45

    num_nodes = len(graph.vs)
    if num_nodes == 0:
        # Nothing to connect, and the degree statistics below would divide by zero.
        print(f"nodes: {len(graph.vs)}, edges: {len(graph.es)}")
        return graph

    # Pull the attributes out of igraph once; per-vertex attribute access in
    # the inner loop is slow.
    previous_ids = graph.vs["previous"]
    next_ids = graph.vs["next"]
    all_position = np.stack(graph.vs["position"])
    all_velocity = np.stack(graph.vs["velocity"])
    all_trans = np.stack(graph.vs["trans"])

    # Edges are collected and inserted in one call; adding them one by one
    # makes igraph rebuild its internal index for every single edge.
    new_edges = []
    continue_flags = []

    for i in range(num_nodes):
        neighbor_ids = []
        for node_offset in adaptive_length:
            idx = i + node_offset
            if idx < 0 or idx >= num_nodes:
                continue
            # Skip a neighbour that is the last (looking backwards) or first
            # (looking forwards) frame of a clip.
            # NOTE(review): this only rejects the boundary frame itself, so
            # frames deeper inside an adjacent clip can still contribute to
            # the threshold -- confirm whether that is intended.
            boundary_marker = next_ids[idx] if node_offset < 0 else previous_ids[idx]
            if boundary_marker == -1:
                continue
            neighbor_ids.append(idx)

        if not neighbor_ids:
            continue

        # Distances from vertex i to every vertex, computed in one shot.
        position_dist = np.linalg.norm(all_position - all_position[i], axis=2)
        velocity_dist = np.linalg.norm(all_velocity - all_velocity[i], axis=2)
        trans_dist = np.linalg.norm(all_trans - all_trans[i], axis=1)

        sum_position = np.zeros(all_position.shape[1])
        sum_velocity = np.zeros(all_position.shape[1])
        sum_trans = 0
        for idx in neighbor_ids:
            sum_position += position_dist[idx]
            sum_velocity += velocity_dist[idx]
            sum_trans += trans_dist[idx]
        count = len(neighbor_ids)
        threshold_position = sum_position / count
        threshold_velocity = sum_velocity / count
        threshold_trans = sum_trans / count

        # Sequential neighbours always get an edge and take precedence over
        # the similarity test.
        is_sequential = np.zeros(num_nodes, dtype=bool)
        for neighbor in (previous_ids[i], next_ids[i]):
            if 0 <= neighbor < num_nodes and neighbor != i:
                is_sequential[neighbor] = True

        is_similar = (
            (trans_dist < threshold_trans)
            & ((position_dist < threshold_position).sum(axis=1) >= min_similar_joints)
            & ((velocity_dist < threshold_velocity).sum(axis=1) >= min_similar_joints)
        )
        is_similar[i] = False

        # Ascending target order keeps the edge ids identical to adding the
        # edges one at a time.
        for j in np.flatnonzero(is_sequential | is_similar):
            new_edges.append((i, int(j)))
            continue_flags.append(1 if is_sequential[j] else 0)

    if new_edges:
        graph.add_edges(new_edges, attributes={"is_continue": continue_flags})

    print(f"nodes: {len(graph.vs)}, edges: {len(graph.es)}")
    in_degrees = graph.indegree()
    out_degrees = graph.outdegree()
    avg_in_degree = sum(in_degrees) / len(in_degrees)
    avg_out_degree = sum(out_degrees) / len(out_degrees)
    print(f"Average In-degree: {avg_in_degree}")
    print(f"Average Out-degree: {avg_out_degree}")
    print(f"max in degree: {max(in_degrees)}, max out degree: {max(out_degrees)}")
    print(f"min in degree: {min(in_degrees)}, min out degree: {min(out_degrees)}")
    # igraph.plot(graph, target="/content/test.png", bbox=(1000, 1000), vertex_size=10)
    return graph


def random_walk(graph, walk_length, start_node=None):
    """Take a random walk along the outgoing edges of ``graph``.

    Args:
        graph: graph exposing ``vs``, ``es``, ``neighbors`` and ``get_eid``.
        walk_length: maximum number of steps to take.
        start_node: vertex to start from; a random vertex of ``graph.vs`` is
            drawn when omitted.

    Returns:
        ``(walk, is_continue)``: the visited vertices (start included) and,
        for each of them, the "is_continue" attribute of the edge used to
        reach it (1 for the start vertex). The walk stops early at a vertex
        without outgoing neighbours.
    """
    if start_node is None:
        start_node = np.random.choice(graph.vs)

    visited = [start_node]
    continuity = [1]
    for _step in range(walk_length):
        here = visited[-1].index
        out_neighbors = graph.neighbors(here, mode="OUT")
        if not out_neighbors:
            # Dead end: nothing left to step to.
            break
        target = np.random.choice(out_neighbors)
        edge = graph.es[graph.get_eid(here, target)]
        edge_flag = edge["is_continue"]
        visited.append(graph.vs[target])
        continuity.append(edge_flag)
    return visited, continuity


def path_visualization(graph, path, is_continue, save_path, verbose_continue=False, audio_path=None, return_motion=False):
    """Render the frames of ``path`` to a video file at ``save_path``.

    Args:
        graph: graph the path was sampled from; the frame rate is read from
            the "fps" attribute of its first vertex.
        path: sequence of vertices providing "video" (one frame each) and
            "axis_angle".
        is_continue: per-step continuity flags (1 = sequential step); only
            used for the printed discontinuity ratio.
        save_path: destination video file.
        verbose_continue: print the fraction of non-sequential steps.
        audio_path: optional audio file muxed into the video with ffmpeg.
            Audio longer than the video is trimmed to the video duration.
        return_motion: when true, return the stacked "axis_angle" values of
            the path.

    Returns:
        The stacked axis-angle array if ``return_motion`` is set, else None.

    Raises:
        subprocess.CalledProcessError: if the ffmpeg muxing step fails.
    """
    all_frames = [node["video"] for node in path]
    average_dis_continue = 1 - sum(is_continue) / len(is_continue)
    if verbose_continue:
        print("average_dis_continue:", average_dis_continue)

    fps = graph.vs[0]["fps"]
    duration = len(all_frames) / fps

    def make_frame(t):
        idx = min(int(t * fps), len(all_frames) - 1)
        return all_frames[idx]

    video_clip = VideoClip(make_frame, duration=duration)

    if audio_path is None:
        # No muxing needed: write the silent video straight to its destination.
        video_clip.write_videofile(save_path, codec="libx264", fps=fps, audio=False)
    else:
        # All intermediate files live in a private directory that is removed
        # even when encoding or muxing fails.
        with tempfile.TemporaryDirectory(prefix="path_visualization_") as work_dir:
            video_only_path = os.path.join(work_dir, "video_only.mp4")
            video_clip.write_videofile(video_only_path, codec="libx264", fps=fps, audio=False)

            audio_input = audio_path
            audio_clip = AudioFileClip(audio_path)
            try:
                if audio_clip.duration > video_clip.duration:
                    # Trim the audio so it does not outlast the video. The
                    # codec is passed explicitly instead of relying on it
                    # being inferred from the ".aac" extension.
                    audio_input = os.path.join(work_dir, "trimmed_audio.aac")
                    trimmed_clip = audio_clip.subclip(0, video_clip.duration)
                    trimmed_clip.write_audiofile(audio_input, codec="aac")
            finally:
                audio_clip.close()

            # Use FFmpeg to combine video and audio; the video stream is
            # copied, the audio is re-encoded to AAC.
            ffmpeg_command = [
                "ffmpeg",
                "-y",
                "-i",
                video_only_path,
                "-i",
                audio_input,
                "-c:v",
                "copy",
                "-c:a",
                "aac",
                "-strict",
                "experimental",
                save_path,
            ]
            subprocess.check_call(ffmpeg_command)

    if return_motion:
        all_motion = [node["axis_angle"] for node in path]
        all_motion = np.stack(all_motion, 0)
        return all_motion


def generate_transition_video(frame_start_path, frame_end_path, output_video_path):
    """Run the bundled frame-interpolation script between two frame images.

    The script ``frame-interpolation-pytorch/inference.py`` next to this file
    is launched as a subprocess with the ``film_net_fp32.pt`` model, asking
    for 3 frames at 30 fps on the GPU.

    Args:
        frame_start_path: image file of the first frame.
        frame_end_path: image file of the last frame.
        output_video_path: where the script should save the generated video.

    Returns:
        None. A failing subprocess is reported on stdout and not re-raised,
        so callers must check that ``output_video_path`` exists.
    """
    import sys

    # Define the path to the model and the inference script
    interpolation_dir = os.path.join(SCRIPT_PATH, "frame-interpolation-pytorch")
    model_path = os.path.join(interpolation_dir, "film_net_fp32.pt")
    inference_script = os.path.join(interpolation_dir, "inference.py")

    # Launch the script with the interpreter running this process, so it sees
    # the same environment and installed packages; fall back to whatever
    # "python" is on PATH if the interpreter path is unknown.
    command = [
        sys.executable or "python",
        inference_script,
        model_path,
        frame_start_path,
        frame_end_path,
        "--save_path",
        output_video_path,
        "--gpu",
        "--frames",
        "3",
        "--fps",
        "30",
    ]

    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error occurred while generating transition video: {e}")
    else:
        print(f"Generated transition video saved at {output_video_path}")


def path_visualization_v2(graph, path, is_continue, save_path, verbose_continue=False, audio_path=None, return_motion=False):
    """Render ``path`` to ``save_path``, smoothing non-sequential jumps.

    this is for hugging face demo for fast interpolation. our paper use a diffusion based interpolation method

    For every discontinuity at index ``i`` (``is_continue[i] == 0``) the
    frames ``i-1, i, i+1`` are replaced by frames interpolated between frame
    ``i-2`` and frame ``i+2``. Discontinuities too close to the ends of the
    path, or whose three-frame window overlaps an already blended window, are
    left untouched, as are positions where interpolation fails.

    Args:
        graph: graph the path was sampled from; the frame rate is read from
            the "fps" attribute of its first vertex.
        path: sequence of vertices providing "video" and "axis_angle".
        is_continue: per-step continuity flags (0 marks a jump).
        save_path: destination video file.
        verbose_continue: print the fraction of non-sequential steps.
        audio_path: optional audio file attached to the video.
        return_motion: when true, return the stacked "axis_angle" values of
            the path (the motion itself is not blended).

    Returns:
        The stacked axis-angle array if ``return_motion`` is set, else None.
    """
    all_frames = [node["video"] for node in path]
    average_dis_continue = 1 - sum(is_continue) / len(is_continue)
    if verbose_continue:
        print("average_dis_continue:", average_dis_continue)
    fps = graph.vs[0]["fps"]
    duration = len(all_frames) / fps

    # First pass: find where blending is needed
    discontinuity_indices = [i for i, cont in enumerate(is_continue) if cont == 0]

    # Select blending positions whose replaced frames do not overlap
    blend_positions = []
    processed_frames = set()
    for i in discontinuity_indices:
        # Blending reads frames i-2 and i+2, so both must exist
        if i - 2 < 0 or i + 2 >= len(all_frames):
            continue
        replaced = range(i - 1, i + 2)
        if any(idx in processed_frames for idx in replaced):
            continue
        processed_frames.update(replaced)
        blend_positions.append(i)

    # Second pass: perform blending. The scratch directory is removed once
    # all generated frames have been read into memory.
    with tempfile.TemporaryDirectory(prefix="blending_frames_") as temp_dir:
        for i in tqdm(blend_positions):
            start_frame_idx = i - 2
            end_frame_idx = i + 2
            frame_start_path = os.path.join(temp_dir, f"frame_{start_frame_idx}.png")
            frame_end_path = os.path.join(temp_dir, f"frame_{end_frame_idx}.png")
            # Save the start and end frames as images
            imageio.imwrite(frame_start_path, all_frames[start_frame_idx])
            imageio.imwrite(frame_end_path, all_frames[end_frame_idx])

            # Generate the in-between video with the interpolation model
            generated_video_path = os.path.join(temp_dir, f"generated_{start_frame_idx}_{end_frame_idx}.mp4")
            generate_transition_video(frame_start_path, frame_end_path, generated_video_path)

            # The generator reports failures without raising, so a missing
            # output means this position cannot be blended.
            if not os.path.exists(generated_video_path):
                print(f"No transition video was generated. Skipping blending at position {i}.")
                continue

            reader = imageio.get_reader(generated_video_path)
            try:
                generated_frames = [frame for frame in reader]
            finally:
                reader.close()

            total_generated_frames = len(generated_frames)
            if total_generated_frames < 5:
                print(f"Generated video has insufficient frames ({total_generated_frames}). Skipping blending at position {i}.")
                continue
            # Generated frame 0 corresponds to frame i-2; frames 1..3 replace
            # frames i-1, i and i+1.
            middle_frames = generated_frames[1:4]
            for frame_idx, new_frame in zip(range(i - 1, i + 2), middle_frames):
                all_frames[frame_idx] = new_frame

    # Create the video clip
    def make_frame(t):
        idx = min(int(t * fps), len(all_frames) - 1)
        return all_frames[idx]

    video_clip = VideoClip(make_frame, duration=duration)
    audio_clip = None
    try:
        if audio_path is not None:
            audio_clip = AudioFileClip(audio_path)
            video_clip = video_clip.set_audio(audio_clip)
        video_clip.write_videofile(save_path, codec="libx264", fps=fps, audio_codec="aac")
    finally:
        if audio_clip is not None:
            audio_clip.close()

    if return_motion:
        all_motion = [node["axis_angle"] for node in path]
        all_motion = np.stack(all_motion, 0)
        return all_motion


def graph_pruning(graph):
    """Return the largest strongly connected component of ``graph``.

    Node and edge counts before and after pruning, plus the degree statistics
    of the pruned graph, are printed along the way.
    """
    pruned = graph.clusters(mode="STRONG").giant()

    print(f"before nodes: {len(graph.vs)}, edges: {len(graph.es)}")
    print(f"after nodes: {len(pruned.vs)}, edges: {len(pruned.es)}")

    indeg = pruned.indegree()
    outdeg = pruned.outdegree()
    mean_in = sum(indeg) / len(indeg)
    mean_out = sum(outdeg) / len(outdeg)

    print(f"Average In-degree: {mean_in}")
    print(f"Average Out-degree: {mean_out}")
    print(f"max in degree: {max(indeg)}, max out degree: {max(outdeg)}")
    print(f"min in degree: {min(indeg)}, min out degree: {min(outdeg)}")
    return pruned


if __name__ == "__main__":
    # Build the motion graph described by --json_save_path, store it as a
    # pickle at --graph_save_path, then run a short random walk as a smoke test.
    parser = argparse.ArgumentParser()
    parser.add_argument("--json_save_path", type=str, default="")
    parser.add_argument("--graph_save_path", type=str, default="")
    args = parser.parse_args()
    json_path = args.json_save_path
    graph_path = args.graph_save_path
    # Fail fast with a usage message instead of an obscure error after the
    # body model has been loaded or the whole graph has been built.
    if not json_path or not graph_path:
        parser.error("both --json_save_path and --graph_save_path must be given")
    print("json_path", json_path)

    # ``device`` stays a module-level name on purpose: helpers in this file
    # may look it up as a global.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    smplx_model = (
        smplx.create(
            os.path.join(SCRIPT_PATH, "emage/smplx_models/"),
            model_type="smplx",
            gender="NEUTRAL_2020",
            use_face_contour=False,
            num_betas=300,
            num_expression_coeffs=100,
            ext="npz",
            use_pca=False,
        )
        .to(device)
        .eval()
    )

    # single_test
    # graph = create_graph('/content/drive/MyDrive/003_Codes/TANGO/datasets/data_json/show_oliver_test/Abortion_Laws_-_Last_Week_Tonight_with_John_Oliver_HBO-DRauXXz6t0Y.webm.json')
    graph = create_graph(json_path, smplx_model)
    graph = create_edges(graph)
    # pool_path = "/content/drive/MyDrive/003_Codes/TANGO-JointEmbedding/datasets/oliver_test/show-oliver-test.pkl"
    # graph = igraph.Graph.Read_Pickle(fname=pool_path)
    # graph = igraph.Graph.Read_Pickle(fname="/content/drive/MyDrive/003_Codes/TANGO-JointEmbedding/datasets/oliver_test/test.pkl")

    # Save before the smoke test so a rendering failure cannot throw away the
    # freshly built graph. The saved graph is the unpruned one.
    graph.write_pickle(fname=graph_path)

    walk, is_continue = random_walk(graph, 100)
    path_visualization(graph, walk, is_continue, "./test.mp4", audio_path=None, verbose_continue=True, return_motion=True)

    # Only prints the statistics of the largest strongly connected component;
    # the pruned graph is not written back to disk.
    graph = graph_pruning(graph)

    # show-oliver
    # json_path = "/content/drive/MyDrive/003_Codes/TANGO/datasets/data_json/show_oliver_test/"
    # pre_node_path = "/content/drive/MyDrive/003_Codes/TANGO/datasets/cached_graph/show_oliver_test/"
    # for json_file in tqdm(os.listdir(json_path)):
    #     graph = create_graph(os.path.join(json_path, json_file))
    #     graph = create_edges(graph)
    #     if not len(graph.vs) >= 1500:
    #         print(f"skip: {len(graph.vs)}", json_file)
    #     graph.write_pickle(fname=os.path.join(pre_node_path, json_file.split(".")[0] + ".pkl"))
    #     print(f"Graph saved at {json_file.split('.')[0]}.pkl")