Spaces:
Paused
Paused
myhanhhyugen
commited on
Commit
•
dc9eaa3
1
Parent(s):
0629725
initial commits
Browse files- TTSInferencing.py +267 -0
- hyperparams.yaml +187 -0
- model.ckpt +3 -0
- module_classes.py +214 -0
TTSInferencing.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import re
|
3 |
+
import logging
|
4 |
+
import torch
|
5 |
+
import torchaudio
|
6 |
+
import random
|
7 |
+
import speechbrain
|
8 |
+
from speechbrain.inference.interfaces import Pretrained
|
9 |
+
from speechbrain.inference.text import GraphemeToPhoneme
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
class TTSInferencing(Pretrained):
|
14 |
+
"""
|
15 |
+
A ready-to-use wrapper for TTS (text -> mel_spec).
|
16 |
+
Arguments
|
17 |
+
---------
|
18 |
+
hparams
|
19 |
+
Hyperparameters (from HyperPyYAML)
|
20 |
+
"""
|
21 |
+
|
22 |
+
HPARAMS_NEEDED = ["modules", "input_encoder"]
|
23 |
+
|
24 |
+
MODULES_NEEDED = ["encoder_prenet", "pos_emb_enc",
|
25 |
+
"decoder_prenet", "pos_emb_dec",
|
26 |
+
"Seq2SeqTransformer", "mel_lin",
|
27 |
+
"stop_lin", "decoder_postnet"]
|
28 |
+
|
29 |
+
|
30 |
+
def __init__(self, *args, **kwargs):
|
31 |
+
super().__init__(*args, **kwargs)
|
32 |
+
lexicon = self.hparams.lexicon
|
33 |
+
lexicon = ["@@"] + lexicon
|
34 |
+
self.input_encoder = self.hparams.input_encoder
|
35 |
+
self.input_encoder.update_from_iterable(lexicon, sequence_input=False)
|
36 |
+
self.input_encoder.add_unk()
|
37 |
+
|
38 |
+
self.modules = self.hparams.modules
|
39 |
+
|
40 |
+
self.g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
def generate_padded_phonemes(self, texts):
|
46 |
+
"""Computes mel-spectrogram for a list of texts
|
47 |
+
|
48 |
+
Arguments
|
49 |
+
---------
|
50 |
+
texts: List[str]
|
51 |
+
texts to be converted to spectrogram
|
52 |
+
|
53 |
+
Returns
|
54 |
+
-------
|
55 |
+
tensors of output spectrograms
|
56 |
+
"""
|
57 |
+
|
58 |
+
# Preprocessing required at the inference time for the input text
|
59 |
+
# "label" below contains input text
|
60 |
+
# "phoneme_labels" contain the phoneme sequences corresponding to input text labels
|
61 |
+
|
62 |
+
phoneme_labels = list()
|
63 |
+
|
64 |
+
for label in texts:
|
65 |
+
|
66 |
+
phoneme_label = list()
|
67 |
+
|
68 |
+
label = self.custom_clean(label).upper()
|
69 |
+
|
70 |
+
words = label.split()
|
71 |
+
words = [word.strip() for word in words]
|
72 |
+
words_phonemes = self.g2p(words)
|
73 |
+
|
74 |
+
for i in range(len(words_phonemes)):
|
75 |
+
words_phonemes_seq = words_phonemes[i]
|
76 |
+
for phoneme in words_phonemes_seq:
|
77 |
+
if not phoneme.isspace():
|
78 |
+
phoneme_label.append(phoneme)
|
79 |
+
phoneme_labels.append(phoneme_label)
|
80 |
+
|
81 |
+
|
82 |
+
# encode the phonemes with input text encoder
|
83 |
+
encoded_phonemes = list()
|
84 |
+
for i in range(len(phoneme_labels)):
|
85 |
+
phoneme_label = phoneme_labels[i]
|
86 |
+
encoded_phoneme = torch.LongTensor(self.input_encoder.encode_sequence(phoneme_label)).to(self.device)
|
87 |
+
encoded_phonemes.append(encoded_phoneme)
|
88 |
+
|
89 |
+
|
90 |
+
# Right zero-pad all one-hot text sequences to max input length
|
91 |
+
input_lengths, ids_sorted_decreasing = torch.sort(
|
92 |
+
torch.LongTensor([len(x) for x in encoded_phonemes]), dim=0, descending=True
|
93 |
+
)
|
94 |
+
|
95 |
+
max_input_len = input_lengths[0]
|
96 |
+
|
97 |
+
phoneme_padded = torch.LongTensor(len(encoded_phonemes), max_input_len).to(self.device)
|
98 |
+
phoneme_padded.zero_()
|
99 |
+
|
100 |
+
for seq_idx, seq in enumerate(encoded_phonemes):
|
101 |
+
phoneme_padded[seq_idx, : len(seq)] = seq
|
102 |
+
|
103 |
+
|
104 |
+
return phoneme_padded.to(self.device, non_blocking=True).float()
|
105 |
+
|
106 |
+
|
107 |
+
def encode_batch(self, texts):
|
108 |
+
"""Computes mel-spectrogram for a list of texts
|
109 |
+
|
110 |
+
Texts must be sorted in decreasing order on their lengths
|
111 |
+
|
112 |
+
Arguments
|
113 |
+
---------
|
114 |
+
texts: List[str]
|
115 |
+
texts to be encoded into spectrogram
|
116 |
+
|
117 |
+
Returns
|
118 |
+
-------
|
119 |
+
tensors of output spectrograms
|
120 |
+
"""
|
121 |
+
|
122 |
+
# generate phonemes and padd the input texts
|
123 |
+
encoded_phoneme_padded = self.generate_padded_phonemes(texts)
|
124 |
+
phoneme_prenet_emb = self.modules['encoder_prenet'](encoded_phoneme_padded)
|
125 |
+
# Positional Embeddings
|
126 |
+
phoneme_pos_emb = self.modules['pos_emb_enc'](encoded_phoneme_padded)
|
127 |
+
# Summing up embeddings
|
128 |
+
enc_phoneme_emb = phoneme_prenet_emb.permute(0,2,1) + phoneme_pos_emb
|
129 |
+
enc_phoneme_emb = enc_phoneme_emb.to(self.device)
|
130 |
+
|
131 |
+
|
132 |
+
with torch.no_grad():
|
133 |
+
|
134 |
+
# generate sequential predictions via transformer decoder
|
135 |
+
start_token = torch.full((80, 1), fill_value= 0)
|
136 |
+
start_token[1] = 2
|
137 |
+
decoder_input = start_token.repeat(enc_phoneme_emb.size(0), 1, 1)
|
138 |
+
decoder_input = decoder_input.to(self.device, non_blocking=True).float()
|
139 |
+
|
140 |
+
num_itr = 0
|
141 |
+
stop_condition = [False] * decoder_input.size(0)
|
142 |
+
max_iter = 100
|
143 |
+
|
144 |
+
# while not all(stop_condition) and num_itr < max_iter:
|
145 |
+
while num_itr < max_iter:
|
146 |
+
|
147 |
+
# Decoder Prenet
|
148 |
+
mel_prenet_emb = self.modules['decoder_prenet'](decoder_input).to(self.device).permute(0,2,1)
|
149 |
+
|
150 |
+
# Positional Embeddings
|
151 |
+
mel_pos_emb = self.modules['pos_emb_dec'](mel_prenet_emb).to(self.device)
|
152 |
+
# Summing up Embeddings
|
153 |
+
dec_mel_spec = mel_prenet_emb + mel_pos_emb
|
154 |
+
|
155 |
+
# Getting the target mask to avoid looking ahead
|
156 |
+
tgt_mask = self.hparams.lookahead_mask(dec_mel_spec).to(self.device)
|
157 |
+
|
158 |
+
# Getting the source mask
|
159 |
+
src_mask = torch.zeros(enc_phoneme_emb.shape[1], enc_phoneme_emb.shape[1]).to(self.device)
|
160 |
+
|
161 |
+
# Padding masks for source and targets
|
162 |
+
src_key_padding_mask = self.hparams.padding_mask(enc_phoneme_emb, pad_idx = self.hparams.blank_index).to(self.device)
|
163 |
+
tgt_key_padding_mask = self.hparams.padding_mask(dec_mel_spec, pad_idx = self.hparams.blank_index).to(self.device)
|
164 |
+
|
165 |
+
|
166 |
+
# Running the Seq2Seq Transformer
|
167 |
+
decoder_outputs = self.modules['Seq2SeqTransformer'](src = enc_phoneme_emb, tgt = dec_mel_spec, src_mask = src_mask, tgt_mask = tgt_mask,
|
168 |
+
src_key_padding_mask = src_key_padding_mask, tgt_key_padding_mask = tgt_key_padding_mask)
|
169 |
+
|
170 |
+
# Mel Linears
|
171 |
+
mel_linears = self.modules['mel_lin'](decoder_outputs).permute(0,2,1)
|
172 |
+
mel_postnet = self.modules['decoder_postnet'](mel_linears) # mel tensor output
|
173 |
+
mel_pred = mel_linears + mel_postnet # mel tensor output
|
174 |
+
|
175 |
+
stop_token_pred = self.modules['stop_lin'](decoder_outputs).squeeze(-1)
|
176 |
+
|
177 |
+
stop_condition_list = self.check_stop_condition(stop_token_pred)
|
178 |
+
|
179 |
+
|
180 |
+
# update the values of main stop conditions
|
181 |
+
stop_condition_update = [True if stop_condition_list[i] else stop_condition[i] for i in range(len(stop_condition))]
|
182 |
+
stop_condition = stop_condition_update
|
183 |
+
|
184 |
+
|
185 |
+
# Prepare input for the transformer input for next iteration
|
186 |
+
current_output = mel_pred[:, :, -1:]
|
187 |
+
|
188 |
+
decoder_input=torch.cat([decoder_input,current_output],dim=2)
|
189 |
+
num_itr = num_itr+1
|
190 |
+
|
191 |
+
mel_outputs = decoder_input[:, :, 1:]
|
192 |
+
|
193 |
+
return mel_outputs
|
194 |
+
|
195 |
+
|
196 |
+
|
197 |
+
def encode_text(self, text):
|
198 |
+
"""Runs inference for a single text str"""
|
199 |
+
return self.encode_batch([text])
|
200 |
+
|
201 |
+
|
202 |
+
def forward(self, text_list):
|
203 |
+
"Encodes the input texts."
|
204 |
+
return self.encode_batch(text_list)
|
205 |
+
|
206 |
+
|
207 |
+
def check_stop_condition(self, stop_token_pred):
|
208 |
+
"""
|
209 |
+
check if stop token / EOS reached or not for mel_specs in the batch
|
210 |
+
"""
|
211 |
+
|
212 |
+
# Applying sigmoid to perform binary classification
|
213 |
+
sigmoid_output = torch.sigmoid(stop_token_pred)
|
214 |
+
# Checking if the probability is greater than 0.5
|
215 |
+
stop_results = sigmoid_output > 0.8
|
216 |
+
stop_output = [all(result) for result in stop_results]
|
217 |
+
|
218 |
+
return stop_output
|
219 |
+
|
220 |
+
|
221 |
+
|
222 |
+
def custom_clean(self, text):
|
223 |
+
"""
|
224 |
+
Uses custom criteria to clean text.
|
225 |
+
|
226 |
+
Arguments
|
227 |
+
---------
|
228 |
+
text : str
|
229 |
+
Input text to be cleaned
|
230 |
+
model_name : str
|
231 |
+
whether to treat punctuations
|
232 |
+
|
233 |
+
Returns
|
234 |
+
-------
|
235 |
+
text : str
|
236 |
+
Cleaned text
|
237 |
+
"""
|
238 |
+
|
239 |
+
_abbreviations = [
|
240 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
241 |
+
for x in [
|
242 |
+
("mrs", "missus"),
|
243 |
+
("mr", "mister"),
|
244 |
+
("dr", "doctor"),
|
245 |
+
("st", "saint"),
|
246 |
+
("co", "company"),
|
247 |
+
("jr", "junior"),
|
248 |
+
("maj", "major"),
|
249 |
+
("gen", "general"),
|
250 |
+
("drs", "doctors"),
|
251 |
+
("rev", "reverend"),
|
252 |
+
("lt", "lieutenant"),
|
253 |
+
("hon", "honorable"),
|
254 |
+
("sgt", "sergeant"),
|
255 |
+
("capt", "captain"),
|
256 |
+
("esq", "esquire"),
|
257 |
+
("ltd", "limited"),
|
258 |
+
("col", "colonel"),
|
259 |
+
("ft", "fort"),
|
260 |
+
]
|
261 |
+
]
|
262 |
+
|
263 |
+
text = re.sub(" +", " ", text)
|
264 |
+
|
265 |
+
for regex, replacement in _abbreviations:
|
266 |
+
text = re.sub(regex, replacement, text)
|
267 |
+
return text
|
hyperparams.yaml
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
############################################################################
|
3 |
+
# Model: TTS with attention-based mechanism
|
4 |
+
# Tokens: g2p + possitional embeddings
|
5 |
+
# losses: MSE & BCE
|
6 |
+
# Training: LJSpeech
|
7 |
+
# ############################################################################
|
8 |
+
|
9 |
+
###################################
|
10 |
+
# Experiment Parameters and setup #
|
11 |
+
###################################
|
12 |
+
seed: 1234
|
13 |
+
__set_seed: !apply:torch.manual_seed [!ref <seed>]
|
14 |
+
|
15 |
+
# Folder set up
|
16 |
+
# output_folder: !ref .\\results\\tts\\<seed>
|
17 |
+
# save_folder: !ref <output_folder>\\save
|
18 |
+
|
19 |
+
output_folder: !ref ./results/<seed>
|
20 |
+
save_folder: !ref <output_folder>/save
|
21 |
+
|
22 |
+
|
23 |
+
################################
|
24 |
+
# Model Parameters and model #
|
25 |
+
################################
|
26 |
+
# Input parameters
|
27 |
+
lexicon:
|
28 |
+
- AA
|
29 |
+
- AE
|
30 |
+
- AH
|
31 |
+
- AO
|
32 |
+
- AW
|
33 |
+
- AY
|
34 |
+
- B
|
35 |
+
- CH
|
36 |
+
- D
|
37 |
+
- DH
|
38 |
+
- EH
|
39 |
+
- ER
|
40 |
+
- EY
|
41 |
+
- F
|
42 |
+
- G
|
43 |
+
- HH
|
44 |
+
- IH
|
45 |
+
- IY
|
46 |
+
- JH
|
47 |
+
- K
|
48 |
+
- L
|
49 |
+
- M
|
50 |
+
- N
|
51 |
+
- NG
|
52 |
+
- OW
|
53 |
+
- OY
|
54 |
+
- P
|
55 |
+
- R
|
56 |
+
- S
|
57 |
+
- SH
|
58 |
+
- T
|
59 |
+
- TH
|
60 |
+
- UH
|
61 |
+
- UW
|
62 |
+
- V
|
63 |
+
- W
|
64 |
+
- Y
|
65 |
+
- Z
|
66 |
+
- ZH
|
67 |
+
|
68 |
+
input_encoder: !new:speechbrain.dataio.encoder.TextEncoder
|
69 |
+
|
70 |
+
|
71 |
+
|
72 |
+
################################
|
73 |
+
# Model Parameters and model #
|
74 |
+
# Transformer Parameters
|
75 |
+
################################
|
76 |
+
d_model: 512
|
77 |
+
nhead: 8
|
78 |
+
num_encoder_layers: 3
|
79 |
+
num_decoder_layers: 3
|
80 |
+
dim_feedforward: 512
|
81 |
+
dropout: 0.1
|
82 |
+
|
83 |
+
|
84 |
+
# Decoder parameters
|
85 |
+
# The number of frames in the target per encoder step
|
86 |
+
n_frames_per_step: 1
|
87 |
+
decoder_rnn_dim: 1024
|
88 |
+
prenet_dim: 256
|
89 |
+
max_decoder_steps: 1000
|
90 |
+
gate_threshold: 0.5
|
91 |
+
p_decoder_dropout: 0.1
|
92 |
+
decoder_no_early_stopping: False
|
93 |
+
|
94 |
+
blank_index: 0 # This special tokes is for padding
|
95 |
+
|
96 |
+
|
97 |
+
# Masks
|
98 |
+
lookahead_mask: !name:speechbrain.lobes.models.transformer.Transformer.get_lookahead_mask
|
99 |
+
padding_mask: !name:speechbrain.lobes.models.transformer.Transformer.get_key_padding_mask
|
100 |
+
|
101 |
+
|
102 |
+
################################
|
103 |
+
# CNN 3-layers Prenet #
|
104 |
+
################################
|
105 |
+
# Encoder Prenet
|
106 |
+
encoder_prenet: !new:module_classes.CNNPrenet
|
107 |
+
|
108 |
+
# Decoder Prenet
|
109 |
+
decoder_prenet: !new:module_classes.CNNDecoderPrenet
|
110 |
+
|
111 |
+
################################
|
112 |
+
# Positional Encodings #
|
113 |
+
################################
|
114 |
+
|
115 |
+
#encoder
|
116 |
+
pos_emb_enc: !new:module_classes.ScaledPositionalEncoding
|
117 |
+
input_size: !ref <d_model>
|
118 |
+
max_len: 5000
|
119 |
+
|
120 |
+
#decoder
|
121 |
+
pos_emb_dec: !new:module_classes.ScaledPositionalEncoding
|
122 |
+
input_size: !ref <d_model>
|
123 |
+
max_len: 5000
|
124 |
+
|
125 |
+
|
126 |
+
################################
|
127 |
+
# S2S Transfomer #
|
128 |
+
################################
|
129 |
+
|
130 |
+
Seq2SeqTransformer: !new:torch.nn.Transformer
|
131 |
+
d_model: !ref <d_model>
|
132 |
+
nhead: !ref <nhead>
|
133 |
+
num_encoder_layers: !ref <num_encoder_layers>
|
134 |
+
num_decoder_layers: !ref <num_decoder_layers>
|
135 |
+
dim_feedforward: !ref <dim_feedforward>
|
136 |
+
dropout: !ref <dropout>
|
137 |
+
batch_first: True
|
138 |
+
|
139 |
+
|
140 |
+
################################
|
141 |
+
# CNN 5-layers PostNet #
|
142 |
+
################################
|
143 |
+
|
144 |
+
decoder_postnet: !new:speechbrain.lobes.models.Tacotron2.Postnet
|
145 |
+
|
146 |
+
|
147 |
+
# Linear transformation on the top of the decoder.
|
148 |
+
stop_lin: !new:speechbrain.nnet.linear.Linear
|
149 |
+
input_size: !ref <d_model>
|
150 |
+
n_neurons: 1
|
151 |
+
|
152 |
+
|
153 |
+
# Linear transformation on the top of the decoder.
|
154 |
+
mel_lin: !new:speechbrain.nnet.linear.Linear
|
155 |
+
input_size: !ref <d_model>
|
156 |
+
n_neurons: 80
|
157 |
+
|
158 |
+
modules:
|
159 |
+
encoder_prenet: !ref <encoder_prenet>
|
160 |
+
pos_emb_enc: !ref <pos_emb_enc>
|
161 |
+
decoder_prenet: !ref <decoder_prenet>
|
162 |
+
pos_emb_dec: !ref <pos_emb_dec>
|
163 |
+
Seq2SeqTransformer: !ref <Seq2SeqTransformer>
|
164 |
+
mel_lin: !ref <mel_lin>
|
165 |
+
stop_lin: !ref <stop_lin>
|
166 |
+
decoder_postnet: !ref <decoder_postnet>
|
167 |
+
|
168 |
+
|
169 |
+
model: !new:torch.nn.ModuleList
|
170 |
+
- [!ref <encoder_prenet>,!ref <pos_emb_enc>,
|
171 |
+
!ref <decoder_prenet>, !ref <pos_emb_dec>, !ref <Seq2SeqTransformer>,
|
172 |
+
!ref <mel_lin>, !ref <stop_lin>, !ref <decoder_postnet>]
|
173 |
+
|
174 |
+
|
175 |
+
pretrained_model_path: ./model.ckpt
|
176 |
+
|
177 |
+
# The pretrainer allows a mapping between pretrained files and instances that
|
178 |
+
# are declared in the yaml. E.g here, we will download the file model.ckpt
|
179 |
+
# and it will be loaded into "model" which is pointing to the <model> defined
|
180 |
+
# before.
|
181 |
+
|
182 |
+
pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
|
183 |
+
collect_in: !ref <save_folder>
|
184 |
+
loadables:
|
185 |
+
model: !ref <model>
|
186 |
+
paths:
|
187 |
+
model: !ref <pretrained_model_path>
|
model.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4e5421fe987116817841652862ce070a421d7f5d7c8bbef68c83bec876b1eafb
|
3 |
+
size 95804314
|
module_classes.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import math
|
6 |
+
|
7 |
+
class CNNPrenet(torch.nn.Module):
|
8 |
+
def __init__(self):
|
9 |
+
super(CNNPrenet, self).__init__()
|
10 |
+
|
11 |
+
# Define the layers using Sequential container
|
12 |
+
self.conv_layers = nn.Sequential(
|
13 |
+
nn.Conv1d(in_channels=1, out_channels=512, kernel_size=3, padding=1),
|
14 |
+
nn.BatchNorm1d(512),
|
15 |
+
nn.ReLU(),
|
16 |
+
nn.Dropout(0.1),
|
17 |
+
|
18 |
+
nn.Conv1d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
|
19 |
+
nn.BatchNorm1d(512),
|
20 |
+
nn.ReLU(),
|
21 |
+
nn.Dropout(0.1),
|
22 |
+
|
23 |
+
nn.Conv1d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
|
24 |
+
nn.BatchNorm1d(512),
|
25 |
+
nn.ReLU(),
|
26 |
+
nn.Dropout(0.1)
|
27 |
+
)
|
28 |
+
|
29 |
+
def forward(self, x):
|
30 |
+
|
31 |
+
# Add a new dimension for the channel
|
32 |
+
x = x.unsqueeze(1)
|
33 |
+
|
34 |
+
# Pass input through convolutional layers
|
35 |
+
x = self.conv_layers(x)
|
36 |
+
|
37 |
+
# Remove the channel dimension
|
38 |
+
x = x.squeeze(1)
|
39 |
+
|
40 |
+
# Scale the output to the range [-1, 1]
|
41 |
+
x = torch.tanh(x)
|
42 |
+
|
43 |
+
return x
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
class CNNDecoderPrenet(nn.Module):
|
48 |
+
def __init__(self, input_dim=80, hidden_dim=256, output_dim=256, final_dim=512, dropout_rate=0.5):
|
49 |
+
super(CNNDecoderPrenet, self).__init__()
|
50 |
+
self.layer1 = nn.Linear(input_dim, hidden_dim)
|
51 |
+
self.layer2 = nn.Linear(hidden_dim, output_dim)
|
52 |
+
self.linear_projection = nn.Linear(output_dim, final_dim) # Added linear projection
|
53 |
+
self.dropout = nn.Dropout(dropout_rate)
|
54 |
+
|
55 |
+
def forward(self, x):
|
56 |
+
|
57 |
+
# Transpose the input tensor to have the feature dimension as the last dimension
|
58 |
+
x = x.transpose(1, 2)
|
59 |
+
# Apply the linear layers
|
60 |
+
x = F.relu(self.layer1(x))
|
61 |
+
x = self.dropout(x)
|
62 |
+
x = F.relu(self.layer2(x))
|
63 |
+
x = self.dropout(x)
|
64 |
+
# Apply the linear projection
|
65 |
+
x = self.linear_projection(x)
|
66 |
+
x = x.transpose(1, 2)
|
67 |
+
|
68 |
+
return x
|
69 |
+
|
70 |
+
|
71 |
+
|
72 |
+
|
73 |
+
class CNNPostNet(torch.nn.Module):
|
74 |
+
"""
|
75 |
+
Conv Postnet
|
76 |
+
Arguments
|
77 |
+
---------
|
78 |
+
n_mel_channels: int
|
79 |
+
input feature dimension for convolution layers
|
80 |
+
postnet_embedding_dim: int
|
81 |
+
output feature dimension for convolution layers
|
82 |
+
postnet_kernel_size: int
|
83 |
+
postnet convolution kernal size
|
84 |
+
postnet_n_convolutions: int
|
85 |
+
number of convolution layers
|
86 |
+
postnet_dropout: float
|
87 |
+
dropout probability fot postnet
|
88 |
+
"""
|
89 |
+
|
90 |
+
def __init__(
|
91 |
+
self,
|
92 |
+
n_mel_channels=80,
|
93 |
+
postnet_embedding_dim=512,
|
94 |
+
postnet_kernel_size=5,
|
95 |
+
postnet_n_convolutions=5,
|
96 |
+
postnet_dropout=0.1,
|
97 |
+
):
|
98 |
+
super(CNNPostNet, self).__init__()
|
99 |
+
|
100 |
+
self.conv_pre = nn.Conv1d(
|
101 |
+
in_channels=n_mel_channels,
|
102 |
+
out_channels=postnet_embedding_dim,
|
103 |
+
kernel_size=postnet_kernel_size,
|
104 |
+
padding="same",
|
105 |
+
)
|
106 |
+
|
107 |
+
self.convs_intermedite = nn.ModuleList()
|
108 |
+
for i in range(1, postnet_n_convolutions - 1):
|
109 |
+
self.convs_intermedite.append(
|
110 |
+
nn.Conv1d(
|
111 |
+
in_channels=postnet_embedding_dim,
|
112 |
+
out_channels=postnet_embedding_dim,
|
113 |
+
kernel_size=postnet_kernel_size,
|
114 |
+
padding="same",
|
115 |
+
),
|
116 |
+
)
|
117 |
+
|
118 |
+
self.conv_post = nn.Conv1d(
|
119 |
+
in_channels=postnet_embedding_dim,
|
120 |
+
out_channels=n_mel_channels,
|
121 |
+
kernel_size=postnet_kernel_size,
|
122 |
+
padding="same",
|
123 |
+
)
|
124 |
+
|
125 |
+
self.tanh = nn.Tanh()
|
126 |
+
self.ln1 = nn.LayerNorm(postnet_embedding_dim)
|
127 |
+
self.ln2 = nn.LayerNorm(postnet_embedding_dim)
|
128 |
+
self.ln3 = nn.LayerNorm(n_mel_channels)
|
129 |
+
self.dropout1 = nn.Dropout(postnet_dropout)
|
130 |
+
self.dropout2 = nn.Dropout(postnet_dropout)
|
131 |
+
self.dropout3 = nn.Dropout(postnet_dropout)
|
132 |
+
|
133 |
+
|
134 |
+
def forward(self, x):
|
135 |
+
"""Computes the forward pass
|
136 |
+
Arguments
|
137 |
+
---------
|
138 |
+
x: torch.Tensor
|
139 |
+
a (batch, time_steps, features) input tensor
|
140 |
+
Returns
|
141 |
+
-------
|
142 |
+
output: torch.Tensor (the spectrogram predicted)
|
143 |
+
"""
|
144 |
+
x = self.conv_pre(x)
|
145 |
+
x = self.ln1(x.permute(0, 2, 1)).permute(0, 2, 1) # Transpose to [batch_size, feature_dim, sequence_length]
|
146 |
+
x = self.tanh(x)
|
147 |
+
x = self.dropout1(x)
|
148 |
+
|
149 |
+
for i in range(len(self.convs_intermedite)):
|
150 |
+
x = self.convs_intermedite[i](x)
|
151 |
+
x = self.ln2(x.permute(0, 2, 1)).permute(0, 2, 1) # Transpose to [batch_size, feature_dim, sequence_length]
|
152 |
+
x = self.tanh(x)
|
153 |
+
x = self.dropout2(x)
|
154 |
+
|
155 |
+
x = self.conv_post(x)
|
156 |
+
x = self.ln3(x.permute(0, 2, 1)).permute(0, 2, 1) # Transpose to [batch_size, feature_dim, sequence_length]
|
157 |
+
x = self.dropout3(x)
|
158 |
+
|
159 |
+
return x
|
160 |
+
|
161 |
+
|
162 |
+
class ScaledPositionalEncoding(nn.Module):
|
163 |
+
"""
|
164 |
+
This class implements the absolute sinusoidal positional encoding function
|
165 |
+
with an adaptive weight parameter alpha.
|
166 |
+
|
167 |
+
PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
|
168 |
+
PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
|
169 |
+
|
170 |
+
Arguments
|
171 |
+
---------
|
172 |
+
input_size: int
|
173 |
+
Embedding dimension.
|
174 |
+
max_len : int, optional
|
175 |
+
Max length of the input sequences (default 2500).
|
176 |
+
Example
|
177 |
+
-------
|
178 |
+
>>> a = torch.rand((8, 120, 512))
|
179 |
+
>>> enc = PositionalEncoding(input_size=a.shape[-1])
|
180 |
+
>>> b = enc(a)
|
181 |
+
>>> b.shape
|
182 |
+
torch.Size([1, 120, 512])
|
183 |
+
"""
|
184 |
+
|
185 |
+
def __init__(self, input_size, max_len=2500):
|
186 |
+
super().__init__()
|
187 |
+
if input_size % 2 != 0:
|
188 |
+
raise ValueError(
|
189 |
+
f"Cannot use sin/cos positional encoding with odd channels (got channels={input_size})"
|
190 |
+
)
|
191 |
+
self.max_len = max_len
|
192 |
+
self.alpha = nn.Parameter(torch.ones(1)) # Define alpha as a trainable parameter
|
193 |
+
pe = torch.zeros(self.max_len, input_size, requires_grad=False)
|
194 |
+
positions = torch.arange(0, self.max_len).unsqueeze(1).float()
|
195 |
+
denominator = torch.exp(
|
196 |
+
torch.arange(0, input_size, 2).float()
|
197 |
+
* -(math.log(10000.0) / input_size)
|
198 |
+
)
|
199 |
+
|
200 |
+
pe[:, 0::2] = torch.sin(positions * denominator)
|
201 |
+
pe[:, 1::2] = torch.cos(positions * denominator)
|
202 |
+
pe = pe.unsqueeze(0)
|
203 |
+
self.register_buffer("pe", pe)
|
204 |
+
|
205 |
+
def forward(self, x):
|
206 |
+
"""
|
207 |
+
Arguments
|
208 |
+
---------
|
209 |
+
x : tensor
|
210 |
+
Input feature shape (batch, time, fea)
|
211 |
+
"""
|
212 |
+
pe_scaled = self.pe[:, :x.size(1)].clone().detach() * self.alpha # Scale positional encoding by alpha
|
213 |
+
return pe_scaled
|
214 |
+
|