theNeofr committed (verified)
Commit 578d30e · Parent: b1bd4de

Upload 9 files

libs/infer_packs/attentions.py CHANGED
@@ -1,459 +1,459 @@
- import copy
- import math
- from typing import Optional
-
- import numpy as np
- import torch
- from torch import nn
- from torch.nn import functional as F
-
- from infer.lib.infer_pack import commons, modules
- from infer.lib.infer_pack.modules import LayerNorm
-
-
- class Encoder(nn.Module):
-     def __init__(
-         self,
-         hidden_channels,
-         filter_channels,
-         n_heads,
-         n_layers,
-         kernel_size=1,
-         p_dropout=0.0,
-         window_size=10,
-         **kwargs
-     ):
-         super(Encoder, self).__init__()
-         self.hidden_channels = hidden_channels
-         self.filter_channels = filter_channels
-         self.n_heads = n_heads
-         self.n_layers = int(n_layers)
-         self.kernel_size = kernel_size
-         self.p_dropout = p_dropout
-         self.window_size = window_size
-
-         self.drop = nn.Dropout(p_dropout)
-         self.attn_layers = nn.ModuleList()
-         self.norm_layers_1 = nn.ModuleList()
-         self.ffn_layers = nn.ModuleList()
-         self.norm_layers_2 = nn.ModuleList()
-         for i in range(self.n_layers):
-             self.attn_layers.append(
-                 MultiHeadAttention(
-                     hidden_channels,
-                     hidden_channels,
-                     n_heads,
-                     p_dropout=p_dropout,
-                     window_size=window_size,
-                 )
-             )
-             self.norm_layers_1.append(LayerNorm(hidden_channels))
-             self.ffn_layers.append(
-                 FFN(
-                     hidden_channels,
-                     hidden_channels,
-                     filter_channels,
-                     kernel_size,
-                     p_dropout=p_dropout,
-                 )
-             )
-             self.norm_layers_2.append(LayerNorm(hidden_channels))
-
-     def forward(self, x, x_mask):
-         attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
-         x = x * x_mask
-         zippep = zip(
-             self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2
-         )
-         for attn_layers, norm_layers_1, ffn_layers, norm_layers_2 in zippep:
-             y = attn_layers(x, x, attn_mask)
-             y = self.drop(y)
-             x = norm_layers_1(x + y)
-
-             y = ffn_layers(x, x_mask)
-             y = self.drop(y)
-             x = norm_layers_2(x + y)
-         x = x * x_mask
-         return x
-
-
- class Decoder(nn.Module):
-     def __init__(
-         self,
-         hidden_channels,
-         filter_channels,
-         n_heads,
-         n_layers,
-         kernel_size=1,
-         p_dropout=0.0,
-         proximal_bias=False,
-         proximal_init=True,
-         **kwargs
-     ):
-         super(Decoder, self).__init__()
-         self.hidden_channels = hidden_channels
-         self.filter_channels = filter_channels
-         self.n_heads = n_heads
-         self.n_layers = n_layers
-         self.kernel_size = kernel_size
-         self.p_dropout = p_dropout
-         self.proximal_bias = proximal_bias
-         self.proximal_init = proximal_init
-
-         self.drop = nn.Dropout(p_dropout)
-         self.self_attn_layers = nn.ModuleList()
-         self.norm_layers_0 = nn.ModuleList()
-         self.encdec_attn_layers = nn.ModuleList()
-         self.norm_layers_1 = nn.ModuleList()
-         self.ffn_layers = nn.ModuleList()
-         self.norm_layers_2 = nn.ModuleList()
-         for i in range(self.n_layers):
-             self.self_attn_layers.append(
-                 MultiHeadAttention(
-                     hidden_channels,
-                     hidden_channels,
-                     n_heads,
-                     p_dropout=p_dropout,
-                     proximal_bias=proximal_bias,
-                     proximal_init=proximal_init,
-                 )
-             )
-             self.norm_layers_0.append(LayerNorm(hidden_channels))
-             self.encdec_attn_layers.append(
-                 MultiHeadAttention(
-                     hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
-                 )
-             )
-             self.norm_layers_1.append(LayerNorm(hidden_channels))
-             self.ffn_layers.append(
-                 FFN(
-                     hidden_channels,
-                     hidden_channels,
-                     filter_channels,
-                     kernel_size,
-                     p_dropout=p_dropout,
-                     causal=True,
-                 )
-             )
-             self.norm_layers_2.append(LayerNorm(hidden_channels))
-
-     def forward(self, x, x_mask, h, h_mask):
-         """
-         x: decoder input
-         h: encoder output
-         """
-         self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
-             device=x.device, dtype=x.dtype
-         )
-         encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
-         x = x * x_mask
-         for i in range(self.n_layers):
-             y = self.self_attn_layers[i](x, x, self_attn_mask)
-             y = self.drop(y)
-             x = self.norm_layers_0[i](x + y)
-
-             y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
-             y = self.drop(y)
-             x = self.norm_layers_1[i](x + y)
-
-             y = self.ffn_layers[i](x, x_mask)
-             y = self.drop(y)
-             x = self.norm_layers_2[i](x + y)
-         x = x * x_mask
-         return x
-
-
- class MultiHeadAttention(nn.Module):
-     def __init__(
-         self,
-         channels,
-         out_channels,
-         n_heads,
-         p_dropout=0.0,
-         window_size=None,
-         heads_share=True,
-         block_length=None,
-         proximal_bias=False,
-         proximal_init=False,
-     ):
-         super(MultiHeadAttention, self).__init__()
-         assert channels % n_heads == 0
-
-         self.channels = channels
-         self.out_channels = out_channels
-         self.n_heads = n_heads
-         self.p_dropout = p_dropout
-         self.window_size = window_size
-         self.heads_share = heads_share
-         self.block_length = block_length
-         self.proximal_bias = proximal_bias
-         self.proximal_init = proximal_init
-         self.attn = None
-
-         self.k_channels = channels // n_heads
-         self.conv_q = nn.Conv1d(channels, channels, 1)
-         self.conv_k = nn.Conv1d(channels, channels, 1)
-         self.conv_v = nn.Conv1d(channels, channels, 1)
-         self.conv_o = nn.Conv1d(channels, out_channels, 1)
-         self.drop = nn.Dropout(p_dropout)
-
-         if window_size is not None:
-             n_heads_rel = 1 if heads_share else n_heads
-             rel_stddev = self.k_channels**-0.5
-             self.emb_rel_k = nn.Parameter(
-                 torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                 * rel_stddev
-             )
-             self.emb_rel_v = nn.Parameter(
-                 torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                 * rel_stddev
-             )
-
-         nn.init.xavier_uniform_(self.conv_q.weight)
-         nn.init.xavier_uniform_(self.conv_k.weight)
-         nn.init.xavier_uniform_(self.conv_v.weight)
-         if proximal_init:
-             with torch.no_grad():
-                 self.conv_k.weight.copy_(self.conv_q.weight)
-                 self.conv_k.bias.copy_(self.conv_q.bias)
-
-     def forward(
-         self, x: torch.Tensor, c: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
-     ):
-         q = self.conv_q(x)
-         k = self.conv_k(c)
-         v = self.conv_v(c)
-
-         x, _ = self.attention(q, k, v, mask=attn_mask)
-
-         x = self.conv_o(x)
-         return x
-
-     def attention(
-         self,
-         query: torch.Tensor,
-         key: torch.Tensor,
-         value: torch.Tensor,
-         mask: Optional[torch.Tensor] = None,
-     ):
-         # reshape [b, d, t] -> [b, n_h, t, d_k]
-         b, d, t_s = key.size()
-         t_t = query.size(2)
-         query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
-         key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
-         value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
-
-         scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
-         if self.window_size is not None:
-             assert (
-                 t_s == t_t
-             ), "Relative attention is only available for self-attention."
-             key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
-             rel_logits = self._matmul_with_relative_keys(
-                 query / math.sqrt(self.k_channels), key_relative_embeddings
-             )
-             scores_local = self._relative_position_to_absolute_position(rel_logits)
-             scores = scores + scores_local
-         if self.proximal_bias:
-             assert t_s == t_t, "Proximal bias is only available for self-attention."
-             scores = scores + self._attention_bias_proximal(t_s).to(
-                 device=scores.device, dtype=scores.dtype
-             )
-         if mask is not None:
-             scores = scores.masked_fill(mask == 0, -1e4)
-             if self.block_length is not None:
-                 assert (
-                     t_s == t_t
-                 ), "Local attention is only available for self-attention."
-                 block_mask = (
-                     torch.ones_like(scores)
-                     .triu(-self.block_length)
-                     .tril(self.block_length)
-                 )
-                 scores = scores.masked_fill(block_mask == 0, -1e4)
-         p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
-         p_attn = self.drop(p_attn)
-         output = torch.matmul(p_attn, value)
-         if self.window_size is not None:
-             relative_weights = self._absolute_position_to_relative_position(p_attn)
-             value_relative_embeddings = self._get_relative_embeddings(
-                 self.emb_rel_v, t_s
-             )
-             output = output + self._matmul_with_relative_values(
-                 relative_weights, value_relative_embeddings
-             )
-         output = (
-             output.transpose(2, 3).contiguous().view(b, d, t_t)
-         )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
-         return output, p_attn
-
-     def _matmul_with_relative_values(self, x, y):
-         """
-         x: [b, h, l, m]
-         y: [h or 1, m, d]
-         ret: [b, h, l, d]
-         """
-         ret = torch.matmul(x, y.unsqueeze(0))
-         return ret
-
-     def _matmul_with_relative_keys(self, x, y):
-         """
-         x: [b, h, l, d]
-         y: [h or 1, m, d]
-         ret: [b, h, l, m]
-         """
-         ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
-         return ret
-
-     def _get_relative_embeddings(self, relative_embeddings, length: int):
-         max_relative_position = 2 * self.window_size + 1
-         # Pad first before slice to avoid using cond ops.
-         pad_length: int = max(length - (self.window_size + 1), 0)
-         slice_start_position = max((self.window_size + 1) - length, 0)
-         slice_end_position = slice_start_position + 2 * length - 1
-         if pad_length > 0:
-             padded_relative_embeddings = F.pad(
-                 relative_embeddings,
-                 # commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
-                 [0, 0, pad_length, pad_length, 0, 0],
-             )
-         else:
-             padded_relative_embeddings = relative_embeddings
-         used_relative_embeddings = padded_relative_embeddings[
-             :, slice_start_position:slice_end_position
-         ]
-         return used_relative_embeddings
-
-     def _relative_position_to_absolute_position(self, x):
-         """
-         x: [b, h, l, 2*l-1]
-         ret: [b, h, l, l]
-         """
-         batch, heads, length, _ = x.size()
-         # Concat columns of pad to shift from relative to absolute indexing.
-         x = F.pad(
-             x,
-             # commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])
-             [0, 1, 0, 0, 0, 0, 0, 0],
-         )
-
-         # Concat extra elements so to add up to shape (len+1, 2*len-1).
-         x_flat = x.view([batch, heads, length * 2 * length])
-         x_flat = F.pad(
-             x_flat,
-             # commons.convert_pad_shape([[0, 0], [0, 0], [0, int(length) - 1]])
-             [0, int(length) - 1, 0, 0, 0, 0],
-         )
-
-         # Reshape and slice out the padded elements.
-         x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
-             :, :, :length, length - 1 :
-         ]
-         return x_final
-
-     def _absolute_position_to_relative_position(self, x):
-         """
-         x: [b, h, l, l]
-         ret: [b, h, l, 2*l-1]
-         """
-         batch, heads, length, _ = x.size()
-         # padd along column
-         x = F.pad(
-             x,
-             # commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, int(length) - 1]])
-             [0, int(length) - 1, 0, 0, 0, 0, 0, 0],
-         )
-         x_flat = x.view([batch, heads, int(length**2) + int(length * (length - 1))])
-         # add 0's in the beginning that will skew the elements after reshape
-         x_flat = F.pad(
-             x_flat,
-             # commons.convert_pad_shape([[0, 0], [0, 0], [int(length), 0]])
-             [length, 0, 0, 0, 0, 0],
-         )
-         x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
-         return x_final
-
-     def _attention_bias_proximal(self, length: int):
-         """Bias for self-attention to encourage attention to close positions.
-         Args:
-             length: an integer scalar.
-         Returns:
-             a Tensor with shape [1, 1, length, length]
-         """
-         r = torch.arange(length, dtype=torch.float32)
-         diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
-         return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
-
-
- class FFN(nn.Module):
-     def __init__(
-         self,
-         in_channels,
-         out_channels,
-         filter_channels,
-         kernel_size,
-         p_dropout=0.0,
-         activation: str = None,
-         causal=False,
-     ):
-         super(FFN, self).__init__()
-         self.in_channels = in_channels
-         self.out_channels = out_channels
-         self.filter_channels = filter_channels
-         self.kernel_size = kernel_size
-         self.p_dropout = p_dropout
-         self.activation = activation
-         self.causal = causal
-         self.is_activation = True if activation == "gelu" else False
-         # if causal:
-         #     self.padding = self._causal_padding
-         # else:
-         #     self.padding = self._same_padding
-
-         self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
-         self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
-         self.drop = nn.Dropout(p_dropout)
-
-     def padding(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
-         if self.causal:
-             padding = self._causal_padding(x * x_mask)
-         else:
-             padding = self._same_padding(x * x_mask)
-         return padding
-
-     def forward(self, x: torch.Tensor, x_mask: torch.Tensor):
-         x = self.conv_1(self.padding(x, x_mask))
-         if self.is_activation:
-             x = x * torch.sigmoid(1.702 * x)
-         else:
-             x = torch.relu(x)
-         x = self.drop(x)
-
-         x = self.conv_2(self.padding(x, x_mask))
-         return x * x_mask
-
-     def _causal_padding(self, x):
-         if self.kernel_size == 1:
-             return x
-         pad_l: int = self.kernel_size - 1
-         pad_r: int = 0
-         # padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-         x = F.pad(
-             x,
-             # commons.convert_pad_shape(padding)
-             [pad_l, pad_r, 0, 0, 0, 0],
-         )
-         return x
-
-     def _same_padding(self, x):
-         if self.kernel_size == 1:
-             return x
-         pad_l: int = (self.kernel_size - 1) // 2
-         pad_r: int = self.kernel_size // 2
-         # padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-         x = F.pad(
-             x,
-             # commons.convert_pad_shape(padding)
-             [pad_l, pad_r, 0, 0, 0, 0],
-         )
-         return x
 
+ import copy
+ import math
+ from typing import Optional
+
+ import numpy as np
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from libs.infer_packs import commons, modules
+ from libs.infer_packs.modules import LayerNorm
+
+
+ class Encoder(nn.Module):
+     def __init__(
+         self,
+         hidden_channels,
+         filter_channels,
+         n_heads,
+         n_layers,
+         kernel_size=1,
+         p_dropout=0.0,
+         window_size=10,
+         **kwargs
+     ):
+         super(Encoder, self).__init__()
+         self.hidden_channels = hidden_channels
+         self.filter_channels = filter_channels
+         self.n_heads = n_heads
+         self.n_layers = int(n_layers)
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.window_size = window_size
+
+         self.drop = nn.Dropout(p_dropout)
+         self.attn_layers = nn.ModuleList()
+         self.norm_layers_1 = nn.ModuleList()
+         self.ffn_layers = nn.ModuleList()
+         self.norm_layers_2 = nn.ModuleList()
+         for i in range(self.n_layers):
+             self.attn_layers.append(
+                 MultiHeadAttention(
+                     hidden_channels,
+                     hidden_channels,
+                     n_heads,
+                     p_dropout=p_dropout,
+                     window_size=window_size,
+                 )
+             )
+             self.norm_layers_1.append(LayerNorm(hidden_channels))
+             self.ffn_layers.append(
+                 FFN(
+                     hidden_channels,
+                     hidden_channels,
+                     filter_channels,
+                     kernel_size,
+                     p_dropout=p_dropout,
+                 )
+             )
+             self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+     def forward(self, x, x_mask):
+         attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+         x = x * x_mask
+         zippep = zip(
+             self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2
+         )
+         for attn_layers, norm_layers_1, ffn_layers, norm_layers_2 in zippep:
+             y = attn_layers(x, x, attn_mask)
+             y = self.drop(y)
+             x = norm_layers_1(x + y)
+
+             y = ffn_layers(x, x_mask)
+             y = self.drop(y)
+             x = norm_layers_2(x + y)
+         x = x * x_mask
+         return x
+
+
+ class Decoder(nn.Module):
+     def __init__(
+         self,
+         hidden_channels,
+         filter_channels,
+         n_heads,
+         n_layers,
+         kernel_size=1,
+         p_dropout=0.0,
+         proximal_bias=False,
+         proximal_init=True,
+         **kwargs
+     ):
+         super(Decoder, self).__init__()
+         self.hidden_channels = hidden_channels
+         self.filter_channels = filter_channels
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.proximal_bias = proximal_bias
+         self.proximal_init = proximal_init
+
+         self.drop = nn.Dropout(p_dropout)
+         self.self_attn_layers = nn.ModuleList()
+         self.norm_layers_0 = nn.ModuleList()
+         self.encdec_attn_layers = nn.ModuleList()
+         self.norm_layers_1 = nn.ModuleList()
+         self.ffn_layers = nn.ModuleList()
+         self.norm_layers_2 = nn.ModuleList()
+         for i in range(self.n_layers):
+             self.self_attn_layers.append(
+                 MultiHeadAttention(
+                     hidden_channels,
+                     hidden_channels,
+                     n_heads,
+                     p_dropout=p_dropout,
+                     proximal_bias=proximal_bias,
+                     proximal_init=proximal_init,
+                 )
+             )
+             self.norm_layers_0.append(LayerNorm(hidden_channels))
+             self.encdec_attn_layers.append(
+                 MultiHeadAttention(
+                     hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
+                 )
+             )
+             self.norm_layers_1.append(LayerNorm(hidden_channels))
+             self.ffn_layers.append(
+                 FFN(
+                     hidden_channels,
+                     hidden_channels,
+                     filter_channels,
+                     kernel_size,
+                     p_dropout=p_dropout,
+                     causal=True,
+                 )
+             )
+             self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+     def forward(self, x, x_mask, h, h_mask):
+         """
+         x: decoder input
+         h: encoder output
+         """
+         self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
+             device=x.device, dtype=x.dtype
+         )
+         encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+         x = x * x_mask
+         for i in range(self.n_layers):
+             y = self.self_attn_layers[i](x, x, self_attn_mask)
+             y = self.drop(y)
+             x = self.norm_layers_0[i](x + y)
+
+             y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
+             y = self.drop(y)
+             x = self.norm_layers_1[i](x + y)
+
+             y = self.ffn_layers[i](x, x_mask)
+             y = self.drop(y)
+             x = self.norm_layers_2[i](x + y)
+         x = x * x_mask
+         return x
+
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(
+         self,
+         channels,
+         out_channels,
+         n_heads,
+         p_dropout=0.0,
+         window_size=None,
+         heads_share=True,
+         block_length=None,
+         proximal_bias=False,
+         proximal_init=False,
+     ):
+         super(MultiHeadAttention, self).__init__()
+         assert channels % n_heads == 0
+
+         self.channels = channels
+         self.out_channels = out_channels
+         self.n_heads = n_heads
+         self.p_dropout = p_dropout
+         self.window_size = window_size
+         self.heads_share = heads_share
+         self.block_length = block_length
+         self.proximal_bias = proximal_bias
+         self.proximal_init = proximal_init
+         self.attn = None
+
+         self.k_channels = channels // n_heads
+         self.conv_q = nn.Conv1d(channels, channels, 1)
+         self.conv_k = nn.Conv1d(channels, channels, 1)
+         self.conv_v = nn.Conv1d(channels, channels, 1)
+         self.conv_o = nn.Conv1d(channels, out_channels, 1)
+         self.drop = nn.Dropout(p_dropout)
+
+         if window_size is not None:
+             n_heads_rel = 1 if heads_share else n_heads
+             rel_stddev = self.k_channels**-0.5
+             self.emb_rel_k = nn.Parameter(
+                 torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                 * rel_stddev
+             )
+             self.emb_rel_v = nn.Parameter(
+                 torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                 * rel_stddev
+             )
+
+         nn.init.xavier_uniform_(self.conv_q.weight)
+         nn.init.xavier_uniform_(self.conv_k.weight)
+         nn.init.xavier_uniform_(self.conv_v.weight)
+         if proximal_init:
+             with torch.no_grad():
+                 self.conv_k.weight.copy_(self.conv_q.weight)
+                 self.conv_k.bias.copy_(self.conv_q.bias)
+
+     def forward(
+         self, x: torch.Tensor, c: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
+     ):
+         q = self.conv_q(x)
+         k = self.conv_k(c)
+         v = self.conv_v(c)
+
+         x, _ = self.attention(q, k, v, mask=attn_mask)
+
+         x = self.conv_o(x)
+         return x
+
+     def attention(
+         self,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         mask: Optional[torch.Tensor] = None,
+     ):
+         # reshape [b, d, t] -> [b, n_h, t, d_k]
+         b, d, t_s = key.size()
+         t_t = query.size(2)
+         query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+         key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+         value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+         scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
+         if self.window_size is not None:
+             assert (
+                 t_s == t_t
+             ), "Relative attention is only available for self-attention."
+             key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+             rel_logits = self._matmul_with_relative_keys(
+                 query / math.sqrt(self.k_channels), key_relative_embeddings
+             )
+             scores_local = self._relative_position_to_absolute_position(rel_logits)
+             scores = scores + scores_local
+         if self.proximal_bias:
+             assert t_s == t_t, "Proximal bias is only available for self-attention."
+             scores = scores + self._attention_bias_proximal(t_s).to(
+                 device=scores.device, dtype=scores.dtype
+             )
+         if mask is not None:
+             scores = scores.masked_fill(mask == 0, -1e4)
+             if self.block_length is not None:
+                 assert (
+                     t_s == t_t
+                 ), "Local attention is only available for self-attention."
+                 block_mask = (
+                     torch.ones_like(scores)
+                     .triu(-self.block_length)
+                     .tril(self.block_length)
+                 )
+                 scores = scores.masked_fill(block_mask == 0, -1e4)
+         p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
+         p_attn = self.drop(p_attn)
+         output = torch.matmul(p_attn, value)
+         if self.window_size is not None:
+             relative_weights = self._absolute_position_to_relative_position(p_attn)
+             value_relative_embeddings = self._get_relative_embeddings(
+                 self.emb_rel_v, t_s
+             )
+             output = output + self._matmul_with_relative_values(
+                 relative_weights, value_relative_embeddings
+             )
+         output = (
+             output.transpose(2, 3).contiguous().view(b, d, t_t)
+         )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
+         return output, p_attn
+
+     def _matmul_with_relative_values(self, x, y):
+         """
+         x: [b, h, l, m]
+         y: [h or 1, m, d]
+         ret: [b, h, l, d]
+         """
+         ret = torch.matmul(x, y.unsqueeze(0))
+         return ret
+
+     def _matmul_with_relative_keys(self, x, y):
+         """
+         x: [b, h, l, d]
+         y: [h or 1, m, d]
+         ret: [b, h, l, m]
+         """
+         ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+         return ret
+
+     def _get_relative_embeddings(self, relative_embeddings, length: int):
+         max_relative_position = 2 * self.window_size + 1
+         # Pad first before slice to avoid using cond ops.
+         pad_length: int = max(length - (self.window_size + 1), 0)
+         slice_start_position = max((self.window_size + 1) - length, 0)
+         slice_end_position = slice_start_position + 2 * length - 1
+         if pad_length > 0:
+             padded_relative_embeddings = F.pad(
+                 relative_embeddings,
+                 # commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+                 [0, 0, pad_length, pad_length, 0, 0],
+             )
+         else:
+             padded_relative_embeddings = relative_embeddings
+         used_relative_embeddings = padded_relative_embeddings[
+             :, slice_start_position:slice_end_position
+         ]
+         return used_relative_embeddings
+
+     def _relative_position_to_absolute_position(self, x):
+         """
+         x: [b, h, l, 2*l-1]
+         ret: [b, h, l, l]
+         """
+         batch, heads, length, _ = x.size()
+         # Concat columns of pad to shift from relative to absolute indexing.
+         x = F.pad(
+             x,
+             # commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])
+             [0, 1, 0, 0, 0, 0, 0, 0],
+         )
+
+         # Concat extra elements so to add up to shape (len+1, 2*len-1).
+         x_flat = x.view([batch, heads, length * 2 * length])
+         x_flat = F.pad(
+             x_flat,
+             # commons.convert_pad_shape([[0, 0], [0, 0], [0, int(length) - 1]])
+             [0, int(length) - 1, 0, 0, 0, 0],
+         )
+
+         # Reshape and slice out the padded elements.
+         x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
+             :, :, :length, length - 1 :
+         ]
+         return x_final
+
+     def _absolute_position_to_relative_position(self, x):
+         """
+         x: [b, h, l, l]
+         ret: [b, h, l, 2*l-1]
+         """
+         batch, heads, length, _ = x.size()
+         # padd along column
+         x = F.pad(
+             x,
+             # commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, int(length) - 1]])
+             [0, int(length) - 1, 0, 0, 0, 0, 0, 0],
+         )
+         x_flat = x.view([batch, heads, int(length**2) + int(length * (length - 1))])
+         # add 0's in the beginning that will skew the elements after reshape
+         x_flat = F.pad(
+             x_flat,
+             # commons.convert_pad_shape([[0, 0], [0, 0], [int(length), 0]])
+             [length, 0, 0, 0, 0, 0],
+         )
+         x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+         return x_final
+
+     def _attention_bias_proximal(self, length: int):
+         """Bias for self-attention to encourage attention to close positions.
+         Args:
+             length: an integer scalar.
+         Returns:
+             a Tensor with shape [1, 1, length, length]
+         """
+         r = torch.arange(length, dtype=torch.float32)
+         diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+         return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+ class FFN(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         filter_channels,
+         kernel_size,
+         p_dropout=0.0,
+         activation: str = None,
+         causal=False,
+     ):
+         super(FFN, self).__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.filter_channels = filter_channels
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.activation = activation
+         self.causal = causal
+         self.is_activation = True if activation == "gelu" else False
+         # if causal:
+         #     self.padding = self._causal_padding
+         # else:
+         #     self.padding = self._same_padding
+
+         self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
+         self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
+         self.drop = nn.Dropout(p_dropout)
+
+     def padding(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
+         if self.causal:
+             padding = self._causal_padding(x * x_mask)
+         else:
+             padding = self._same_padding(x * x_mask)
+         return padding
+
+     def forward(self, x: torch.Tensor, x_mask: torch.Tensor):
+         x = self.conv_1(self.padding(x, x_mask))
+         if self.is_activation:
+             x = x * torch.sigmoid(1.702 * x)
+         else:
+             x = torch.relu(x)
+         x = self.drop(x)
+
+         x = self.conv_2(self.padding(x, x_mask))
+         return x * x_mask
+
+     def _causal_padding(self, x):
+         if self.kernel_size == 1:
+             return x
+         pad_l: int = self.kernel_size - 1
+         pad_r: int = 0
+         # padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+         x = F.pad(
+             x,
+             # commons.convert_pad_shape(padding)
+             [pad_l, pad_r, 0, 0, 0, 0],
+         )
+         return x
+
+     def _same_padding(self, x):
+         if self.kernel_size == 1:
+             return x
+         pad_l: int = (self.kernel_size - 1) // 2
+         pad_r: int = self.kernel_size // 2
+         # padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+         x = F.pad(
+             x,
+             # commons.convert_pad_shape(padding)
+             [pad_l, pad_r, 0, 0, 0, 0],
+         )
+         return x
libs/infer_packs/models.py CHANGED
@@ -10,8 +10,8 @@ from torch import nn
  from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
  from torch.nn import functional as F
  from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
- from libs.infer_pack import attentions, commons, modules
- from libs.infer_pack.commons import get_padding, init_weights
+ from libs.infer_packs import attentions, commons, modules
+ from libs.infer_packs.commons import get_padding, init_weights

  has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available())

libs/infer_packs/modules.py CHANGED
@@ -10,9 +10,9 @@ from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
  from torch.nn import functional as F
  from torch.nn.utils import remove_weight_norm, weight_norm

- from libs.infer_pack import commons
- from libs.infer_pack.commons import get_padding, init_weights
- from libs.infer_pack.transforms import piecewise_rational_quadratic_transform
+ from libs.infer_packs import commons
+ from libs.infer_packs.commons import get_padding, init_weights
+ from libs.infer_packs.transforms import piecewise_rational_quadratic_transform

  LRELU_SLOPE = 0.1
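
For reference, a minimal usage sketch of the renamed package path. This is not part of the commit; the hyperparameter values are illustrative only, and it assumes the repository root is on PYTHONPATH so that libs.infer_packs resolves as a package.

    # Illustrative only: build the Encoder defined in libs/infer_packs/attentions.py
    # and run a dummy forward pass. Shapes follow the file's own conventions:
    # x is [batch, hidden_channels, time]; x_mask is [batch, 1, time], 1 = valid frame.
    import torch

    from libs.infer_packs.attentions import Encoder

    encoder = Encoder(
        hidden_channels=192,   # assumed value, not taken from this commit
        filter_channels=768,   # assumed value
        n_heads=2,
        n_layers=6,
        kernel_size=3,
        p_dropout=0.1,
    )
    x = torch.randn(1, 192, 100)
    x_mask = torch.ones(1, 1, 100)
    out = encoder(x, x_mask)   # -> tensor of shape [1, 192, 100]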