X-Portrait / app.py
fffiloni's picture
Update app.py
eab49d0 verified
raw
history blame
8 kB
import glob
import os
import subprocess
import sys
from datetime import datetime

import cv2
import gradio as gr
from huggingface_hub import hf_hub_download
# True only when running inside the original public Space; it enables the
# shared-UI restrictions (driving-video trimming). SPACE_ID is provided by the
# Hugging Face Spaces runtime, so fall back to "" when it is absent (e.g. a
# local run) instead of raising KeyError at import time.
is_shared_ui = "fffiloni/X-Portrait" in os.environ.get("SPACE_ID", "")

# Ensure the 'checkpoint' directory exists, then fetch the model weights into
# it (run_xportrait passes checkpoint/model_state-415001.th as --resume_dir).
os.makedirs("checkpoint", exist_ok=True)
hf_hub_download(
    repo_id="fffiloni/X-Portrait",
    filename="model_state-415001.th",
    local_dir="checkpoint",
)
def trim_video(video_path, output_dir="trimmed_videos", max_duration=2):
    """Return a path to a copy of *video_path* cut to its first *max_duration* seconds.

    If the video is not longer than *max_duration* (or its duration cannot be
    derived from the container's FPS / frame-count metadata), nothing is
    written and *video_path* itself is returned.

    Args:
        video_path: Path to the input video file.
        output_dir: Directory for the trimmed copy; created if missing.
        max_duration: Maximum length to keep, in seconds.

    Returns:
        Path to the trimmed ``.mp4`` file, or *video_path* when no trim was needed.

    Raises:
        ValueError: If the input video cannot be opened.
        RuntimeError: If the output video writer cannot be opened.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Microsecond resolution keeps two requests in the same second from
    # writing to the same file.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    output_path = os.path.join(output_dir, f"trimmed_video_{timestamp}.mp4")

    # OpenCV is already a dependency of this app; the previous implementation
    # used moviepy's VideoFileClip without ever importing it (NameError).
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        raise ValueError(f"Cannot open video file: {video_path}")
    try:
        fps = capture.get(cv2.CAP_PROP_FPS)
        frame_count = capture.get(cv2.CAP_PROP_FRAME_COUNT)
        # Unknown metadata or a short enough clip: keep the original file.
        if fps <= 0 or frame_count <= 0 or frame_count / fps <= max_duration:
            return video_path

        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        max_frames = max(1, round(fps * max_duration))

        writer = cv2.VideoWriter(
            output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
        )
        if not writer.isOpened():
            raise RuntimeError(f"Cannot open video writer for: {output_path}")
        try:
            for _ in range(max_frames):
                ok, frame = capture.read()
                if not ok:
                    break  # Stream ended earlier than the metadata claimed.
                writer.write(frame)
        finally:
            writer.release()
    finally:
        capture.release()
    return output_path
def extract_frames_with_labels(video_path, base_output_dir="frames"):
    """Save every frame of the driving video as a JPEG for the frames gallery.

    On the shared UI (``is_shared_ui``) the video is first passed through
    ``trim_video``, so only its first seconds are extracted.

    Args:
        video_path: Path to the driving video, or None/"" when the video
            input has been cleared.
        base_output_dir: Parent directory for the per-call frames folder.

    Returns:
        A tuple ``(frame_data, accordion_update)``. ``frame_data`` is a list of
        ``(jpeg_path, label)`` pairs whose label is the zero-padded frame index
        (e.g. "0000"). ``accordion_update`` opens the frames panel, or keeps it
        closed when there is no video.

    Raises:
        ValueError: If the video file cannot be opened.
    """
    # The .change event also fires when the video is cleared; there is
    # nothing to extract then, so empty the gallery instead of crashing.
    if not video_path:
        return [], gr.update(open=False)

    if is_shared_ui:
        video_path = trim_video(video_path)
    print("Path to the (trimmed) driving video:", video_path)

    video_capture = cv2.VideoCapture(video_path)
    if not video_capture.isOpened():
        raise ValueError(f"Cannot open video file: {video_path}")

    # Microsecond resolution: this handler runs with queue=False, so two
    # uploads within the same second must not share (and overwrite) a folder.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    output_dir = os.path.join(base_output_dir, f"frames_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    frame_data = []
    try:
        frame_index = 0
        while True:
            ret, frame = video_capture.read()
            if not ret:
                break  # No frames left to read.
            # Zero-padded index doubles as the filename suffix and gallery label.
            frame_label = f"{frame_index:04}"
            frame_filename = os.path.join(output_dir, f"frame_{frame_label}.jpg")
            cv2.imwrite(frame_filename, frame)
            frame_data.append((frame_filename, frame_label))
            frame_index += 1
    finally:
        # Release the capture even if reading or writing a frame fails.
        video_capture.release()

    return frame_data, gr.update(open=True)
def _cli_int(value):
    """Format a gr.Number value for an integer CLI flag (36.0 -> "36"; others unchanged)."""
    if isinstance(value, float) and value.is_integer():
        return str(int(value))
    return str(value)


def run_xportrait(source_image, driving_video, seed, uc_scale, best_frame, out_frames, num_mix, ddim_steps, progress=gr.Progress(track_tqdm=True)):
    """Run X-Portrait inference (core/test_xportrait.py) in a subprocess.

    Args:
        source_image: File path of the portrait to animate.
        driving_video: File path of the video providing the motion.
        seed: Random seed.
        uc_scale: Value forwarded as --uc_scale.
        best_frame: Index of the driving-video frame whose head pose best
            matches the source image.
        out_frames: Number of frames to generate (forwarded as --out_frames).
        num_mix: Value forwarded as --num_mix.
        ddim_steps: Number of DDIM sampling steps.
        progress: Injected by Gradio; not referenced directly here.

    Returns:
        A ``(status_message, video_path_or_None)`` tuple for the status
        textbox and the output video component.
    """
    # Without both inputs the argv list would contain None and subprocess
    # would raise TypeError; report it in the status box instead.
    if not source_image or not driving_video:
        return "Please provide both a source image and a driving video.", None

    # NOTE(review): on the shared UI only the frames gallery is built from a
    # trimmed copy (extract_frames_with_labels); the full-length driving video
    # is passed to inference here despite the UI notice — confirm intent.

    # Unique per-run output directory (microseconds avoid same-second clashes).
    output_dir_base = "outputs"
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    output_dir = os.path.join(output_dir_base, f"output_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    model_config = "config/cldm_v15_appearance_pose_local_mm.yaml"
    resume_dir = "checkpoint/model_state-415001.th"

    # sys.executable runs the script with this app's own interpreter/venv
    # rather than whatever "python3" happens to be first on PATH.
    # Integer flags go through _cli_int so a float from gr.Number such as 36.0
    # is sent as "36".
    command = [
        sys.executable, "core/test_xportrait.py",
        "--model_config", model_config,
        "--output_dir", output_dir,
        "--resume_dir", resume_dir,
        "--seed", _cli_int(seed),
        "--uc_scale", str(uc_scale),
        "--source_image", source_image,
        "--driving_video", driving_video,
        "--best_frame", _cli_int(best_frame),
        "--out_frames", _cli_int(out_frames),
        "--num_mix", _cli_int(num_mix),
        "--ddim_steps", _cli_int(ddim_steps),
    ]

    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        return f"An error occurred: {e}", None

    # Sorted so the returned file is deterministic if several .mp4s exist.
    video_files = sorted(glob.glob(os.path.join(output_dir, "*.mp4")))
    print(video_files)
    if video_files:
        return f"Output video saved at: {video_files[0]}", video_files[0]
    return "No video file was found in the output directory.", None
# Set up Gradio interface
# Extra CSS for gr.Blocks: force scrollbars on the driving-video frames
# gallery (the Gallery created with elem_id="frames-gallery").
css = (
    "\n"
    "div#frames-gallery{\n"
    "overflow: scroll!important;\n"
    "}\n"
)
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# X-Portrait: Expressive Portrait Animation with Hierarchical Motion Attention")
        # The trimming notice only holds where trimming is enabled
        # (is_shared_ui), so it is not shown on duplicated or local copies.
        if is_shared_ui:
            gr.Markdown("On this shared UI, driving video input will be trimmed to 2 seconds max. Duplicate this space for more controls.")
        gr.HTML("""
