import torch

from Modules.EmbeddingModel.GST import GSTStyleEncoder
from Modules.EmbeddingModel.StyleTTSEncoder import StyleEncoder as StyleTTSEncoder


class StyleEmbedding(torch.nn.Module):
    """
    The style embedding should provide information about the speaker and their speaking style.

    The feedback signal for this module comes from the TTS objective, so it doesn't have a
    dedicated train loop. The TTS train loop does, however, supply additional supervision in
    the form of a Barlow Twins objective.

    See the git history for some other approaches to style embedding, such as the SWIN
    transformer and a simple LSTM baseline. GST turned out to be the best.
    """

    def __init__(self, embedding_dim=16, style_tts_encoder=False):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.use_gst = not style_tts_encoder
        if style_tts_encoder:
            self.style_encoder = StyleTTSEncoder(style_dim=embedding_dim)
        else:
            self.style_encoder = GSTStyleEncoder(gst_token_dim=embedding_dim)

    def forward(self,
                batch_of_feature_sequences,
                batch_of_feature_sequence_lengths):
        """
        Args:
            batch_of_feature_sequences: b is the batch axis, f features per timestep
                                        and l time-steps, which may include padding
                                        for most elements in the batch (b, l, f)
            batch_of_feature_sequence_lengths: indicates for every element in the batch
                                               what the true length is, since they are
                                               all padded to the length of the longest
                                               element in the batch (b,)
        Returns:
            batch of n dimensional embeddings (b, n)
        """
        minimum_sequence_length = 512
        specs = list()
        for index, spec_length in enumerate(batch_of_feature_sequence_lengths):
            # cut off the padding
            spec = batch_of_feature_sequences[index][:spec_length]
            # double the length at least once by tiling, then keep doubling
            # until the minimum length the style encoder expects is reached
            spec = spec.repeat((2, 1))
            current_spec_length = len(spec)
            while current_spec_length < minimum_sequence_length:
                # make it longer
                spec = spec.repeat((2, 1))
                current_spec_length = len(spec)
            # cut back down to exactly the minimum length, so the batch is uniform
            specs.append(spec[:minimum_sequence_length])

        spec_batch = torch.stack(specs, dim=0)
        return self.style_encoder(speech=spec_batch)
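

# The Barlow Twins supervision mentioned in the class docstring lives in the TTS
# train loop, not in this file. The function below is only a minimal sketch of
# such an objective, shown here for illustration: it takes two batches of style
# embeddings (e.g. computed from two different utterances of the same speaker)
# and pushes their cross-correlation matrix towards the identity. The function
# name and the lambda_offdiag weight are assumptions, not part of the original
# training code.
def barlow_twins_loss_sketch(embeddings_view_a, embeddings_view_b, lambda_offdiag=0.005):
    # normalize each embedding dimension to zero mean and unit variance across the batch
    a_norm = (embeddings_view_a - embeddings_view_a.mean(0)) / (embeddings_view_a.std(0) + 1e-6)
    b_norm = (embeddings_view_b - embeddings_view_b.mean(0)) / (embeddings_view_b.std(0) + 1e-6)
    batch_size = embeddings_view_a.size(0)
    # cross-correlation matrix between the two views, shape (embedding_dim, embedding_dim)
    cross_correlation = (a_norm.T @ b_norm) / batch_size
    # invariance term: the diagonal should be 1 (corresponding dimensions agree)
    on_diag = (torch.diagonal(cross_correlation) - 1).pow(2).sum()
    # redundancy reduction term: the off-diagonal should be 0 (dimensions decorrelate)
    off_diag = (cross_correlation - torch.diag(torch.diagonal(cross_correlation))).pow(2).sum()
    return on_diag + lambda_offdiag * off_diag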


if __name__ == '__main__':
    style_emb = StyleEmbedding(style_tts_encoder=False)
    print(f"GST parameter count: {sum(p.numel() for p in style_emb.style_encoder.parameters() if p.requires_grad)}")

    seq_length = 398
    print(style_emb(torch.randn(5, seq_length, 512),
                    torch.tensor([seq_length, seq_length, seq_length, seq_length, seq_length])).shape)

    style_emb = StyleEmbedding(style_tts_encoder=True)
    print(f"StyleTTS encoder parameter count: {sum(p.numel() for p in style_emb.style_encoder.parameters() if p.requires_grad)}")

    seq_length = 398
    print(style_emb(torch.randn(5, seq_length, 512),
                    torch.tensor([seq_length, seq_length, seq_length, seq_length, seq_length])).shape)
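
    # additional sanity check (not in the original file): a batch with ragged
    # true lengths, to exercise the padding removal and tiling in forward();
    # the result should still be one embedding per batch element, i.e. (5, 16)
    ragged_lengths = torch.tensor([398, 120, 17, 398, 256])
    print(style_emb(torch.randn(5, 398, 512),
                    ragged_lengths).shape)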