Update gradio_app.py
gradio_app.py  CHANGED  (+120 -133)
@@ -126,141 +126,123 @@ def untranspose(tensor):

Old version:

     return tensor.transpose(ndim-1, ndim-2)

 @spaces.GPU(duration=200)
-def 
-
-
-
     seed_everything(seed)
-
-
-
-
-
-
-
-
-
-    if steps > 60:
-        steps = 60
-
-    global model
-    # model = model_list[gpu_id]
-    model = model.cuda()

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    #
-
-
-
-
-    # #cn_tensor = (cn_tensor / 255. - 0.5) * 2
-    # cn_tensor = ( cn_tensor/255.0 )
-    # cn_tensor = transpose_if_needed(cn_tensor)
-    # cn_tensor_resized = transform(cn_tensor) #3,h,w
-
-    # cn_video = cn_tensor_resized.unsqueeze(0).unsqueeze(2) # bc1hw
-    # cn_videos.append(cn_video)
-
-    # cn_videos = torch.cat(cn_videos, dim=2)
-    # if cn_videos.shape[2] > frames:
-    #     idxs = []
-    #     for i in range(frames):
-    #         index = int((i + 0.5) * cn_videos.shape[2] / frames)
-    #         idxs.append(min(index, cn_videos.shape[2] - 1))
-    #     cn_videos = cn_videos[:, :, idxs, :, :]
-    #     print("cn_videos.shape after slicing", cn_videos.shape)
-    #     model_list = []
-    #     for model in model_list:
-    #         model.control_scale = control_scale
-    #         model_list.append(model)
-
-    # else:
-    cn_videos = None
-
-    print("image cond")
-
-    # img cond
-    img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
-    input_h, input_w = img_tensor.shape[1:]
-    img_tensor = (img_tensor / 255. - 0.5) * 2
-    img_tensor = transpose_if_needed(img_tensor)
-
-    image_tensor_resized = transform(img_tensor) #3,h,w
-    videos = image_tensor_resized.unsqueeze(0).unsqueeze(2) # bc1hw
-    print("get latent z")
-    # z = get_latent_z(model, videos) #bc,1,hw
-    videos = repeat(videos, 'b c t h w -> b c (repeat t) h w', repeat=frames//2)
-
-    if sketch is not None:
-        img_tensor2 = torch.from_numpy(sketch).permute(2, 0, 1).float().to(model.device)
-        img_tensor2 = (img_tensor2 / 255. - 0.5) * 2
-        img_tensor2 = transpose_if_needed(img_tensor2)
-        image_tensor_resized2 = transform(img_tensor2) #3,h,w
-        videos2 = image_tensor_resized2.unsqueeze(0).unsqueeze(2) # bchw
-        videos2 = repeat(videos2, 'b c t h w -> b c (repeat t) h w', repeat=frames//2)
-
-        videos = torch.cat([videos, videos2], dim=2)
     else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    print("
-
-
-
-
-
-
-
-
-
-
-
-
-
-    global result_dir
-    global save_fps
-    if input_h > input_w:
-        batch_samples = untranspose(batch_samples)
-
-    save_videos(batch_samples, result_dir, filenames=[prompt_str], fps=save_fps)
-    print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
-    model = model.cpu()
-    saved_result_dir = os.path.join(result_dir, f"{prompt_str}.mp4")
-    print("result saved to:", saved_result_dir)
-    return saved_result_dir


-



@@ -314,17 +296,22 @@ def dynamicrafter_demo(result_dir='./tmp/', res=1024):

Old version:

                     i2v_end_btn = gr.Button("Generate")
                 with gr.Column():
                     with gr.Row():
-
                     with gr.Row():
                         i2v_output_video = gr.Video(label="Generated Video",elem_id="output_vid",autoplay=True,show_share_button=True)

             gr.Examples(examples=i2v_examples_interp_1024,
-                        inputs=[i2v_input_image,
                         outputs=[i2v_output_video],
                         fn = get_image,
                         cache_examples=False,
             )
-    i2v_end_btn.click(inputs=[i2v_input_image,
                          outputs=[i2v_output_video],
                          fn = get_image
     )
New version (@@ -126,141 +126,123 @@):

     return tensor.transpose(ndim-1, ndim-2)

 @spaces.GPU(duration=200)
+def image_guided_synthesis(model, prompts, image1, image2, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1., \
+                           unconditional_guidance_scale=1.0, cfg_img=None, fs=None, seed=123, text_input=False, multiple_cond_cfg=False, \
+                           loop=False, interp=False, timestep_spacing='uniform', guidance_rescale=0.0, **kwargs):
+
     seed_everything(seed)
+    # image1 = Image.open(file_list[2*idx]).convert('RGB')
+    image_tensor1 = transform(image1).unsqueeze(1) # [c,1,h,w]
+    # image2 = Image.open(file_list[2*idx+1]).convert('RGB')
+    image_tensor2 = transform(image2).unsqueeze(1) # [c,1,h,w]
+    frame_tensor1 = repeat(image_tensor1, 'c t h w -> c (repeat t) h w', repeat=15)
+    frame_tensor2 = repeat(image_tensor2, 'c t h w -> c (repeat t) h w', repeat=1)
+    videos = torch.cat([frame_tensor1, frame_tensor2], dim=1)
+    # frame_tensor = torch.cat([frame_tensor1, frame_tensor1], dim=1)
+    # _, filename = os.path.split(file_list[idx*2])

+    ddim_sampler = DDIMSampler(model) if not multiple_cond_cfg else DDIMSampler_multicond(model)
+    batch_size = noise_shape[0]
+    fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
+
+    if not text_input:
+        prompts = [""]*batch_size
+
+    img = videos[:,:,0] #bchw
+    img_emb = model.embedder(img) ## blc
+    img_emb = model.image_proj_model(img_emb)
+
+    cond_emb = model.get_learned_conditioning(prompts)
+    cond = {"c_crossattn": [torch.cat([cond_emb,img_emb], dim=1)]}
+    if model.model.conditioning_key == 'hybrid':
+        z, hs = get_latent_z_with_hidden_states(model, videos) # b c t h w
+        if loop or interp:
+            img_cat_cond = torch.zeros_like(z)
+            img_cat_cond[:,:,0,:,:] = z[:,:,0,:,:]
+            img_cat_cond[:,:,-1,:,:] = z[:,:,-1,:,:]
         else:
+            img_cat_cond = z[:,:,:1,:,:]
+            img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=z.shape[2])
+        cond["c_concat"] = [img_cat_cond] # b c 1 h w
+
+    if unconditional_guidance_scale != 1.0:
+        if model.uncond_type == "empty_seq":
+            prompts = batch_size * [""]
+            uc_emb = model.get_learned_conditioning(prompts)
+        elif model.uncond_type == "zero_embed":
+            uc_emb = torch.zeros_like(cond_emb)
+        uc_img_emb = model.embedder(torch.zeros_like(img)) ## b l c
+        uc_img_emb = model.image_proj_model(uc_img_emb)
+        uc = {"c_crossattn": [torch.cat([uc_emb,uc_img_emb],dim=1)]}
+        if model.model.conditioning_key == 'hybrid':
+            uc["c_concat"] = [img_cat_cond]
+    else:
+        uc = None
+    #
+    # for i, h in enumerate(hs):
+    #     print("h:", h.shape)
+    #     hs[i] = hs[i][:,:,0,:,:].unsqueeze(2)
+    additional_decode_kwargs = {'ref_context': hs}
+    # additional_decode_kwargs = {'ref_context': None}
+
+    ## we need one more unconditioning image=yes, text=""
+    if multiple_cond_cfg and cfg_img != 1.0:
+        uc_2 = {"c_crossattn": [torch.cat([uc_emb,img_emb],dim=1)]}
+        if model.model.conditioning_key == 'hybrid':
+            uc_2["c_concat"] = [img_cat_cond]
+        kwargs.update({"unconditional_conditioning_img_nonetext": uc_2})
+    else:
+        kwargs.update({"unconditional_conditioning_img_nonetext": None})

+    z0 = None
+    cond_mask = None

+    batch_variants = []
+    for _ in range(n_samples):
+
+        if z0 is not None:
+            cond_z0 = z0.clone()
+            kwargs.update({"clean_cond": True})
+        else:
+            cond_z0 = None
+        if ddim_sampler is not None:
+
+            samples, _ = ddim_sampler.sample(S=ddim_steps,
+                                             conditioning=cond,
+                                             batch_size=batch_size,
+                                             shape=noise_shape[1:],
+                                             verbose=False,
+                                             unconditional_guidance_scale=unconditional_guidance_scale,
+                                             unconditional_conditioning=uc,
+                                             eta=ddim_eta,
+                                             cfg_img=cfg_img,
+                                             mask=cond_mask,
+                                             x0=cond_z0,
+                                             fs=fs,
+                                             timestep_spacing=timestep_spacing,
+                                             guidance_rescale=guidance_rescale,
+                                             **kwargs
+                                             )
+
+        ## reconstruct from latent to pixel space
+        batch_images = model.decode_first_stage(samples, **additional_decode_kwargs)
+
+        index = list(range(samples.shape[2]))
+        del index[1]
+        del index[-2]
+        samples = samples[:,:,index,:,:]
+        ## reconstruct from latent to pixel space
+        batch_images_middle = model.decode_first_stage(samples, **additional_decode_kwargs)
+        batch_images[:,:,batch_images.shape[2]//2-1:batch_images.shape[2]//2+1] = batch_images_middle[:,:,batch_images.shape[2]//2-2:batch_images.shape[2]//2]
+
+
+
+        batch_variants.append(batch_images)
+    ## variants, batch, c, t, h, w
+    batch_variants = torch.stack(batch_variants)
+    return batch_variants.permute(1, 0, 2, 3, 4, 5)
+
+



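For orientation, here is a minimal sketch of how the new image_guided_synthesis entry point could be driven outside Gradio. It assumes the rest of gradio_app.py still provides a loaded model and the module-level transform used inside the function, and it mirrors the endpoint-image normalization from the removed handler ((x / 255. - 0.5) * 2). The helper name run_interp_example, the latent-shape arithmetic, and all default values are illustrative assumptions, not part of this commit.

# Hypothetical driver (not part of the commit): `model`, `transform`, and
# `image_guided_synthesis` are assumed to come from this module.
import numpy as np
import torch
from PIL import Image


def run_interp_example(model, first_frame_path, last_frame_path, prompt,
                       steps=50, cfg_scale=7.5, eta=1.0, fs=10, seed=123,
                       frames=16, height=576, width=1024):
    def to_tensor(path):
        # HWC uint8 -> normalized CHW float tensor, matching the removed handler.
        arr = np.array(Image.open(path).convert("RGB"))
        t = torch.from_numpy(arr).permute(2, 0, 1).float().to(model.device)
        return (t / 255. - 0.5) * 2

    image1, image2 = to_tensor(first_frame_path), to_tensor(last_frame_path)

    # Latent-space shape for the sampler: [batch, channels, frames, H/8, W/8].
    # The channel lookup and the /8 factor are assumptions based on the usual
    # latent-diffusion setup; the repo's own config is authoritative.
    channels = model.model.diffusion_model.out_channels
    noise_shape = [1, channels, frames, height // 8, width // 8]

    with torch.no_grad():
        batch_samples = image_guided_synthesis(
            model, [prompt], image1, image2, noise_shape,
            n_samples=1, ddim_steps=steps, ddim_eta=eta,
            unconditional_guidance_scale=cfg_scale, cfg_img=None,
            fs=fs, seed=seed, text_input=True, interp=True,
        )
    return batch_samples  # [batch, n_samples, c, t, h, w], still in [-1, 1]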
New version (@@ -314,17 +296,22 @@):

                     i2v_end_btn = gr.Button("Generate")
                 with gr.Column():
                     with gr.Row():
+                        i2v_input_image2 = gr.Image(label="Input Image 2",elem_id="input_img2")
                     with gr.Row():
                         i2v_output_video = gr.Video(label="Generated Video",elem_id="output_vid",autoplay=True,show_share_button=True)

+
+            # s(model, prompts, image1, image2, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1., \
+            #     unconditional_guidance_scale=1.0, cfg_img=None, fs=None, text_input=False, multiple_cond_cfg=False, \
+            #     loop=False, interp=False, timestep_spacing='uniform', guidance_rescale=0.0, **kwargs):
+
             gr.Examples(examples=i2v_examples_interp_1024,
+                        inputs=[i2v_input_image, i2v_input_text, i2v_input_image, i2v_input_image2, [72, 108], 1, i2v_steps, i2v_eta, 1.0, None, i2v_motion, i2v_seed],
                         outputs=[i2v_output_video],
                         fn = get_image,
                         cache_examples=False,
             )
+    i2v_end_btn.click(inputs=[i2v_input_image, i2v_input_text, i2v_input_image, i2v_input_image2, [72, 108], 1, i2v_steps, i2v_eta, 1.0, None, i2v_motion, i2v_seed],
                          outputs=[i2v_output_video],
                          fn = get_image
     )