updated requirements
Browse files
- models/vallex.py +8 -3
models/vallex.py
CHANGED
@@ -22,7 +22,6 @@ import torch.nn.functional as F
|
|
22 |
# from icefall.utils import make_pad_mask
|
23 |
# from torchmetrics.classification import MulticlassAccuracy
|
24 |
|
25 |
-
|
26 |
from modules.embedding import SinePositionalEmbedding, TokenEmbedding
|
27 |
from modules.transformer import (
|
28 |
AdaptiveLayerNorm,
|
@@ -493,7 +492,10 @@ class VALLE(VALLF):
|
|
493 |
x = self.ar_text_embedding(text)
|
494 |
# Add language embedding
|
495 |
prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
|
496 |
-
|
|
|
|
|
|
|
497 |
x[:, :enroll_x_lens, :] += self.ar_language_embedding(prompt_language_id)
|
498 |
x[:, enroll_x_lens:, :] += self.ar_language_embedding(text_language_id)
|
499 |
x = self.ar_text_prenet(x)
|
@@ -599,7 +601,10 @@ class VALLE(VALLF):
|
|
599 |
x = self.nar_text_embedding(text)
|
600 |
# Add language embedding
|
601 |
prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
|
602 |
-
|
|
|
|
|
|
|
603 |
x[:, :enroll_x_lens, :] += self.nar_language_embedding(prompt_language_id)
|
604 |
x[:, enroll_x_lens:, :] += self.nar_language_embedding(text_language_id)
|
605 |
x = self.nar_text_prenet(x)
|
|
|
22 |
# from icefall.utils import make_pad_mask
|
23 |
# from torchmetrics.classification import MulticlassAccuracy
|
24 |
|
|
|
25 |
from modules.embedding import SinePositionalEmbedding, TokenEmbedding
|
26 |
from modules.transformer import (
|
27 |
AdaptiveLayerNorm,
|
|
|
492 |
x = self.ar_text_embedding(text)
|
493 |
# Add language embedding
|
494 |
prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
|
495 |
+
if isinstance(text_language, str):
|
496 |
+
text_language_id = torch.LongTensor(np.array([self.language_ID[text_language]])).to(x.device)
|
497 |
+
elif isinstance(text_language, List):
|
498 |
+
text_language_id = torch.LongTensor(np.array([self.language_ID[tl] for tl in text_language])).to(x.device)
|
499 |
x[:, :enroll_x_lens, :] += self.ar_language_embedding(prompt_language_id)
|
500 |
x[:, enroll_x_lens:, :] += self.ar_language_embedding(text_language_id)
|
501 |
x = self.ar_text_prenet(x)
|
|
|
601 |
x = self.nar_text_embedding(text)
|
602 |
# Add language embedding
|
603 |
prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
|
604 |
+
if isinstance(text_language, str):
|
605 |
+
text_language_id = torch.LongTensor(np.array([self.language_ID[text_language]])).to(x.device)
|
606 |
+
elif isinstance(text_language, List):
|
607 |
+
text_language_id = torch.LongTensor(np.array([self.language_ID[tl] for tl in text_language])).to(x.device)
|
608 |
x[:, :enroll_x_lens, :] += self.nar_language_embedding(prompt_language_id)
|
609 |
x[:, enroll_x_lens:, :] += self.nar_language_embedding(text_language_id)
|
610 |
x = self.nar_text_prenet(x)
|