Update app.py
app.py CHANGED
@@ -14,7 +14,11 @@ hf_hub_download(
     local_dir="checkpoint"
 )
 
-def extract_frames_with_labels(video_path, output_dir="frames"):
+def extract_frames_with_labels(video_path, base_output_dir="frames"):
+    # Generate a timestamped folder name
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    output_dir = os.path.join(base_output_dir, f"frames_{timestamp}")
+
     # Ensure output directory exists
     os.makedirs(output_dir, exist_ok=True)
 
@@ -51,27 +55,19 @@ def extract_frames_with_labels(video_path, output_dir="frames"):
     return frame_data
 
 # Define a function to run your script with selected inputs
-def run_xportrait(
-    model_config,
-    output_dir_base,
-    resume_dir,
-    seed,
-    uc_scale,
-    source_image,
-    driving_video,
-    best_frame,
-    out_frames,
-    num_mix,
-    ddim_steps
-):
+def run_xportrait(source_image, driving_video, seed, uc_scale, best_frame, out_frames, num_mix, ddim_steps):
+    model_config = "config/cldm_v15_appearance_pose_local_mm.yaml"
+    resume_dir = "checkpoint/model_state-415001.th"
+    output_dir_base = "outputs"
+
     # Check if the model weights are in place
     if not os.path.exists(resume_dir):
         return "Model weights not found in checkpoint directory. Please download them first."
 
     # Create a unique output directory name based on current date and time
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     output_dir = os.path.join(output_dir_base, f"output_{timestamp}")
     os.makedirs(output_dir, exist_ok=True)
 
     # Construct the command
     command = [
@@ -94,7 +90,7 @@ def run_xportrait(
         subprocess.run(command, check=True)
 
         # Find the generated video file in the output directory
-        video_files = glob.glob(os.path.join(output_dir, "*.mp4"))
+        video_files = glob.glob(os.path.join(output_dir, "*.mp4"))
         print(video_files)
         if video_files:
             return f"Output video saved at: {video_files[0]}", video_files[0]
@@ -104,25 +100,41 @@ def run_xportrait(
         return f"An error occurred: {e}", None
 
 # Set up Gradio interface
-…
-)
+with gr.Blocks() as demo:
+    with gr.Column(elem_id="col-container"):
+        with gr.Row():
+            with gr.Column():
+                source_image = gr.Image(label="Source Image", type="filepath")
+                driving_video = gr.Video(label="Driving Video")
+                with gr.Row():
+                    seed = gr.Number(value=999, label="Seed")
+                    uc_scale = gr.Number(value=5, label="UC Scale")
+                with gr.Group():
+                    with gr.Row():
+                        best_frame = gr.Number(value=36, label="Best Frame")
+                        out_frames = gr.Number(value=-1, label="Out Frames")
+                    with gr.Accordion("Driving video Frames"):
+                        driving_frames = gr.Gallery(show_label=True)
+                with gr.Row():
+                    num_mix = gr.Number(value=4, label="Number of Mix")
+                    ddim_steps = gr.Number(value=30, label="DDIM Steps")
+                submit_btn = gr.Button("Submit")
+            with gr.Column():
+                video_output = gr.Video(label="Output Video")
+                status = gr.Textbox(label="status")
+
+    driving_video.upload(
+        fn = extract_frames_with_labels,
+        inputs = [driving_video],
+        outputs = [driving_frames],
+        queue = False
+    )
+
+    submit_btn.click(
+        fn = run_xportrait,
+        inputs = [source_image, driving_video, seed, uc_scale, best_frame, out_frames, num_mix, ddim_steps],
+        outputs = [status, video_output]
+    )
 
 # Launch the Gradio app
-
+demo.launch()
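The diff only shows the new signature of extract_frames_with_labels, its timestamped output folder, and that it returns frame_data for the gr.Gallery; the loop that actually pulls frames out of the video is outside the changed hunks. A minimal sketch of a compatible body, assuming OpenCV is available and that the gallery expects (path, caption) pairs; the sampling interval is illustrative and not taken from the commit:

import os
from datetime import datetime

import cv2  # assumed dependency; the real function body is not shown in the diff


def extract_frames_with_labels(video_path, base_output_dir="frames"):
    """Sample frames from the driving video and return (path, label) pairs for gr.Gallery."""
    # Generate a timestamped folder name, as in the commit
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(base_output_dir, f"frames_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    step = 10  # illustrative sampling interval, not taken from the commit
    frame_data = []
    capture = cv2.VideoCapture(video_path)
    index = 0
    while True:
        ok, frame = capture.read()
        if not ok:
            break
        if index % step == 0:
            frame_path = os.path.join(output_dir, f"frame_{index:05d}.jpg")
            cv2.imwrite(frame_path, frame)
            # gr.Gallery accepts (image, caption) tuples, so the frame index
            # can double as the label shown under each thumbnail
            frame_data.append((frame_path, f"Frame {index}"))
        index += 1
    capture.release()
    return frame_data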
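The middle of run_xportrait (the command list handed to the X-Portrait inference script) is not part of the diff, but the surrounding hunks show the overall pattern: shell out with subprocess.run(..., check=True), glob the timestamped output directory for an .mp4, and turn subprocess.CalledProcessError into a status string. A condensed sketch of that pattern; the helper name run_and_collect_video is illustrative, command stays a placeholder because the real flags are not shown, and the "no video" fallback is assumed:

import glob
import os
import subprocess


def run_and_collect_video(command, output_dir):
    """Run an external command and return (status message, produced video path or None)."""
    try:
        # check=True raises CalledProcessError if the script exits non-zero
        subprocess.run(command, check=True)
        # Find the generated video file in the output directory
        video_files = glob.glob(os.path.join(output_dir, "*.mp4"))
        if video_files:
            return f"Output video saved at: {video_files[0]}", video_files[0]
        return "No output video was found.", None  # assumed fallback; not shown in the diff
    except subprocess.CalledProcessError as e:
        return f"An error occurred: {e}", None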
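The new Blocks layout wires two events: uploading the driving video fills a gallery with extracted frames, and the Submit button runs generation. A stripped-down, runnable sketch of that wiring, with stub callbacks standing in for extract_frames_with_labels and run_xportrait:

import gradio as gr


def show_frames(video_path):
    # stub: would return (image, caption) pairs for the gallery
    return []


def generate(source_image, driving_video, seed, uc_scale, best_frame, out_frames, num_mix, ddim_steps):
    # stub: would return a status string and an output video path
    return "done", None


with gr.Blocks() as demo:
    source_image = gr.Image(label="Source Image", type="filepath")
    driving_video = gr.Video(label="Driving Video")
    seed = gr.Number(value=999, label="Seed")
    uc_scale = gr.Number(value=5, label="UC Scale")
    best_frame = gr.Number(value=36, label="Best Frame")
    out_frames = gr.Number(value=-1, label="Out Frames")
    num_mix = gr.Number(value=4, label="Number of Mix")
    ddim_steps = gr.Number(value=30, label="DDIM Steps")
    driving_frames = gr.Gallery(label="Driving video Frames")
    submit_btn = gr.Button("Submit")
    status = gr.Textbox(label="status")
    video_output = gr.Video(label="Output Video")

    # .upload fires when a file lands in the Video component
    driving_video.upload(fn=show_frames, inputs=[driving_video], outputs=[driving_frames], queue=False)
    # .click runs generation and fills the status box and output video
    submit_btn.click(
        fn=generate,
        inputs=[source_image, driving_video, seed, uc_scale, best_frame, out_frames, num_mix, ddim_steps],
        outputs=[status, video_output],
    )

demo.launch()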