Spaces:
Runtime error
Runtime error
alan
committed on
Commit
·
cb5b6f4
1
Parent(s):
a2f3b4f
Update speaker
Browse files
app.py
CHANGED
@@ -39,7 +39,8 @@ MAX_INPUT_AUDIO_LENGTH = 60 # in seconds
|
|
39 |
DEFAULT_TARGET_LANGUAGE = "French"
|
40 |
|
41 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
42 |
-
|
|
|
43 |
processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large")
|
44 |
model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large").to(device)
|
45 |
# processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
|
@@ -82,7 +83,8 @@ def predict(
|
|
82 |
if task_name in ["S2TT", "T2TT"]:
|
83 |
tokens_ids = model.generate(**input_data, generate_speech=False, tgt_lang=target_language_code, num_beams=5, do_sample=True)[0].cpu().squeeze().detach().tolist()
|
84 |
else:
|
85 |
-
|
|
|
86 |
|
87 |
waveform = output.waveform.cpu().squeeze().detach().numpy()
|
88 |
tokens_ids = output.sequences.cpu().squeeze().detach().tolist()
|
|
|
39 |
DEFAULT_TARGET_LANGUAGE = "French"
|
40 |
|
41 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
42 |
+
# if torch.backends.mps.is_available():
|
43 |
+
# device = torch.device("mps")
|
44 |
processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large")
|
45 |
model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large").to(device)
|
46 |
# processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")
|
|
|
83 |
if task_name in ["S2TT", "T2TT"]:
|
84 |
tokens_ids = model.generate(**input_data, generate_speech=False, tgt_lang=target_language_code, num_beams=5, do_sample=True)[0].cpu().squeeze().detach().tolist()
|
85 |
else:
|
86 |
+
print(input_data.input_features.shape)
|
87 |
+
output = model.generate(**input_data, return_intermediate_token_ids=True, tgt_lang=target_language_code, speaker_id=LANG_TO_SPKR_ID[target_language_code][0], num_beams=5, do_sample=True)
|
88 |
|
89 |
waveform = output.waveform.cpu().squeeze().detach().numpy()
|
90 |
tokens_ids = output.sequences.cpu().squeeze().detach().tolist()
|