Spaces:
Runtime error
Runtime error
Update gradio_app.py
Browse files- gradio_app.py +11 -8
gradio_app.py
CHANGED
@@ -212,14 +212,17 @@ def get_image(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, i
|
|
212 |
# z = get_latent_z(model, videos) #bc,1,hw
|
213 |
videos = repeat(videos, 'b c t h w -> b c (repeat t) h w', repeat=frames//2)
|
214 |
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
|
|
|
|
|
|
223 |
z, hs = get_latent_z_with_hidden_states(model, videos)
|
224 |
|
225 |
img_tensor_repeat = torch.zeros_like(z)
|
|
|
212 |
# z = get_latent_z(model, videos) #bc,1,hw
|
213 |
videos = repeat(videos, 'b c t h w -> b c (repeat t) h w', repeat=frames//2)
|
214 |
|
215 |
+
if image2 is not None:
|
216 |
+
img_tensor2 = torch.from_numpy(image2).permute(2, 0, 1).float().to(model.device)
|
217 |
+
img_tensor2 = (img_tensor2 / 255. - 0.5) * 2
|
218 |
+
image_tensor_resized2 = transform(img_tensor2) #3,h,w
|
219 |
+
videos2 = image_tensor_resized2.unsqueeze(0).unsqueeze(2) # bchw
|
220 |
+
videos2 = repeat(videos2, 'b c t h w -> b c (repeat t) h w', repeat=frames//2)
|
221 |
+
|
222 |
+
videos = torch.cat([videos, videos2], dim=2)
|
223 |
+
else:
|
224 |
+
videos = torch.cat([videos, videos], dim=2)
|
225 |
+
|
226 |
z, hs = get_latent_z_with_hidden_states(model, videos)
|
227 |
|
228 |
img_tensor_repeat = torch.zeros_like(z)
|