fbnnb committed (verified)
Commit 349445d · Parent(s): 418cea4

Update gradio_app.py

Files changed (1)
  1. gradio_app.py +120 -133
gradio_app.py CHANGED
@@ -126,141 +126,123 @@ def untranspose(tensor):
     return tensor.transpose(ndim-1, ndim-2)
 
 @spaces.GPU(duration=200)
-def get_image(image, sketch, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, control_scale=0.6):
-    print("enter fn")
-    # control_frames = extract_frames(frame_guides)
-    print("extract frames")
+def image_guided_synthesis(model, prompts, image1, image2, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1., \
+                           unconditional_guidance_scale=1.0, cfg_img=None, fs=None, seed=123, text_input=False, multiple_cond_cfg=False, \
+                           loop=False, interp=False, timestep_spacing='uniform', guidance_rescale=0.0, **kwargs):
+
     seed_everything(seed)
-    transform = transforms.Compose([
-        transforms.Resize(min(resolution)),
-        transforms.CenterCrop(resolution),
-        ])
-    print("before empty cache")
-    torch.cuda.empty_cache()
-    print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
-    start = time.time()
-    gpu_id=0
-    if steps > 60:
-        steps = 60
-
-    global model
-    # model = model_list[gpu_id]
-    model = model.cuda()
+    # image1 = Image.open(file_list[2*idx]).convert('RGB')
+    image_tensor1 = transform(image1).unsqueeze(1) # [c,1,h,w]
+    # image2 = Image.open(file_list[2*idx+1]).convert('RGB')
+    image_tensor2 = transform(image2).unsqueeze(1) # [c,1,h,w]
+    frame_tensor1 = repeat(image_tensor1, 'c t h w -> c (repeat t) h w', repeat=15)
+    frame_tensor2 = repeat(image_tensor2, 'c t h w -> c (repeat t) h w', repeat=1)
+    videos = torch.cat([frame_tensor1, frame_tensor2], dim=1)
+    # frame_tensor = torch.cat([frame_tensor1, frame_tensor1], dim=1)
+    # _, filename = os.path.split(file_list[idx*2])
 
-    batch_size=1
-    channels = model.model.diffusion_model.out_channels
-    frames = model.temporal_length
-    h, w = resolution[0] // 8, resolution[1] // 8
-    noise_shape = [batch_size, channels, frames, h, w]
-
-    # text cond
-    transposed = False
-    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
-        text_emb = model.get_learned_conditioning([prompt])
-        print("before control")
-        # control cond
-        # if frame_guides is not None:
-        #     cn_videos = []
-        #     for frame in control_frames:
-        #         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
-        #         frame = cv2.bitwise_not(frame)
-        #         cn_tensor = torch.from_numpy(frame).unsqueeze(2).permute(2, 0, 1).float().to(model.device)
-
-        #         #cn_tensor = (cn_tensor / 255. - 0.5) * 2
-        #         cn_tensor = ( cn_tensor/255.0 )
-        #         cn_tensor = transpose_if_needed(cn_tensor)
-        #         cn_tensor_resized = transform(cn_tensor) #3,h,w
-
-        #         cn_video = cn_tensor_resized.unsqueeze(0).unsqueeze(2) # bc1hw
-        #         cn_videos.append(cn_video)
-
-        #     cn_videos = torch.cat(cn_videos, dim=2)
-        #     if cn_videos.shape[2] > frames:
-        #         idxs = []
-        #         for i in range(frames):
-        #             index = int((i + 0.5) * cn_videos.shape[2] / frames)
-        #             idxs.append(min(index, cn_videos.shape[2] - 1))
-        #         cn_videos = cn_videos[:, :, idxs, :, :]
-        #         print("cn_videos.shape after slicing", cn_videos.shape)
-        #     model_list = []
-        #     for model in model_list:
-        #         model.control_scale = control_scale
-        #         model_list.append(model)
-
-        # else:
-        cn_videos = None
-
-        print("image cond")
-
-        # img cond
-        img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
-        input_h, input_w = img_tensor.shape[1:]
-        img_tensor = (img_tensor / 255. - 0.5) * 2
-        img_tensor = transpose_if_needed(img_tensor)
-
-        image_tensor_resized = transform(img_tensor) #3,h,w
-        videos = image_tensor_resized.unsqueeze(0).unsqueeze(2) # bc1hw
-        print("get latent z")
-        # z = get_latent_z(model, videos) #bc,1,hw
-        videos = repeat(videos, 'b c t h w -> b c (repeat t) h w', repeat=frames//2)
-
-        if sketch is not None:
-            img_tensor2 = torch.from_numpy(sketch).permute(2, 0, 1).float().to(model.device)
-            img_tensor2 = (img_tensor2 / 255. - 0.5) * 2
-            img_tensor2 = transpose_if_needed(img_tensor2)
-            image_tensor_resized2 = transform(img_tensor2) #3,h,w
-            videos2 = image_tensor_resized2.unsqueeze(0).unsqueeze(2) # bchw
-            videos2 = repeat(videos2, 'b c t h w -> b c (repeat t) h w', repeat=frames//2)
-
-            videos = torch.cat([videos, videos2], dim=2)
+    ddim_sampler = DDIMSampler(model) if not multiple_cond_cfg else DDIMSampler_multicond(model)
+    batch_size = noise_shape[0]
+    fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
+
+    if not text_input:
+        prompts = [""]*batch_size
+
+    img = videos[:,:,0] #bchw
+    img_emb = model.embedder(img) ## blc
+    img_emb = model.image_proj_model(img_emb)
+
+    cond_emb = model.get_learned_conditioning(prompts)
+    cond = {"c_crossattn": [torch.cat([cond_emb,img_emb], dim=1)]}
+    if model.model.conditioning_key == 'hybrid':
+        z, hs = get_latent_z_with_hidden_states(model, videos) # b c t h w
+        if loop or interp:
+            img_cat_cond = torch.zeros_like(z)
+            img_cat_cond[:,:,0,:,:] = z[:,:,0,:,:]
+            img_cat_cond[:,:,-1,:,:] = z[:,:,-1,:,:]
         else:
-            videos = torch.cat([videos, videos], dim=2)
-
-        z, hs = get_latent_z_with_hidden_states(model, videos)
-
-        img_tensor_repeat = torch.zeros_like(z)
-
-        img_tensor_repeat[:,:,:1,:,:] = z[:,:,:1,:,:]
-        img_tensor_repeat[:,:,-1:,:,:] = z[:,:,-1:,:,:]
-
-        print("image embedder")
-        cond_images = model.embedder(img_tensor.unsqueeze(0)) ## blc
-        img_emb = model.image_proj_model(cond_images)
-
-        imtext_cond = torch.cat([text_emb, img_emb], dim=1)
-
-        fs = torch.tensor([fs], dtype=torch.long, device=model.device)
-        # print("cn videos:",cn_videos.shape, "img emb:", img_emb.shape)
-        cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat], "control_cond": cn_videos}
-
-        print("before sample loop")
-        ## inference
-        batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale, hs=hs)
-
-        ## remove the last frame
-        # if image2 is None:
-        batch_samples = batch_samples[:,:,:,:-1,...]
-        ## b,samples,c,t,h,w
-        prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
-        prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str
-        prompt_str=prompt_str[:40]
-        if len(prompt_str) == 0:
-            prompt_str = 'empty_prompt'
-
-        global result_dir
-        global save_fps
-        if input_h > input_w:
-            batch_samples = untranspose(batch_samples)
-
-        save_videos(batch_samples, result_dir, filenames=[prompt_str], fps=save_fps)
-        print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
-        model = model.cpu()
-        saved_result_dir = os.path.join(result_dir, f"{prompt_str}.mp4")
-        print("result saved to:", saved_result_dir)
-        return saved_result_dir
+            img_cat_cond = z[:,:,:1,:,:]
+            img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=z.shape[2])
+        cond["c_concat"] = [img_cat_cond] # b c 1 h w
+
+    if unconditional_guidance_scale != 1.0:
+        if model.uncond_type == "empty_seq":
+            prompts = batch_size * [""]
+            uc_emb = model.get_learned_conditioning(prompts)
+        elif model.uncond_type == "zero_embed":
+            uc_emb = torch.zeros_like(cond_emb)
+        uc_img_emb = model.embedder(torch.zeros_like(img)) ## b l c
+        uc_img_emb = model.image_proj_model(uc_img_emb)
+        uc = {"c_crossattn": [torch.cat([uc_emb,uc_img_emb],dim=1)]}
+        if model.model.conditioning_key == 'hybrid':
+            uc["c_concat"] = [img_cat_cond]
+    else:
+        uc = None
+    #
+    # for i, h in enumerate(hs):
+    #     print("h:", h.shape)
+    #     hs[i] = hs[i][:,:,0,:,:].unsqueeze(2)
+    additional_decode_kwargs = {'ref_context': hs}
+    # additional_decode_kwargs = {'ref_context': None}
+
+    ## we need one more unconditioning image=yes, text=""
+    if multiple_cond_cfg and cfg_img != 1.0:
+        uc_2 = {"c_crossattn": [torch.cat([uc_emb,img_emb],dim=1)]}
+        if model.model.conditioning_key == 'hybrid':
+            uc_2["c_concat"] = [img_cat_cond]
+        kwargs.update({"unconditional_conditioning_img_nonetext": uc_2})
+    else:
+        kwargs.update({"unconditional_conditioning_img_nonetext": None})
 
+    z0 = None
+    cond_mask = None
 
-# @spaces.GPU
+    batch_variants = []
+    for _ in range(n_samples):
+
+        if z0 is not None:
+            cond_z0 = z0.clone()
+            kwargs.update({"clean_cond": True})
+        else:
+            cond_z0 = None
+        if ddim_sampler is not None:
+
+            samples, _ = ddim_sampler.sample(S=ddim_steps,
+                                             conditioning=cond,
+                                             batch_size=batch_size,
+                                             shape=noise_shape[1:],
+                                             verbose=False,
+                                             unconditional_guidance_scale=unconditional_guidance_scale,
+                                             unconditional_conditioning=uc,
+                                             eta=ddim_eta,
+                                             cfg_img=cfg_img,
+                                             mask=cond_mask,
+                                             x0=cond_z0,
+                                             fs=fs,
+                                             timestep_spacing=timestep_spacing,
+                                             guidance_rescale=guidance_rescale,
+                                             **kwargs
+                                             )
+
+            ## reconstruct from latent to pixel space
+            batch_images = model.decode_first_stage(samples, **additional_decode_kwargs)
+
+            index = list(range(samples.shape[2]))
+            del index[1]
+            del index[-2]
+            samples = samples[:,:,index,:,:]
+            ## reconstruct from latent to pixel space
+            batch_images_middle = model.decode_first_stage(samples, **additional_decode_kwargs)
+            batch_images[:,:,batch_images.shape[2]//2-1:batch_images.shape[2]//2+1] = batch_images_middle[:,:,batch_images.shape[2]//2-2:batch_images.shape[2]//2]
+
+            batch_variants.append(batch_images)
+    ## variants, batch, c, t, h, w
+    batch_variants = torch.stack(batch_variants)
+    return batch_variants.permute(1, 0, 2, 3, 4, 5)
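
Compared with the removed get_image, the new image_guided_synthesis above no longer builds a transform, moves the model onto the GPU, or writes the video file; it only prepares the conditioning and runs the DDIM sampler, so the surrounding app has to supply those pieces. The sketch below is a minimal, hypothetical wrapper showing how they could fit together; it is not part of this commit, and the transform definition, the fixed output file name, and the flag choices (text_input=True, interp=True) are assumptions.

# Hypothetical glue code, NOT part of this commit: one way the Gradio app could
# call the new image_guided_synthesis(). It assumes module-level objects that the
# rest of gradio_app.py already defines (model, resolution, result_dir, save_fps,
# save_videos) plus a module-level `transform`, which the new function references
# but no longer builds itself.
import os
import torch
from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor(),                       # HWC uint8 -> CHW float in [0, 1]
    transforms.Lambda(lambda x: x * 2.0 - 1.0),  # rescale to [-1, 1], as the old get_image did
    transforms.Resize(min(resolution)),
    transforms.CenterCrop(resolution),
])

def get_image(image, image2, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123):
    """Hypothetical wrapper, reusing the get_image name the UI below still binds to fn=."""
    channels = model.model.diffusion_model.out_channels
    frames = model.temporal_length
    noise_shape = [1, channels, frames, resolution[0] // 8, resolution[1] // 8]
    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
        batch_samples = image_guided_synthesis(
            model, [prompt], image, image2, noise_shape,
            ddim_steps=steps, ddim_eta=eta,
            unconditional_guidance_scale=cfg_scale,
            fs=fs, seed=seed, text_input=True, interp=True,
        )
    # fixed file name for the sketch; the old code derived it from the prompt
    save_videos(batch_samples, result_dir, filenames=["sample"], fps=save_fps)
    return os.path.join(result_dir, "sample.mp4")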
 
@@ -314,17 +296,22 @@ def dynamicrafter_demo(result_dir='./tmp/', res=1024):
                         i2v_end_btn = gr.Button("Generate")
                     with gr.Column():
                         with gr.Row():
-                            i2v_input_sketch = gr.Image(label="Input End SKetch",elem_id="input_img2")
+                            i2v_input_image2 = gr.Image(label="Input Image 2",elem_id="input_img2")
                         with gr.Row():
                             i2v_output_video = gr.Video(label="Generated Video",elem_id="output_vid",autoplay=True,show_share_button=True)
 
+
+                # s(model, prompts, image1, image2, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1., \
+                #   unconditional_guidance_scale=1.0, cfg_img=None, fs=None, text_input=False, multiple_cond_cfg=False, \
+                #   loop=False, interp=False, timestep_spacing='uniform', guidance_rescale=0.0, **kwargs):
+
                 gr.Examples(examples=i2v_examples_interp_1024,
-                            inputs=[i2v_input_image, i2v_input_sketch, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed, control_scale],
+                            inputs=[i2v_input_image, i2v_input_text, i2v_input_image, i2v_input_image2, [72, 108], 1, i2v_steps, i2v_eta, 1.0, None, i2v_motion, i2v_seed],
                             outputs=[i2v_output_video],
                             fn = get_image,
                            cache_examples=False,
                 )
-                i2v_end_btn.click(inputs=[i2v_input_image, i2v_input_sketch, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed, control_scale],
+                i2v_end_btn.click(inputs=[i2v_input_image, i2v_input_text, i2v_input_image, i2v_input_image2, [72, 108], 1, i2v_steps, i2v_eta, 1.0, None, i2v_motion, i2v_seed],
                                   outputs=[i2v_output_video],
                                   fn = get_image
                 )
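
A note on the wiring in this last hunk: Gradio's `inputs=` argument expects a list of components only and forwards their current values to the callback positionally, so raw literals such as `[72, 108]`, `1`, `1.0`, and `None` in the new lists are not passed through that way (and `i2v_input_image` appears twice). A common pattern is to bind such constants ahead of time and keep only components in `inputs=`. The sketch below does that with functools.partial; `run_interp` is a hypothetical handler standing in for whatever wrapper ends up calling image_guided_synthesis, so treat this as an illustration rather than the commit's intended fix.

# Hypothetical alternative wiring, NOT part of this commit: constants from the
# new inputs= lists are bound with functools.partial so that inputs= contains
# Gradio components only. run_interp is a made-up handler name.
from functools import partial

bound_fn = partial(
    run_interp,
    noise_shape_hw=[72, 108],          # constant latent height/width from the inputs list
    n_samples=1,
    unconditional_guidance_scale=1.0,
    cfg_img=None,
)

i2v_end_btn.click(
    fn=bound_fn,
    inputs=[i2v_input_image, i2v_input_text, i2v_input_image2,
            i2v_steps, i2v_eta, i2v_motion, i2v_seed],
    outputs=[i2v_output_video],
)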