martylabs committed · Commit 9c5161e · verified · 1 Parent(s): 191efb6

Update generate_multitalk.py

Files changed (1):
  1. generate_multitalk.py +39 -40
generate_multitalk.py CHANGED
@@ -23,7 +23,6 @@ from wan.utils.utils import cache_image, cache_video, str2bool
 from wan.utils.multitalk_utils import save_video_ffmpeg
 
 from transformers import Wav2Vec2FeatureExtractor
-from transformers import Wav2Vec2ForCTC
 from src.audio_analysis.wav2vec2 import Wav2Vec2Model
 
 import librosa
@@ -215,7 +214,7 @@ def _parse_args():
     return args
 
 def custom_init(device, wav2vec):
-    audio_encoder = Wav2Vec2Model.from_pretrained(args.wav2vec_dir, attn_implementation="eager").to(device)
+    audio_encoder = Wav2Vec2Model.from_pretrained(wav2vec, attn_implementation="eager").to(device)
     audio_encoder.freeze_feature_extractor()
     wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec, local_files_only=True)
     return wav2vec_feature_extractor, audio_encoder
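The one-line fix above removes a hidden dependency on module-level state: the old body read the global `args.wav2vec_dir` and silently ignored the function's own `wav2vec` parameter, so any caller passing a different path (or calling before `args` existed) got the wrong model directory. A minimal, self-contained sketch of the bug class and the fix; all names and paths here are hypothetical, not the repo's code:

import argparse

# Hypothetical stand-in for the script's parsed CLI arguments
args = argparse.Namespace(wav2vec_dir='/models/default-wav2vec')

def load_encoder_buggy(device, wav2vec):
    # BUG: reads the global, so the caller's `wav2vec` argument is ignored
    return f"loading {args.wav2vec_dir} on {device}"

def load_encoder_fixed(device, wav2vec):
    # FIX (mirroring this commit): resolve the path from the explicit parameter
    return f"loading {wav2vec} on {device}"

print(load_encoder_buggy('cpu', '/models/other-wav2vec'))  # default path: wrong
print(load_encoder_fixed('cpu', '/models/other-wav2vec'))  # caller's path: right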
@@ -373,50 +372,50 @@ def generate(args):
 
     assert args.task == "multitalk-14B", 'You should choose multitalk in args.task.'
 
-    # Initialize a placeholder for all processes
-    input_data = None
-
-    # Let only the main process prepare the data
+    # Let only the main process (rank 0) prepare the audio embeddings and overwrite the input JSON file.
     if rank == 0:
         with open(args.input_json, 'r', encoding='utf-8') as f:
             input_data = json.load(f)
 
-        wav2vec_feature_extractor, audio_encoder= custom_init('cpu', args.wav2vec_dir)
-        args.audio_save_dir = os.path.join(args.audio_save_dir, input_data['cond_image'].split('/')[-1].split('.')[0])
-        os.makedirs(args.audio_save_dir,exist_ok=True)
-
-        if len(input_data['cond_audio'])==2:
-            new_human_speech1, new_human_speech2, sum_human_speechs = audio_prepare_multi(input_data['cond_audio']['person1'], input_data['cond_audio']['person2'], input_data['audio_type'])
-            audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder)
-            audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder)
-            emb1_path = os.path.join(args.audio_save_dir, '1.pt')
-            emb2_path = os.path.join(args.audio_save_dir, '2.pt')
-            sum_audio = os.path.join(args.audio_save_dir, 'sum.wav')
-            sf.write(sum_audio, sum_human_speechs, 16000)
-            torch.save(audio_embedding_1, emb1_path)
-            torch.save(audio_embedding_2, emb2_path)
-            input_data['cond_audio']['person1'] = emb1_path
-            input_data['cond_audio']['person2'] = emb2_path
-            input_data['video_audio'] = sum_audio
-        elif len(input_data['cond_audio'])==1:
-            human_speech = audio_prepare_single(input_data['cond_audio']['person1'])
-            audio_embedding = get_embedding(human_speech, wav2vec_feature_extractor, audio_encoder)
-            emb_path = os.path.join(args.audio_save_dir, '1.pt')
-            sum_audio = os.path.join(args.audio_save_dir, 'sum.wav')
-            sf.write(sum_audio, human_speech, 16000)
-            torch.save(audio_embedding, emb_path)
-            input_data['cond_audio']['person1'] = emb_path
-            input_data['video_audio'] = sum_audio
-
-    # Broadcast the data from rank 0 to all other processes
+        wav2vec_feature_extractor, audio_encoder = custom_init('cpu', args.wav2vec_dir)
+        args.audio_save_dir = os.path.join(args.audio_save_dir, input_data['cond_image'].split('/')[-1].split('.')[0])
+        os.makedirs(args.audio_save_dir, exist_ok=True)
+
+        if len(input_data['cond_audio']) == 2:
+            new_human_speech1, new_human_speech2, sum_human_speechs = audio_prepare_multi(input_data['cond_audio']['person1'], input_data['cond_audio']['person2'], input_data['audio_type'])
+            audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder)
+            audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder)
+            emb1_path = os.path.join(args.audio_save_dir, '1.pt')
+            emb2_path = os.path.join(args.audio_save_dir, '2.pt')
+            sum_audio = os.path.join(args.audio_save_dir, 'sum.wav')
+            sf.write(sum_audio, sum_human_speechs, 16000)
+            torch.save(audio_embedding_1, emb1_path)
+            torch.save(audio_embedding_2, emb2_path)
+            input_data['cond_audio']['person1'] = emb1_path
+            input_data['cond_audio']['person2'] = emb2_path
+            input_data['video_audio'] = sum_audio
+        elif len(input_data['cond_audio']) == 1:
+            human_speech = audio_prepare_single(input_data['cond_audio']['person1'])
+            audio_embedding = get_embedding(human_speech, wav2vec_feature_extractor, audio_encoder)
+            emb_path = os.path.join(args.audio_save_dir, '1.pt')
+            sum_audio = os.path.join(args.audio_save_dir, 'sum.wav')
+            sf.write(sum_audio, human_speech, 16000)
+            torch.save(audio_embedding, emb_path)
+            input_data['cond_audio']['person1'] = emb_path
+            input_data['video_audio'] = sum_audio
+
+        # Overwrite the temporary JSON file with the updated paths
+        with open(args.input_json, 'w', encoding='utf-8') as f:
+            json.dump(input_data, f, indent=4)
+
+    # Barrier to ensure rank 0 has finished writing all files (embeddings AND the json)
     if dist.is_initialized():
-        objects_to_broadcast = [input_data] if rank == 0 else [None]
-        dist.broadcast_object_list(objects_to_broadcast, src=0)
-        input_data = objects_to_broadcast[0]
-
-        # Wait for all file I/O to be complete before proceeding
         dist.barrier()
 
+    # Now, ALL processes read the (now updated) JSON file, ensuring data consistency.
+    with open(args.input_json, 'r', encoding='utf-8') as f:
+        input_data = json.load(f)
+
     logging.info("Creating MultiTalk pipeline.")
     wan_i2v = wan.MultiTalkPipeline(
         config=cfg,
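The hunk above swaps the synchronization strategy: instead of broadcasting the in-memory `input_data` object with `dist.broadcast_object_list`, rank 0 now persists its results into the input JSON, a barrier fences the writes, and every rank re-reads the file. A minimal sketch of that write-barrier-read pattern, assuming an initialized torch.distributed process group and a filesystem shared by all ranks (path and payload are hypothetical):

import json
import torch.distributed as dist

def share_via_file(path, rank):
    if rank == 0:
        # Only rank 0 prepares the data and persists it.
        payload = {'video_audio': 'outputs/sum.wav'}  # hypothetical payload
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(payload, f, indent=4)
    # Fence: no rank may read until rank 0 has finished writing.
    if dist.is_initialized():
        dist.barrier()
    # Every rank, rank 0 included, reads the same file and sees identical data.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

Unlike `broadcast_object_list`, this requires all ranks to see the same filesystem (true for the single-node multi-GPU runs the script appears to target), and single-process runs where `dist` is never initialized now take the same code path with no special casing.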
@@ -461,7 +460,7 @@ def generate(args):
     args.save_file = f"{args.task}_{args.size.replace('*','x') if sys.platform=='win32' else args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}"
 
     logging.info(f"Saving generated video to {args.save_file}.mp4")
-    save_video_ffmpeg(video, args.save_file, [input_data['video_audio']])
+    save_video_ffmpeg(video, args.save_file, [input_data['video_audio']])
 
     logging.info("Finished.")
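For reference, the branches on `len(input_data['cond_audio'])` in the preprocessing hunk imply an input JSON shaped roughly as follows; the key names come from the diff, while every value is a made-up example:

# Hypothetical two-speaker input; drop 'person2' to exercise the single-speaker branch.
input_data = {
    'cond_image': 'examples/speakers.png',  # file stem also names args.audio_save_dir
    'audio_type': 'add',                    # hypothetical value; consumed by audio_prepare_multi
    'cond_audio': {
        'person1': 'examples/person1.wav',
        'person2': 'examples/person2.wav',
    },
}

After rank 0's preprocessing, the 'cond_audio' entries point at the saved .pt embedding files and a 'video_audio' key referencing sum.wav has been added, which is exactly what the other ranks read back from the rewritten JSON.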