import glob
import os
import subprocess
import sys
import tempfile
from datetime import datetime

import cv2
import gradio as gr
from huggingface_hub import hf_hub_download

# Fetch the X-Portrait weights into ./checkpoint when the script starts.
# run_xportrait() later hands "checkpoint/model_state-415001.th" to the
# inference script through --resume_dir.
os.makedirs(name="checkpoint", exist_ok=True)
hf_hub_download("fffiloni/X-Portrait", "model_state-415001.th", local_dir="checkpoint")

def extract_frames_with_labels(video_path, base_output_dir="frames"):
    """Decode every frame of a video into JPEG files for a gallery.

    Args:
        video_path: Path of the video file to read with OpenCV.
        base_output_dir: Parent directory. A new, uniquely named
            ``frames_<timestamp>_*`` sub-directory is created inside it
            for the frames of this call.

    Returns:
        A list of ``(jpeg_path, label)`` tuples in frame order, where
        ``label`` is the zero-padded frame index (``"0000"``, ``"0001"``, ...).
        The list is empty if the video yields no frames.

    Raises:
        ValueError: If OpenCV cannot open ``video_path``.
        OSError: If a frame cannot be written to disk.
    """
    video_capture = cv2.VideoCapture(video_path)
    try:
        # Validate the input before touching the filesystem so an unreadable
        # upload does not leave an empty frames folder behind.
        if not video_capture.isOpened():
            raise ValueError(f"Cannot open video file: {video_path}")

        # mkdtemp guarantees a fresh directory even when two uploads arrive
        # within the same second (the upload handler runs unqueued), so
        # concurrent calls never overwrite each other's frame files.
        os.makedirs(base_output_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = tempfile.mkdtemp(prefix=f"frames_{timestamp}_", dir=base_output_dir)

        frame_data = []
        frame_index = 0
        while True:
            ret, frame = video_capture.read()
            if not ret:
                break  # No frames left to read.

            # Zero-padded frame index used for both the filename and the label.
            frame_label = f"{frame_index:04}"
            frame_filename = os.path.join(output_dir, f"frame_{frame_label}.jpg")

            # imwrite reports failure via its return value rather than raising;
            # surface it instead of handing the gallery a path that doesn't exist.
            if not cv2.imwrite(frame_filename, frame):
                raise OSError(f"Failed to write frame to {frame_filename}")

            frame_data.append((frame_filename, frame_label))
            frame_index += 1
    finally:
        # Always release the capture, including when a read/write raises.
        video_capture.release()

    return frame_data

# Define a function to run your script with selected inputs
def run_xportrait(source_image, driving_video, seed, uc_scale, best_frame, out_frames, num_mix, ddim_steps):

    # Create a unique output directory name based on current date and time
    output_dir_base = "outputs"
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(output_dir_base, f"output_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    model_config = "config/cldm_v15_appearance_pose_local_mm.yaml"
    resume_dir = "checkpoint/model_state-415001.th"
    
    # Construct the command
    command = [
        "python3", "core/test_xportrait.py",
        "--model_config", model_config,
        "--output_dir", output_dir,
        "--resume_dir", resume_dir,
        "--seed", str(seed),
        "--uc_scale", str(uc_scale),
        "--source_image", source_image,
        "--driving_video", driving_video,
        "--best_frame", str(best_frame),
        "--out_frames", str(out_frames),
        "--num_mix", str(num_mix),
        "--ddim_steps", str(ddim_steps)
    ]
    
    # Run the command
    try:
        subprocess.run(command, check=True)
        
        # Find the generated video file in the output directory
        video_files = glob.glob(os.path.join(output_dir, "*.mp4"))
        print(video_files)
        if video_files:
            return f"Output video saved at: {video_files[0]}", video_files[0]
        else:
            return "No video file was found in the output directory.", None
    except subprocess.CalledProcessError as e:
        return f"An error occurred: {e}", None

# Set up the Gradio interface: inputs and settings on the left, the result
# video and a status message on the right.
with gr.Blocks() as demo:
    with gr.Column(elem_id="col-container"):
        with gr.Row():
            with gr.Column():
                source_image = gr.Image(label="Source Image", type="filepath")
                driving_video = gr.Video(label="Driving Video")
                # precision=0 makes Gradio round and return an int, so values
                # meant to be integers reach the inference script as "36",
                # never as "36.5" or "36.0".
                with gr.Row():
                    seed = gr.Number(value=999, label="Seed", precision=0)
                    uc_scale = gr.Number(value=5, label="UC Scale")
                with gr.Group():
                    with gr.Row():
                        best_frame = gr.Number(value=36, label="Best Frame", precision=0)
                        out_frames = gr.Number(value=-1, label="Out Frames", precision=0)
                    # Filled with the labelled frames of the driving video on
                    # upload, to help pick a "Best Frame" index.
                    with gr.Accordion("Driving video Frames"):
                        driving_frames = gr.Gallery(show_label=True, columns=6, height=512)
                with gr.Row():
                    num_mix = gr.Number(value=4, label="Number of Mix", precision=0)
                    ddim_steps = gr.Number(value=30, label="DDIM Steps", precision=0)
                submit_btn = gr.Button("Submit")
            with gr.Column():
                video_output = gr.Video(label="Output Video")
                status = gr.Textbox(label="status")

    # Extract and show the driving video's frames as soon as it is uploaded.
    driving_video.upload(
        fn = extract_frames_with_labels,
        inputs = [driving_video],
        outputs = [driving_frames],
        queue = False
    )

    # run_xportrait returns (status message, video path) in this order.
    submit_btn.click(
        fn = run_xportrait,
        inputs = [source_image, driving_video, seed, uc_scale, best_frame, out_frames, num_mix, ddim_steps],
        outputs = [status, video_output]
    )

# Launch the Gradio app
demo.launch()