liuhuadai committed
Commit cdc4845 · verified · 1 Parent(s): 9f03cb9

Optimize memory usage

Files changed (1)
  1. app.py +19 -23
app.py CHANGED
@@ -208,32 +208,28 @@ seed_everything(seed, workers=True)
 with open("ThinkSound/configs/model_configs/thinksound.json") as f:
     model_config = json.load(f)
 
-model = create_model_from_config(model_config)
-
+diffusion_model = create_model_from_config(model_config)
+ckpt_path = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="thinksound_light.ckpt",repo_type="model")
 ## speed by torch.compile
 if args.compile:
-    model = torch.compile(model)
+    diffusion_model = torch.compile(diffusion_model)
 
 if args.pretrained_ckpt_path:
-    copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path,prefix='diffusion.')) # autoencoder. diffusion.
+    copy_state_dict(diffusion_model, load_ckpt_state_dict(args.pretrained_ckpt_path,prefix='diffusion.')) # autoencoder. diffusion.
 
 if args.remove_pretransform_weight_norm == "pre_load":
-    remove_weight_norm_from_model(model.pretransform)
+    remove_weight_norm_from_model(diffusion_model.pretransform)
 
 
 load_vae_state = load_ckpt_state_dict(vae_ckpt, prefix='autoencoder.')
 # new_state_dict = {k.replace("autoencoder.", ""): v for k, v in load_vae_state.items() if k.startswith("autoencoder.")}
-model.pretransform.load_state_dict(load_vae_state)
+diffusion_model.pretransform.load_state_dict(load_vae_state)
 
 # Remove weight_norm from the pretransform if specified
 if args.remove_pretransform_weight_norm == "post_load":
-    remove_weight_norm_from_model(model.pretransform)
-ckpt_path = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="thinksound.ckpt",repo_type="model")
-training_wrapper = create_training_wrapper_from_config(model_config, model)
-# choose map_location according to the device when loading the model weights
-training_wrapper.load_state_dict(torch.load(ckpt_path)['state_dict'])
+    remove_weight_norm_from_model(diffusion_model.pretransform)
 
-training_wrapper.to("cuda")
+diffusion_model.to(device)
 
 def get_video_duration(video_path):
     video = VideoFileClip(video_path)
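The memory saving in this hunk comes from no longer instantiating the training wrapper at all: the old path downloaded the full thinksound.ckpt, built `training_wrapper` with `create_training_wrapper_from_config`, and loaded the entire training state dict into it, dragging training-time bookkeeping (and, depending on the config, an EMA copy of the weights) into what is an inference-only Space; it also hard-coded `.to("cuda")`. The new path keeps just the bare `diffusion_model`, downloads the smaller `thinksound_light.ckpt`, and moves a single copy of the weights with `.to(device)` (the hunk only downloads the checkpoint; the actual weight loading presumably happens elsewhere in app.py). A minimal sketch of the same idea in plain PyTorch, assuming a Lightning-style checkpoint whose keys carry a `diffusion.` prefix — the helper name and the prefix are assumptions, not the repo's API:

```python
import torch

def load_prefixed_weights(model, ckpt_path, prefix="diffusion."):
    """Hypothetical helper: pull only the sub-module weights we need out of a
    training checkpoint, strip the wrapper prefix, and free the CPU copies,
    instead of rebuilding the full training wrapper just to host them."""
    state = torch.load(ckpt_path, map_location="cpu")
    state = state.get("state_dict", state)  # Lightning checkpoints nest weights here
    filtered = {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}
    missing, unexpected = model.load_state_dict(filtered, strict=False)
    del state, filtered  # drop the temporary CPU dicts before moving the model
    return missing, unexpected
```

Loading with `map_location="cpu"` and deleting the intermediate dicts keeps at most one extra CPU copy of the selected weights alive before the single `.to(device)` transfer.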
@@ -276,36 +272,36 @@ def synthesize_video_with_audio(video_file, caption, cot):
     sync_seq_len = preprocessed_data['sync_features'].shape[0]
     clip_seq_len = preprocessed_data['metaclip_features'].shape[0]
     latent_seq_len = (int)(194/9*duration_sec)
-    training_wrapper.diffusion.model.model.update_seq_lengths(latent_seq_len, clip_seq_len, sync_seq_len)
+    diffusion_model.model.model.update_seq_lengths(latent_seq_len, clip_seq_len, sync_seq_len)
 
     metadata = [preprocessed_data]
 
     batch_size = 1
     length = latent_seq_len
     with torch.amp.autocast(device):
-        conditioning = training_wrapper.diffusion.conditioner(metadata, training_wrapper.device)
+        conditioning = diffusion_model.conditioner(metadata, device)
 
     video_exist = torch.stack([item['video_exist'] for item in metadata],dim=0)
-    conditioning['metaclip_features'][~video_exist] = training_wrapper.diffusion.model.model.empty_clip_feat
-    conditioning['sync_features'][~video_exist] = training_wrapper.diffusion.model.model.empty_sync_feat
+    conditioning['metaclip_features'][~video_exist] = diffusion_model.model.model.empty_clip_feat
+    conditioning['sync_features'][~video_exist] = diffusion_model.model.model.empty_sync_feat
 
     yield "⏳ Inferring…", None
 
-    cond_inputs = training_wrapper.diffusion.get_conditioning_inputs(conditioning)
-    noise = torch.randn([batch_size, training_wrapper.diffusion.io_channels, length]).to(training_wrapper.device)
+    cond_inputs = diffusion_model.get_conditioning_inputs(conditioning)
+    noise = torch.randn([batch_size, diffusion_model.io_channels, length]).to(device)
     with torch.amp.autocast(device):
-        model = training_wrapper.diffusion.model
-        if training_wrapper.diffusion_objective == "v":
+        model = diffusion_model.model
+        if diffusion_model.diffusion_objective == "v":
            fakes = sample(model, noise, 24, 0, **cond_inputs, cfg_scale=5, batch_cfg=True)
-        elif training_wrapper.diffusion_objective == "rectified_flow":
+        elif diffusion_model.diffusion_objective == "rectified_flow":
            import time
            start_time = time.time()
            fakes = sample_discrete_euler(model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True)
            end_time = time.time()
            execution_time = end_time - start_time
            print(f"Execution time: {execution_time:.2f} s")
-        if training_wrapper.diffusion.pretransform is not None:
-            fakes = training_wrapper.diffusion.pretransform.decode(fakes)
+        if diffusion_model.pretransform is not None:
+            fakes = diffusion_model.pretransform.decode(fakes)
 
     audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
     with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_audio:
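The last context line of the hunk converts the decoded audio to 16-bit PCM by peak-normalizing to [-1, 1] and scaling by 32767 before it is written to a temporary .wav file. A small self-contained illustration of that conversion (the input tensor below is synthetic; only the dtype and scaling mirror the app):

```python
import torch

def to_int16_pcm(audio: torch.Tensor) -> torch.Tensor:
    """Peak-normalize a float waveform to [-1, 1] and quantize to 16-bit PCM,
    mirroring the conversion applied to `fakes` before the .wav is written."""
    audio = audio.to(torch.float32)
    peak = torch.max(torch.abs(audio))
    audio = (audio / peak).clamp(-1, 1)      # peak-normalize, then guard the range
    return (audio * 32767).to(torch.int16).cpu()

# Example with a synthetic stereo clip: 2 channels x 1 s at 44.1 kHz
waveform = torch.randn(2, 44100)
pcm = to_int16_pcm(waveform)
print(pcm.dtype, pcm.shape)  # torch.int16 torch.Size([2, 44100])
```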