fffiloni committed
Commit 24d5f06 · verified · 1 Parent(s): c85dadc

Create app.py

Files changed (1)
app.py +392 -0
app.py ADDED
@@ -0,0 +1,392 @@
+ import gradio as gr
+ import numpy as np
+ import cv2
+ from PIL import Image, ImageFilter
+ import uuid
+ from scipy.interpolate import interp1d, PchipInterpolator
+ import torchvision
+ # Explicit imports for names used below (torch, os, einops, tqdm, torchvision
+ # transforms/utils); `from utils import *` may already re-export some of these.
+ import os
+ import torch
+ from einops import rearrange, repeat
+ from torchvision import transforms, utils
+ from tqdm import tqdm
+ from utils import *
+
+ output_dir = "outputs"
+ ensure_dirname(output_dir)
+
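+ # Resample a hand-drawn track to a fixed number of points. PCHIP (monotone
+ # cubic) interpolation is used instead of the plain cubic spline commented out
+ # below, presumably to avoid overshoot between clicked points.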
+ def interpolate_trajectory(points, n_points):
+     x = [point[0] for point in points]
+     y = [point[1] for point in points]
+
+     t = np.linspace(0, 1, len(points))
+
+     # fx = interp1d(t, x, kind='cubic')
+     # fy = interp1d(t, y, kind='cubic')
+     fx = PchipInterpolator(t, x)
+     fy = PchipInterpolator(t, y)
+
+     new_t = np.linspace(0, 1, n_points)
+
+     new_x = fx(new_t)
+     new_y = fy(new_t)
+     new_points = list(zip(new_x, new_y))
+
+     return new_points
+
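+ # Overlay the user-drawn tracks on a semi-transparent copy of the first frame:
+ # each multi-point track is resampled to 16 points and drawn as a polyline with
+ # an arrowhead on its last segment; single-point tracks are drawn as dots.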
+ def visualize_drag_v2(background_image_path, splited_tracks, width, height):
+     trajectory_maps = []
+
+     background_image = Image.open(background_image_path).convert('RGBA')
+     background_image = background_image.resize((width, height))
+     w, h = background_image.size
+     transparent_background = np.array(background_image)
+     transparent_background[:, :, -1] = 128
+     transparent_background = Image.fromarray(transparent_background)
+
+     # Create a transparent layer with the same size as the background image
+     transparent_layer = np.zeros((h, w, 4))
+     for splited_track in splited_tracks:
+         if len(splited_track) > 1:
+             splited_track = interpolate_trajectory(splited_track, 16)
+             splited_track = splited_track[:16]
+             for i in range(len(splited_track)-1):
+                 start_point = (int(splited_track[i][0]), int(splited_track[i][1]))
+                 end_point = (int(splited_track[i+1][0]), int(splited_track[i+1][1]))
+                 vx = end_point[0] - start_point[0]
+                 vy = end_point[1] - start_point[1]
+                 arrow_length = np.sqrt(vx**2 + vy**2)
+                 if i == len(splited_track)-2:
+                     cv2.arrowedLine(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2, tipLength=8 / arrow_length)
+                 else:
+                     cv2.line(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2)
+         else:
+             cv2.circle(transparent_layer, (int(splited_track[0][0]), int(splited_track[0][1])), 5, (255, 0, 0, 192), -1)
+
+     transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
+     trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
+     trajectory_maps.append(trajectory_map)
+     return trajectory_maps, transparent_layer
+
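+ # Wrapper around the DragNUWA/SVD network: loads the config and checkpoint,
+ # turns the user tracks into a dense flow conditioning tensor, and samples
+ # videos from it. `import_filename`, `file2data`, `adaptively_load_state_dict`
+ # and `split_filename` are presumably provided by the wildcard utils import.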
+ class Drag:
+     def __init__(self, device, model_path, cfg_path, height, width, model_length):
+         self.device = device
+         cf = import_filename(cfg_path)
+         Net, args = cf.Net, cf.args
+         drag_nuwa_net = Net(args)
+         state_dict = file2data(model_path, map_location='cpu')
+         adaptively_load_state_dict(drag_nuwa_net, state_dict)
+         drag_nuwa_net.eval()
+         drag_nuwa_net.to(device)
+         # drag_nuwa_net.half()
+         self.drag_nuwa_net = drag_nuwa_net
+         self.height = height
+         self.width = width
+         _, model_step, _ = split_filename(model_path)
+         self.ouput_prefix = f'{model_step}_{width}X{height}'
+         self.model_length = model_length
+
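+     # Run one sampling pass of the video diffusion model: the drag tensor is
+     # smoothed with a Gaussian filter, padded with a zero-flow first frame, and
+     # fed as flow conditioning alongside the (lightly noised) first frame;
+     # classifier-free guidance uses the conditional/unconditional pair returned
+     # by the conditioner.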
+     @torch.no_grad()
+     def forward_sample(self, input_drag, input_first_frame, motion_bucket_id, outputs=dict()):
+         device = self.device
+
+         b, l, h, w, c = input_drag.size()
+         drag = self.drag_nuwa_net.apply_gaussian_filter_on_drag(input_drag)
+         drag = torch.cat([torch.zeros_like(drag[:, 0]).unsqueeze(1), drag], dim=1)  # pad the first frame with zero flow
+         drag = rearrange(drag, 'b l h w c -> b l c h w')
+
+         input_conditioner = dict()
+         input_conditioner['cond_frames_without_noise'] = input_first_frame
+         input_conditioner['cond_frames'] = (input_first_frame + 0.02 * torch.randn_like(input_first_frame))
+         input_conditioner['motion_bucket_id'] = torch.tensor([motion_bucket_id]).to(drag.device).repeat(b * (l+1))
+         input_conditioner['fps_id'] = torch.tensor([self.drag_nuwa_net.args.fps]).to(drag.device).repeat(b * (l+1))
+         input_conditioner['cond_aug'] = torch.tensor([0.02]).to(drag.device).repeat(b * (l+1))
+
+         input_conditioner_uc = {}
+         for key in input_conditioner.keys():
+             if key not in input_conditioner_uc and isinstance(input_conditioner[key], torch.Tensor):
+                 input_conditioner_uc[key] = input_conditioner[key].clone()
+
+         c, uc = self.drag_nuwa_net.conditioner.get_unconditional_conditioning(
+             input_conditioner,
+             batch_uc=input_conditioner_uc,
+             force_uc_zero_embeddings=[
+                 "cond_frames",
+                 "cond_frames_without_noise",
+             ],
+         )
+
+         for k in ["crossattn", "concat"]:
+             uc[k] = repeat(uc[k], "b ... -> b t ...", t=self.drag_nuwa_net.num_frames)
+             uc[k] = rearrange(uc[k], "b t ... -> (b t) ...")
+             c[k] = repeat(c[k], "b ... -> b t ...", t=self.drag_nuwa_net.num_frames)
+             c[k] = rearrange(c[k], "b t ... -> (b t) ...")
+
+         H, W = input_conditioner['cond_frames_without_noise'].shape[2:]
+         shape = (self.drag_nuwa_net.num_frames, 4, H // 8, W // 8)
+         randn = torch.randn(shape).to(self.device)
+
+         additional_model_inputs = {}
+         additional_model_inputs["image_only_indicator"] = torch.zeros(
+             2, self.drag_nuwa_net.num_frames
+         ).to(self.device)
+         additional_model_inputs["num_video_frames"] = self.drag_nuwa_net.num_frames
+         additional_model_inputs["flow"] = drag.repeat(2, 1, 1, 1, 1)  # c and uc
+
+         def denoiser(input, sigma, c):
+             return self.drag_nuwa_net.denoiser(self.drag_nuwa_net.model, input, sigma, c, **additional_model_inputs)
+
+         samples_z = self.drag_nuwa_net.sampler(denoiser, randn, cond=c, uc=uc)
+         samples = self.drag_nuwa_net.decode_first_stage(samples_z)
+
+         outputs['logits_imgs'] = rearrange(samples, '(b l) c h w -> b l c h w', b=b)
+         return outputs
+
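+     # Gradio entry point: rasterise the clicked tracks into a sparse flow field
+     # (one (dx, dy) vector per frame at each track point), render the drag
+     # visualisation, sample the video, and write the frames to an animated GIF.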
+     def run(self, first_frame_path, tracking_points, inference_batch_size, motion_bucket_id):
+         original_width, original_height = 576, 320
+
+         input_all_points = tracking_points.constructor_args['value']
+         resized_all_points = [tuple([tuple([int(e1[0]*self.width/original_width), int(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points]
+
+         input_drag = torch.zeros(self.model_length - 1, self.height, self.width, 2)
+         for splited_track in resized_all_points:
+             if len(splited_track) == 1:  # stationary point
+                 displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
+                 splited_track = tuple([splited_track[0], displacement_point])
+             # interpolate the track
+             splited_track = interpolate_trajectory(splited_track, self.model_length)
+             splited_track = splited_track[:self.model_length]
+             if len(splited_track) < self.model_length:
+                 splited_track = splited_track + [splited_track[-1]] * (self.model_length - len(splited_track))
+             for i in range(self.model_length - 1):
+                 start_point = splited_track[i]
+                 end_point = splited_track[i+1]
+                 input_drag[i][int(start_point[1])][int(start_point[0])][0] = end_point[0] - start_point[0]
+                 input_drag[i][int(start_point[1])][int(start_point[0])][1] = end_point[1] - start_point[1]
+
+         dir, base, ext = split_filename(first_frame_path)
+         id = base.split('_')[-1]
+
+         image_pil = image2pil(first_frame_path)
+         image_pil = image_pil.resize((self.width, self.height), Image.BILINEAR).convert('RGB')
+
+         visualized_drag, _ = visualize_drag_v2(first_frame_path, resized_all_points, self.width, self.height)
+
+         first_frames_transform = transforms.Compose([
+             lambda x: Image.fromarray(x),
+             transforms.ToTensor(),
+             transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
+         ])
+
+         outputs = None
+         ouput_video_list = []
+         num_inference = 1
+         for i in tqdm(range(num_inference)):
+             if not outputs:
+                 first_frames = image2arr(first_frame_path)
+                 first_frames = repeat(first_frames_transform(first_frames), 'c h w -> b c h w', b=inference_batch_size).to(self.device)
+             else:
+                 first_frames = outputs['logits_imgs'][:, -1]
+
+             outputs = self.forward_sample(
+                 repeat(input_drag[i*(self.model_length - 1):(i+1)*(self.model_length - 1)], 'l h w c -> b l h w c', b=inference_batch_size).to(self.device),
+                 first_frames,
+                 motion_bucket_id)
+             ouput_video_list.append(outputs['logits_imgs'])
+
+         for i in range(inference_batch_size):
+             ouput_tensor = [ouput_video_list[0][i]]
+             for j in range(num_inference - 1):
+                 ouput_tensor.append(ouput_video_list[j+1][i][1:])
+             ouput_tensor = torch.cat(ouput_tensor, dim=0)
+             outputs_path = os.path.join(output_dir, f'output_{i}_{id}.gif')
+             data2file([transforms.ToPILImage('RGB')(utils.make_grid(e.to(torch.float32).cpu(), normalize=True, range=(-1, 1))) for e in ouput_tensor], outputs_path,
+                       printable=False, duration=1 / 6, override=True)
+
+         return visualized_drag[0], outputs_path
+
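+ # Gradio UI: header and demo description, usage instructions, the clickable
+ # drag-drawing canvas, and the sliders that control sampling.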
+ with gr.Blocks() as demo:
+     gr.Markdown("""<h1 align="center">DragNUWA 1.5</h1><br>""")
+     gr.HTML("""
+     <p style="margin:12px auto;display: flex;justify-content: center;">
+         <a href="https://huggingface.co/spaces/fffiloni/DragNUWA?duplicate=true"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg.svg" alt="Duplicate this Space"></a>
+     </p>
+     """)
+     gr.Markdown("""Official Gradio Demo for <a href='https://arxiv.org/abs/2308.08089'><b>DragNUWA: Fine-grained Control in Video Generation by Integrating Text, Image, and Trajectory</b></a>.<br>
+     🔥DragNUWA enables users to manipulate backgrounds or objects within images directly, and the model seamlessly translates these actions into **camera movements** or **object motions**, generating the corresponding video.<br>
+     🔥DragNUWA 1.5 enables Stable Video Diffusion to animate an image according to a specific path.<br>""")
+
+     gr.Image(label="DragNUWA", value="assets/DragNUWA1.5/Figure1.gif")
+
+     gr.Markdown("""## Usage: <br>
+     1. Upload an image via the "Upload Image" button.<br>
+     2. Draw some drags.<br>
+     2.1. Click "Add Drag" to start a new control path.<br>
+     2.2. Click several points on the image to form a path.<br>
+     2.3. Click "Delete last drag" to delete the whole latest path.<br>
+     2.4. Click "Delete last step" to delete the latest clicked control point.<br>
+     3. Click "Run" to animate the image along the path(s). <br>""")
+
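+     # Load the network once at startup: 320x576 frames, 14-frame clips, on the
+     # first CUDA device. The two gr.State values hold the path of the saved
+     # first frame and the list of clicked tracks for the current session.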
+     DragNUWA_net = Drag("cuda:0", 'models/drag_nuwa_svd.pth', 'DragNUWA_net.py', 320, 576, 14)
+     first_frame_path = gr.State()
+     tracking_points = gr.State([])
+
+     def reset_states(first_frame_path, tracking_points):
+         first_frame_path = gr.State()
+         tracking_points = gr.State([])
+         return first_frame_path, tracking_points
+
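+     # Resize and centre-crop the uploaded image to the model resolution
+     # (576x320), save it under outputs/ with a short random suffix, and reset
+     # the tracking state.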
234
+ def preprocess_image(image):
235
+ image_pil = image2pil(image.name)
236
+ raw_w, raw_h = image_pil.size
237
+ resize_ratio = max(576/raw_w, 320/raw_h)
238
+ image_pil = image_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR)
239
+ image_pil = transforms.CenterCrop((320, 576))(image_pil.convert('RGB'))
240
+
241
+ first_frame_path = os.path.join(output_dir, f"first_frame_{str(uuid.uuid4())[:4]}.png")
242
+ image_pil.save(first_frame_path)
243
+
244
+ return first_frame_path, first_frame_path, gr.State([])
245
+
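+     # Track-editing callbacks: add_drag starts a new, empty track; the delete
+     # and click handlers below mutate tracking_points in place and redraw all
+     # tracks (polylines with an arrowhead on the final segment) over the first
+     # frame.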
+     def add_drag(tracking_points):
+         tracking_points.constructor_args['value'].append([])
+         return tracking_points
+
+     def delete_last_drag(tracking_points, first_frame_path):
+         tracking_points.constructor_args['value'].pop()
+         transparent_background = Image.open(first_frame_path).convert('RGBA')
+         w, h = transparent_background.size
+         transparent_layer = np.zeros((h, w, 4))
+         for track in tracking_points.constructor_args['value']:
+             if len(track) > 1:
+                 for i in range(len(track)-1):
+                     start_point = track[i]
+                     end_point = track[i+1]
+                     vx = end_point[0] - start_point[0]
+                     vy = end_point[1] - start_point[1]
+                     arrow_length = np.sqrt(vx**2 + vy**2)
+                     if i == len(track)-2:
+                         cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
+                     else:
+                         cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2)
+             else:
+                 cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
+
+         transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
+         trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
+         return tracking_points, trajectory_map
+
+     def delete_last_step(tracking_points, first_frame_path):
+         tracking_points.constructor_args['value'][-1].pop()
+         transparent_background = Image.open(first_frame_path).convert('RGBA')
+         w, h = transparent_background.size
+         transparent_layer = np.zeros((h, w, 4))
+         for track in tracking_points.constructor_args['value']:
+             if len(track) > 1:
+                 for i in range(len(track)-1):
+                     start_point = track[i]
+                     end_point = track[i+1]
+                     vx = end_point[0] - start_point[0]
+                     vy = end_point[1] - start_point[1]
+                     arrow_length = np.sqrt(vx**2 + vy**2)
+                     if i == len(track)-2:
+                         cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
+                     else:
+                         cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2)
+             else:
+                 cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
+
+         transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
+         trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
+         return tracking_points, trajectory_map
+
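+     # Click handler for the input image: append the clicked pixel coordinate
+     # (evt.index) to the latest track and redraw the overlay.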
+     def add_tracking_points(tracking_points, first_frame_path, evt: gr.SelectData):  # SelectData is a subclass of EventData
+         print(f"You selected {evt.value} at {evt.index} from {evt.target}")
+         tracking_points.constructor_args['value'][-1].append(evt.index)
+
+         transparent_background = Image.open(first_frame_path).convert('RGBA')
+         w, h = transparent_background.size
+         transparent_layer = np.zeros((h, w, 4))
+         for track in tracking_points.constructor_args['value']:
+             if len(track) > 1:
+                 for i in range(len(track)-1):
+                     start_point = track[i]
+                     end_point = track[i+1]
+                     vx = end_point[0] - start_point[0]
+                     vy = end_point[1] - start_point[1]
+                     arrow_length = np.sqrt(vx**2 + vy**2)
+                     if i == len(track)-2:
+                         cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
+                     else:
+                         cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2)
+             else:
+                 cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
+
+         transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
+         trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
+         return tracking_points, trajectory_map
+
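+     # Layout: control buttons on the left, the clickable input image and the
+     # drag visualisation on the right; a second row holds the sampling sliders
+     # and the generated GIF (displayed in an Image component).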
+     with gr.Row():
+         with gr.Column(scale=1):
+             image_upload_button = gr.UploadButton(label="Upload Image", file_types=["image"])
+             add_drag_button = gr.Button(value="Add Drag")
+             reset_button = gr.Button(value="Reset")
+             run_button = gr.Button(value="Run")
+             delete_last_drag_button = gr.Button(value="Delete last drag")
+             delete_last_step_button = gr.Button(value="Delete last step")
+
+         with gr.Column(scale=7):
+             with gr.Row():
+                 with gr.Column(scale=6):
+                     input_image = gr.Image(label=None,
+                                            interactive=True,
+                                            height=320,
+                                            width=576)
+                 with gr.Column(scale=6):
+                     output_image = gr.Image(label=None,
+                                             height=320,
+                                             width=576)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             inference_batch_size = gr.Slider(label='Inference Batch Size',
+                                              minimum=1,
+                                              maximum=1,
+                                              step=1,
+                                              value=1)
+
+             motion_bucket_id = gr.Slider(label='Motion Bucket',
+                                          minimum=1,
+                                          maximum=100,
+                                          step=1,
+                                          value=4)
+
+         with gr.Column(scale=5):
+             output_video = gr.Image(label="Output Video",
+                                     height=320,
+                                     width=576)
+
+     with gr.Row():
+         gr.Markdown("""
+         ## Citation
+         ```bibtex
+         @article{yin2023dragnuwa,
+             title={DragNUWA: Fine-grained control in video generation by integrating text, image, and trajectory},
+             author={Yin, Shengming and Wu, Chenfei and Liang, Jian and Shi, Jie and Li, Houqiang and Ming, Gong and Duan, Nan},
+             journal={arXiv preprint arXiv:2308.08089},
+             year={2023}
+         }
+         ```
+         """)
+
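+     # Wire the UI events to the callbacks defined above; "Run" calls
+     # DragNUWA_net.run and returns the drag visualisation plus the GIF path.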
+     image_upload_button.upload(preprocess_image, image_upload_button, [input_image, first_frame_path, tracking_points])
+
+     add_drag_button.click(add_drag, tracking_points, tracking_points)
+
+     delete_last_drag_button.click(delete_last_drag, [tracking_points, first_frame_path], [tracking_points, input_image])
+
+     delete_last_step_button.click(delete_last_step, [tracking_points, first_frame_path], [tracking_points, input_image])
+
+     reset_button.click(reset_states, [first_frame_path, tracking_points], [first_frame_path, tracking_points])
+
+     input_image.select(add_tracking_points, [tracking_points, first_frame_path], [tracking_points, input_image])
+
+     run_button.click(DragNUWA_net.run, [first_frame_path, tracking_points, inference_batch_size, motion_bucket_id], [output_image, output_video])
+
+ demo.launch(server_name="0.0.0.0", debug=True)