liuhuadai committed
Commit f6a5ed7 · verified · 1 Parent(s): 273dd2b

Optimize memory usage

Files changed (1)
  1. app.py (+20, -33)
app.py CHANGED
@@ -208,32 +208,19 @@ seed_everything(seed, workers=True)
 with open("ThinkSound/configs/model_configs/thinksound.json") as f:
     model_config = json.load(f)
 
-model = create_model_from_config(model_config)
+diffusion_model = create_model_from_config(model_config)
+ckpt_path = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="thinksound_light.ckpt",repo_type="model")
+diffusion_model.load_state_dict(torch.load(ckpt_path))
+diffusion_model.to(device)
 
 ## speed by torch.compile
 if args.compile:
-    model = torch.compile(model)
-
-if args.pretrained_ckpt_path:
-    copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path,prefix='diffusion.')) # autoencoder. diffusion.
-
-if args.remove_pretransform_weight_norm == "pre_load":
-    remove_weight_norm_from_model(model.pretransform)
+    diffusion_model = torch.compile(diffusion_model)
 
 
 load_vae_state = load_ckpt_state_dict(vae_ckpt, prefix='autoencoder.')
 # new_state_dict = {k.replace("autoencoder.", ""): v for k, v in load_vae_state.items() if k.startswith("autoencoder.")}
-model.pretransform.load_state_dict(load_vae_state)
-
-# Remove weight_norm from the pretransform if specified
-if args.remove_pretransform_weight_norm == "post_load":
-    remove_weight_norm_from_model(model.pretransform)
-ckpt_path = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="thinksound.ckpt",repo_type="model")
-training_wrapper = create_training_wrapper_from_config(model_config, model)
-# choose map_location according to the device when loading the model weights
-training_wrapper.load_state_dict(torch.load(ckpt_path)['state_dict'])
-
-training_wrapper.to("cuda")
+diffusion_model.pretransform.load_state_dict(load_vae_state)
 
 def get_video_duration(video_path):
     video = VideoFileClip(video_path)
@@ -276,36 +263,36 @@ def synthesize_video_with_audio(video_file, caption, cot):
     sync_seq_len = preprocessed_data['sync_features'].shape[0]
     clip_seq_len = preprocessed_data['metaclip_features'].shape[0]
     latent_seq_len = (int)(194/9*duration_sec)
-    training_wrapper.diffusion.model.model.update_seq_lengths(latent_seq_len, clip_seq_len, sync_seq_len)
+    diffusion_model.model.model.update_seq_lengths(latent_seq_len, clip_seq_len, sync_seq_len)
 
     metadata = [preprocessed_data]
 
     batch_size = 1
     length = latent_seq_len
     with torch.amp.autocast(device):
-        conditioning = training_wrapper.diffusion.conditioner(metadata, training_wrapper.device)
+        conditioning = diffusion_model.conditioner(metadata, device)
 
     video_exist = torch.stack([item['video_exist'] for item in metadata],dim=0)
-    conditioning['metaclip_features'][~video_exist] = training_wrapper.diffusion.model.model.empty_clip_feat
-    conditioning['sync_features'][~video_exist] = training_wrapper.diffusion.model.model.empty_sync_feat
+    conditioning['metaclip_features'][~video_exist] = diffusion_model.model.model.empty_clip_feat
+    conditioning['sync_features'][~video_exist] = diffusion_model.model.model.empty_sync_feat
 
     yield "⏳ Inferring…", None
 
-    cond_inputs = training_wrapper.diffusion.get_conditioning_inputs(conditioning)
-    noise = torch.randn([batch_size, training_wrapper.diffusion.io_channels, length]).to(training_wrapper.device)
+    cond_inputs = diffusion_model.get_conditioning_inputs(conditioning)
+    noise = torch.randn([batch_size, diffusion_model.io_channels, length]).to(device)
     with torch.amp.autocast(device):
-        model = training_wrapper.diffusion.model
-        if training_wrapper.diffusion_objective == "v":
-            fakes = sample(model, noise, 24, 0, **cond_inputs, cfg_scale=5, batch_cfg=True)
-        elif training_wrapper.diffusion_objective == "rectified_flow":
+        if diffusion_model.diffusion_objective == "v":
+            fakes = sample(diffusion_model.model, noise, 24, 0, **cond_inputs, cfg_scale=5, batch_cfg=True)
+        elif diffusion_model.diffusion_objective == "rectified_flow":
             import time
             start_time = time.time()
-            fakes = sample_discrete_euler(model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True)
+            fakes = sample_discrete_euler(diffusion_model.model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True)
             end_time = time.time()
             execution_time = end_time - start_time
-            print(f"Execution time: {execution_time:.2f} s")
-        if training_wrapper.diffusion.pretransform is not None:
-            fakes = training_wrapper.diffusion.pretransform.decode(fakes)
+            print(f"execution_time: {execution_time:.2f} s")
+
+        if diffusion_model.pretransform is not None:
+            fakes = diffusion_model.pretransform.decode(fakes)
 
     audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
     with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_audio:
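For orientation only, not part of the commit: the new loading path can be written in a slightly more device-aware form by mapping the checkpoint to CPU before copying it into the module, which also keeps peak GPU memory lower. This is a sketch built from names that already exist in app.py (model_config, create_model_from_config, device); the map_location and cleanup steps are assumptions, not the committed code.

    import gc
    import torch
    from huggingface_hub import hf_hub_download

    # Sketch only: assumes model_config, create_model_from_config and device
    # are defined as in app.py. map_location="cpu" keeps the raw checkpoint
    # off the GPU until the module itself is moved there.
    ckpt_path = hf_hub_download(
        repo_id="FunAudioLLM/ThinkSound",
        filename="thinksound_light.ckpt",
        repo_type="model",
    )
    diffusion_model = create_model_from_config(model_config)
    state_dict = torch.load(ckpt_path, map_location="cpu")
    diffusion_model.load_state_dict(state_dict)
    del state_dict   # drop the CPU copy of the weights once they are loaded
    gc.collect()
    diffusion_model.to(device)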
 
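Two reading aids for the synthesis hunk. The sequence-length formula (int)(194/9*duration_sec) corresponds to roughly 21.5 latent frames per second, so a 9-second clip gives a latent length of exactly 194. And the final tensor chain before the temporary .wav file is just peak normalisation followed by 16-bit PCM quantisation; unpacked, it looks like the sketch below (the fakes tensor here is a stand-in, not data produced by the app).

    import torch

    # Stand-in for decoded audio latents: (batch, channels, samples), float.
    fakes = torch.randn(1, 2, 48_000)

    audio = fakes.to(torch.float32)
    audio = audio / audio.abs().max()     # peak-normalise to [-1, 1]
    audio = audio.clamp(-1, 1)            # guard against numerical overshoot
    audio_int16 = (audio * 32767).to(torch.int16).cpu()  # 16-bit PCM samples for the .wav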