"""Gradio demo for X-Portrait: expressive portrait animation.

Downloads the model checkpoint, lets the user upload a source image and a
driving video, and runs ``core/test_xportrait.py`` as a subprocess to
render the animated output video.
"""
import glob
import os
import subprocess
from datetime import datetime

import cv2
import gradio as gr
from huggingface_hub import hf_hub_download

# On the public shared Space, driving videos are trimmed to keep runs short.
# ``.get`` avoids a KeyError when running outside Hugging Face Spaces, where
# SPACE_ID is not set.
is_shared_ui = "fffiloni/X-Portrait" in os.environ.get("SPACE_ID", "")

# Ensure 'checkpoint' directory exists
os.makedirs("checkpoint", exist_ok=True)
hf_hub_download(
    repo_id="fffiloni/X-Portrait",
    filename="model_state-415001.th",
    local_dir="checkpoint",
)


def _timestamp():
    """Return a filesystem-safe timestamp with microsecond resolution.

    Microseconds keep files/directories created within the same second
    (e.g. concurrent uploads, which run with ``queue=False``) from
    overwriting each other.
    """
    return datetime.now().strftime("%Y%m%d_%H%M%S_%f")


def trim_video(video_path, output_dir="trimmed_videos", max_duration=2):
    """Trim a video to its first ``max_duration`` seconds.

    Uses OpenCV (already a dependency of this app) to read the frames and
    re-encode the kept prefix as MP4 (``mp4v``). Audio is not copied.

    Args:
        video_path: Path to the input video.
        output_dir: Directory where the trimmed video is written.
        max_duration: Maximum duration in seconds to keep.

    Returns:
        The path of the trimmed video when the input is longer than
        ``max_duration``; otherwise (or when the duration cannot be
        determined) the original ``video_path``.

    Raises:
        ValueError: If the video cannot be opened.
    """
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"trimmed_video_{_timestamp()}.mp4")

    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        raise ValueError(f"Cannot open video file: {video_path}")

    writer = None
    try:
        fps = capture.get(cv2.CAP_PROP_FPS)
        frame_count = capture.get(cv2.CAP_PROP_FRAME_COUNT)
        # Without a usable frame rate / frame count the duration is unknown;
        # fall back to the untouched input rather than guessing.
        if fps <= 0 or frame_count <= 0:
            return video_path
        if frame_count / fps <= max_duration:
            # Already short enough: reuse the original file.
            return video_path

        max_frames = int(max_duration * fps)
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        writer = cv2.VideoWriter(
            output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
        )
        for _ in range(max_frames):
            ok, frame = capture.read()
            if not ok:
                # The container may report more frames than it actually has.
                break
            writer.write(frame)
    finally:
        capture.release()
        if writer is not None:
            writer.release()
    return output_path


def load_driving_video(video_path):
    """Handle a driving-video upload.

    On the shared UI the video is first trimmed. Frames are then extracted
    for the gallery, and the frames accordion is opened.

    Returns:
        Tuple of (possibly trimmed) video path, list of (frame_path, label)
        tuples, and a Gradio update opening the accordion.
    """
    if is_shared_ui:
        video_path = trim_video(video_path)
        print("Path to the (trimmed) driving video:", video_path)
    frames_data = extract_frames_with_labels(video_path)
    # ``open`` is a boolean property; the previous string "True" only worked
    # because any non-empty string is truthy.
    return video_path, frames_data, gr.update(open=True)


def extract_frames_with_labels(video_path, base_output_dir="frames"):
    """Save every frame of a video as a JPEG and return (path, label) pairs.

    Frames are written to a new timestamped folder under ``base_output_dir``
    as ``frame_0000.jpg``, ``frame_0001.jpg``, ...; the label is the
    zero-padded frame index, matching the ``best_frame`` index the user picks.

    Raises:
        ValueError: If the video cannot be opened.
    """
    output_dir = os.path.join(base_output_dir, f"frames_{_timestamp()}")
    os.makedirs(output_dir, exist_ok=True)

    video_capture = cv2.VideoCapture(video_path)
    if not video_capture.isOpened():
        raise ValueError(f"Cannot open video file: {video_path}")

    frame_data = []
    frame_index = 0
    try:
        while True:
            ret, frame = video_capture.read()
            if not ret:
                break  # No frames left to read.
            # Zero-padded frame index for filename and label.
            frame_label = f"{frame_index:04}"
            frame_filename = os.path.join(output_dir, f"frame_{frame_label}.jpg")
            cv2.imwrite(frame_filename, frame)
            frame_data.append((frame_filename, frame_label))
            frame_index += 1
    finally:
        # Release the capture even if reading/writing raises.
        video_capture.release()
    return frame_data


