Update gradio_app.py
gradio_app.py CHANGED (+19 -1)
@@ -142,6 +142,16 @@ save_fps = 8
 print("resolution:", resolution)
 print("init done.")
 
+def transpose_if_needed(tensor):
+    h = tensor.shape[-2]
+    w = tensor.shape[-1]
+    if h > w:
+        tensor = tensor.permute(0, 2, 1)
+    return tensor
+
+def untranspose(tensor):
+    ndim = tensor.ndim
+    return tensor.transpose(ndim-1, ndim-2)
 
 @spaces.GPU(duration=200)
 def get_image(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, image2=None, frame_guides=None,control_scale=0.6):
@@ -172,6 +182,7 @@ def get_image(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, i
     noise_shape = [batch_size, channels, frames, h, w]
 
     # text cond
+    transposed = False
     with torch.no_grad(), torch.cuda.amp.autocast():
         text_emb = model.get_learned_conditioning([prompt])
         print("before control")
@@ -185,6 +196,7 @@ def get_image(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, i
 
         #cn_tensor = (cn_tensor / 255. - 0.5) * 2
         cn_tensor = ( cn_tensor/255.0 )
+        cn_tensor = transpose_if_needed(cn_tensor)
         cn_tensor_resized = transform(cn_tensor) #3,h,w
 
         cn_video = cn_tensor_resized.unsqueeze(0).unsqueeze(2) # bc1hw
@@ -211,8 +223,10 @@ def get_image(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, i
 
         # img cond
         img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
+        input_h, input_w = img_tensor.shape[1:]
         img_tensor = (img_tensor / 255. - 0.5) * 2
-
+        img_tensor = transpose_if_needed(img_tensor)
+
         image_tensor_resized = transform(img_tensor) #3,h,w
         videos = image_tensor_resized.unsqueeze(0).unsqueeze(2) # bc1hw
         print("get latent z")
@@ -222,6 +236,7 @@ def get_image(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, i
         if image2 is not None:
             img_tensor2 = torch.from_numpy(image2).permute(2, 0, 1).float().to(model.device)
             img_tensor2 = (img_tensor2 / 255. - 0.5) * 2
+            img_tensor2 = transpose_if_needed(img_tensor2)
             image_tensor_resized2 = transform(img_tensor2) #3,h,w
             videos2 = image_tensor_resized2.unsqueeze(0).unsqueeze(2) # bchw
             videos2 = repeat(videos2, 'b c t h w -> b c (repeat t) h w', repeat=frames//2)
@@ -263,6 +278,9 @@ def get_image(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, i
 
         global result_dir
         global save_fps
+        if input_h > input_w:
+            batch_samples = untranspose(batch_samples)
+
        save_videos(batch_samples, result_dir, filenames=[prompt_str], fps=save_fps)
         print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
         model = model.cpu()
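For context, the change rotates portrait inputs into landscape orientation before preprocessing and rotates the sampled batch back before saving. A minimal round-trip sketch of the two helpers added by this commit (the helper bodies are copied from the diff; the example shapes are purely illustrative):

import torch

def transpose_if_needed(tensor):
    # Swap the two trailing (spatial) dims of a CHW tensor when height > width,
    # so a portrait input is processed in landscape orientation.
    h = tensor.shape[-2]
    w = tensor.shape[-1]
    if h > w:
        tensor = tensor.permute(0, 2, 1)
    return tensor

def untranspose(tensor):
    # Inverse for a tensor of any rank: swap its last two dims,
    # e.g. the (b, c, t, h, w) sample batch before saving.
    ndim = tensor.ndim
    return tensor.transpose(ndim - 1, ndim - 2)

img = torch.rand(3, 512, 320)         # illustrative portrait CHW input
landscape = transpose_if_needed(img)  # shape (3, 320, 512)
restored = untranspose(landscape)     # back to (3, 512, 320)
assert torch.equal(img, restored)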