Dionyssos committed on
Commit
d353343
·
1 Parent(s): 4795b97

HIFIGAN tune v 1.0

Modules/hifigan.py CHANGED
@@ -4,13 +4,18 @@ import torch.nn as nn
4
  from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5
  from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6
  import math
7
- import random
8
  import numpy as np
9
 
 
 
 
10
  def get_padding(kernel_size, dilation=1):
11
  return int((kernel_size*dilation - dilation)/2)
12
 
13
- LRELU_SLOPE = 0.1
 
 
 
14
 
15
  class AdaIN1d(nn.Module):
16
 
@@ -22,15 +27,22 @@ class AdaIN1d(nn.Module):
22
  self.fc = nn.Linear(style_dim, num_features*2)
23
 
24
  def forward(self, x, s):
 
 
 
 
 
 
25
 
26
- s = self.fc(s) # [bs, 1024, 130]
27
- s = F.interpolate(s[:, :, 0, :].transpose(1,2), x.shape[2], mode='linear') # different time-resolution than Dur
28
 
29
- gamma, beta = torch.chunk(s, chunks=2, dim=1) # channels vary in for loop
30
 
31
- # affine (1 + lin(x)) * inst(x) + lin(x) is this a skip connection where the weight is a lin of itself
32
 
33
- return (1 + gamma) * self.norm(x) + beta # norm(x) = PLBERT has norm / beta&gamma = style has no norm()
 
 
 
 
34
 
35
  class AdaINResBlock1(torch.nn.Module):
36
  def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
@@ -73,10 +85,10 @@ class AdaINResBlock1(torch.nn.Module):
73
 
74
  def forward(self, x, s):
75
  for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
76
- xt = n1(x, s)
77
  xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
78
  xt = c1(xt)
79
- xt = n2(xt, s)
80
  xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
81
  xt = c2(xt)
82
  x = xt + x
@@ -89,205 +101,80 @@ class AdaINResBlock1(torch.nn.Module):
89
  remove_weight_norm(l)
90
 
91
  class SineGen(torch.nn.Module):
92
- """ Definition of sine generator
93
- SineGen(samp_rate, harmonic_num = 0,
94
- sine_amp = 0.1, noise_std = 0.003,
95
- voiced_threshold = 0,
96
- flag_for_pulse=False)
97
- samp_rate: sampling rate in Hz
98
- harmonic_num: number of harmonic overtones (default 0)
99
- sine_amp: amplitude of sine-wavefrom (default 0.1)
100
- noise_std: std of Gaussian noise (default 0.003)
101
- voiced_thoreshold: F0 threshold for U/V classification (default 0)
102
- flag_for_pulse: this SinGen is used inside PulseGen (default False)
103
- Note: when flag_for_pulse is True, the first time step of a voiced
104
- segment is always sin(np.pi) or cos(0)
105
- """
106
-
107
- def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
108
- sine_amp=0.1, noise_std=0.003,
109
- voiced_threshold=0,
110
- flag_for_pulse=False):
111
  super(SineGen, self).__init__()
112
- self.sine_amp = sine_amp
113
- self.noise_std = noise_std
114
  self.harmonic_num = harmonic_num
115
- self.dim = self.harmonic_num + 1
116
  self.sampling_rate = samp_rate
117
  self.voiced_threshold = voiced_threshold
118
- self.flag_for_pulse = flag_for_pulse
119
  self.upsample_scale = upsample_scale
120
 
121
- def _f02uv(self, f0):
122
- # generate uv signal
123
- uv = (f0 > self.voiced_threshold).type(torch.float32)
124
- return uv
125
-
126
  def _f02sine(self, f0_values):
127
- """ f0_values: (batchsize, length, dim)
128
- where dim indicates fundamental tone and overtones
129
- """
130
- # convert to F0 in rad. The interger part n can be ignored
131
- # because 2 * np.pi * n doesn't affect phase
132
- rad_values = (f0_values / self.sampling_rate) % 1
133
-
134
- # initial phase noise (no noise for fundamental component)
135
- rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
136
- device=f0_values.device)
137
- rand_ini[:, 0] = 0
138
- rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
139
-
140
- # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
141
- if not self.flag_for_pulse:
142
- # # for normal case
143
-
144
- # # To prevent torch.cumsum numerical overflow,
145
- # # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
146
- # # Buffer tmp_over_one_idx indicates the time step to add -1.
147
- # # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
148
- # tmp_over_one = torch.cumsum(rad_values, 1) % 1
149
- # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
150
- # cumsum_shift = torch.zeros_like(rad_values)
151
- # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
152
-
153
- # phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
154
- rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
155
- scale_factor=1/self.upsample_scale,
156
- mode="linear").transpose(1, 2)
157
-
158
- # tmp_over_one = torch.cumsum(rad_values, 1) % 1
159
- # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
160
- # cumsum_shift = torch.zeros_like(rad_values)
161
- # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
162
-
163
- phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
164
- phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
165
- scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
166
- sines = torch.sin(phase)
167
-
168
- else:
169
- # If necessary, make sure that the first time step of every
170
- # voiced segments is sin(pi) or cos(0)
171
- # This is used for pulse-train generation
172
-
173
- # identify the last time step in unvoiced segments
174
- uv = self._f02uv(f0_values)
175
- uv_1 = torch.roll(uv, shifts=-1, dims=1)
176
- uv_1[:, -1, :] = 1
177
- u_loc = (uv < 1) * (uv_1 > 0)
178
-
179
- # get the instantanouse phase
180
- tmp_cumsum = torch.cumsum(rad_values, dim=1)
181
- # different batch needs to be processed differently
182
- for idx in range(f0_values.shape[0]):
183
- temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
184
- temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
185
- # stores the accumulation of i.phase within
186
- # each voiced segments
187
- tmp_cumsum[idx, :, :] = 0
188
- tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
189
-
190
- # rad_values - tmp_cumsum: remove the accumulation of i.phase
191
- # within the previous voiced segment.
192
- i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
193
-
194
- # get the sines
195
- sines = torch.cos(i_phase * 2 * np.pi)
196
  return sines
197
 
198
  def forward(self, f0):
199
- """ sine_tensor, uv = forward(f0)
200
- input F0: tensor(batchsize=1, length, dim=1)
201
- f0 for unvoiced steps should be 0
202
- output sine_tensor: tensor(batchsize=1, length, dim)
203
- output uv: tensor(batchsize=1, length, 1)
204
- """
205
- f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
206
- device=f0.device)
207
- # fundamental component
208
- fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
209
-
210
- # generate sine waveforms
211
- sine_waves = self._f02sine(fn) * self.sine_amp
212
-
213
- # generate uv signal
214
- # uv = torch.ones(f0.shape)
215
- # uv = uv * (f0 > self.voiced_threshold)
216
- uv = self._f02uv(f0)
217
-
218
- # noise: for unvoiced should be similar to sine_amp
219
- # std = self.sine_amp/3 -> max value ~ self.sine_amp
220
- # . for voiced regions is self.noise_std
221
- noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
222
- noise = noise_amp * torch.randn_like(sine_waves)
223
-
224
- # first: set the unvoiced part to 0 by uv
225
- # then: additive noise
226
- sine_waves = sine_waves * uv + noise
227
- return sine_waves, uv, noise
228
-
229
 
230
- class SourceModuleHnNSF(torch.nn.Module):
231
- """ SourceModule for hn-nsf
232
- SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
233
- add_noise_std=0.003, voiced_threshod=0)
234
- sampling_rate: sampling_rate in Hz
235
- harmonic_num: number of harmonic above F0 (default: 0)
236
- sine_amp: amplitude of sine source signal (default: 0.1)
237
- add_noise_std: std of additive Gaussian noise (default: 0.003)
238
- note that amplitude of noise in unvoiced is decided
239
- by sine_amp
240
- voiced_threshold: threhold to set U/V given F0 (default: 0)
241
- Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
242
- F0_sampled (batchsize, length, 1)
243
- Sine_source (batchsize, length, 1)
244
- noise_source (batchsize, length 1)
245
- uv (batchsize, length, 1)
246
- """
247
-
248
- def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
249
- add_noise_std=0.003, voiced_threshod=0):
250
- super(SourceModuleHnNSF, self).__init__()
251
 
252
- self.sine_amp = sine_amp
253
- self.noise_std = add_noise_std
 
254
 
255
- # to produce sine waveforms
256
- self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
257
- sine_amp, add_noise_std, voiced_threshod)
258
 
259
- # to merge source harmonics into a single excitation
260
- self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
 
 
 
261
  self.l_tanh = torch.nn.Tanh()
262
 
263
  def forward(self, x):
264
- """
265
- Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
266
- F0_sampled (batchsize, length, 1)
267
- Sine_source (batchsize, length, 1)
268
- noise_source (batchsize, length 1)
269
- """
270
- # source for harmonic branch
271
- with torch.no_grad():
272
- sine_wavs, uv, _ = self.l_sin_gen(x)
273
- sine_merge = self.l_tanh(self.l_linear(sine_wavs))
274
-
275
- # source for noise branch, in the same shape as uv
276
- noise = torch.randn_like(uv) * self.sine_amp / 3
277
- return sine_merge, noise, uv
278
 
279
  class Generator(torch.nn.Module):
280
- def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes):
 
 
 
 
 
 
281
  super(Generator, self).__init__()
282
  self.num_kernels = len(resblock_kernel_sizes)
283
  self.num_upsamples = len(upsample_rates)
284
- resblock = AdaINResBlock1
285
-
286
- self.m_source = SourceModuleHnNSF(
287
- sampling_rate=24000,
288
- upsample_scale=np.prod(upsample_rates),
289
- harmonic_num=8, voiced_threshod=10)
290
-
291
  self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
292
  self.noise_convs = nn.ModuleList()
293
  self.ups = nn.ModuleList()
@@ -304,10 +191,10 @@ class Generator(torch.nn.Module):
304
  stride_f0 = np.prod(upsample_rates[i + 1:])
305
  self.noise_convs.append(Conv1d(
306
  1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
307
- self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim))
308
  else:
309
  self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
310
- self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim))
311
 
312
  self.resblocks = nn.ModuleList()
313
 
@@ -319,28 +206,35 @@ class Generator(torch.nn.Module):
319
  self.alphas.append(nn.Parameter(torch.ones(1, ch, 1)))
320
 
321
  for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
322
- self.resblocks.append(resblock(ch, k, d, style_dim))
323
 
324
  self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
325
 
326
 
327
  def forward(self, x, s, f0):
328
 
329
- f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
 
 
 
330
 
331
- har_source, noi_source, uv = self.m_source(f0)
332
- har_source = har_source.transpose(1, 2)
333
 
 
 
334
  for i in range(self.num_upsamples):
 
335
  x = x + (1 / self.alphas[i]) * (torch.sin(self.alphas[i] * x) ** 2)
336
  x_source = self.noise_convs[i](har_source)
337
  x_source = self.noise_res[i](x_source, s)
338
-
339
  x = self.ups[i](x)
 
340
  x = x + x_source
341
-
342
  xs = None
343
  for j in range(self.num_kernels):
 
344
  if xs is None:
345
  xs = self.resblocks[i*self.num_kernels+j](x, s)
346
  else:
@@ -373,9 +267,7 @@ class AdainResBlk1d(nn.Module):
373
  self.upsample_type = upsample
374
  self.upsample = UpSample1d(upsample)
375
  self.learned_sc = dim_in != dim_out
376
- self._build_weights(dim_in, dim_out, style_dim)
377
- self.dropout = nn.Dropout(dropout_p)
378
-
379
  if upsample == 'none':
380
  self.pool = nn.Identity()
381
  else:
@@ -400,10 +292,10 @@ class AdainResBlk1d(nn.Module):
400
  x = self.norm1(x, s)
401
  x = self.actv(x)
402
  x = self.pool(x)
403
- x = self.conv1(self.dropout(x))
404
  x = self.norm2(x, s)
405
  x = self.actv(x)
406
- x = self.conv2(self.dropout(x))
407
  return x
408
 
409
  def forward(self, x, s):
@@ -440,7 +332,7 @@ class Decoder(nn.Module):
440
  self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
441
  self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
442
 
443
- self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
444
 
445
  self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
446
 
@@ -453,21 +345,17 @@ class Decoder(nn.Module):
453
 
454
 
455
  def forward(self, asr=None, F0_curve=None, N=None, s=None):
456
- if self.training:
457
- downlist = [0, 3, 7]
458
- F0_down = downlist[random.randint(0, 2)]
459
- downlist = [0, 3, 7, 15]
460
- N_down = downlist[random.randint(0, 3)]
461
- if F0_down:
462
- F0_curve = nn.functional.conv1d(F0_curve.unsqueeze(1), torch.ones(1, 1, F0_down).to('cuda'), padding=F0_down//2).squeeze(1) / F0_down
463
- if N_down:
464
- N = nn.functional.conv1d(N.unsqueeze(1), torch.ones(1, 1, N_down).to('cuda'), padding=N_down//2).squeeze(1) / N_down
465
-
466
 
467
- F0 = self.F0_conv(F0_curve.unsqueeze(1))
468
- N = self.N_conv(N.unsqueeze(1))
 
 
 
 
 
469
 
470
  x = torch.cat([asr, F0, N], axis=1)
 
471
  x = self.encode(x, s)
472
 
473
  asr_res = self.asr_res(asr)
@@ -475,7 +363,10 @@ class Decoder(nn.Module):
475
  res = True
476
  for block in self.decode:
477
  if res:
 
 
478
  x = torch.cat([x, asr_res, F0, N], axis=1)
 
479
  x = block(x, s)
480
  if block.upsample_type != "none":
481
  res = False
@@ -483,4 +374,4 @@ class Decoder(nn.Module):
483
  x = self.generator(x, s, F0_curve)
484
  return x
485
 
486
-
 
4
  from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5
  from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6
  import math
 
7
  import numpy as np
8
 
9
+
10
+ LRELU_SLOPE = 0.1
11
+
12
  def get_padding(kernel_size, dilation=1):
13
  return int((kernel_size*dilation - dilation)/2)
14
 
15
+ def _tile(x,
16
+ length=None):
17
+ x = x.repeat(1, 1, int(length / x.shape[2]) + 1)[:, :, :length]
18
+ return x
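A minimal sketch of what this helper does, assuming the shapes noted in the AdaIN1d comments below ([1, 7, 1, 128] style, 248 content frames): it repeats a [B, C, T] tensor along time and crops to the target length, replacing the F.interpolate call removed above.

    import torch

    def _tile(x, length=None):
        # repeat along the time axis, then crop to exactly `length` frames
        return x.repeat(1, 1, int(length / x.shape[2]) + 1)[:, :, :length]

    s = torch.randn(1, 1024, 7)       # style after self.fc, 7 reference frames
    s_tiled = _tile(s, length=248)    # broadcast to the 248 content frames
    print(s_tiled.shape)              # torch.Size([1, 1024, 248])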
19
 
20
  class AdaIN1d(nn.Module):
21
 
 
27
  self.fc = nn.Linear(style_dim, num_features*2)
28
 
29
  def forward(self, x, s):
30
+
31
+ # x = torch.Size([1, 512, 248]) same as output
32
+ # s = torch.Size([1, 7, 1, 128])
33
+
34
+
35
+ s = self.fc(s.transpose(1, 2)).transpose(1, 2)
36
 
 
 
37
 
 
38
 
39
+ s = _tile(s, length=x.shape[2])
40
 
41
+ gamma, beta = torch.chunk(s, chunks=2, dim=1)
42
+ return (1+gamma) * self.norm(x) + beta
43
+
44
+
45
+
46
 
47
  class AdaINResBlock1(torch.nn.Module):
48
  def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
 
85
 
86
  def forward(self, x, s):
87
  for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
88
+ xt = n1(x, s) # THIS IS ADAIN - EXPECTS conv1d dims
89
  xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
90
  xt = c1(xt)
91
+ xt = n2(xt, s) # THIS IS ADAIN - EXPECTS conv1d dims
92
  xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
93
  xt = c2(xt)
94
  x = xt + x
 
101
  remove_weight_norm(l)
102
 
103
  class SineGen(torch.nn.Module):
104
+
105
+ def __init__(self,
106
+ samp_rate=24000,
107
+ upsample_scale=300,
108
+ harmonic_num=8, # hard-coded: must match the nn.Linear() input size in SourceModuleHnNSF
109
+ voiced_threshold=10):
110
+
111
  super(SineGen, self).__init__()
 
 
112
  self.harmonic_num = harmonic_num
 
113
  self.sampling_rate = samp_rate
114
  self.voiced_threshold = voiced_threshold
 
115
  self.upsample_scale = upsample_scale
116
 
 
 
 
 
 
117
  def _f02sine(self, f0_values):
118
+ # --
119
+ # debug shapes (HiFiGAN path):
120
+ # f0_values: torch.Size([1, 145200, 9])
121
+ # rad_values below keeps the same shape: torch.Size([1, 145200, 9])
122
+
123
+ rad_values = (f0_values / self.sampling_rate) % 1 # phase increment per sample, in cycles; Python's % is non-negative, e.g. -21 % 10 = 9 since -3*10 + 9 = -21
124
+
125
+ # print('BEF', rad_values.shape)
126
+
127
+
128
+
129
+ rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
130
+ scale_factor=1/self.upsample_scale,
131
+ mode="linear").transpose(1, 2)
132
+ print('AFt', rad_values.shape) # phase increments downsampled to 1/upsample_scale of the length; the cumulative sum below accumulates them
133
+ phase = torch.cumsum(rad_values, dim=1) * 1.84 * np.pi # 1.89 also sounds nice but adds a woofer-like thump at punctuation
134
+ phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
135
+ scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
136
+ sines = torch.sin(phase)
137
  return sines
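The phase computation above can be read in isolation: per-sample phase increments are pooled down by upsample_scale, accumulated with cumsum, then linearly upsampled back (scaled by upsample_scale to restore the increment magnitude). A self-contained sketch, assuming a constant 110 Hz F0, the default 24 kHz rate and scale 300, and the non-standard 1.84*pi factor from this commit:

    import numpy as np
    import torch
    import torch.nn.functional as F

    f0 = torch.full((1, 24000, 1), 110.0)                    # 1 s of constant 110 Hz
    rad = (f0 / 24000) % 1                                   # phase increment per sample, in cycles
    rad = F.interpolate(rad.transpose(1, 2), scale_factor=1/300, mode="linear").transpose(1, 2)
    phase = torch.cumsum(rad, dim=1) * 1.84 * np.pi          # accumulate at the low rate
    phase = F.interpolate(phase.transpose(1, 2) * 300, scale_factor=300, mode="linear").transpose(1, 2)
    sines = torch.sin(phase)                                 # [1, 24000, 1] quasi-sinusoid near 110 Hz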
138
 
139
  def forward(self, f0):
140
 
141
+ # f0 is already full length - [1, 142600, 1]
142
+
143
+ fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)) # [1, 145200, 9]
144
+
145
+ sine_waves = self._f02sine(fn) * .007 # sine amplitude (default 0.1); strong audible effect, very sensitive to the speaker
146
 
147
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
148
+
149
+ return sine_waves * uv #+ noise
150
 
151
+ class SourceModuleHnNSF(torch.nn.Module):
 
 
152
 
153
+ def __init__(self, harmonic_num=8):
154
+
155
+ super(SourceModuleHnNSF, self).__init__()
156
+ self.l_sin_gen = SineGen()
157
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) # harmonic_num=8 is fixed by the input size of this nn.Linear()
158
  self.l_tanh = torch.nn.Tanh()
159
 
160
  def forward(self, x):
161
+ # print(' HNnSF', x.shape) # why this is [1, 300, 1, 535800]
162
+ sine_wavs = self.l_sin_gen(x)
163
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs)) # This linear sums all 9 harmonics
164
+ return sine_merge
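Taken together with SineGen above, this module maps an already-upsampled F0 curve to a single-channel harmonic excitation; a usage sketch (assuming this file is importable as Modules.hifigan and the shapes from the comments):

    import torch
    from Modules.hifigan import SourceModuleHnNSF   # assumes the classes defined above

    f0 = torch.rand(1, 24000, 1) * 200              # F0 in Hz at waveform rate; ~0 means unvoiced
    source = SourceModuleHnNSF()                    # harmonic_num=8 -> nn.Linear(9, 1)
    har = source(f0)                                # [1, 24000, 1] merged harmonic excitation in [-1, 1]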
165
 
166
  class Generator(torch.nn.Module):
167
+ def __init__(self,
168
+ style_dim,
169
+ resblock_kernel_sizes,
170
+ upsample_rates,
171
+ upsample_initial_channel,
172
+ resblock_dilation_sizes,
173
+ upsample_kernel_sizes):
174
  super(Generator, self).__init__()
175
  self.num_kernels = len(resblock_kernel_sizes)
176
  self.num_upsamples = len(upsample_rates)
177
+ self.m_source = SourceModuleHnNSF()
 
 
 
 
 
 
178
  self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
179
  self.noise_convs = nn.ModuleList()
180
  self.ups = nn.ModuleList()
 
191
  stride_f0 = np.prod(upsample_rates[i + 1:])
192
  self.noise_convs.append(Conv1d(
193
  1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
194
+ self.noise_res.append(AdaINResBlock1(c_cur, 7, [1,3,5], style_dim))
195
  else:
196
  self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
197
+ self.noise_res.append(AdaINResBlock1(c_cur, 11, [1,3,5], style_dim))
198
 
199
  self.resblocks = nn.ModuleList()
200
 
 
206
  self.alphas.append(nn.Parameter(torch.ones(1, ch, 1)))
207
 
208
  for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
209
+ self.resblocks.append(AdaINResBlock1(ch, k, d, style_dim))
210
 
211
  self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
212
 
213
 
214
  def forward(self, x, s, f0):
215
 
216
+ # x.shape=torch.Size([1, 512, 484]) s.shape=torch.Size([1, 1, 1, 128]) f0.shape=torch.Size([1, 484]) GENERAT 249
217
+ f0 = self.f0_upsamp(f0).transpose(1, 2)
218
+ print(f'{x.shape=} {s.shape=} {f0.shape=} GENERAT 249 LALALALALA\n\n')
219
+ # x.shape=torch.Size([1, 512, 484]) s.shape=torch.Size([1, 1, 1, 128]) f0.shape=torch.Size([1, 145200, 1]) GENERAT 253
220
 
221
+ har_source = self.m_source(f0) # [1, 145400, 1]; f0 arrives already upsampled to the full 24 kHz waveform length
 
222
 
223
+ har_source = har_source.transpose(1, 2)
224
+
225
  for i in range(self.num_upsamples):
226
+
227
  x = x + (1 / self.alphas[i]) * (torch.sin(self.alphas[i] * x) ** 2)
228
  x_source = self.noise_convs[i](har_source)
229
  x_source = self.noise_res[i](x_source, s)
230
+
231
  x = self.ups[i](x)
232
+ print(x.min(), x.max(), x_source.min(), x_source.max())
233
  x = x + x_source
234
+
235
  xs = None
236
  for j in range(self.num_kernels):
237
+
238
  if xs is None:
239
  xs = self.resblocks[i*self.num_kernels+j](x, s)
240
  else:
 
267
  self.upsample_type = upsample
268
  self.upsample = UpSample1d(upsample)
269
  self.learned_sc = dim_in != dim_out
270
+ self._build_weights(dim_in, dim_out, style_dim)
 
 
271
  if upsample == 'none':
272
  self.pool = nn.Identity()
273
  else:
 
292
  x = self.norm1(x, s)
293
  x = self.actv(x)
294
  x = self.pool(x)
295
+ x = self.conv1(x)
296
  x = self.norm2(x, s)
297
  x = self.actv(x)
298
+ x = self.conv2(x)
299
  return x
300
 
301
  def forward(self, x, s):
 
332
  self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
333
  self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
334
 
335
+ self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)) # smooth
336
 
337
  self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
338
 
 
345
 
346
 
347
  def forward(self, asr=None, F0_curve=None, N=None, s=None):
348
 
349
+ print('p', asr.shape, F0_curve.shape, N.shape)
350
+ F0 = self.F0_conv(F0_curve)
351
+ N = self.N_conv(N)
352
+
353
+
354
+ print(asr.shape, F0.shape, N.shape, 'TF')
355
+
356
 
357
  x = torch.cat([asr, F0, N], axis=1)
358
+
359
  x = self.encode(x, s)
360
 
361
  asr_res = self.asr_res(asr)
 
363
  res = True
364
  for block in self.decode:
365
  if res:
366
+
367
+
368
  x = torch.cat([x, asr_res, F0, N], axis=1)
369
+
370
  x = block(x, s)
371
  if block.upsample_type != "none":
372
  res = False
 
374
  x = self.generator(x, s, F0_curve)
375
  return x
376
 
377
+
Utils/ASR/__init__.py DELETED
@@ -1 +0,0 @@
1
-
 
 
Utils/ASR/config.yml DELETED
@@ -1,29 +0,0 @@
1
- log_dir: "logs/20201006"
2
- save_freq: 5
3
- device: "cuda"
4
- epochs: 180
5
- batch_size: 64
6
- pretrained_model: ""
7
- train_data: "ASRDataset/train_list.txt"
8
- val_data: "ASRDataset/val_list.txt"
9
-
10
- dataset_params:
11
- data_augmentation: false
12
-
13
- preprocess_parasm:
14
- sr: 24000
15
- spect_params:
16
- n_fft: 2048
17
- win_length: 1200
18
- hop_length: 300
19
- mel_params:
20
- n_mels: 80
21
-
22
- model_params:
23
- input_dim: 80
24
- hidden_dim: 256
25
- n_token: 178
26
- token_embedding_dim: 512
27
-
28
- optimizer_params:
29
- lr: 0.0005
 
Utils/ASR/epoch_00080.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:fedd55a1234b0c56e1e8b509c74edf3a5e2f27106a66038a4a946047a775bd6c
3
- size 94552811
Utils/ASR/layers.py DELETED
@@ -1,354 +0,0 @@
1
- import math
2
- import torch
3
- from torch import nn
4
- from typing import Optional, Any
5
- from torch import Tensor
6
- import torch.nn.functional as F
7
- import torchaudio
8
- import torchaudio.functional as audio_F
9
-
10
- import random
11
- random.seed(0)
12
-
13
-
14
- def _get_activation_fn(activ):
15
- if activ == 'relu':
16
- return nn.ReLU()
17
- elif activ == 'lrelu':
18
- return nn.LeakyReLU(0.2)
19
- elif activ == 'swish':
20
- return lambda x: x*torch.sigmoid(x)
21
- else:
22
- raise RuntimeError('Unexpected activ type %s, expected [relu, lrelu, swish]' % activ)
23
-
24
- class LinearNorm(torch.nn.Module):
25
- def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
26
- super(LinearNorm, self).__init__()
27
- self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
28
-
29
- torch.nn.init.xavier_uniform_(
30
- self.linear_layer.weight,
31
- gain=torch.nn.init.calculate_gain(w_init_gain))
32
-
33
- def forward(self, x):
34
- return self.linear_layer(x)
35
-
36
-
37
- class ConvNorm(torch.nn.Module):
38
- def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
39
- padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
40
- super(ConvNorm, self).__init__()
41
- if padding is None:
42
- assert(kernel_size % 2 == 1)
43
- padding = int(dilation * (kernel_size - 1) / 2)
44
-
45
- self.conv = torch.nn.Conv1d(in_channels, out_channels,
46
- kernel_size=kernel_size, stride=stride,
47
- padding=padding, dilation=dilation,
48
- bias=bias)
49
-
50
- torch.nn.init.xavier_uniform_(
51
- self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
52
-
53
- def forward(self, signal):
54
- conv_signal = self.conv(signal)
55
- return conv_signal
56
-
57
- class CausualConv(nn.Module):
58
- def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=1, dilation=1, bias=True, w_init_gain='linear', param=None):
59
- super(CausualConv, self).__init__()
60
- if padding is None:
61
- assert(kernel_size % 2 == 1)
62
- padding = int(dilation * (kernel_size - 1) / 2) * 2
63
- else:
64
- self.padding = padding * 2
65
- self.conv = nn.Conv1d(in_channels, out_channels,
66
- kernel_size=kernel_size, stride=stride,
67
- padding=self.padding,
68
- dilation=dilation,
69
- bias=bias)
70
-
71
- torch.nn.init.xavier_uniform_(
72
- self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
73
-
74
- def forward(self, x):
75
- x = self.conv(x)
76
- x = x[:, :, :-self.padding]
77
- return x
78
-
79
- class CausualBlock(nn.Module):
80
- def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='lrelu'):
81
- super(CausualBlock, self).__init__()
82
- self.blocks = nn.ModuleList([
83
- self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
84
- for i in range(n_conv)])
85
-
86
- def forward(self, x):
87
- for block in self.blocks:
88
- res = x
89
- x = block(x)
90
- x += res
91
- return x
92
-
93
- def _get_conv(self, hidden_dim, dilation, activ='lrelu', dropout_p=0.2):
94
- layers = [
95
- CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
96
- _get_activation_fn(activ),
97
- nn.BatchNorm1d(hidden_dim),
98
- nn.Dropout(p=dropout_p),
99
- CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
100
- _get_activation_fn(activ),
101
- nn.Dropout(p=dropout_p)
102
- ]
103
- return nn.Sequential(*layers)
104
-
105
- class ConvBlock(nn.Module):
106
- def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='relu'):
107
- super().__init__()
108
- self._n_groups = 8
109
- self.blocks = nn.ModuleList([
110
- self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
111
- for i in range(n_conv)])
112
-
113
-
114
- def forward(self, x):
115
- for block in self.blocks:
116
- res = x
117
- x = block(x)
118
- x += res
119
- return x
120
-
121
- def _get_conv(self, hidden_dim, dilation, activ='relu', dropout_p=0.2):
122
- layers = [
123
- ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
124
- _get_activation_fn(activ),
125
- nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
126
- nn.Dropout(p=dropout_p),
127
- ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
128
- _get_activation_fn(activ),
129
- nn.Dropout(p=dropout_p)
130
- ]
131
- return nn.Sequential(*layers)
132
-
133
- class LocationLayer(nn.Module):
134
- def __init__(self, attention_n_filters, attention_kernel_size,
135
- attention_dim):
136
- super(LocationLayer, self).__init__()
137
- padding = int((attention_kernel_size - 1) / 2)
138
- self.location_conv = ConvNorm(2, attention_n_filters,
139
- kernel_size=attention_kernel_size,
140
- padding=padding, bias=False, stride=1,
141
- dilation=1)
142
- self.location_dense = LinearNorm(attention_n_filters, attention_dim,
143
- bias=False, w_init_gain='tanh')
144
-
145
- def forward(self, attention_weights_cat):
146
- processed_attention = self.location_conv(attention_weights_cat)
147
- processed_attention = processed_attention.transpose(1, 2)
148
- processed_attention = self.location_dense(processed_attention)
149
- return processed_attention
150
-
151
-
152
- class Attention(nn.Module):
153
- def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
154
- attention_location_n_filters, attention_location_kernel_size):
155
- super(Attention, self).__init__()
156
- self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
157
- bias=False, w_init_gain='tanh')
158
- self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
159
- w_init_gain='tanh')
160
- self.v = LinearNorm(attention_dim, 1, bias=False)
161
- self.location_layer = LocationLayer(attention_location_n_filters,
162
- attention_location_kernel_size,
163
- attention_dim)
164
- self.score_mask_value = -float("inf")
165
-
166
- def get_alignment_energies(self, query, processed_memory,
167
- attention_weights_cat):
168
- """
169
- PARAMS
170
- ------
171
- query: decoder output (batch, n_mel_channels * n_frames_per_step)
172
- processed_memory: processed encoder outputs (B, T_in, attention_dim)
173
- attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
174
- RETURNS
175
- -------
176
- alignment (batch, max_time)
177
- """
178
-
179
- processed_query = self.query_layer(query.unsqueeze(1))
180
- processed_attention_weights = self.location_layer(attention_weights_cat)
181
- energies = self.v(torch.tanh(
182
- processed_query + processed_attention_weights + processed_memory))
183
-
184
- energies = energies.squeeze(-1)
185
- return energies
186
-
187
- def forward(self, attention_hidden_state, memory, processed_memory,
188
- attention_weights_cat, mask):
189
- """
190
- PARAMS
191
- ------
192
- attention_hidden_state: attention rnn last output
193
- memory: encoder outputs
194
- processed_memory: processed encoder outputs
195
- attention_weights_cat: previous and cummulative attention weights
196
- mask: binary mask for padded data
197
- """
198
- alignment = self.get_alignment_energies(
199
- attention_hidden_state, processed_memory, attention_weights_cat)
200
-
201
- if mask is not None:
202
- alignment.data.masked_fill_(mask, self.score_mask_value)
203
-
204
- attention_weights = F.softmax(alignment, dim=1)
205
- attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
206
- attention_context = attention_context.squeeze(1)
207
-
208
- return attention_context, attention_weights
209
-
210
-
211
- class ForwardAttentionV2(nn.Module):
212
- def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
213
- attention_location_n_filters, attention_location_kernel_size):
214
- super(ForwardAttentionV2, self).__init__()
215
- self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
216
- bias=False, w_init_gain='tanh')
217
- self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
218
- w_init_gain='tanh')
219
- self.v = LinearNorm(attention_dim, 1, bias=False)
220
- self.location_layer = LocationLayer(attention_location_n_filters,
221
- attention_location_kernel_size,
222
- attention_dim)
223
- self.score_mask_value = -float(1e20)
224
-
225
- def get_alignment_energies(self, query, processed_memory,
226
- attention_weights_cat):
227
- """
228
- PARAMS
229
- ------
230
- query: decoder output (batch, n_mel_channels * n_frames_per_step)
231
- processed_memory: processed encoder outputs (B, T_in, attention_dim)
232
- attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
233
- RETURNS
234
- -------
235
- alignment (batch, max_time)
236
- """
237
-
238
- processed_query = self.query_layer(query.unsqueeze(1))
239
- processed_attention_weights = self.location_layer(attention_weights_cat)
240
- energies = self.v(torch.tanh(
241
- processed_query + processed_attention_weights + processed_memory))
242
-
243
- energies = energies.squeeze(-1)
244
- return energies
245
-
246
- def forward(self, attention_hidden_state, memory, processed_memory,
247
- attention_weights_cat, mask, log_alpha):
248
- """
249
- PARAMS
250
- ------
251
- attention_hidden_state: attention rnn last output
252
- memory: encoder outputs
253
- processed_memory: processed encoder outputs
254
- attention_weights_cat: previous and cummulative attention weights
255
- mask: binary mask for padded data
256
- """
257
- log_energy = self.get_alignment_energies(
258
- attention_hidden_state, processed_memory, attention_weights_cat)
259
-
260
- #log_energy =
261
-
262
- if mask is not None:
263
- log_energy.data.masked_fill_(mask, self.score_mask_value)
264
-
265
- #attention_weights = F.softmax(alignment, dim=1)
266
-
267
- #content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME]
268
- #log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1]
269
-
270
- #log_total_score = log_alpha + content_score
271
-
272
- #previous_attention_weights = attention_weights_cat[:,0,:]
273
-
274
- log_alpha_shift_padded = []
275
- max_time = log_energy.size(1)
276
- for sft in range(2):
277
- shifted = log_alpha[:,:max_time-sft]
278
- shift_padded = F.pad(shifted, (sft,0), 'constant', self.score_mask_value)
279
- log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
280
-
281
- biased = torch.logsumexp(torch.cat(log_alpha_shift_padded,2), 2)
282
-
283
- log_alpha_new = biased + log_energy
284
-
285
- attention_weights = F.softmax(log_alpha_new, dim=1)
286
-
287
- attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
288
- attention_context = attention_context.squeeze(1)
289
-
290
- return attention_context, attention_weights, log_alpha_new
291
-
292
-
293
- class PhaseShuffle2d(nn.Module):
294
- def __init__(self, n=2):
295
- super(PhaseShuffle2d, self).__init__()
296
- self.n = n
297
- self.random = random.Random(1)
298
-
299
- def forward(self, x, move=None):
300
- # x.size = (B, C, M, L)
301
- if move is None:
302
- move = self.random.randint(-self.n, self.n)
303
-
304
- if move == 0:
305
- return x
306
- else:
307
- left = x[:, :, :, :move]
308
- right = x[:, :, :, move:]
309
- shuffled = torch.cat([right, left], dim=3)
310
- return shuffled
311
-
312
- class PhaseShuffle1d(nn.Module):
313
- def __init__(self, n=2):
314
- super(PhaseShuffle1d, self).__init__()
315
- self.n = n
316
- self.random = random.Random(1)
317
-
318
- def forward(self, x, move=None):
319
- # x.size = (B, C, M, L)
320
- if move is None:
321
- move = self.random.randint(-self.n, self.n)
322
-
323
- if move == 0:
324
- return x
325
- else:
326
- left = x[:, :, :move]
327
- right = x[:, :, move:]
328
- shuffled = torch.cat([right, left], dim=2)
329
-
330
- return shuffled
331
-
332
- class MFCC(nn.Module):
333
- def __init__(self, n_mfcc=40, n_mels=80):
334
- super(MFCC, self).__init__()
335
- self.n_mfcc = n_mfcc
336
- self.n_mels = n_mels
337
- self.norm = 'ortho'
338
- dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
339
- self.register_buffer('dct_mat', dct_mat)
340
-
341
- def forward(self, mel_specgram):
342
- if len(mel_specgram.shape) == 2:
343
- mel_specgram = mel_specgram.unsqueeze(0)
344
- unsqueezed = True
345
- else:
346
- unsqueezed = False
347
- # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
348
- # -> (channel, time, n_mfcc).tranpose(...)
349
- mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
350
-
351
- # unpack batch
352
- if unsqueezed:
353
- mfcc = mfcc.squeeze(0)
354
- return mfcc
Utils/ASR/models.py DELETED
@@ -1,186 +0,0 @@
1
- import math
2
- import torch
3
- from torch import nn
4
- from torch.nn import TransformerEncoder
5
- import torch.nn.functional as F
6
- from .layers import MFCC, Attention, LinearNorm, ConvNorm, ConvBlock
7
-
8
- class ASRCNN(nn.Module):
9
- def __init__(self,
10
- input_dim=80,
11
- hidden_dim=256,
12
- n_token=35,
13
- n_layers=6,
14
- token_embedding_dim=256,
15
-
16
- ):
17
- super().__init__()
18
- self.n_token = n_token
19
- self.n_down = 1
20
- self.to_mfcc = MFCC()
21
- self.init_cnn = ConvNorm(input_dim//2, hidden_dim, kernel_size=7, padding=3, stride=2)
22
- self.cnns = nn.Sequential(
23
- *[nn.Sequential(
24
- ConvBlock(hidden_dim),
25
- nn.GroupNorm(num_groups=1, num_channels=hidden_dim)
26
- ) for n in range(n_layers)])
27
- self.projection = ConvNorm(hidden_dim, hidden_dim // 2)
28
- self.ctc_linear = nn.Sequential(
29
- LinearNorm(hidden_dim//2, hidden_dim),
30
- nn.ReLU(),
31
- LinearNorm(hidden_dim, n_token))
32
- self.asr_s2s = ASRS2S(
33
- embedding_dim=token_embedding_dim,
34
- hidden_dim=hidden_dim//2,
35
- n_token=n_token)
36
-
37
- def forward(self, x, src_key_padding_mask=None, text_input=None):
38
- x = self.to_mfcc(x)
39
- x = self.init_cnn(x)
40
- x = self.cnns(x)
41
- x = self.projection(x)
42
- x = x.transpose(1, 2)
43
- ctc_logit = self.ctc_linear(x)
44
- if text_input is not None:
45
- _, s2s_logit, s2s_attn = self.asr_s2s(x, src_key_padding_mask, text_input)
46
- return ctc_logit, s2s_logit, s2s_attn
47
- else:
48
- return ctc_logit
49
-
50
- def get_feature(self, x):
51
- x = self.to_mfcc(x.squeeze(1))
52
- x = self.init_cnn(x)
53
- x = self.cnns(x)
54
- x = self.projection(x)
55
- return x
56
-
57
- def length_to_mask(self, lengths):
58
- mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
59
- mask = torch.gt(mask+1, lengths.unsqueeze(1)).to(lengths.device)
60
- return mask
61
-
62
- def get_future_mask(self, out_length, unmask_future_steps=0):
63
- """
64
- Args:
65
- out_length (int): returned mask shape is (out_length, out_length).
66
- unmask_futre_steps (int): unmasking future step size.
67
- Return:
68
- mask (torch.BoolTensor): mask future timesteps mask[i, j] = True if i > j + unmask_future_steps else False
69
- """
70
- index_tensor = torch.arange(out_length).unsqueeze(0).expand(out_length, -1)
71
- mask = torch.gt(index_tensor, index_tensor.T + unmask_future_steps)
72
- return mask
73
-
74
- class ASRS2S(nn.Module):
75
- def __init__(self,
76
- embedding_dim=256,
77
- hidden_dim=512,
78
- n_location_filters=32,
79
- location_kernel_size=63,
80
- n_token=40):
81
- super(ASRS2S, self).__init__()
82
- self.embedding = nn.Embedding(n_token, embedding_dim)
83
- val_range = math.sqrt(6 / hidden_dim)
84
- self.embedding.weight.data.uniform_(-val_range, val_range)
85
-
86
- self.decoder_rnn_dim = hidden_dim
87
- self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token)
88
- self.attention_layer = Attention(
89
- self.decoder_rnn_dim,
90
- hidden_dim,
91
- hidden_dim,
92
- n_location_filters,
93
- location_kernel_size
94
- )
95
- self.decoder_rnn = nn.LSTMCell(self.decoder_rnn_dim + embedding_dim, self.decoder_rnn_dim)
96
- self.project_to_hidden = nn.Sequential(
97
- LinearNorm(self.decoder_rnn_dim * 2, hidden_dim),
98
- nn.Tanh())
99
- self.sos = 1
100
- self.eos = 2
101
-
102
- def initialize_decoder_states(self, memory, mask):
103
- """
104
- moemory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
105
- """
106
- B, L, H = memory.shape
107
- self.decoder_hidden = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory)
108
- self.decoder_cell = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory)
109
- self.attention_weights = torch.zeros((B, L)).type_as(memory)
110
- self.attention_weights_cum = torch.zeros((B, L)).type_as(memory)
111
- self.attention_context = torch.zeros((B, H)).type_as(memory)
112
- self.memory = memory
113
- self.processed_memory = self.attention_layer.memory_layer(memory)
114
- self.mask = mask
115
- self.unk_index = 3
116
- self.random_mask = 0.1
117
-
118
- def forward(self, memory, memory_mask, text_input):
119
- """
120
- moemory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
121
- moemory_mask.shape = (B, L, )
122
- texts_input.shape = (B, T)
123
- """
124
- self.initialize_decoder_states(memory, memory_mask)
125
- # text random mask
126
- random_mask = (torch.rand(text_input.shape) < self.random_mask).to(text_input.device)
127
- _text_input = text_input.clone()
128
- _text_input.masked_fill_(random_mask, self.unk_index)
129
- decoder_inputs = self.embedding(_text_input).transpose(0, 1) # -> [T, B, channel]
130
- start_embedding = self.embedding(
131
- torch.LongTensor([self.sos]*decoder_inputs.size(1)).to(decoder_inputs.device))
132
- decoder_inputs = torch.cat((start_embedding.unsqueeze(0), decoder_inputs), dim=0)
133
-
134
- hidden_outputs, logit_outputs, alignments = [], [], []
135
- while len(hidden_outputs) < decoder_inputs.size(0):
136
-
137
- decoder_input = decoder_inputs[len(hidden_outputs)]
138
- hidden, logit, attention_weights = self.decode(decoder_input)
139
- hidden_outputs += [hidden]
140
- logit_outputs += [logit]
141
- alignments += [attention_weights]
142
-
143
- hidden_outputs, logit_outputs, alignments = \
144
- self.parse_decoder_outputs(
145
- hidden_outputs, logit_outputs, alignments)
146
-
147
- return hidden_outputs, logit_outputs, alignments
148
-
149
-
150
- def decode(self, decoder_input):
151
-
152
- cell_input = torch.cat((decoder_input, self.attention_context), -1)
153
- self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
154
- cell_input,
155
- (self.decoder_hidden, self.decoder_cell))
156
-
157
- attention_weights_cat = torch.cat(
158
- (self.attention_weights.unsqueeze(1),
159
- self.attention_weights_cum.unsqueeze(1)),dim=1)
160
-
161
- self.attention_context, self.attention_weights = self.attention_layer(
162
- self.decoder_hidden,
163
- self.memory,
164
- self.processed_memory,
165
- attention_weights_cat,
166
- self.mask)
167
-
168
- self.attention_weights_cum += self.attention_weights
169
-
170
- hidden_and_context = torch.cat((self.decoder_hidden, self.attention_context), -1)
171
- hidden = self.project_to_hidden(hidden_and_context)
172
-
173
- # dropout to increasing g
174
- logit = self.project_to_n_symbols(F.dropout(hidden, 0.5, self.training))
175
-
176
- return hidden, logit, self.attention_weights
177
-
178
- def parse_decoder_outputs(self, hidden, logit, alignments):
179
-
180
- # -> [B, T_out + 1, max_time]
181
- alignments = torch.stack(alignments).transpose(0,1)
182
- # [T_out + 1, B, n_symbols] -> [B, T_out + 1, n_symbols]
183
- logit = torch.stack(logit).transpose(0, 1).contiguous()
184
- hidden = torch.stack(hidden).transpose(0, 1).contiguous()
185
-
186
- return hidden, logit, alignments
models.py CHANGED
@@ -6,9 +6,9 @@ import torch
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
  from torch.nn.utils import weight_norm, spectral_norm
9
- from Utils.ASR.models import ASRCNN
10
  from Utils.JDC.model import JDCNet
11
- from Modules.hifigan import AdainResBlk1d
12
  import yaml
13
 
14
 
@@ -257,20 +257,6 @@ class TextEncoder(nn.Module):
257
  x, batch_first=True)
258
  x = x.transpose(-1, -2)
259
  return x
260
-
261
- # def inference(self, x):
262
- # x = self.embedding(x)
263
- # x = x.transpose(1, 2)
264
- # x = self.cnn(x)
265
- # x = x.transpose(1, 2)
266
- # self.lstm.flatten_parameters()
267
- # x, _ = self.lstm(x)
268
- # return x
269
-
270
- # def length_to_mask(self, lengths):
271
- # mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
272
- # mask = torch.gt(mask+1, lengths.unsqueeze(1))
273
- # return mask
274
 
275
  class AdaLayerNorm(nn.Module):
276
 
@@ -318,25 +304,28 @@ class ProsodyPredictor(nn.Module):
318
  self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
319
 
320
  def F0Ntrain(self, x, s):
321
- x, _ = self.shared(x.transpose(-1, -2))
 
 
 
322
 
323
- F0 = x.transpose(-1, -2)
324
 
 
325
 
326
  for block in self.F0:
327
- print(f'F)N {F0.shape=} {s.shape=}\n')
328
  # )N F0.shape=torch.Size([1, 512, 147]) s.shape=torch.Size([1, 128])
329
-
330
- F0 = block(F0, s)
331
  F0 = self.F0_proj(F0)
332
-
333
- N = x.transpose(-1, -2)
 
334
  for block in self.N:
335
  N = block(N, s)
336
  N = self.N_proj(N)
337
 
338
- return F0.squeeze(1), N.squeeze(1)
339
-
340
  class DurationEncoder(nn.Module):
341
 
342
  def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
@@ -357,13 +346,13 @@ class DurationEncoder(nn.Module):
357
  self.sty_dim = sty_dim
358
 
359
  def forward(self, x, style, text_lengths):
 
 
360
 
361
- style = style[:, :, 0, :].transpose(2, 1) # [bs, 128, 11]
362
-
363
- style = F.interpolate(style, x.shape[2], mode='nearest')
364
 
365
  x = torch.cat([x, style], axis=1) # [bs, 640, 75]
366
-
367
  input_lengths = text_lengths.cpu().numpy()
368
 
369
  for block in self.lstms:
@@ -398,28 +387,4 @@ def load_F0_models(path):
398
  F0_model.load_state_dict(params)
399
  _ = F0_model.train()
400
 
401
- return F0_model
402
-
403
- def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
404
- # load ASR model
405
- def _load_config(path):
406
- with open(path) as f:
407
- config = yaml.safe_load(f)
408
- model_config = config['model_params']
409
- return model_config
410
-
411
- def _load_model(model_config, model_path):
412
- model = ASRCNN(**model_config)
413
- params = torch.load(
414
- model_path,
415
- map_location='cpu',
416
- weights_only=False
417
- )['model']
418
- model.load_state_dict(params)
419
- return model
420
-
421
- asr_model_config = _load_config(ASR_MODEL_CONFIG)
422
- asr_model = _load_model(asr_model_config, ASR_MODEL_PATH)
423
- _ = asr_model.train()
424
-
425
- return asr_model
 
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
  from torch.nn.utils import weight_norm, spectral_norm
9
+ # from Utils.ASR.models import ASRCNN
10
  from Utils.JDC.model import JDCNet
11
+ from Modules.hifigan import _tile, AdainResBlk1d
12
  import yaml
13
 
14
 
 
257
  x, batch_first=True)
258
  x = x.transpose(-1, -2)
259
  return x
260
 
261
  class AdaLayerNorm(nn.Module):
262
 
 
304
  self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
305
 
306
  def F0Ntrain(self, x, s):
307
+ print(x.shape, s.shape, 'F)N T T T')
308
+ x, _ = self.shared(x.transpose(1, 2)) # [bs, time, ch] LSTM
309
+
310
+ x = x.transpose(1, 2) # [bs, ch, time]
311
 
 
312
 
313
+ F0 = x
314
 
315
  for block in self.F0:
316
+ print(f'LOOP {F0.shape=} {s.shape=}\n')
317
  # )N F0.shape=torch.Size([1, 512, 147]) s.shape=torch.Size([1, 128])
318
+ F0 = block(F0, s) # AdainResBlk1d; expects conv1d-shaped input [bs, ch, time]
 
319
  F0 = self.F0_proj(F0)
320
+ print('____________________________2nd F0Ntra')
321
+ N = x
322
+
323
  for block in self.N:
324
  N = block(N, s)
325
  N = self.N_proj(N)
326
 
327
+ return F0, N
328
+
329
  class DurationEncoder(nn.Module):
330
 
331
  def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
 
346
  self.sty_dim = sty_dim
347
 
348
  def forward(self, x, style, text_lengths):
349
+
350
+ # style = style[:, :, 0, :].transpose(2, 1) # [bs, 128, 11]
351
 
352
+ style = _tile(style, length=x.shape[2]) # replicate the style vector across the text duration (cyclic tiling instead of F.interpolate)
 
 
353
 
354
  x = torch.cat([x, style], axis=1) # [bs, 640, 75]
355
+
356
  input_lengths = text_lengths.cpu().numpy()
357
 
358
  for block in self.lstms:
 
387
  F0_model.load_state_dict(params)
388
  _ = F0_model.train()
389
 
390
+ return F0_model
msinference.py CHANGED
@@ -7,14 +7,9 @@ import numpy as np
7
  import yaml
8
  import torchaudio
9
  import librosa
10
- from models import ProsodyPredictor, TextEncoder, StyleEncoder, load_ASR_models, load_F0_models
11
  from nltk.tokenize import word_tokenize
12
 
13
- torch.manual_seed(0)
14
- # torch.backends.cudnn.benchmark = False
15
- # torch.backends.cudnn.deterministic = True
16
- np.random.seed(0)
17
-
18
  # IPA Phonemizer: https://github.com/bootphon/phonemizer
19
 
20
  _pad = "$"
@@ -72,8 +67,11 @@ def compute_style(path):
72
  with torch.no_grad():
73
  ref_s = style_encoder(mel_tensor.unsqueeze(1))
74
  ref_p = predictor_encoder(mel_tensor.unsqueeze(1)) # [bs, 11, 1, 128]
75
- print(f'\n\n\n\nCOMPUTE STYLe {ref_s.shape=} {ref_p.shape=}')
76
- return torch.cat([ref_s, ref_p], dim=3) # [bs, 11, 1, 256]
 
 
 
77
 
78
  device = 'cpu'
79
  if torch.cuda.is_available():
@@ -91,53 +89,14 @@ global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_
91
  args = yaml.safe_load(open(str('Utils/config.yml')))
92
  ASR_config = args['ASR_config']
93
 
94
- ASR_path = args['ASR_path']
95
- text_aligner = load_ASR_models(ASR_path, ASR_config).eval().to(device)
96
-
97
  F0_path = args['F0_path']
98
  pitch_extractor = load_F0_models(F0_path).eval().to(device)
99
 
100
  from Utils.PLBERT.util import load_plbert
101
- bert = load_plbert(args['PLBERT_dir']).eval().to(device)
102
- # model_params = recursive_munch(config['model_params'])
103
- # --
104
- # def build_model(args, text_aligner, pitch_extractor, bert):
105
- # print(f'\n==============\n {args.decoder.type=}\n==============L584 models.py @ build_model()\n')
106
- # # ======================================
107
- # In [4]: args['model_params']
108
- # Out[4]:
109
- # {'decoder': {'resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
110
- # 'resblock_kernel_sizes': [3, 7, 11],
111
- # 'type': 'hifigan',
112
- # 'upsample_initial_channel': 512,
113
- # 'upsample_kernel_sizes': [20, 10, 6, 4],
114
- # 'upsample_rates': [10, 5, 3, 2]},
115
- # 'diffusion': {'dist': {'estimate_sigma_data': True,
116
- # 'mean': -3.0,
117
- # 'sigma_data': 0.19926648961191362,
118
- # 'std': 1.0},
119
- # 'embedding_mask_proba': 0.1,
120
- # 'transformer': {'head_features': 64,
121
- # 'multiplier': 2,
122
- # 'num_heads': 8,
123
- # 'num_layers': 3}},
124
- # 'dim_in': 64,
125
- # 'dropout': 0.2,
126
- # 'hidden_dim': 512,
127
- # 'max_conv_dim': 512,
128
- # 'max_dur': 50,
129
- # 'multispeaker': True,
130
- # 'n_layer': 3,
131
- # 'n_mels': 80,
132
- # 'n_token': 178,
133
- # 'slm': {'hidden': 768,
134
- # 'initial_channel': 64,
135
- # 'model': 'microsoft/wavlm-base-plus',
136
- # 'nlayers': 13,
137
- # 'sr': 16000},
138
- # 'style_dim': 128}
139
- # # ===============================================
140
  from Modules.hifigan import Decoder
 
 
 
141
  decoder = Decoder(dim_in=512,
142
  style_dim=128,
143
  dim_out=80, # n_mels
@@ -166,12 +125,7 @@ predictor_encoder = StyleEncoder(dim_in=64,
166
  style_dim=128,
167
  max_conv_dim=512).eval().to(device) # prosodic style encoder
168
  bert_encoder = torch.nn.Linear(bert.config.hidden_size, 512).eval().to(device)
169
- # --
170
- # model = build_model(model_params, text_aligner, pitch_extractor, plbert)
171
- # _ = [model[key].eval() for key in model]
172
- # _ = [model[key].to(device) for key in model]
173
 
174
- # params_whole = torch.load("Models/LibriTTS/epochs_2nd_00020.pth", map_location='cpu')
175
  # params_whole = torch.load('freevc2/yl4579_styletts2.pth' map_location='cpu')
176
  params_whole = torch.load(str(cached_path("hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth")), map_location='cpu')
177
  params = params_whole['net']
@@ -204,7 +158,6 @@ decoder.load_state_dict( _del_prefix(params['decoder']), strict=True)
204
  text_encoder.load_state_dict(_del_prefix(params['text_encoder']), strict=True)
205
  predictor_encoder.load_state_dict(_del_prefix(params['predictor_encoder']), strict=True)
206
  style_encoder.load_state_dict(_del_prefix(params['style_encoder']), strict=True)
207
- text_aligner.load_state_dict( _del_prefix(params['text_aligner']), strict=True)
208
  pitch_extractor.load_state_dict(_del_prefix(params['pitch_extractor']), strict=True)
209
 
210
  # def _shift(x):
@@ -236,40 +189,22 @@ def inference(text,
236
  with torch.no_grad():
237
  input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
238
 
239
- # -----------------------
240
- # WHO TRANSLATES these tokens to sylla
241
- # print(text_mask.shape, '\n__\n', tokens, '\n__\n', text_mask.min(), text_mask.max())
242
- # text_mask=is binary
243
- # tokes = tensor([[ 0, 55, 157, 86, 125, 83, 55, 156, 57, 158, 123, 48, 83, 61,
244
- # 157, 102, 61, 16, 138, 64, 16, 53, 156, 138, 54, 62, 131, 85,
245
- # 123, 83, 54, 16, 50, 156, 86, 123, 102, 125, 102, 46, 147, 16,
246
- # 62, 135, 16, 76, 158, 92, 55, 156, 86, 56, 62, 177, 46, 16,
247
- # 50, 157, 43, 102, 58, 85, 55, 156, 51, 158, 46, 51, 158, 83,
248
- # 16, 48, 76, 158, 123, 16, 72, 53, 61, 157, 86, 61, 83, 44,
249
- # 156, 102, 54, 177, 125, 51, 16, 72, 56, 46, 16, 102, 112, 53,
250
- # 54, 156, 63, 158, 147, 83, 56, 16, 4]], device='cuda:0')
251
-
252
-
253
- t_en = text_encoder(tokens, input_lengths)
254
  bert_dur = bert(tokens, attention_mask=None)
255
  d_en = bert_encoder(bert_dur).transpose(-1, -2)
256
-
257
- ref = ref_s[:, :, :, :128] # [bs, 11, 1, 128]
258
- s = ref_s[:, :, :, 128:] # have channels as last dim so it can go through nn.Linear layers
259
-
260
-
261
- # ON compute style we dont know yet the size to interpolate
262
- # Perhaps we can interpolate ref_s here as now we know how many bert time-frames the text needs
263
- # s = .74 * s # prosody / arousal & fading unvoiced syllabes [x0.7 - x1.2]
264
-
265
 
266
- print(f'{d_en.shape=} {s.shape=} {input_lengths.shape=}')
267
  d = predictor.text_encoder(d_en,
268
  s,
269
  input_lengths)
270
 
271
  x, _ = predictor.lstm(d)
272
- print(d.shape, x.shape, 'Lstm')
273
  duration = predictor.duration_proj(x)
274
 
275
  duration = torch.sigmoid(duration).sum(axis=-1)
@@ -281,24 +216,23 @@ def inference(text,
281
  for i in range(pred_aln_trg.size(0)):
282
  pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
283
  c_frame += int(pred_dur[i].data)
284
-
285
- # encode prosody
286
  en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
287
-
288
- asr_new = torch.zeros_like(en)
289
- asr_new[:, :, 0] = en[:, :, 0]
290
- asr_new[:, :, 1:] = en[:, :, 0:-1]
291
- en = asr_new
292
- print('_________________________________________F0_____________________________')
293
  F0_pred, N_pred = predictor.F0Ntrain(en, s)
294
 
295
- asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
 
 
 
 
 
 
 
 
 
 
 
296
 
297
- asr_new = torch.zeros_like(asr)
298
- asr_new[:, :, 0] = asr[:, :, 0]
299
- asr_new[:, :, 1:] = asr[:, :, 0:-1]
300
- asr = asr_new
301
- print('_________________________________________HiFI_____________________________')
302
  x = decoder(asr=asr,
303
  F0_curve=F0_pred,
304
  N=N_pred,
 
7
  import yaml
8
  import torchaudio
9
  import librosa
10
+ from models import ProsodyPredictor, TextEncoder, StyleEncoder, load_F0_models
11
  from nltk.tokenize import word_tokenize
12
 
 
 
 
 
 
13
  # IPA Phonemizer: https://github.com/bootphon/phonemizer
14
 
15
  _pad = "$"
 
67
  with torch.no_grad():
68
  ref_s = style_encoder(mel_tensor.unsqueeze(1))
69
  ref_p = predictor_encoder(mel_tensor.unsqueeze(1)) # [bs, 11, 1, 128]
70
+
71
+ s = torch.cat([ref_s, ref_p], dim=3) # [bs, 11, 1, 256]
72
+
73
+ s = s[:, :, 0, :].transpose(1, 2) # [1, 256, 11]
74
+ return s # [1, 256, 11]
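Downstream (see the inference() changes later in this diff) the 256 style channels are split in half; a shape-only sketch, purely illustrative:

    import torch

    ref_s = torch.randn(1, 256, 11)   # compute_style() output for an 11-frame reference
    ref = ref_s[:, :128, :]           # acoustic style, used by the decoder
    s = ref_s[:, 128:, :]             # prosodic style, used by the duration / F0 / N predictors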
75
 
76
  device = 'cpu'
77
  if torch.cuda.is_available():
 
89
  args = yaml.safe_load(open(str('Utils/config.yml')))
90
  ASR_config = args['ASR_config']
91
 
 
 
 
92
  F0_path = args['F0_path']
93
  pitch_extractor = load_F0_models(F0_path).eval().to(device)
94
 
95
  from Utils.PLBERT.util import load_plbert
96
  from Modules.hifigan import Decoder
97
+
98
+ bert = load_plbert(args['PLBERT_dir']).eval().to(device)
99
+
100
  decoder = Decoder(dim_in=512,
101
  style_dim=128,
102
  dim_out=80, # n_mels
 
125
  style_dim=128,
126
  max_conv_dim=512).eval().to(device) # prosodic style encoder
127
  bert_encoder = torch.nn.Linear(bert.config.hidden_size, 512).eval().to(device)
 
 
 
 
128
 
 
129
  # params_whole = torch.load('freevc2/yl4579_styletts2.pth' map_location='cpu')
130
  params_whole = torch.load(str(cached_path("hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth")), map_location='cpu')
131
  params = params_whole['net']
 
158
  text_encoder.load_state_dict(_del_prefix(params['text_encoder']), strict=True)
159
  predictor_encoder.load_state_dict(_del_prefix(params['predictor_encoder']), strict=True)
160
  style_encoder.load_state_dict(_del_prefix(params['style_encoder']), strict=True)
 
161
  pitch_extractor.load_state_dict(_del_prefix(params['pitch_extractor']), strict=True)
162
 
163
  # def _shift(x):
 
189
  with torch.no_grad():
190
  input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
191
 
192
+ hidden_states = text_encoder(tokens, input_lengths)
193
+
194
  bert_dur = bert(tokens, attention_mask=None)
195
  d_en = bert_encoder(bert_dur).transpose(-1, -2)
196
+ ref = ref_s[:, :128, :] # [bs, 128, 11]
197
+ s = ref_s[:, 128:, :]
198
+ d = predictor.text_encoder(d_en, s, input_lengths)
199
+ d = d.transpose(1, 2)
200
+ # -------------------------------- pred_aln_trg clones each BERT frame according to its predicted duration
 
 
 
 
201
 
 
202
  d = predictor.text_encoder(d_en,
203
  s,
204
  input_lengths)
205
 
206
  x, _ = predictor.lstm(d)
207
+
208
  duration = predictor.duration_proj(x)
209
 
210
  duration = torch.sigmoid(duration).sum(axis=-1)
 
216
  for i in range(pred_aln_trg.size(0)):
217
  pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
218
  c_frame += int(pred_dur[i].data)
 
 
219
  en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
220
+
 
 
 
 
 
221
  F0_pred, N_pred = predictor.F0Ntrain(en, s)
222
 
223
+ asr = (hidden_states @ pred_aln_trg.unsqueeze(0).to(device))
224
+
225
+ # -- END DURATION
226
+
227
+ # [bs, 640, 198]
228
+
229
+ # Hubert-like frames replicated by each frame's duration, e.g. [bs, 640, 130] -> [bs, 640, 198]
230
+
231
+ # every Hubert frame can be cloned from 1 to ~12 times and appended to the final array
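A worked example of the alignment matrix these comments describe (illustrative only; small sizes assumed):

    import torch

    pred_dur = torch.tensor([2, 1, 3])                    # predicted duration of each text frame
    pred_aln_trg = torch.zeros(3, int(pred_dur.sum()))    # [n_text_frames, n_output_frames]
    c = 0
    for i, dur in enumerate(pred_dur):
        pred_aln_trg[i, c:c + int(dur)] = 1
        c += int(dur)
    d = torch.randn(1, 3, 640)                            # one hidden vector per text frame
    en = d.transpose(-1, -2) @ pred_aln_trg               # [1, 640, 6]: frame i repeated pred_dur[i] times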
232
+
233
+
234
+ F0_pred, N_pred = predictor.F0Ntrain(en, s)
235
 
 
 
 
 
 
236
  x = decoder(asr=asr,
237
  F0_curve=F0_pred,
238
  N=N_pred,