Update scripts/gradio/i2v_test_application.py
scripts/gradio/i2v_test_application.py
CHANGED
@@ -30,7 +30,8 @@ def extract_frames(video_path):
         # フレームをリストに追加
         frame_list.append(frame)
         frame_num += 1
-
+
+    print("load video length:", len(frame_list))
     # 動画ファイルを閉じる
     cap.release()
 
@@ -80,12 +81,15 @@ class Image2Video():
 
     @spaces.GPU(duration=100)
     def get_image(self, image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, image2=None, frame_guides=None,control_scale=0.6):
+        print("enter fn")
         control_frames = extract_frames(frame_guides)
+        print("extract frames")
         seed_everything(seed)
         transform = transforms.Compose([
             transforms.Resize(min(self.resolution)),
             transforms.CenterCrop(self.resolution),
             ])
+        print("before empty cache")
         torch.cuda.empty_cache()
         print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
         start = time.time()
@@ -103,7 +107,7 @@ class Image2Video():
         # text cond
         with torch.no_grad(), torch.cuda.amp.autocast():
             text_emb = model.get_learned_conditioning([prompt])
-
+            print("before control")
             #control cond
             if frame_guides is not None:
                 cn_videos = []
@@ -129,7 +133,7 @@
             else:
                 cn_videos = None
 
-
+            print("image cond")
 
             # img cond
             img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
@@ -137,7 +141,7 @@
 
             image_tensor_resized = transform(img_tensor) #3,h,w
             videos = image_tensor_resized.unsqueeze(0).unsqueeze(2) # bc1hw
-
+            print("get latent z")
             # z = get_latent_z(model, videos) #bc,1,hw
             videos = repeat(videos, 'b c t h w -> b c (repeat t) h w', repeat=frames//2)
 
@@ -156,7 +160,7 @@
             img_tensor_repeat[:,:,:1,:,:] = z[:,:,:1,:,:]
             img_tensor_repeat[:,:,-1:,:,:] = z[:,:,-1:,:,:]
 
-
+            print("image embedder")
             cond_images = model.embedder(img_tensor.unsqueeze(0)) ## blc
             img_emb = model.image_proj_model(cond_images)
 
@@ -164,7 +168,8 @@
 
             fs = torch.tensor([fs], dtype=torch.long, device=model.device)
             cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat], "control_cond": cn_videos}
-
+
+            print("before sample loop")
             ## inference
             batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale, hs=hs)
 
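Note on the change: the added print() calls are checkpoints for locating where get_image stops once the Space hits its runtime error. For context, here is a minimal sketch of what extract_frames around the new "load video length:" checkpoint plausibly looks like. Only the fragments visible in the first hunk (frame_list.append(frame), frame_num += 1, cap.release(), and the new print) come from this diff; the cv2.VideoCapture read loop, the loop-exit condition, and the return value are assumptions for illustration.

import cv2  # assumption: OpenCV is what reads the guide video here

def extract_frames(video_path):
    # Sketch only -- reconstructed around the fragments shown in the hunk above.
    cap = cv2.VideoCapture(video_path)
    frame_list = []
    frame_num = 0
    while True:
        ret, frame = cap.read()
        if not ret:  # assumed exit: no more frames, or the file could not be opened
            break
        # add the frame to the list (original comment: フレームをリストに追加)
        frame_list.append(frame)
        frame_num += 1

    print("load video length:", len(frame_list))  # checkpoint added by this commit
    # close the video file (original comment: 動画ファイルを閉じる)
    cap.release()
    return frame_list  # assumed return value; not shown in the diff

One practical caveat: if the worker process dies before stdout is flushed, the last checkpoint before the failure may never reach the Space logs; print("...", flush=True) makes each checkpoint land immediately.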