fbnnb committed · Commit f3e748e · verified · 1 parent: caea331

Update scripts/gradio/i2v_test_application.py

scripts/gradio/i2v_test_application.py CHANGED
@@ -30,7 +30,8 @@ def extract_frames(video_path):
         # Add the frame to the list
         frame_list.append(frame)
         frame_num += 1
-
+
+    print("load video length:", len(frame_list))
     # Close the video file
     cap.release()
 
@@ -80,12 +81,15 @@ class Image2Video():
 
     @spaces.GPU(duration=100)
     def get_image(self, image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, image2=None, frame_guides=None,control_scale=0.6):
+        print("enter fn")
         control_frames = extract_frames(frame_guides)
+        print("extract frames")
         seed_everything(seed)
         transform = transforms.Compose([
             transforms.Resize(min(self.resolution)),
             transforms.CenterCrop(self.resolution),
             ])
+        print("before empty cache")
         torch.cuda.empty_cache()
         print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
         start = time.time()
@@ -103,7 +107,7 @@ class Image2Video():
         # text cond
         with torch.no_grad(), torch.cuda.amp.autocast():
             text_emb = model.get_learned_conditioning([prompt])
-
+            print("before control")
             #control cond
             if frame_guides is not None:
                 cn_videos = []
@@ -129,7 +133,7 @@ class Image2Video():
             else:
                 cn_videos = None
 
-
+            print("image cond")
 
             # img cond
             img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
@@ -137,7 +141,7 @@ class Image2Video():
 
             image_tensor_resized = transform(img_tensor) #3,h,w
             videos = image_tensor_resized.unsqueeze(0).unsqueeze(2) # bc1hw
-
+            print("get latent z")
             # z = get_latent_z(model, videos) #bc,1,hw
             videos = repeat(videos, 'b c t h w -> b c (repeat t) h w', repeat=frames//2)
 
@@ -156,7 +160,7 @@ class Image2Video():
             img_tensor_repeat[:,:,:1,:,:] = z[:,:,:1,:,:]
             img_tensor_repeat[:,:,-1:,:,:] = z[:,:,-1:,:,:]
 
-
+            print("image embedder")
             cond_images = model.embedder(img_tensor.unsqueeze(0)) ## blc
             img_emb = model.image_proj_model(cond_images)
 
@@ -164,7 +168,8 @@ class Image2Video():
 
             fs = torch.tensor([fs], dtype=torch.long, device=model.device)
             cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat], "control_cond": cn_videos}
-
+
+            print("before sample loop")
             ## inference
             batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale, hs=hs)
 
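For context, the print added in the first hunk reports how many frames extract_frames collected before the video handle is released. The helper's full body is not part of this diff; the following is only a minimal sketch of such a helper, assuming OpenCV (the cap.release() call in the context lines suggests cv2.VideoCapture):

import cv2

# Sketch only; the real extract_frames in scripts/gradio/i2v_test_application.py may differ.
def extract_frames(video_path):
    # Open the guide video and read it frame by frame.
    cap = cv2.VideoCapture(video_path)
    frame_list = []
    frame_num = 0
    while True:
        ret, frame = cap.read()
        if not ret:  # no more frames, or the file could not be read
            break
        # Add the frame to the list
        frame_list.append(frame)
        frame_num += 1

    print("load video length:", len(frame_list))
    # Close the video file
    cap.release()
    return frame_list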
  # フレームをリストに追加
31
  frame_list.append(frame)
32
  frame_num += 1
33
+
34
+ print("load video length:", len(frame_list))
35
  # 動画ファイルを閉じる
36
  cap.release()
37
 
 
81
 
82
  @spaces.GPU(duration=100)
83
  def get_image(self, image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, image2=None, frame_guides=None,control_scale=0.6):
84
+ print("enter fn")
85
  control_frames = extract_frames(frame_guides)
86
+ print("extract frames")
87
  seed_everything(seed)
88
  transform = transforms.Compose([
89
  transforms.Resize(min(self.resolution)),
90
  transforms.CenterCrop(self.resolution),
91
  ])
92
+ print("before empty cache")
93
  torch.cuda.empty_cache()
94
  print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
95
  start = time.time()
 
107
  # text cond
108
  with torch.no_grad(), torch.cuda.amp.autocast():
109
  text_emb = model.get_learned_conditioning([prompt])
110
+ print("before control")
111
  #control cond
112
  if frame_guides is not None:
113
  cn_videos = []
 
133
  else:
134
  cn_videos = None
135
 
136
+ print("image cond")
137
 
138
  # img cond
139
  img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
 
141
 
142
  image_tensor_resized = transform(img_tensor) #3,h,w
143
  videos = image_tensor_resized.unsqueeze(0).unsqueeze(2) # bc1hw
144
+ print("get latent z")
145
  # z = get_latent_z(model, videos) #bc,1,hw
146
  videos = repeat(videos, 'b c t h w -> b c (repeat t) h w', repeat=frames//2)
147
 
 
160
  img_tensor_repeat[:,:,:1,:,:] = z[:,:,:1,:,:]
161
  img_tensor_repeat[:,:,-1:,:,:] = z[:,:,-1:,:,:]
162
 
163
+ print("image embedder")
164
  cond_images = model.embedder(img_tensor.unsqueeze(0)) ## blc
165
  img_emb = model.image_proj_model(cond_images)
166
 
 
168
 
169
  fs = torch.tensor([fs], dtype=torch.long, device=model.device)
170
  cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat], "control_cond": cn_videos}
171
+
172
+ print("before sample loop")
173
  ## inference
174
  batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale, hs=hs)
175
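The later hunks sit in the image-conditioning path, where the input still is resized, center-cropped, and tiled along the time axis before being encoded. Pulled together as a standalone sketch: the resolution and frames values below are stand-ins (the app reads them from self.resolution and model.temporal_length), and device placement and normalization are omitted.

import numpy as np
import torch
from einops import repeat
from torchvision import transforms

resolution = (320, 512)   # stand-in; the app reads this from self.resolution
frames = 16               # stand-in; the app reads model.temporal_length

transform = transforms.Compose([
    transforms.Resize(min(resolution)),
    transforms.CenterCrop(resolution),
])

# Stand-in for the still image handed over by Gradio (H, W, 3 uint8).
image = np.random.randint(0, 255, (576, 1024, 3), dtype=np.uint8)

img_tensor = torch.from_numpy(image).permute(2, 0, 1).float()  # 3, H, W
image_tensor_resized = transform(img_tensor)                   # 3, h, w
videos = image_tensor_resized.unsqueeze(0).unsqueeze(2)        # b, c, 1, h, w
# Tile the single conditioning frame over half the temporal length,
# as in the "videos = repeat(...)" context line of the diff.
videos = repeat(videos, 'b c t h w -> b c (repeat t) h w', repeat=frames // 2)
print(videos.shape)  # torch.Size([1, 3, 8, 320, 512])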