<div style="display:flex;column-gap:4px;">
<a href='https://github.com/bytedance/X-Portrait'>
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
<a href='https://byteaigc.github.io/x-portrait/'>
<img src='https://img.shields.io/badge/Project-Page-green'>
</a>
</div>
""")
        with gr.Row():
            # Left column: inputs and generation parameters.
            with gr.Column():
                with gr.Row():
                    source_image = gr.Image(label="Source Image", type="filepath")
                    driving_video = gr.Video(label="Driving Video")
                with gr.Group():
                    with gr.Row():
                        # precision=0 makes Gradio hand back ints, matching the
                        # integer CLI flags of core/test_xportrait.py.
                        best_frame = gr.Number(value=36, precision=0, label="Best Frame", info="specify the frame index in the driving video where the head pose best matches the source image (note: precision of best_frame index might affect the final quality)")
                        out_frames = gr.Number(value=-1, precision=0, label="Out Frames", info="number of generation frames")
                    # Filled by extract_frames_with_labels: each frame labelled
                    # with its zero-padded index, to help choose Best Frame.
                    with gr.Accordion("Driving video Frames", open=False) as frames_gallery_panel:
                        driving_frames = gr.Gallery(show_label=True, columns=6, height=380, elem_id="frames-gallery")
                with gr.Row():
                    seed = gr.Number(value=999, precision=0, label="Seed")
                    uc_scale = gr.Number(value=5, label="UC Scale")
                with gr.Row():
                    num_mix = gr.Number(value=4, precision=0, label="Number of Mix")
                    ddim_steps = gr.Number(value=30, precision=0, label="DDIM Steps")
                submit_btn = gr.Button("Submit")
            # Right column: result, status and examples.
            with gr.Column():
                video_output = gr.Video(label="Output Video")
                status = gr.Textbox(label="status")
                gr.Examples(
                    examples=[
                        ["./assets/source_image.png", "./assets/driving_video.mp4"]
                    ],
                    inputs=[source_image, driving_video],
                )
                gr.HTML("""
<div style="display:flex;column-gap:4px;">
<a href="https://huggingface.co/spaces/fffiloni/X-Portrait?duplicate=true">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-xl.svg" alt="Duplicate this Space">
</a>
<a href="https://huggingface.co/fffiloni">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-xl-dark.svg" alt="Follow me on HF">
</a>
</div>
""")

    # Rebuild the frames gallery whenever the driving video changes (or is
    # cleared); queue=False bypasses Gradio's request queue for this preview.
    driving_video.change(
        fn=extract_frames_with_labels,
        inputs=[driving_video],
        outputs=[driving_frames, frames_gallery_panel],
        queue=False,
    )
    submit_btn.click(
        fn=run_xportrait,
        inputs=[source_image, driving_video, seed, uc_scale, best_frame, out_frames, num_mix, ddim_steps],
        outputs=[status, video_output],
    )
# Launch the Gradio app only when executed as a script (`python app.py`),
# so importing this module (e.g. for tooling or reload mode) does not start
# a server.
if __name__ == "__main__":
    demo.launch()