Mendoza33 committed
Commit 02e643c · verified · 1 Parent(s): 95daf41

Upload 3 files

Files changed (3)
  1. kokoro-v0_19.pth +3 -0
  2. kokoro.py +165 -0
  3. models.py +372 -0
kokoro-v0_19.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3b0c392f87508da38fad3a2f9d94c359f1b657ebd2ef79f9d56d69503e470b0a
size 327211206
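These three lines are a Git LFS pointer, not the checkpoint itself: the ~327 MB kokoro-v0_19.pth is stored out of band and identified by its SHA-256 digest and byte size. As a sanity check after downloading, a small sketch like the one below (the local file path is an assumption) can compare a local copy against the recorded values:

import hashlib
import os

# Values copied from the LFS pointer above.
EXPECTED_SHA256 = '3b0c392f87508da38fad3a2f9d94c359f1b657ebd2ef79f9d56d69503e470b0a'
EXPECTED_SIZE = 327211206

def verify_checkpoint(path='kokoro-v0_19.pth'):  # hypothetical local path
    # Cheap check first: the pointer records the exact byte size.
    assert os.path.getsize(path) == EXPECTED_SIZE, 'size mismatch'
    h = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):  # hash in 1 MiB chunks
            h.update(chunk)
    assert h.hexdigest() == EXPECTED_SHA256, 'sha256 mismatch'
    return True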
kokoro.py ADDED
@@ -0,0 +1,165 @@
import phonemizer
import re
import torch
import numpy as np

def split_num(num):
    num = num.group()
    if '.' in num:
        return num
    elif ':' in num:
        h, m = [int(n) for n in num.split(':')]
        if m == 0:
            return f"{h} o'clock"
        elif m < 10:
            return f'{h} oh {m}'
        return f'{h} {m}'
    year = int(num[:4])
    if year < 1100 or year % 1000 < 10:
        return num
    left, right = num[:2], int(num[2:4])
    s = 's' if num.endswith('s') else ''
    if 100 <= year % 1000 <= 999:
        if right == 0:
            return f'{left} hundred{s}'
        elif right < 10:
            return f'{left} oh {right}{s}'
    return f'{left} {right}{s}'

def flip_money(m):
    m = m.group()
    bill = 'dollar' if m[0] == '$' else 'pound'
    if m[-1].isalpha():
        return f'{m[1:]} {bill}s'
    elif '.' not in m:
        s = '' if m[1:] == '1' else 's'
        return f'{m[1:]} {bill}{s}'
    b, c = m[1:].split('.')
    s = '' if b == '1' else 's'
    c = int(c.ljust(2, '0'))
    coins = f"cent{'' if c == 1 else 's'}" if m[0] == '$' else ('penny' if c == 1 else 'pence')
    return f'{b} {bill}{s} and {c} {coins}'

def point_num(num):
    a, b = num.group().split('.')
    return ' point '.join([a, ' '.join(b)])

def normalize_text(text):
    text = text.replace(chr(8216), "'").replace(chr(8217), "'")
    text = text.replace('«', chr(8220)).replace('»', chr(8221))
    text = text.replace(chr(8220), '"').replace(chr(8221), '"')
    text = text.replace('(', '«').replace(')', '»')
    for a, b in zip('、。！，：；？', ',.!,:;?'):
        text = text.replace(a, b+' ')
    text = re.sub(r'[^\S \n]', ' ', text)
    text = re.sub(r'  +', ' ', text)
    text = re.sub(r'(?<=\n) +(?=\n)', '', text)
    text = re.sub(r'\bD[Rr]\.(?= [A-Z])', 'Doctor', text)
    text = re.sub(r'\b(?:Mr\.|MR\.(?= [A-Z]))', 'Mister', text)
    text = re.sub(r'\b(?:Ms\.|MS\.(?= [A-Z]))', 'Miss', text)
    text = re.sub(r'\b(?:Mrs\.|MRS\.(?= [A-Z]))', 'Mrs', text)
    text = re.sub(r'\betc\.(?! [A-Z])', 'etc', text)
    text = re.sub(r'(?i)\b(y)eah?\b', r"\1e'a", text)
    text = re.sub(r'\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)', split_num, text)
    text = re.sub(r'(?<=\d),(?=\d)', '', text)
    text = re.sub(r'(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b', flip_money, text)
    text = re.sub(r'\d*\.\d+', point_num, text)
    text = re.sub(r'(?<=\d)-(?=\d)', ' to ', text)
    text = re.sub(r'(?<=\d)S', ' S', text)
    text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
    text = re.sub(r"(?<=X')S\b", 's', text)
    text = re.sub(r'(?:[A-Za-z]\.){2,} [a-z]', lambda m: m.group().replace('.', '-'), text)
    text = re.sub(r'(?i)(?<=[A-Z])\.(?=[A-Z])', '-', text)
    return text.strip()

def get_vocab():
    _pad = "$"
    _punctuation = ';:,.!?¡¿—…"«»“” '
    _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
    _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
    symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
    dicts = {}
    for i in range(len((symbols))):
        dicts[symbols[i]] = i
    return dicts

VOCAB = get_vocab()
def tokenize(ps):
    return [i for i in map(VOCAB.get, ps) if i is not None]

phonemizers = dict(
    a=phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True),
    b=phonemizer.backend.EspeakBackend(language='en-gb', preserve_punctuation=True, with_stress=True),
)
def phonemize(text, lang, norm=True):
    if norm:
        text = normalize_text(text)
    ps = phonemizers[lang].phonemize([text])
    ps = ps[0] if ps else ''
    # https://en.wiktionary.org/wiki/kokoro#English
    ps = ps.replace('kəkˈoːɹoʊ', 'kˈoʊkəɹoʊ').replace('kəkˈɔːɹəʊ', 'kˈəʊkəɹəʊ')
    ps = ps.replace('ʲ', 'j').replace('r', 'ɹ').replace('x', 'k').replace('ɬ', 'l')
    ps = re.sub(r'(?<=[a-zɹː])(?=hˈʌndɹɪd)', ' ', ps)
    ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', 'z', ps)
    if lang == 'a':
        ps = re.sub(r'(?<=nˈaɪn)ti(?!ː)', 'di', ps)
    ps = ''.join(filter(lambda p: p in VOCAB, ps))
    return ps.strip()

def length_to_mask(lengths):
    mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
    mask = torch.gt(mask+1, lengths.unsqueeze(1))
    return mask

@torch.no_grad()
def forward(model, tokens, ref_s, speed):
    device = ref_s.device
    tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
    text_mask = length_to_mask(input_lengths).to(device)
    bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
    d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
    s = ref_s[:, 128:]
    d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
    x, _ = model.predictor.lstm(d)
    duration = model.predictor.duration_proj(x)
    duration = torch.sigmoid(duration).sum(axis=-1) / speed
    pred_dur = torch.round(duration).clamp(min=1).long()
    pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
    c_frame = 0
    for i in range(pred_aln_trg.size(0)):
        pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
        c_frame += pred_dur[0,i].item()
    en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
    F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
    t_en = model.text_encoder(tokens, input_lengths, text_mask)
    asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
    return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()

def generate(model, text, voicepack, lang='a', speed=1, ps=None):
    ps = ps or phonemize(text, lang)
    tokens = tokenize(ps)
    if not tokens:
        return None
    elif len(tokens) > 510:
        tokens = tokens[:510]
        print('Truncated to 510 tokens')
    ref_s = voicepack[len(tokens)]
    out = forward(model, tokens, ref_s, speed)
    ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
    return out, ps

def generate_full(model, text, voicepack, lang='a', speed=1, ps=None):
    ps = ps or phonemize(text, lang)
    tokens = tokenize(ps)
    if not tokens:
        return None
    outs = []
    loop_count = len(tokens)//510 + (1 if len(tokens) % 510 != 0 else 0)
    for i in range(loop_count):
        ref_s = voicepack[len(tokens[i*510:(i+1)*510])]
        out = forward(model, tokens[i*510:(i+1)*510], ref_s, speed)
        outs.append(out)
    outs = np.concatenate(outs)
    ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
    return outs, ps
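kokoro.py is the whole inference front end: normalize_text cleans the input, phonemize runs espeak through the phonemizer backends, tokenize maps phonemes to vocabulary indices, and generate/generate_full wrap the StyleTTS2-style forward pass above. A minimal usage sketch follows; it assumes build_model from models.py in this same commit, a voice tensor indexed by token count (the 'af.pt' file name is hypothetical), and a 24 kHz output rate, none of which are guaranteed by this file alone:

import torch
import soundfile as sf
from models import build_model
from kokoro import generate

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = build_model('kokoro-v0_19.pth', device)
voicepack = torch.load('af.pt', weights_only=True).to(device)  # hypothetical voice file, indexed by token count

result = generate(model, 'Hello there, this is a quick test.', voicepack, lang='a')
if result is not None:  # generate returns None when no tokens survive phonemization
    audio, phonemes = result
    sf.write('out.wav', audio, 24000)  # sample rate assumed
    print(phonemes)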
models.py ADDED
@@ -0,0 +1,372 @@
# https://github.com/yl4579/StyleTTS2/blob/main/models.py
from istftnet import AdaIN1d, Decoder
from munch import Munch
from pathlib import Path
from plbert import load_plbert
from torch.nn.utils import weight_norm, spectral_norm
import json
import numpy as np
import os
import os.path as osp
import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)

class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)

class TextEncoder(nn.Module):
    def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
        super().__init__()
        self.embedding = nn.Embedding(n_symbols, channels)

        padding = (kernel_size - 1) // 2
        self.cnn = nn.ModuleList()
        for _ in range(depth):
            self.cnn.append(nn.Sequential(
                weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
                LayerNorm(channels),
                actv,
                nn.Dropout(0.2),
            ))
        # self.cnn = nn.Sequential(*self.cnn)

        self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)

    def forward(self, x, input_lengths, m):
        x = self.embedding(x)  # [B, T, emb]
        x = x.transpose(1, 2)  # [B, emb, T]
        m = m.to(input_lengths.device).unsqueeze(1)
        x.masked_fill_(m, 0.0)

        for c in self.cnn:
            x = c(x)
            x.masked_fill_(m, 0.0)

        x = x.transpose(1, 2)  # [B, T, chn]

        input_lengths = input_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(
            x, input_lengths, batch_first=True, enforce_sorted=False)

        self.lstm.flatten_parameters()
        x, _ = self.lstm(x)
        x, _ = nn.utils.rnn.pad_packed_sequence(
            x, batch_first=True)

        x = x.transpose(-1, -2)
        x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])

        x_pad[:, :, :x.shape[-1]] = x
        x = x_pad.to(x.device)

        x.masked_fill_(m, 0.0)

        return x

    def inference(self, x):
        x = self.embedding(x)
        x = x.transpose(1, 2)
        x = self.cnn(x)
        x = x.transpose(1, 2)
        self.lstm.flatten_parameters()
        x, _ = self.lstm(x)
        return x

    def length_to_mask(self, lengths):
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask+1, lengths.unsqueeze(1))
        return mask


class UpSample1d(nn.Module):
    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        if self.layer_type == 'none':
            return x
        else:
            return F.interpolate(x, scale_factor=2, mode='nearest')

class AdainResBlk1d(nn.Module):
    def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
                 upsample='none', dropout_p=0.0):
        super().__init__()
        self.actv = actv
        self.upsample_type = upsample
        self.upsample = UpSample1d(upsample)
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)
        self.dropout = nn.Dropout(dropout_p)

        if upsample == 'none':
            self.pool = nn.Identity()
        else:
            self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))


    def _build_weights(self, dim_in, dim_out, style_dim):
        self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
        self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
        self.norm1 = AdaIN1d(style_dim, dim_in)
        self.norm2 = AdaIN1d(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))

    def _shortcut(self, x):
        x = self.upsample(x)
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x, s):
        x = self.norm1(x, s)
        x = self.actv(x)
        x = self.pool(x)
        x = self.conv1(self.dropout(x))
        x = self.norm2(x, s)
        x = self.actv(x)
        x = self.conv2(self.dropout(x))
        return x

    def forward(self, x, s):
        out = self._residual(x, s)
        out = (out + self._shortcut(x)) / np.sqrt(2)
        return out

class AdaLayerNorm(nn.Module):
    def __init__(self, style_dim, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.fc = nn.Linear(style_dim, channels*2)

    def forward(self, x, s):
        x = x.transpose(-1, -2)
        x = x.transpose(1, -1)

        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)


        x = F.layer_norm(x, (self.channels,), eps=self.eps)
        x = (1 + gamma) * x + beta
        return x.transpose(1, -1).transpose(-1, -2)

class ProsodyPredictor(nn.Module):

    def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
        super().__init__()

        self.text_encoder = DurationEncoder(sty_dim=style_dim,
                                            d_model=d_hid,
                                            nlayers=nlayers,
                                            dropout=dropout)

        self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
        self.duration_proj = LinearNorm(d_hid, max_dur)

        self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
        self.F0 = nn.ModuleList()
        self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
        self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
        self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))

        self.N = nn.ModuleList()
        self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
        self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
        self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))

        self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
        self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)


    def forward(self, texts, style, text_lengths, alignment, m):
        d = self.text_encoder(texts, style, text_lengths, m)

        batch_size = d.shape[0]
        text_size = d.shape[1]

        # predict duration
        input_lengths = text_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(
            d, input_lengths, batch_first=True, enforce_sorted=False)

        m = m.to(text_lengths.device).unsqueeze(1)

        self.lstm.flatten_parameters()
        x, _ = self.lstm(x)
        x, _ = nn.utils.rnn.pad_packed_sequence(
            x, batch_first=True)

        x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])

        x_pad[:, :x.shape[1], :] = x
        x = x_pad.to(x.device)

        duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))

        en = (d.transpose(-1, -2) @ alignment)

        return duration.squeeze(-1), en

    def F0Ntrain(self, x, s):
        x, _ = self.shared(x.transpose(-1, -2))

        F0 = x.transpose(-1, -2)
        for block in self.F0:
            F0 = block(F0, s)
        F0 = self.F0_proj(F0)

        N = x.transpose(-1, -2)
        for block in self.N:
            N = block(N, s)
        N = self.N_proj(N)

        return F0.squeeze(1), N.squeeze(1)

    def length_to_mask(self, lengths):
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask+1, lengths.unsqueeze(1))
        return mask

class DurationEncoder(nn.Module):

    def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
        super().__init__()
        self.lstms = nn.ModuleList()
        for _ in range(nlayers):
            self.lstms.append(nn.LSTM(d_model + sty_dim,
                                      d_model // 2,
                                      num_layers=1,
                                      batch_first=True,
                                      bidirectional=True,
                                      dropout=dropout))
            self.lstms.append(AdaLayerNorm(sty_dim, d_model))


        self.dropout = dropout
        self.d_model = d_model
        self.sty_dim = sty_dim

    def forward(self, x, style, text_lengths, m):
        masks = m.to(text_lengths.device)

        x = x.permute(2, 0, 1)
        s = style.expand(x.shape[0], x.shape[1], -1)
        x = torch.cat([x, s], axis=-1)
        x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)

        x = x.transpose(0, 1)
        input_lengths = text_lengths.cpu().numpy()
        x = x.transpose(-1, -2)

        for block in self.lstms:
            if isinstance(block, AdaLayerNorm):
                x = block(x.transpose(-1, -2), style).transpose(-1, -2)
                x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
                x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
            else:
                x = x.transpose(-1, -2)
                x = nn.utils.rnn.pack_padded_sequence(
                    x, input_lengths, batch_first=True, enforce_sorted=False)
                block.flatten_parameters()
                x, _ = block(x)
                x, _ = nn.utils.rnn.pad_packed_sequence(
                    x, batch_first=True)
                x = F.dropout(x, p=self.dropout, training=self.training)
                x = x.transpose(-1, -2)

                x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])

                x_pad[:, :, :x.shape[-1]] = x
                x = x_pad.to(x.device)

        return x.transpose(-1, -2)

    def inference(self, x, style):
        x = self.embedding(x.transpose(-1, -2)) * np.sqrt(self.d_model)
        style = style.expand(x.shape[0], x.shape[1], -1)
        x = torch.cat([x, style], axis=-1)
        src = self.pos_encoder(x)
        output = self.transformer_encoder(src).transpose(0, 1)
        return output

    def length_to_mask(self, lengths):
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask+1, lengths.unsqueeze(1))
        return mask

# https://github.com/yl4579/StyleTTS2/blob/main/utils.py
def recursive_munch(d):
    if isinstance(d, dict):
        return Munch((k, recursive_munch(v)) for k, v in d.items())
    elif isinstance(d, list):
        return [recursive_munch(v) for v in d]
    else:
        return d

def build_model(path, device):
    config = Path(__file__).parent / 'config.json'
    assert config.exists(), f'Config path incorrect: config.json not found at {config}'
    with open(config, 'r') as r:
        args = recursive_munch(json.load(r))
    assert args.decoder.type == 'istftnet', f'Unknown decoder type: {args.decoder.type}'
    decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
            resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
            upsample_rates = args.decoder.upsample_rates,
            upsample_initial_channel=args.decoder.upsample_initial_channel,
            resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
            upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
            gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
    text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
    predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
    bert = load_plbert()
    bert_encoder = nn.Linear(bert.config.hidden_size, args.hidden_dim)
    for parent in [bert, bert_encoder, predictor, decoder, text_encoder]:
        for child in parent.children():
            if isinstance(child, nn.RNNBase):
                child.flatten_parameters()
    model = Munch(
        bert=bert.to(device).eval(),
        bert_encoder=bert_encoder.to(device).eval(),
        predictor=predictor.to(device).eval(),
        decoder=decoder.to(device).eval(),
        text_encoder=text_encoder.to(device).eval(),
    )
    for key, state_dict in torch.load(path, map_location='cpu', weights_only=True)['net'].items():
        assert key in model, key
        try:
            model[key].load_state_dict(state_dict)
        except:
            state_dict = {k[7:]: v for k, v in state_dict.items()}
            model[key].load_state_dict(state_dict, strict=False)
    return model
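build_model ties the pieces together: it reads config.json from the module directory, builds the iSTFTNet decoder (istftnet.py), text encoder, prosody predictor, and PL-BERT (plbert.py), collects them in a Munch in eval mode, and loads each sub-state-dict from the checkpoint's 'net' entry (retrying with the 'module.' prefix stripped when the first attempt fails). A quick inspection sketch, assuming the checkpoint sits in the working directory, shows the structure that loading loop expects:

import torch

ckpt = torch.load('kokoro-v0_19.pth', map_location='cpu', weights_only=True)
# Keys here should be a subset of {bert, bert_encoder, predictor, decoder, text_encoder},
# since build_model asserts each one exists in the Munch before loading it.
print(list(ckpt['net'].keys()))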