liuhuadai committed
Commit 273dd2b · verified · 1 Parent(s): 7c36308

Update app.py

Files changed (1)
  1. app.py +23 -20
app.py CHANGED
@@ -208,34 +208,37 @@ seed_everything(seed, workers=True)
 with open("ThinkSound/configs/model_configs/thinksound.json") as f:
     model_config = json.load(f)

-diffusion_model = create_model_from_config(model_config)
-ckpt_path = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="thinksound_light.ckpt",repo_type="model")
+model = create_model_from_config(model_config)
+
 ## speed by torch.compile
 if args.compile:
-    diffusion_model = torch.compile(diffusion_model)
+    model = torch.compile(model)

 if args.pretrained_ckpt_path:
-    copy_state_dict(diffusion_model, load_ckpt_state_dict(args.pretrained_ckpt_path,prefix='diffusion.')) # autoencoder. diffusion.
+    copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path,prefix='diffusion.')) # autoencoder. diffusion.

 if args.remove_pretransform_weight_norm == "pre_load":
-    remove_weight_norm_from_model(diffusion_model.pretransform)
+    remove_weight_norm_from_model(model.pretransform)


 load_vae_state = load_ckpt_state_dict(vae_ckpt, prefix='autoencoder.')
 # new_state_dict = {k.replace("autoencoder.", ""): v for k, v in load_vae_state.items() if k.startswith("autoencoder.")}
-diffusion_model.pretransform.load_state_dict(load_vae_state)
+model.pretransform.load_state_dict(load_vae_state)

 # Remove weight_norm from the pretransform if specified
 if args.remove_pretransform_weight_norm == "post_load":
-    remove_weight_norm_from_model(diffusion_model.pretransform)
-
-diffusion_model.to(device)
+    remove_weight_norm_from_model(model.pretransform)
+ckpt_path = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="thinksound.ckpt",repo_type="model")
+training_wrapper = create_training_wrapper_from_config(model_config, model)
+# 加载模型权重时根据设备选择map_location
+training_wrapper.load_state_dict(torch.load(ckpt_path)['state_dict'])
+
+training_wrapper.to("cuda")

 def get_video_duration(video_path):
     video = VideoFileClip(video_path)
     return video.duration

-a
 @spaces.GPU(duration=60)
 @torch.inference_mode()
 @torch.no_grad()
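Note: this hunk swaps the lightweight thinksound_light.ckpt for the full thinksound.ckpt and loads it through create_training_wrapper_from_config, presumably so that the checkpoint's state_dict keys resolve against the training wrapper rather than the bare diffusion model. The new Chinese comment above training_wrapper.load_state_dict says, roughly, "choose map_location according to the device when loading the model weights", yet the call passes no map_location to torch.load. A minimal sketch of what a device-aware load could look like, reusing the names from this hunk (the device variable and the CPU fallback are assumptions, not part of the commit):

import torch
from huggingface_hub import hf_hub_download

# Hypothetical, device-aware variant of the new checkpoint-loading lines above.
device = "cuda" if torch.cuda.is_available() else "cpu"
ckpt_path = hf_hub_download(repo_id="FunAudioLLM/ThinkSound",
                            filename="thinksound.ckpt", repo_type="model")
state = torch.load(ckpt_path, map_location=device)      # map tensors to the available device
training_wrapper.load_state_dict(state["state_dict"])   # training_wrapper as built above
training_wrapper.to(device)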
@@ -273,36 +276,36 @@ def synthesize_video_with_audio(video_file, caption, cot):
     sync_seq_len = preprocessed_data['sync_features'].shape[0]
     clip_seq_len = preprocessed_data['metaclip_features'].shape[0]
     latent_seq_len = (int)(194/9*duration_sec)
-    diffusion_model.model.model.update_seq_lengths(latent_seq_len, clip_seq_len, sync_seq_len)
+    training_wrapper.diffusion.model.model.update_seq_lengths(latent_seq_len, clip_seq_len, sync_seq_len)

     metadata = [preprocessed_data]

     batch_size = 1
     length = latent_seq_len
     with torch.amp.autocast(device):
-        conditioning = diffusion_model.conditioner(metadata, device)
+        conditioning = training_wrapper.diffusion.conditioner(metadata, training_wrapper.device)

     video_exist = torch.stack([item['video_exist'] for item in metadata],dim=0)
-    conditioning['metaclip_features'][~video_exist] = diffusion_model.model.model.empty_clip_feat
-    conditioning['sync_features'][~video_exist] = diffusion_model.model.model.empty_sync_feat
+    conditioning['metaclip_features'][~video_exist] = training_wrapper.diffusion.model.model.empty_clip_feat
+    conditioning['sync_features'][~video_exist] = training_wrapper.diffusion.model.model.empty_sync_feat

     yield "⏳ Inferring…", None

-    cond_inputs = diffusion_model.get_conditioning_inputs(conditioning)
-    noise = torch.randn([batch_size, diffusion_model.io_channels, length]).to(device)
+    cond_inputs = training_wrapper.diffusion.get_conditioning_inputs(conditioning)
+    noise = torch.randn([batch_size, training_wrapper.diffusion.io_channels, length]).to(training_wrapper.device)
     with torch.amp.autocast(device):
-        model = diffusion_model.model
-        if diffusion_model.diffusion_objective == "v":
+        model = training_wrapper.diffusion.model
+        if training_wrapper.diffusion_objective == "v":
             fakes = sample(model, noise, 24, 0, **cond_inputs, cfg_scale=5, batch_cfg=True)
-        elif diffusion_model.diffusion_objective == "rectified_flow":
+        elif training_wrapper.diffusion_objective == "rectified_flow":
             import time
             start_time = time.time()
             fakes = sample_discrete_euler(model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True)
             end_time = time.time()
             execution_time = end_time - start_time
             print(f"执行时间: {execution_time:.2f} 秒")
-        if diffusion_model.pretransform is not None:
-            fakes = diffusion_model.pretransform.decode(fakes)
+        if training_wrapper.diffusion.pretransform is not None:
+            fakes = training_wrapper.diffusion.pretransform.decode(fakes)

     audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
     with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_audio:
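Note: in the second hunk, inference now reaches the diffusion model through training_wrapper.diffusion instead of the old standalone diffusion_model, while the sequence-length bookkeeping is unchanged: latent_seq_len = (int)(194/9*duration_sec) still allocates roughly 21.5 latent frames per second of video before update_seq_lengths is called. A quick check of that arithmetic (illustrative only, not part of app.py):

# 194/9 ≈ 21.56 latent frames per second of video
for duration_sec in (5.0, 9.0, 10.0):
    latent_seq_len = int(194 / 9 * duration_sec)
    print(duration_sec, latent_seq_len)  # 5.0 -> 107, 9.0 -> 194, 10.0 -> 215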
 
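Note: the unchanged audios = fakes.to(torch.float32)... line at the end of the second hunk peak-normalizes the decoded waveform and quantizes it to 16-bit PCM before it is written to a temporary WAV file. The same chain, unrolled into a hypothetical helper for readability (the function name is illustrative, not from app.py):

import torch

def to_int16_pcm(fakes: torch.Tensor) -> torch.Tensor:
    # Peak-normalize to [-1, 1], then scale to the int16 range used for WAV output.
    audio = fakes.to(torch.float32)
    audio = audio / torch.max(torch.abs(audio))  # peak normalization
    audio = audio.clamp(-1, 1).mul(32767)        # map to int16 full scale
    return audio.to(torch.int16).cpu()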