def run_xportrait(source_image, driving_video, seed, uc_scale, best_frame,
                  out_frames, num_mix, ddim_steps,
                  progress=gr.Progress(track_tqdm=True)):
    """Run the X-Portrait inference script and return its output video.

    Returns:
        Tuple ``(status_message, video_path_or_None)`` for the status textbox
        and the output video component.
    """
    # Guard against missing inputs: ``None`` inside the argv list would make
    # subprocess.run raise a TypeError that is not caught below.
    if not source_image or not driving_video:
        return "Please provide both a source image and a driving video.", None

    # Unique output directory per run.
    output_dir = os.path.join("outputs", f"output_{_timestamp()}")
    os.makedirs(output_dir, exist_ok=True)

    model_config = "config/cldm_v15_appearance_pose_local_mm.yaml"
    resume_dir = "checkpoint/model_state-415001.th"

    # Argument list (no shell) so user-supplied paths are never shell-parsed.
    command = [
        "python3", "core/test_xportrait.py",
        "--model_config", model_config,
        "--output_dir", output_dir,
        "--resume_dir", resume_dir,
        "--seed", str(seed),
        "--uc_scale", str(uc_scale),
        "--source_image", source_image,
        "--driving_video", driving_video,
        "--best_frame", str(best_frame),
        "--out_frames", str(out_frames),
        "--num_mix", str(num_mix),
        "--ddim_steps", str(ddim_steps),
    ]

    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        return f"An error occurred: {e}", None

    # Sorted so the chosen file is deterministic if several are produced.
    video_files = sorted(glob.glob(os.path.join(output_dir, "*.mp4")))
    print(video_files)
    if video_files:
        return f"Output video saved at: {video_files[0]}", video_files[0]
    return "No video file was found in the output directory.", None


# Set up Gradio interface
css = """
div#frames-gallery{
    overflow: scroll!important;
}
"""

example_frame_data = extract_frames_with_labels("./assets/driving_video.mp4")

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# X-Portrait: Expressive Portrait Animation with Hierarchical Motion Attention")
        gr.Markdown("On this shared UI, driving video input will be trimmed to 2 seconds max. Duplicate this space for more controls.")
        gr.HTML("""
""")
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    source_image = gr.Image(label="Source Image", type="filepath")
                    driving_video = gr.Video(label="Driving Video")
                with gr.Group():
                    with gr.Row():
                        # precision=0 makes integer-valued fields arrive as ints,
                        # so the CLI receives e.g. "36" rather than "36.0".
                        best_frame = gr.Number(value=36, precision=0, label="Best Frame", info="specify the frame index in the driving video where the head pose best matches the source image (note: precision of best_frame index might affect the final quality)")
                        out_frames = gr.Number(value=-1, precision=0, label="Out Frames", info="number of generation frames")
                    with gr.Accordion("Driving video Frames", open=False) as frames_gallery_panel:
                        driving_frames = gr.Gallery(show_label=True, columns=6, height=380, elem_id="frames-gallery")
                with gr.Row():
                    seed = gr.Number(value=999, precision=0, label="Seed")
                    uc_scale = gr.Number(value=5, label="UC Scale")
                with gr.Row():
                    num_mix = gr.Number(value=4, precision=0, label="Number of Mix")
                    ddim_steps = gr.Number(value=30, precision=0, label="DDIM Steps")
                submit_btn = gr.Button("Submit")
                gr.Examples(
                    examples=[
                        ["./assets/source_image.png", "./assets/driving_video.mp4", example_frame_data]
                    ],
                    inputs=[source_image, driving_video, driving_frames]
                )
            with gr.Column():
                video_output = gr.Video(label="Output Video")
                status = gr.Textbox(label="status")
        gr.HTML(""" """)

    driving_video.upload(
        fn=load_driving_video,
        inputs=[driving_video],
        outputs=[driving_video, driving_frames, frames_gallery_panel],
        queue=False
    )

    submit_btn.click(
        fn=run_xportrait,
        inputs=[source_image, driving_video, seed, uc_scale, best_frame, out_frames, num_mix, ddim_steps],
        outputs=[status, video_output]
    )

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()