fbnnb commited on
Commit
9b61a85
·
verified ·
1 Parent(s): 8a8a66c

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +19 -1
gradio_app.py CHANGED
@@ -142,6 +142,16 @@ save_fps = 8
142
  print("resolution:", resolution)
143
  print("init done.")
144
 
 
 
 
 
 
 
 
 
 
 
145
 
146
  @spaces.GPU(duration=200)
147
  def get_image(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, image2=None, frame_guides=None,control_scale=0.6):
@@ -172,6 +182,7 @@ def get_image(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, i
172
  noise_shape = [batch_size, channels, frames, h, w]
173
 
174
  # text cond
 
175
  with torch.no_grad(), torch.cuda.amp.autocast():
176
  text_emb = model.get_learned_conditioning([prompt])
177
  print("before control")
@@ -185,6 +196,7 @@ def get_image(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, i
185
 
186
  #cn_tensor = (cn_tensor / 255. - 0.5) * 2
187
  cn_tensor = ( cn_tensor/255.0 )
 
188
  cn_tensor_resized = transform(cn_tensor) #3,h,w
189
 
190
  cn_video = cn_tensor_resized.unsqueeze(0).unsqueeze(2) # bc1hw
@@ -211,8 +223,10 @@ def get_image(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, i
211
 
212
  # img cond
213
  img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
 
214
  img_tensor = (img_tensor / 255. - 0.5) * 2
215
-
 
216
  image_tensor_resized = transform(img_tensor) #3,h,w
217
  videos = image_tensor_resized.unsqueeze(0).unsqueeze(2) # bc1hw
218
  print("get latent z")
@@ -222,6 +236,7 @@ def get_image(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, i
222
  if image2 is not None:
223
  img_tensor2 = torch.from_numpy(image2).permute(2, 0, 1).float().to(model.device)
224
  img_tensor2 = (img_tensor2 / 255. - 0.5) * 2
 
225
  image_tensor_resized2 = transform(img_tensor2) #3,h,w
226
  videos2 = image_tensor_resized2.unsqueeze(0).unsqueeze(2) # bchw
227
  videos2 = repeat(videos2, 'b c t h w -> b c (repeat t) h w', repeat=frames//2)
@@ -263,6 +278,9 @@ def get_image(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, i
263
 
264
  global result_dir
265
  global save_fps
 
 
 
266
  save_videos(batch_samples, result_dir, filenames=[prompt_str], fps=save_fps)
267
  print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
268
  model = model.cpu()
 
142
  print("resolution:", resolution)
143
  print("init done.")
144
 
145
def transpose_if_needed(tensor):
    """Swap the last two (spatial) dimensions of *tensor* if it is taller than wide.

    The pipeline's resize transform appears to assume landscape orientation
    (W >= H); portrait inputs are rotated into landscape here and restored
    later with ``untranspose`` — TODO confirm against the resize transform.

    Args:
        tensor: a tensor of rank >= 2 whose last two dimensions are (H, W).

    Returns:
        The same tensor with its last two dimensions swapped when H > W,
        otherwise the tensor unchanged.
    """
    h = tensor.shape[-2]
    w = tensor.shape[-1]
    if h > w:
        # transpose(-2, -1) equals the original permute(0, 2, 1) for 3-D
        # (C, H, W) input, but also works for any rank >= 2 — consistent
        # with the rank-generic untranspose() counterpart.
        tensor = tensor.transpose(-2, -1)
    return tensor
151
+
152
def untranspose(tensor):
    """Swap the last two dimensions of *tensor* back.

    Inverse of the orientation flip applied by ``transpose_if_needed``:
    unconditionally exchanges the final two axes, for any rank >= 2.
    """
    # Negative dims address the trailing axes directly — equivalent to
    # transpose(ndim - 1, ndim - 2).
    return tensor.transpose(-1, -2)
155
 
156
  @spaces.GPU(duration=200)
157
  def get_image(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, image2=None, frame_guides=None,control_scale=0.6):
 
182
  noise_shape = [batch_size, channels, frames, h, w]
183
 
184
  # text cond
185
+ transposed = False
186
  with torch.no_grad(), torch.cuda.amp.autocast():
187
  text_emb = model.get_learned_conditioning([prompt])
188
  print("before control")
 
196
 
197
  #cn_tensor = (cn_tensor / 255. - 0.5) * 2
198
  cn_tensor = ( cn_tensor/255.0 )
199
+ cn_tensor = transpose_if_needed(cn_tensor)
200
  cn_tensor_resized = transform(cn_tensor) #3,h,w
201
 
202
  cn_video = cn_tensor_resized.unsqueeze(0).unsqueeze(2) # bc1hw
 
223
 
224
  # img cond
225
  img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
226
+ input_h, input_w = img_tensor.shape[1:]
227
  img_tensor = (img_tensor / 255. - 0.5) * 2
228
+ img_tensor = transpose_if_needed(img_tensor)
229
+
230
  image_tensor_resized = transform(img_tensor) #3,h,w
231
  videos = image_tensor_resized.unsqueeze(0).unsqueeze(2) # bc1hw
232
  print("get latent z")
 
236
  if image2 is not None:
237
  img_tensor2 = torch.from_numpy(image2).permute(2, 0, 1).float().to(model.device)
238
  img_tensor2 = (img_tensor2 / 255. - 0.5) * 2
239
+ img_tensor2 = transpose_if_needed(img_tensor2)
240
  image_tensor_resized2 = transform(img_tensor2) #3,h,w
241
  videos2 = image_tensor_resized2.unsqueeze(0).unsqueeze(2) # bchw
242
  videos2 = repeat(videos2, 'b c t h w -> b c (repeat t) h w', repeat=frames//2)
 
278
 
279
  global result_dir
280
  global save_fps
281
+ if input_h > input_w:
282
+ batch_samples = untranspose(batch_samples)
283
+
284
  save_videos(batch_samples, result_dir, filenames=[prompt_str], fps=save_fps)
285
  print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
286
  model = model.cpu()