alan committed
Commit cb5b6f4 · 1 Parent(s): a2f3b4f

Update speaker

Files changed (1): app.py (+4 -2)
app.py CHANGED
@@ -39,7 +39,8 @@ MAX_INPUT_AUDIO_LENGTH = 60 # in seconds
 DEFAULT_TARGET_LANGUAGE = "French"
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
+# if torch.backends.mps.is_available():
+#     device = torch.device("mps")
 processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large")
 model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large").to(device)
 # processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
@@ -82,7 +83,8 @@ def predict(
     if task_name in ["S2TT", "T2TT"]:
         tokens_ids = model.generate(**input_data, generate_speech=False, tgt_lang=target_language_code, num_beams=5, do_sample=True)[0].cpu().squeeze().detach().tolist()
     else:
-        output = model.generate(**input_data, return_intermediate_token_ids=True, tgt_lang=target_language_code, speaker_id=LANG_TO_SPKR_ID[target_language_code], num_beams=5, do_sample=True)
+        print(input_data.input_features.shape)
+        output = model.generate(**input_data, return_intermediate_token_ids=True, tgt_lang=target_language_code, speaker_id=LANG_TO_SPKR_ID[target_language_code][0], num_beams=5, do_sample=True)
 
         waveform = output.waveform.cpu().squeeze().detach().numpy()
         tokens_ids = output.sequences.cpu().squeeze().detach().tolist()
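
The new, still commented-out lines at 42-43 hint at an Apple Silicon fallback. A minimal device-selection sketch along those lines, with MPS actually enabled, could look like the snippet below; the commit itself does not enable it, and whether seamless-m4t-v2-large behaves well on MPS is not verified here.

import torch

# Prefer CUDA, then fall back to Apple-Silicon MPS, then CPU.
# Enabling MPS is only suggested by the commented-out lines in this commit.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")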
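
The substantive change behind the "Update speaker" message is the added [0] index on LANG_TO_SPKR_ID[target_language_code]: the indexing suggests the mapping stores a list of speaker ids per language, while SeamlessM4Tv2Model.generate takes a single integer speaker_id, so the first entry is passed. A minimal sketch of the S2ST branch after this commit follows; the LANG_TO_SPKR_ID contents and the "fra" language code are illustrative placeholders, not values taken from app.py.

import torch
from transformers import AutoProcessor, SeamlessM4Tv2Model

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large")
model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large").to(device)

# Hypothetical mapping for illustration; the real LANG_TO_SPKR_ID lives in app.py.
LANG_TO_SPKR_ID = {"fra": [0, 1, 2]}

def speech_to_speech(audio_array, sampling_rate, target_language_code="fra"):
    # The processor turns raw audio into input_features for the model.
    input_data = processor(audios=audio_array, sampling_rate=sampling_rate, return_tensors="pt").to(device)
    output = model.generate(
        **input_data,
        return_intermediate_token_ids=True,  # also return the generated text tokens
        tgt_lang=target_language_code,
        speaker_id=LANG_TO_SPKR_ID[target_language_code][0],  # generate() expects a single int
        num_beams=5,
        do_sample=True,
    )
    waveform = output.waveform.cpu().squeeze().detach().numpy()
    tokens_ids = output.sequences.cpu().squeeze().detach().tolist()
    return waveform, tokens_ids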