fffiloni commited on
Commit
8af9162
1 Parent(s): 4264aae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -35
app.py CHANGED
@@ -14,7 +14,11 @@ hf_hub_download(
14
  local_dir="checkpoint"
15
  )
16
 
17
- def extract_frames_with_labels(video_path, output_dir="frames"):
 
 
 
 
18
  # Ensure output directory exists
19
  os.makedirs(output_dir, exist_ok=True)
20
 
@@ -51,27 +55,19 @@ def extract_frames_with_labels(video_path, output_dir="frames"):
51
  return frame_data
52
 
53
  # Define a function to run your script with selected inputs
54
- def run_xportrait(
55
- model_config,
56
- output_dir_base,
57
- resume_dir,
58
- seed,
59
- uc_scale,
60
- source_image,
61
- driving_video,
62
- best_frame,
63
- out_frames,
64
- num_mix,
65
- ddim_steps
66
- ):
67
  # Check if the model weights are in place
68
  if not os.path.exists(resume_dir):
69
  return "Model weights not found in checkpoint directory. Please download them first."
70
 
71
  # Create a unique output directory name based on current date and time
 
72
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
73
  output_dir = os.path.join(output_dir_base, f"output_{timestamp}")
74
  os.makedirs(output_dir, exist_ok=True)
 
 
 
75
 
76
  # Construct the command
77
  command = [
@@ -94,7 +90,7 @@ def run_xportrait(
94
  subprocess.run(command, check=True)
95
 
96
  # Find the generated video file in the output directory
97
- video_files = glob.glob(os.path.join(output_dir, "*.mp4")) + glob.glob(os.path.join(output_dir, "*.avi"))
98
  print(video_files)
99
  if video_files:
100
  return f"Output video saved at: {video_files[0]}", video_files[0]
@@ -104,25 +100,41 @@ def run_xportrait(
104
  return f"An error occurred: {e}", None
105
 
106
  # Set up Gradio interface
107
- app = gr.Interface(
108
- fn=run_xportrait,
109
- inputs=[
110
- gr.Textbox(value="config/cldm_v15_appearance_pose_local_mm.yaml", label="Model Config Path"),
111
- gr.Textbox(value="outputs", label="Output Directory"),
112
- gr.Textbox(value="checkpoint/model_state-415001.th", label="Resume Directory"),
113
- gr.Number(value=999, label="Seed"),
114
- gr.Number(value=5, label="UC Scale"),
115
- gr.Image(label="Source Image", type="filepath"),
116
- gr.Video(label="Driving Video"),
117
- gr.Number(value=36, label="Best Frame"),
118
- gr.Number(value=-1, label="Out Frames"),
119
- gr.Number(value=4, label="Number of Mix"),
120
- gr.Number(value=30, label="DDIM Steps")
121
- ],
122
- outputs=["text", "video"],
123
- title="XPortrait Model Runner",
124
- description="Run XPortrait with customizable parameters."
125
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
  # Launch the Gradio app
128
- app.launch()
 
14
  local_dir="checkpoint"
15
  )
16
 
17
+ def extract_frames_with_labels(video_path, base_output_dir="frames"):
18
+ # Generate a timestamped folder name
19
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
20
+ output_dir = os.path.join(base_output_dir, f"frames_{timestamp}")
21
+
22
  # Ensure output directory exists
23
  os.makedirs(output_dir, exist_ok=True)
24
 
 
55
  return frame_data
56
 
57
  # Define a function to run your script with selected inputs
58
+ def run_xportrait(source_image, driving_video, seed, uc_scale, best_frame, out_frames, num_mix, ddim_steps):
 
 
 
 
 
 
 
 
 
 
 
 
59
  # Check if the model weights are in place
60
  if not os.path.exists(resume_dir):
61
  return "Model weights not found in checkpoint directory. Please download them first."
62
 
63
  # Create a unique output directory name based on current date and time
64
+ output_dir_base = "outputs"
65
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
66
  output_dir = os.path.join(output_dir_base, f"output_{timestamp}")
67
  os.makedirs(output_dir, exist_ok=True)
68
+
69
+ model_config = "config/cldm_v15_appearance_pose_local_mm.yaml"
70
+ resume_dir = "checkpoint/model_state-415001.th"
71
 
72
  # Construct the command
73
  command = [
 
90
  subprocess.run(command, check=True)
91
 
92
  # Find the generated video file in the output directory
93
+ video_files = glob.glob(os.path.join(output_dir, "*.mp4"))
94
  print(video_files)
95
  if video_files:
96
  return f"Output video saved at: {video_files[0]}", video_files[0]
 
100
  return f"An error occurred: {e}", None
101
 
102
  # Set up Gradio interface
103
+ with gr.Blocks() as demo:
104
+ with gr.Column(elem_id="col-container"):
105
+ with gr.Row():
106
+ with gr.Column():
107
+ source_image = gr.Image(label="Source Image", type="filepath")
108
+ driving_video = gr.Video(label="Driving Video")
109
+ with gr.Row():
110
+ seed = gr.Number(value=999, label="Seed")
111
+ uc_scale = gr.Number(value=5, label="UC Scale")
112
+ with gr.Group():
113
+ with gr.Row():
114
+ best_frame = gr.Number(value=36, label="Best Frame")
115
+ out_frames = gr.Number(value=-1, label="Out Frames")
116
+ with gr.Accordion("Driving video Frames"):
117
+ driving_frames = gr.Gallery(show_label=True)
118
+ with gr.Row():
119
+ num_mix = gr.Number(value=4, label="Number of Mix")
120
+ ddim_steps = gr.Number(value=30, label="DDIM Steps")
121
+ submit_btn = gr.Button("Submit")
122
+ with gr.Column():
123
+ video_output = gr.Video(label="Output Video")
124
+ status = gr.Textbox(label="status")
125
+
126
+ driving_video.upload(
127
+ fn = extract_frames_with_labels,
128
+ inputs = [driving_video],
129
+ ouputs = [driving_frames],
130
+ queue = False
131
+ )
132
+
133
+ submit_btn.click(
134
+ fn = run_xportrait,
135
+ inputs = [source_image, driving_video, seed, uc_scale, best_frame, out_frames, num_mix, ddim_steps],
136
+ outputs = [status, video_output]
137
+ )
138
 
139
  # Launch the Gradio app
140
+ demo.launch()