https://github.com/audeering/shift
tts.py
ADDED
@@ -0,0 +1,847 @@

import torch
import nltk
nltk.download('punkt', download_dir='./')  # COMMENT IF DOWNLOADED
nltk.download('punkt_tab', download_dir='./')  # COMMENT IF DOWNLOADED
nltk.data.path.append('.')
import librosa
import audiofile
import torch.nn.functional as F
import math
import numpy as np
import torch.nn as nn
import string
import textwrap
import phonemizer
from espeak_util import set_espeak_library
from transformers import AlbertConfig, AlbertModel
from huggingface_hub import hf_hub_download
from nltk.tokenize import word_tokenize
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils import spectral_norm

_pad = "$"
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
MAX_PHONEMES = 424  # max phoneme length of a single (non-split) sentence for StyleTTS2 inference, to avoid OOM

symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)

dicts = {}
for i in range(len(symbols)):
    dicts[symbols[i]] = i


class TextCleaner:
    def __init__(self, dummy=None):
        self.word_index_dictionary = dicts
        print(len(dicts))

    def __call__(self, text):
        indexes = []
        for char in text:
            try:
                indexes.append(self.word_index_dictionary[char])
            except KeyError:
                # NONVOCAL == \x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~\x7f
                # print(f'NonVOCAL {char}', end='\r')
                pass
        return indexes


set_espeak_library()

textclenaer = TextCleaner()

global_phonemizer = phonemizer.backend.EspeakBackend(language="en-us", preserve_punctuation=True, with_stress=True)


def _del_prefix(d):
    # strip the "module." prefix
    out = {}
    for k, v in d.items():
        out[k[7:]] = v
    return out


class StyleTTS2(nn.Module):

    def __init__(self):
        super().__init__()
        albert_base_configuration = AlbertConfig(vocab_size=178,
                                                 hidden_size=768,
                                                 num_attention_heads=12,
                                                 intermediate_size=2048,
                                                 max_position_embeddings=512,
                                                 num_hidden_layers=12,
                                                 dropout=0.1)
        self.bert = AlbertModel(albert_base_configuration)
        state_dict = torch.load(hf_hub_download(repo_id='dkounadis/artificial-styletts2',
                                                filename='Utils/PLBERT/step_1000000.pth'),
                                map_location='cpu')['net']
        new_state_dict = {}
        for k, v in state_dict.items():
            name = k[7:]  # remove `module.`
            if name.startswith('encoder.'):
                name = name[8:]  # remove `encoder.`
            new_state_dict[name] = v
        del new_state_dict["embeddings.position_ids"]
        self.bert.load_state_dict(new_state_dict, strict=True)
        self.decoder = Decoder(dim_in=512,
                               style_dim=128,
                               dim_out=80,  # n_mels
                               resblock_kernel_sizes=[3, 7, 11],
                               upsample_rates=[10, 5, 3, 2],
                               upsample_initial_channel=512,
                               resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                               upsample_kernel_sizes=[20, 10, 6, 4])
        self.text_encoder = TextEncoder(channels=512,
                                        kernel_size=5,
                                        depth=3,  # args['model_params']['n_layer']
                                        n_symbols=178,  # args['model_params']['n_token']
                                        )
        self.predictor = ProsodyPredictor(style_dim=128,
                                          d_hid=512,
                                          nlayers=3,  # OFFICIAL config.nlayers=5
                                          max_dur=50)
        self.style_encoder = StyleEncoder()
        self.predictor_encoder = StyleEncoder()
        self.bert_encoder = torch.nn.Linear(self.bert.config.hidden_size, 512)
        self.mel_spec = MelSpec()
        params = torch.load(hf_hub_download(repo_id='yl4579/StyleTTS2-LibriTTS',
                                            filename='Models/LibriTTS/epochs_2nd_00020.pth'),
                            map_location='cpu')['net']
        self.bert.load_state_dict(_del_prefix(params['bert']), strict=True)
        self.bert_encoder.load_state_dict(_del_prefix(params['bert_encoder']), strict=True)
        self.predictor.load_state_dict(_del_prefix(params['predictor']), strict=True)
        self.decoder.load_state_dict(_del_prefix(params['decoder']), strict=True)
        self.text_encoder.load_state_dict(_del_prefix(params['text_encoder']), strict=True)
        self.predictor_encoder.load_state_dict(_del_prefix(params['predictor_encoder']), strict=True)
        self.style_encoder.load_state_dict(_del_prefix(params['style_encoder']), strict=True)

        # FOR LSTM
        for n, p in self.named_parameters():
            p.requires_grad = False
        self.eval()

    def device(self):
        return self.style_encoder.unshared.weight.device

    def compute_style(self, wav_file=None):

        x, sr = librosa.load(wav_file, sr=24000)
        x, _ = librosa.effects.trim(x, top_db=30)
        if sr != 24000:
            x = librosa.resample(x, orig_sr=sr, target_sr=24000)
        # log-mel: MelSpec has a ~16 kHz default basis but is called on the 24 kHz .wav
        x = torch.from_numpy(x[None, :]).to(device=self.device(),
                                            dtype=torch.float)
        mel_tensor = (torch.log(1e-5 + self.mel_spec(x)) + 4) / 4
        # mel_tensor = preprocess(audio).to(device)
        ref_s = self.style_encoder(mel_tensor)
        ref_p = self.predictor_encoder(mel_tensor)  # [bs, 11, 1, 128]
        s = torch.cat([ref_s, ref_p], dim=3)  # [bs, 11, 1, 256]
        s = s[:, :, 0, :].transpose(1, 2)  # [1, 256, 11]
        return s  # [1, 256, 11]

    def inference(self,
                  text,
                  ref_s=None):
        '''text may become too long when phonemized'''

        if isinstance(ref_s, str):
            ref_s = self.compute_style(ref_s)
        else:
            pass  # assume ref_s is a precomputed style vector

        # text = transliterate_number(text, lang='en').strip()
        # as we are in English, transliteration is already done by the text cleaner?
        # somehow we have phonemes in text that try to be re-phonemized
        # the ds txt should be ASCII only

        if isinstance(text, str):

            _translator = str.maketrans('', '', string.punctuation)

            text = [sub_sent.translate(_translator) + '.' for sub_sent in textwrap.wrap(text, 74)]

            # text = nltk.sent_tokenize(text)
            # text = [i for sent in sentences for i in textwrap.wrap(sent, width=120)]
            # text = textwrap.wrap(text, width=MAX_PHONEMES)  # phonemes, thus sent_tokenize() can't split them into sentences

        device = ref_s.device
        total = []
        for _t in text:

            _t = global_phonemizer.phonemize([_t])
            _t = word_tokenize(_t[0])
            _t = ' '.join(_t)

            tokens = textclenaer(_t)[:MAX_PHONEMES] + [4]  # textclenaer('.;?!') = [4, 1, 6, 5]; append '.' to assure proper sound termination (pulse issue)

            # after the filter we should assure the input terminates like a sentence
            # print(len(_t), len(tokens), 'Msi')  # , textclenaer('.;?!'))
            # ===== delete phonemes if len(phonemes) > len(text) => OOM during training
            tokens.insert(0, 0)
            tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
            with torch.no_grad():
                hidden_states = self.text_encoder(tokens)
                bert_dur = self.bert(tokens, attention_mask=torch.ones_like(tokens)
                                     ).last_hidden_state
                d_en = self.bert_encoder(bert_dur).transpose(-1, -2)
                aln_trg, F0_pred, N_pred = self.predictor(d_en=d_en, s=ref_s[:, 128:, :])
                asr = torch.bmm(aln_trg, hidden_states)
                asr = asr.transpose(1, 2)
                asr_new = torch.zeros_like(asr)
                asr_new[:, :, 0] = asr[:, :, 0]
                asr_new[:, :, 1:] = asr[:, :, 0:-1]
                asr = asr_new
                x = self.decoder(asr=asr,
                                 F0_curve=F0_pred,
                                 N=N_pred,
                                 s=ref_s[:, :128, :])  # different part of ref_s
            # print(x.shape, 'TTS TTS TTS TTS')
            if x.shape[2] < 100:
                x = torch.zeros(1, 1, 1000, device=self.device())  # silence if this sentence was empty

            # normalise / crop scratch at end (the ending-scratch sound is not solved even with nltk sentence split & punctuation)
            x = x[..., 40:-4000]
            # x /= x.abs().max() + 1e-7  # preserve as torch
            # return x
            if x.shape[2] == 0:
                # nothing to vocode
                x = torch.zeros(1, 1, 1000, device=self.device())
            total.append(x)

        # --
        total = 1.94 * torch.cat(total, 2)  # 1.94x - perhaps exceeding [-1, 1] affects MIMI encode
        total /= 1.02 * total.abs().max() + 1e-7
        # --
        return total


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def _tile(x,
          length=None):
    x = x.repeat(1, 1, int(length / x.shape[2]) + 1)[:, :, :length]
    return x


class AdaIN1d(nn.Module):

    # used by HiFiGan & ProsodyPredictor

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm1d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):

        # x = torch.Size([1, 512, 248]) same as output
        # s = torch.Size([1, 7, 1, 128])

        s = self.fc(s.transpose(1, 2)).transpose(1, 2)

        s = _tile(s, length=x.shape[2])

        gamma, beta = torch.chunk(s, chunks=2, dim=1)
        return (1 + gamma) * self.norm(x) + beta


class AdaINResBlock1(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
        super(AdaINResBlock1, self).__init__()
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2])))
        ])
        # self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
        ])
        # self.convs2.apply(init_weights)

        self.adain1 = nn.ModuleList([
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
        ])

        self.adain2 = nn.ModuleList([
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
        ])

        self.alpha1 = nn.ParameterList(
            [nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
        self.alpha2 = nn.ParameterList(
            [nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])

    def forward(self, x, s):
        for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
            xt = n1(x, s)  # AdaIN - expects conv1d dims
            xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2)  # Snake1D
            xt = c1(xt)
            xt = n2(xt, s)  # AdaIN - expects conv1d dims
            xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2)  # Snake1D
            xt = c2(xt)
            x = xt + x
        return x


class SourceModuleHnNSF(torch.nn.Module):

    def __init__(self):

        super().__init__()
        self.harmonic_num = 8
        self.l_linear = torch.nn.Linear(self.harmonic_num + 1, 1)
        self.upsample_scale = 300

    def forward(self, x):
        # --
        x = torch.multiply(x, torch.FloatTensor(
            [[range(1, self.harmonic_num + 2)]]).to(x.device))  # [1, 145200, 9]

        # modulo of negative f0 values => -21 % 10 = 9 as -3*10 + 9 = -21; NOTICE THAT f0_values IS SIGNED
        rad_values = x / 25647  # ).clamp(0, 1)
        # rad_values = torch.where(torch.logical_or(rad_values < 0, rad_values > 1), 0.5, rad_values)
        rad_values = rad_values % 1  # % of neg values
        rad_values = F.interpolate(rad_values.transpose(1, 2),
                                   scale_factor=1 / self.upsample_scale,
                                   mode='linear').transpose(1, 2)

        # 1.89 also sounds nice, has woofer at punctuation
        phase = torch.cumsum(rad_values, dim=1) * 1.84 * np.pi
        phase = F.interpolate(phase.transpose(1, 2) * self.upsample_scale,
                              scale_factor=self.upsample_scale, mode='linear').transpose(1, 2)
        x = .009 * phase.sin()
        # --
        x = self.l_linear(x).tanh()
        return x


class Generator(torch.nn.Module):
    def __init__(self,
                 style_dim,
                 resblock_kernel_sizes,
                 upsample_rates,
                 upsample_initial_channel,
                 resblock_dilation_sizes,
                 upsample_kernel_sizes):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.m_source = SourceModuleHnNSF()
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
        self.noise_convs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.noise_res = nn.ModuleList()

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            c_cur = upsample_initial_channel // (2 ** (i + 1))

            self.ups.append(weight_norm(ConvTranspose1d(upsample_initial_channel // (2 ** i),
                                                        upsample_initial_channel // (2 ** (i + 1)),
                                                        k, u, padding=(u // 2 + u % 2), output_padding=u % 2)))

            if i + 1 < len(upsample_rates):
                stride_f0 = np.prod(upsample_rates[i + 1:])
                self.noise_convs.append(Conv1d(
                    1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0 + 1) // 2))
                self.noise_res.append(AdaINResBlock1(
                    c_cur, 7, [1, 3, 5], style_dim))
            else:
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
                self.noise_res.append(AdaINResBlock1(
                    c_cur, 11, [1, 3, 5], style_dim))

        self.resblocks = nn.ModuleList()

        self.alphas = nn.ParameterList()
        self.alphas.append(nn.Parameter(
            torch.ones(1, upsample_initial_channel, 1)))

        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            self.alphas.append(nn.Parameter(torch.ones(1, ch, 1)))

            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(AdaINResBlock1(ch, k, d, style_dim))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))

    def forward(self, x, s, f0):

        # x.shape=torch.Size([1, 512, 484]) s.shape=torch.Size([1, 1, 1, 128]) f0.shape=torch.Size([1, 484]) GENERAT 249
        f0 = self.f0_upsamp(f0).transpose(1, 2)

        # x.shape=torch.Size([1, 512, 484]) s.shape=torch.Size([1, 1, 1, 128]) f0.shape=torch.Size([1, 145200, 1]) GENERAT 253

        # [1, 145400, 1] f0 enters already upsampled to the full 24 kHz wav length
        har_source = self.m_source(f0)

        har_source = har_source.transpose(1, 2)

        for i in range(self.num_upsamples):

            x = x + (1 / self.alphas[i]) * (torch.sin(self.alphas[i] * x) ** 2)
            x_source = self.noise_convs[i](har_source)
            x_source = self.noise_res[i](x_source, s)

            x = self.ups[i](x)

            x = x + x_source

            xs = None
            for j in range(self.num_kernels):

                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x, s)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x, s)
            x = xs / self.num_kernels
        # x = x + (1 / self.alphas[i + 1]) * (torch.sin(self.alphas[i + 1] * x) ** 2)  # noisy
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x


class AdainResBlk1d(nn.Module):

    # also used in ProsodyPredictor()

    def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
                 upsample='none', dropout_p=0.0):
        super().__init__()
        self.actv = actv
        self.upsample_type = upsample
        self.upsample = UpSample1d(upsample)
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)
        if upsample == 'none':
            self.pool = nn.Identity()
        else:
            self.pool = weight_norm(nn.ConvTranspose1d(
                dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))

    def _build_weights(self, dim_in, dim_out, style_dim):
        self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
        self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
        self.norm1 = AdaIN1d(style_dim, dim_in)
        self.norm2 = AdaIN1d(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = weight_norm(
                nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))

    def _shortcut(self, x):
        x = self.upsample(x)
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x, s):
        x = self.norm1(x, s)
        x = self.actv(x)
        x = self.pool(x)
        x = self.conv1(x)
        x = self.norm2(x, s)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x, s):
        out = self._residual(x, s)
        out = (out + self._shortcut(x)) / math.sqrt(2)
        return out


class UpSample1d(nn.Module):
    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        if self.layer_type == 'none':
            return x
        else:
            return F.interpolate(x, scale_factor=2, mode='nearest-exact')


class Decoder(nn.Module):
    def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
                 resblock_kernel_sizes=[3, 7, 11],
                 upsample_rates=[10, 5, 3, 2],
                 upsample_initial_channel=512,
                 resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                 upsample_kernel_sizes=[20, 10, 6, 4]):
        super().__init__()

        self.decode = nn.ModuleList()

        self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)

        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(
            1024 + 2 + 64, 512, style_dim, upsample=True))

        self.F0_conv = weight_norm(
            nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))  # smooth

        self.N_conv = weight_norm(
            nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))

        self.asr_res = nn.Sequential(
            weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
        )

        self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
                                   upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes)

    def forward(self, asr=None, F0_curve=None, N=None, s=None):

        F0 = self.F0_conv(F0_curve)
        N = self.N_conv(N)

        x = torch.cat([asr, F0, N], axis=1)

        x = self.encode(x, s)

        asr_res = self.asr_res(asr)

        res = True
        for block in self.decode:
            if res:
                x = torch.cat([x, asr_res, F0, N], axis=1)
            x = block(x, s)
            if block.upsample_type != "none":
                res = False

        x = self.generator(x, s, F0_curve)
        return x


class MelSpec(torch.nn.Module):

    def __init__(self,
                 sample_rate=17402,  # https://github.com/fakerybakery/styletts2-cli/blob/main/msinference.py => default 16000; however 17400 vocalises better, also for "en_US/vctk_p274"
                 n_fft=2048,
                 win_length=1200,
                 hop_length=300,
                 n_mels=80
                 ):
        '''avoids dependency on torchaudio'''
        super().__init__()
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        # -- build the triangular mel filterbank (HTK mel scale: m = 2595 * log10(1 + f / 700))
        f_min = 0.0
        f_max = float(sample_rate // 2)
        all_freqs = torch.linspace(0, sample_rate // 2, n_fft // 2 + 1)
        m_min = 2595.0 * math.log10(1.0 + (f_min / 700.0))
        m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0))
        m_pts = torch.linspace(m_min, m_max, n_mels + 2)
        f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0)
        f_diff = f_pts[1:] - f_pts[:-1]  # (n_mels + 1)
        slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)
        zero = torch.zeros(1)
        down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1]  # (n_freqs, n_mels)
        up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_freqs, n_mels)
        fb = torch.max(zero, torch.min(down_slopes, up_slopes))
        # --
        self.register_buffer('fb', fb, persistent=False)
        window = torch.hann_window(self.win_length)
        self.register_buffer('window', window, persistent=False)

    def forward(self, x):
        spec_f = torch.stft(x,
                            self.n_fft,
                            self.hop_length,
                            self.win_length,
                            self.window,
                            center=True,
                            pad_mode="reflect",
                            normalized=False,
                            onesided=True,
                            return_complex=True)  # [bs, 1025, 56]
        mel_specgram = torch.matmul(spec_f.abs().pow(2).transpose(1, 2), self.fb).transpose(1, 2)
        return mel_specgram[:, None, :, :]  # [bs, 1, 80, time]


class LearnedDownSample(nn.Module):
    def __init__(self, dim_in):
        super().__init__()
        self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(
            3, 3), stride=(2, 2), groups=dim_in, padding=1))

    def forward(self, x):
        return self.conv(x)


class ResBlk(nn.Module):
    def __init__(self,
                 dim_in, dim_out):
        super().__init__()
        self.actv = nn.LeakyReLU(0.2)  # .07 also nice
        self.downsample_res = LearnedDownSample(dim_in)
        self.learned_sc = dim_in != dim_out
        self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
        self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
        if self.learned_sc:
            self.conv1x1 = spectral_norm(
                nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if x.shape[3] % 2 != 0:  # [bs, 128, Freq, Time]
            x = torch.cat([x, x[:, :, :, -1:]], dim=3)
        return F.interpolate(x, scale_factor=.5, mode='nearest-exact')  # F.avg_pool2d(x, 2)

    def _residual(self, x):
        x = self.actv(x)
        x = self.conv1(x)
        x = self.downsample_res(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / math.sqrt(2)  # unit variance


class StyleEncoder(nn.Module):

    # for both acoustic & prosodic ref_s/p

    def __init__(self,
                 dim_in=64,
                 style_dim=128,
                 max_conv_dim=512):
        super().__init__()
        blocks = [spectral_norm(nn.Conv2d(1, dim_in, 3, stride=1, padding=1))]
        for _ in range(4):
            dim_out = min(dim_in * 2,
                          max_conv_dim)
            blocks += [ResBlk(dim_in, dim_out)]
            dim_in = dim_out
        blocks += [nn.LeakyReLU(0.24),  # w/o this activation - produces no speech
                   spectral_norm(nn.Conv2d(dim_out, dim_out, 5, stride=1, padding=0)),
                   nn.LeakyReLU(0.2)  # 0.3 sounds nice
                   ]
        self.shared = nn.Sequential(*blocks)
        self.unshared = nn.Linear(dim_out, style_dim)

    def forward(self, x):
        x = self.shared(x)
        x = x.mean(3, keepdims=True)  # comment this line for a time-varying style vector
        x = x.transpose(1, 3)
        s = self.unshared(x)
        return s


class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True):
        super().__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

    def forward(self, x):
        return self.linear_layer(x)


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)


class TextEncoder(nn.Module):
    def __init__(self, channels, kernel_size, depth, n_symbols):
        super().__init__()
        self.embedding = nn.Embedding(n_symbols, channels)
        padding = (kernel_size - 1) // 2
        self.cnn = nn.ModuleList()
        for _ in range(depth):
            self.cnn.append(nn.Sequential(
                weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
                LayerNorm(channels),
                nn.LeakyReLU(0.24))
            )
        self.lstm = nn.LSTM(channels, channels // 2, 1,
                            batch_first=True, bidirectional=True)

    def forward(self, x):
        x = self.embedding(x)  # [B, T, emb]
        x = x.transpose(1, 2)
        for c in self.cnn:
            x = c(x)
        x = x.transpose(1, 2)
        x, _ = self.lstm(x)
        return x


class AdaLayerNorm(nn.Module):

    def __init__(self, style_dim, channels=None, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.fc = nn.Linear(style_dim, 1024)

    def forward(self, x, s):
        h = self.fc(s)
        gamma = h[:, :, :512]
        beta = h[:, :, 512:1024]
        x = F.layer_norm(x, (512,), eps=self.eps)
        x = (1 + gamma) * x + beta
        return x  # [1, 75, 512]


class ProsodyPredictor(nn.Module):

    def __init__(self, style_dim, d_hid, nlayers, max_dur=50):
        super().__init__()

        self.text_encoder = DurationEncoder(sty_dim=style_dim,
                                            d_model=d_hid,
                                            nlayers=nlayers)  # called outside forward
        self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2,
                            1, batch_first=True, bidirectional=True)
        self.duration_proj = LinearNorm(d_hid, max_dur)
        self.shared = nn.LSTM(d_hid + style_dim, d_hid //
                              2, 1, batch_first=True, bidirectional=True)
        self.F0 = nn.ModuleList([
            AdainResBlk1d(d_hid, d_hid, style_dim),
            AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True),
            AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim),
        ])
        self.N = nn.ModuleList([
            AdainResBlk1d(d_hid, d_hid, style_dim),
            AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True),
            AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim)
        ])
        self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
        self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)

    def F0Ntrain(self, x, s):

        x, _ = self.shared(x)  # [bs, time, ch] LSTM

        x = x.transpose(1, 2)  # [bs, ch, time]

        F0 = x

        for block in self.F0:
            # print(f'LOOP {F0.shape=} {s.shape=}\n')
            # F0.shape=torch.Size([1, 512, 147]) s.shape=torch.Size([1, 128])
            # this is an AdainResBlk1d, expects conv1d dimensions
            F0 = block(F0, s)
        F0 = self.F0_proj(F0)

        N = x

        for block in self.N:
            N = block(N, s)
        N = self.N_proj(N)

        return F0, N

    def forward(self, d_en=None, s=None):
        blend = self.text_encoder(d_en, s)
        x, _ = self.lstm(blend)
        dur = self.duration_proj(x)  # [bs, 150, 50]

        _, input_length, classifier_50 = dur.shape

        dur = dur[0, :, :]
        dur = torch.sigmoid(dur).sum(1)
        dur = dur.round().clamp(min=1).to(torch.int64)
        aln_trg = torch.zeros(1,
                              dur.sum(),
                              input_length,
                              device=s.device)
        c_frame = 0
        for i in range(input_length):
            aln_trg[:, c_frame:c_frame + dur[i], i] = 1
            c_frame += dur[i]
        en = torch.bmm(aln_trg, blend)
        F0_pred, N_pred = self.F0Ntrain(en, s)
        return aln_trg, F0_pred, N_pred


class DurationEncoder(nn.Module):

    def __init__(self, sty_dim=128, d_model=512, nlayers=3):
        super().__init__()
        self.lstms = nn.ModuleList()
        for _ in range(nlayers):
            self.lstms.append(nn.LSTM(d_model + sty_dim,
                                      d_model // 2,
                                      num_layers=1,
                                      batch_first=True,
                                      bidirectional=True
                                      ))
            self.lstms.append(AdaLayerNorm(sty_dim, d_model))

    def forward(self, x, style):

        _, _, input_lengths = x.shape  # [bs, 512, time]

        style = _tile(style, length=x.shape[2]).transpose(1, 2)
        x = x.transpose(1, 2)

        for block in self.lstms:
            if isinstance(block, AdaLayerNorm):

                x = block(x, style)  # LSTM has transposed x

            else:
                x = torch.cat([x, style], axis=2)
                # LSTM

                x, _ = block(x)  # expects [bs, time, chan]; outputs [bs, time, 2*chan] (2x from bidirectional)

        return torch.cat([x, style], axis=2)  # predictor.lstm()
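
Usage sketch (not part of the committed file; a minimal assumption of how this module is driven). It presumes espeak-ng is installed, espeak_util.py is importable, and the Hugging Face checkpoints are reachable; the reference wav and output path are hypothetical.

    import audiofile
    from tts import StyleTTS2

    tts = StyleTTS2()                                  # downloads PLBERT + LibriTTS weights on first run
    style = tts.compute_style('reference_voice.wav')   # hypothetical reference recording, resampled to 24 kHz
    wav = tts.inference('Hello from StyleTTS2.', ref_s=style)   # torch tensor [1, 1, samples], roughly in [-1, 1]
    audiofile.write('out.wav', wav[0, 0, :].cpu().numpy(), 24000)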