ganeshblank commited on
Commit
45b0642
1 Parent(s): bc6e3e3
Files changed (1) hide show
  1. app.py +273 -0
app.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+
3
+ """
4
+ The entrance of the gradio
5
+ """
6
+
7
+ import tyro
8
+ import gradio as gr
9
+ import os.path as osp
10
+ from src.utils.helper import load_description
11
+ from src.gradio_pipeline import GradioPipeline
12
+ from src.config.crop_config import CropConfig
13
+ from src.config.argument_config import ArgumentConfig
14
+ from src.config.inference_config import InferenceConfig
15
+
16
+
17
+ def partial_fields(target_class, kwargs):
18
+ return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
19
+
20
+
21
+ # set tyro theme
22
+ tyro.extras.set_accent_color("bright_cyan")
23
+ args = tyro.cli(ArgumentConfig)
24
+
25
+ # specify configs for inference
26
+ inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
27
+ crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
28
+
29
+ gradio_pipeline = GradioPipeline(
30
+ inference_cfg=inference_cfg,
31
+ crop_cfg=crop_cfg,
32
+ args=args
33
+ )
34
+
35
+
36
+ def gpu_wrapped_execute_video(*args, **kwargs):
37
+ return gradio_pipeline.execute_video(*args, **kwargs)
38
+
39
+ def gpu_wrapped_execute_s_video(*args, **kwargs):
40
+ return gradio_pipeline.execute_s_video(*args, **kwargs)
41
+
42
+ def gpu_wrapped_execute_image(*args, **kwargs):
43
+ return gradio_pipeline.execute_image(*args, **kwargs)
44
+
45
+
46
+ # assets
47
+ title_md = "assets/gradio_title.md"
48
+ example_portrait_dir = "assets/examples/source"
49
+ example_video_dir = "assets/examples/driving"
50
+ data_examples = [
51
+ [osp.join(example_portrait_dir, "s9.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
52
+ [osp.join(example_portrait_dir, "s6.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
53
+ [osp.join(example_portrait_dir, "s10.jpg"), osp.join(example_video_dir, "d0.mp4"), True, True, True, False],
54
+ [osp.join(example_portrait_dir, "s5.jpg"), osp.join(example_video_dir, "d18.mp4"), True, True, True, False],
55
+ [osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d19.mp4"), True, True, True, False],
56
+ [osp.join(example_portrait_dir, "s2.jpg"), osp.join(example_video_dir, "d13.mp4"), True, True, True, True],
57
+ ]
58
+ #################### interface logic ####################
59
+
60
+ # Define components first
61
+ eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
62
+ lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
63
+ retargeting_input_image = gr.Image(type="filepath")
64
+ output_image = gr.Image(type="numpy")
65
+ output_image_paste_back = gr.Image(type="numpy")
66
+ output_video = gr.Video()
67
+ output_video_concat = gr.Video()
68
+
69
+ output_video1 = gr.Video()
70
+ output_video_concat1 = gr.Video()
71
+
72
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
73
+ gr.HTML(load_description(title_md))
74
+ gr.Markdown(load_description("assets/gradio_description_upload.md"))
75
+ with gr.Row():
76
+ with gr.Accordion(open=True, label="Source Portrait"):
77
+ image_input = gr.Image(type="filepath")
78
+ gr.Examples(
79
+ examples=[
80
+ [osp.join(example_portrait_dir, "s9.jpg")],
81
+ [osp.join(example_portrait_dir, "s6.jpg")],
82
+ [osp.join(example_portrait_dir, "s10.jpg")],
83
+ [osp.join(example_portrait_dir, "s5.jpg")],
84
+ [osp.join(example_portrait_dir, "s7.jpg")],
85
+ [osp.join(example_portrait_dir, "s12.jpg")],
86
+ ],
87
+ inputs=[image_input],
88
+ cache_examples=False,
89
+ )
90
+ with gr.Accordion(open=True, label="Driving Video"):
91
+ video_input = gr.Video()
92
+ gr.Examples(
93
+ examples=[
94
+ [osp.join(example_video_dir, "d0.mp4")],
95
+ [osp.join(example_video_dir, "d18.mp4")],
96
+ [osp.join(example_video_dir, "d19.mp4")],
97
+ [osp.join(example_video_dir, "d14.mp4")],
98
+ [osp.join(example_video_dir, "d6.mp4")],
99
+ ],
100
+ inputs=[video_input],
101
+ cache_examples=False,
102
+ )
103
+ with gr.Row():
104
+ with gr.Accordion(open=False, label="Animation Instructions and Options"):
105
+ gr.Markdown(load_description("assets/gradio_description_animation.md"))
106
+ with gr.Row():
107
+ flag_relative_input = gr.Checkbox(value=True, label="relative motion")
108
+ flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)")
109
+ flag_remap_input = gr.Checkbox(value=True, label="paste-back")
110
+ flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving video)")
111
+ with gr.Row():
112
+ with gr.Column():
113
+ process_button_animation = gr.Button("🚀 Animate", variant="primary")
114
+ with gr.Column():
115
+ process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="🧹 Clear")
116
+ with gr.Row():
117
+ with gr.Column():
118
+ with gr.Accordion(open=True, label="The animated video in the original image space"):
119
+ output_video.render()
120
+ with gr.Column():
121
+ with gr.Accordion(open=True, label="The animated video"):
122
+ output_video_concat.render()
123
+ with gr.Row():
124
+ # Examples
125
+ gr.Markdown("## You could also choose the examples below by one click ⬇️")
126
+ with gr.Row():
127
+ gr.Examples(
128
+ examples=data_examples,
129
+ fn=gpu_wrapped_execute_video,
130
+ inputs=[
131
+ image_input,
132
+ video_input,
133
+ flag_relative_input,
134
+ flag_do_crop_input,
135
+ flag_remap_input,
136
+ flag_crop_driving_video_input
137
+ ],
138
+ outputs=[output_image, output_image_paste_back],
139
+ examples_per_page=len(data_examples),
140
+ cache_examples=False,
141
+ )
142
+
143
+ with gr.Row():
144
+ # Examples
145
+ gr.Markdown("## For video to video")
146
+ # for video portrait
147
+ with gr.Row():
148
+ with gr.Accordion(open=True, label="Video Portrait"):
149
+ source_video_input = gr.Video()
150
+ gr.Examples(
151
+ examples=[
152
+ [osp.join(example_video_dir, "d0.mp4")],
153
+ [osp.join(example_video_dir, "d18.mp4")],
154
+ [osp.join(example_video_dir, "d19.mp4")],
155
+ [osp.join(example_video_dir, "d14.mp4")],
156
+ [osp.join(example_video_dir, "d6.mp4")],
157
+ ],
158
+ inputs=[source_video_input],
159
+ cache_examples=False,
160
+ )
161
+ with gr.Accordion(open=True, label="Driving Video"):
162
+ video_input = gr.Video()
163
+ gr.Examples(
164
+ examples=[
165
+ [osp.join(example_video_dir, "d0.mp4")],
166
+ [osp.join(example_video_dir, "d18.mp4")],
167
+ [osp.join(example_video_dir, "d19.mp4")],
168
+ [osp.join(example_video_dir, "d14.mp4")],
169
+ [osp.join(example_video_dir, "d6.mp4")],
170
+ ],
171
+ inputs=[video_input],
172
+ cache_examples=False,
173
+ )
174
+ with gr.Row():
175
+ with gr.Accordion(open=False, label="source Animation Instructions and Options"):
176
+ gr.Markdown(load_description("assets/gradio_description_animation.md"))
177
+ with gr.Row():
178
+ flag_relative_input = gr.Checkbox(value=True, label="relative motion")
179
+ flag_do_crop_input = gr.Checkbox(value=True, label="do crop (source)")
180
+ flag_remap_input = gr.Checkbox(value=True, label="paste-back")
181
+ flag_crop_driving_video_input = gr.Checkbox(value=False, label="do crop (driving video)")
182
+ with gr.Row():
183
+ with gr.Column():
184
+ process_button_source_animation = gr.Button("🚀 Animate video", variant="primary")
185
+ with gr.Column():
186
+ process_button_reset = gr.ClearButton([source_video_input, video_input, output_video1, output_video_concat1], value="🧹 Clear")
187
+ with gr.Row():
188
+ with gr.Column():
189
+ with gr.Accordion(open=True, label="The animated video in the original image space"):
190
+ output_video1.render()
191
+ with gr.Column():
192
+ with gr.Accordion(open=True, label="The animated video"):
193
+ output_video_concat1.render()
194
+
195
+ gr.Markdown(load_description("assets/gradio_description_retargeting.md"), visible=True)
196
+ with gr.Row(visible=True):
197
+ eye_retargeting_slider.render()
198
+ lip_retargeting_slider.render()
199
+ with gr.Row(visible=True):
200
+ process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
201
+ process_button_reset_retargeting = gr.ClearButton(
202
+ [
203
+ eye_retargeting_slider,
204
+ lip_retargeting_slider,
205
+ retargeting_input_image,
206
+ output_image,
207
+ output_image_paste_back
208
+ ],
209
+ value="🧹 Clear"
210
+ )
211
+ with gr.Row(visible=True):
212
+ with gr.Column():
213
+ with gr.Accordion(open=True, label="Retargeting Input"):
214
+ retargeting_input_image.render()
215
+ gr.Examples(
216
+ examples=[
217
+ [osp.join(example_portrait_dir, "s9.jpg")],
218
+ [osp.join(example_portrait_dir, "s6.jpg")],
219
+ [osp.join(example_portrait_dir, "s10.jpg")],
220
+ [osp.join(example_portrait_dir, "s5.jpg")],
221
+ [osp.join(example_portrait_dir, "s7.jpg")],
222
+ [osp.join(example_portrait_dir, "s12.jpg")],
223
+ ],
224
+ inputs=[retargeting_input_image],
225
+ cache_examples=False,
226
+ )
227
+ with gr.Column():
228
+ with gr.Accordion(open=True, label="Retargeting Result"):
229
+ output_image.render()
230
+ with gr.Column():
231
+ with gr.Accordion(open=True, label="Paste-back Result"):
232
+ output_image_paste_back.render()
233
+ # binding functions for buttons
234
+ process_button_retargeting.click(
235
+ # fn=gradio_pipeline.execute_image,
236
+ fn=gpu_wrapped_execute_image,
237
+ inputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image, flag_do_crop_input],
238
+ outputs=[output_image, output_image_paste_back],
239
+ show_progress=True
240
+ )
241
+ process_button_animation.click(
242
+ fn=gpu_wrapped_execute_video,
243
+ inputs=[
244
+ image_input,
245
+ video_input,
246
+ flag_relative_input,
247
+ flag_do_crop_input,
248
+ flag_remap_input,
249
+ flag_crop_driving_video_input
250
+ ],
251
+ outputs=[output_video, output_video_concat],
252
+ show_progress=True
253
+ )
254
+ process_button_source_animation.click(
255
+ fn=gpu_wrapped_execute_s_video,
256
+ inputs=[
257
+ source_video_input,
258
+ video_input,
259
+ flag_relative_input,
260
+ flag_do_crop_input,
261
+ flag_remap_input,
262
+ flag_crop_driving_video_input
263
+ ],
264
+ outputs=[output_video1, output_video_concat1],
265
+ show_progress=True
266
+ )
267
+
268
+
269
+ demo.launch(
270
+ server_port=args.server_port,
271
+ share=args.share,
272
+ server_name=args.server_name
273
+ )