Sin2pi committed on
Commit 7354230 · verified · Parent(s): 2b51e86

Create model.py

Files changed (1)
  model.py +1581 -0
model.py ADDED
@@ -0,0 +1,1581 @@
import pyworld as pw
import os
import math
import warnings
import logging
import gzip
import base64
import torch
import torchaudio
import torchcrepe
import torch.nn.functional as F
import torch.nn.init as init
from torch import nn, Tensor
import numpy as np
from typing import Optional, Dict, Union, List, Tuple, Any
from functools import partial
from datetime import datetime
from datasets import load_dataset, Audio
from transformers.trainer_seq2seq import Seq2SeqTrainer
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
import transformers
import evaluate
from dataclasses import dataclass

torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision('high')
transformers.utils.logging.set_verbosity_error()

# fall back to CPU when CUDA is unavailable
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

torch.set_default_dtype(dtype)
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.ERROR)
tox = {"device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), "dtype": torch.float32}

extractor = None
tokenizer = None
optimizer = None
scheduler = None
model = None

@dataclass
class Dimensions:
    vocab: int
    text_ctx: int
    text_dims: int
    text_head: int
    text_idx: int
    mels: int
    aud_ctx: int
    aud_dims: int
    aud_head: int
    aud_idx: int
    act: str
    debug: List[str]
    cross_attn: bool
    features: List[str]
    f0_rotary: bool

def exists(v):
    return v is not None

def default(v, b):
    return v if exists(v) else b

class Conv1d(nn.Conv1d):
    def _conv_forward(
            self, x: Tensor, weight: Tensor, bias) -> Tensor:
        return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))

class Conv2d(nn.Conv2d):
    def _conv_forward(
            self, x: Tensor, weight: Tensor, bias) -> Tensor:
        return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))

class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        init.xavier_uniform_(self.linear.weight)
        if bias:
            init.zeros_(self.linear.bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x)

class RMSNorm(nn.Module):
    def __init__(self, dims: Union[int, Tensor, List, Tuple],
                 eps=1e-8, elementwise_affine=True):
        super(RMSNorm, self).__init__()
        if isinstance(dims, int):
            self.normalized_shape = (dims,)
        else:
            self.normalized_shape = tuple(dims)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape))
            init.ones_(self.weight)
        else:
            self.register_parameter("weight", None)

    def forward(self, x):
        return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)

def LayerNorm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple],
              weight: Optional[Tensor] = None, bias: Optional[Tensor] = None,
              eps: float = 1e-5) -> Tensor:
    return F.layer_norm(x, normalized_shape, weight, bias, eps)

def get_device():
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def get_dtype():
    return torch.float32 if torch.cuda.is_available() else torch.float64

def get_tox():
    return {"device": get_device(), "dtype": get_dtype()}

def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

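# Shape sketch for the sinusoidal table above (illustrative only):
#   >>> pe = sinusoids(length=1500, channels=512)
#   >>> pe.shape
#   torch.Size([1500, 512])
# The first half of the channel axis holds sin terms, the second half cos,
# with timescales spaced geometrically from 1 up to max_timescale.
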
class rotary(nn.Module):
    _seen = set()

    def __init__(self, dims, max_ctx=1500, theta=10000, learned_freq=True, variable_radius=False,
                 learned_radius=False, learned_theta=False, learned_pitch=False, debug: List[str] = []):
        super().__init__()
        self.use_pbias = False

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device
        dtype = torch.float32
        self.dtype = dtype
        self.debug = debug
        self._counter = 0
        self.dims = dims
        self.max_ctx = max_ctx
        self.variable_radius = variable_radius

        self.inv_freq = nn.Parameter(
            1.0 / (19000 ** (torch.arange(0, dims, 2, device=device, dtype=dtype) / dims)),
            requires_grad=learned_freq)
        self.theta = nn.Parameter(
            torch.tensor(float(theta)), requires_grad=learned_theta)
        self.min_theta = nn.Parameter(
            torch.tensor(600.0), requires_grad=learned_theta)
        self.max_theta = nn.Parameter(
            torch.tensor(2400.0), requires_grad=learned_theta)

        self.pitch_scale = nn.Parameter(torch.tensor(1.0),
                                        requires_grad=learned_pitch)

        if variable_radius:
            self.radius = nn.Parameter(
                torch.ones(dims // 2),
                requires_grad=learned_radius)

    def get_pitch_bias(self, f0):
        if f0 is None:
            return None

        f0_flat = f0.squeeze().float()
        f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
        f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
                                        f0_norm.unsqueeze(1)) * self.pitch_scale)
        return f0_sim.unsqueeze(0).unsqueeze(0)

    def add_to_rotary(self):
        def get_sim(self, freqs):
            real = freqs.real.squeeze(0)
            imag = freqs.imag.squeeze(0)
            vecs = torch.cat([real.unsqueeze(-2), imag.unsqueeze(-2)], dim=-1)
            vecs = vecs.squeeze(-2)
            return F.cosine_similarity(vecs.unsqueeze(1), vecs.unsqueeze(0), dim=-1)

        def fwd_sim(self, x=None, f0=None):
            freqs = self.forward(x, f0)
            sim = get_sim(self, freqs)
            return freqs, sim

        rotary.get_sim = get_sim
        rotary.fwd_sim = fwd_sim

    def forward(self, x=None, f0=None) -> Tensor:
        if isinstance(x, int):
            t = torch.arange(x, device=self.device).float()
        else:
            t = x.float().to(self.inv_freq.device)

        if f0 is not None:
            f0_mean = f0.mean()
            perceptual_factor = torch.log(1 + f0_mean / 700.0) / torch.log(torch.tensor(1 + 300.0 / 700.0))
            min_theta, max_theta = 800.0, 10000.0
            f0_theta = self.theta + perceptual_factor * (max_theta - min_theta)
            inv_freq = 1.0 / (f0_theta ** (torch.arange(0, self.dims, 2, device=self.device) / self.dims))
        else:
            inv_freq = self.inv_freq
        freqs = torch.einsum('i,j->ij', t, inv_freq)

        freqs = freqs.float()
        if self.variable_radius:
            radius = F.softplus(self.radius)
            freqs = torch.polar(radius.unsqueeze(0).expand_as(freqs), freqs)
        else:
            freqs = torch.polar(torch.ones_like(freqs), freqs)
        freqs = freqs.unsqueeze(0)

        if "rotary" in self.debug:
            if f0 is not None:
                key = f"{self._counter}_{f0_theta:.2f}"
                if key not in rotary._seen:
                    if not hasattr(self, '_prev_f0_theta'):
                        self._prev_f0_theta = f0_theta
                        print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
                    elif abs(self._prev_f0_theta - f0_theta) > 200.0:
                        print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
                        self._prev_f0_theta = f0_theta
                    rotary._seen.add(key)
        self._counter += 1

        return freqs

    @staticmethod
    def apply_rotary(x, freqs):
        multihead_format = len(freqs.shape) == 4
        if multihead_format:
            x1 = x[..., :freqs.shape[-1]*2]
            x2 = x[..., freqs.shape[-1]*2:]
            x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
            x1 = torch.view_as_complex(x1)
            x1 = x1 * freqs
            x1 = torch.view_as_real(x1).flatten(-2)
            return torch.cat([x1.type_as(x), x2], dim=-1)
        else:
            x1 = x[..., :freqs.shape[-1]*2]
            x2 = x[..., freqs.shape[-1]*2:]

            if x.ndim == 2:
                x1 = x1.unsqueeze(0)
                x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
                x1 = torch.view_as_complex(x1)
                x1 = x1 * freqs
                x1 = torch.view_as_real(x1).flatten(-2)
                x1 = x1.squeeze(0)
                return torch.cat([x1.type_as(x), x2], dim=-1)
            else:
                x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
                x1 = torch.view_as_complex(x1)
                x1 = x1 * freqs
                x1 = torch.view_as_real(x1).flatten(-2)
                return torch.cat([x1.type_as(x), x2], dim=-1)

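# Usage sketch for rotary (illustrative; head_dim=64 is an assumption):
#   >>> rope = rotary(dims=64)
#   >>> freqs = rope(1500)                      # complex, (1, 1500, 32)
#   >>> q = torch.randn(2, 4, 1500, 64)         # (batch, head, ctx, head_dim)
#   >>> q_rot = rotary.apply_rotary(q, freqs)   # same shape as q
# Passing f0 shifts the effective theta from the utterance's mean pitch, so
# higher-pitched audio rotates the query/key pairs faster.
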
class SliceAttention(nn.Module):
    def __init__(self, dims, heads, dropout=0.0):
        super().__init__()
        self.dims = dims
        self.heads = heads
        self.head_dim = dims // heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = Linear(dims, dims)
        self.k_proj = Linear(dims, dims)
        self.v_proj = Linear(dims, dims)
        self.out_proj = Linear(dims, dims)
        self.dropout = nn.Dropout(dropout)

        assert dims % heads == 0, f"Dimensions {dims} not divisible by heads {heads}"

    def parallel_slice(self, q, k, v, mask=None):
        # q, k, v: (batch, ctx, dims) -- heads are sliced out of the channel dim
        batch, ctx, dims = q.shape
        head_dim = self.head_dim
        ctx_len = k.shape[1]
        num_heads = dims // head_dim

        scores = torch.zeros(batch, num_heads, ctx, ctx_len, device=q.device)

        for h in range(num_heads):
            start_idx = h * head_dim
            end_idx = start_idx + head_dim
            q_h = q[:, :, start_idx:end_idx]
            k_h = k[:, :, start_idx:end_idx]

            scores[:, h] = torch.bmm(q_h, k_h.transpose(1, 2)) / math.sqrt(head_dim)

        if mask is not None:
            scores = scores + mask.unsqueeze(0).unsqueeze(0)

        attn_weights = F.softmax(scores, dim=-1)

        output = torch.zeros_like(q)
        for h in range(num_heads):
            start_idx = h * head_dim
            end_idx = start_idx + head_dim
            v_h = v[:, :, start_idx:end_idx]
            output[:, :, start_idx:end_idx] = torch.bmm(attn_weights[:, h], v_h)
        return output

    def forward(self, x, context=None, mask=None):
        batch, ctx, _ = x.shape
        if context is None:
            context = x

        ctx_len = context.shape[1]
        q = self.q_proj(x)
        k = self.k_proj(context)
        v = self.v_proj(context)
        output = torch.zeros_like(q)

        for h in range(self.heads):
            start_idx = h * self.head_dim
            end_idx = start_idx + self.head_dim

            q_h = q[:, :, start_idx:end_idx]
            k_h = k[:, :, start_idx:end_idx]
            v_h = v[:, :, start_idx:end_idx]

            attn_scores = torch.bmm(q_h, k_h.transpose(1, 2)) * self.scale
            if mask is not None:
                attn_scores = attn_scores + mask[:ctx, :ctx_len].unsqueeze(0)

            attn_weights = F.softmax(attn_scores, dim=-1)
            attn_weights = self.dropout(attn_weights)
            head_output = torch.bmm(attn_weights, v_h)
            output[:, :, start_idx:end_idx] = head_output
        return self.out_proj(output)

def optim_attn(q, k, v, mask=None, scale=None, pad_token=0, fzero_val=0.0001):
    # Padding is detected heuristically from the first key channel; padded
    # positions get a large negative additive bias instead of a hard mask.
    batch, heads, ctx, dims = q.shape
    token_ids = k[:, :, :, 0]
    is_padding = (token_ids.float() == pad_token).unsqueeze(-2)
    log_scale_factor = -10.0
    attn_mask = torch.zeros((batch, heads, ctx, ctx), device=q.device)

    if mask is not None:
        attn_mask = attn_mask + mask.unsqueeze(0).unsqueeze(0)
    attn_mask = torch.where(is_padding,
                            torch.tensor(log_scale_factor, device=q.device),
                            attn_mask)
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask,
        dropout_p=0.0, is_causal=False)
    attn_output = attn_output.permute(0, 2, 1, 3).flatten(start_dim=2)
    return attn_output

class MultiheadA(nn.Module):
    _seen = set()
    rbf = False

    def __init__(self, dims: int, head: int, rotary_emb: bool = False,
                 zero_val: float = 0.0001, minz: float = 0.0, maxz: float = 0.001,
                 debug: List[str] = [], optim_attn=False):
        super(MultiheadA, self).__init__()

        self.debug = debug
        self.pad_token = 0
        self.dims = dims
        self.head = head
        self.head_dim = dims // head
        self.rotary_emb = rotary_emb
        self.minz = minz
        self.maxz = maxz
        self.zero_val = zero_val
        self.optim_attn = optim_attn
        self._counter = 0
        if dims % head != 0:
            raise ValueError(f"Dimensions {dims} must be divisible by number of heads {head}.")
        if zero_val < minz or zero_val > maxz:
            raise ValueError(f"Zero value {zero_val} must be between {minz} and {maxz}.")

        self.q = Linear(dims, dims)
        self.k = Linear(dims, dims, bias=False)
        self.v = Linear(dims, dims)
        self.o = Linear(dims, dims)
        self.fzero = nn.Parameter(torch.tensor(zero_val, dtype=torch.float32), requires_grad=True)

        if rotary_emb:
            self.rope = rotary(
                dims=self.head_dim,
                debug=debug,
                max_ctx=1500,
            )
        else:
            self.rope = None

    def enhanced_attention_scores(self, q, k, rbf_sigma=1.0, rbf_ratio=0.0):
        scale = (self.dims // self.head) ** -0.25
        dot_scores = torch.matmul(q, k.transpose(-1, -2)) * scale
        if rbf_ratio <= 0.0:
            return dot_scores
        q_norm = q.pow(2).sum(dim=-1, keepdim=True)
        k_norm = k.pow(2).sum(dim=-1, keepdim=True)
        qk = torch.matmul(q, k.transpose(-1, -2))
        dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
        rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
        return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores

    def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None,
                return_attn: bool = False, f0: Tensor = None) -> tuple:

        batch, ctx, dims = x.shape
        scale = (self.dims // self.head) ** -0.25

        z = default(xa, x)
        q = self.q(x).to(x.dtype)
        k = self.k(z).to(x.dtype)
        v = self.v(z).to(x.dtype)

        if self.rotary_emb:
            if f0 is not None:
                qf = self.rope(q.size(1), f0=f0)
                kf = self.rope(k.size(1), f0=f0)
            else:
                qf = self.rope(q.size(1))
                kf = self.rope(k.size(1))

            q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
            k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
            v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)

            q = self.rope.apply_rotary(q, qf)
            k = self.rope.apply_rotary(k, kf)
        else:
            q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
            k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
            v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
        batch, head, ctx, head_dim = q.shape

        if self.optim_attn and not return_attn:
            wv = optim_attn(q * scale, k * scale, v, mask=mask,
                            pad_token=self.pad_token,
                            fzero_val=torch.clamp(F.softplus(self.fzero), self.minz, self.maxz).item())
            return self.o(wv), None

        if self.rbf:
            qk = self.enhanced_attention_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
        else:
            qk = (q * scale) @ (k * scale).transpose(-1, -2)
        if f0 is not None and self.rope is not None and self.rope.use_pbias:
            pbias = self.rope.get_pitch_bias(f0)
            if pbias is not None:
                qk = qk + pbias[:, :, :q.shape[2], :q.shape[2]]
        token_ids = k[:, :, :, 0]
        zscale = torch.ones_like(token_ids)
        fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
        zscale[token_ids.float() == self.pad_token] = fzero.to(q.device, q.dtype)

        if mask is not None:
            mask = mask[:q.shape[2], :q.shape[2]]
            qk = qk + mask.unsqueeze(0).unsqueeze(0) * zscale.unsqueeze(-2).expand(qk.shape)
        qk = qk * zscale.unsqueeze(-2)
        if return_attn:
            return qk, v
        w = F.softmax(qk, dim=-1).to(q.dtype)
        wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)

        if "multihead" in self.debug and self._counter % 100 == 0:
            print(f"Step {self._counter}: Using rotary embeddings: {self.rotary_emb}")
            print(f"MHA: q={q.shape}, k={k.shape}, v={v.shape}")
            print(f"Attention shape: {qk.shape}, wv shape: {wv.shape}")
        self._counter += 1
        return self.o(wv), qk.detach()

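# Usage sketch for MultiheadA (illustrative values):
#   >>> attn = MultiheadA(dims=512, head=4, rotary_emb=True)
#   >>> x = torch.randn(2, 100, 512)
#   >>> out, qk = attn(x)                              # self-attention: out (2, 100, 512)
#   >>> out, qk = attn(x, xa=torch.randn(2, 50, 512))  # cross-attention over xa
# qk is the raw pre-softmax score map, returned detached for inspection.
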
class FCGate(nn.Module):
    def __init__(self, dims, dim):
        super().__init__()
        self.proj = Linear(dim, dims // 4)
        self.gate = nn.Sequential(
            Linear(dims + dims // 4, dims // 2),
            nn.SiLU(),
            Linear(dims // 2, 1),
            nn.Sigmoid()
        )

    def forward(self, x, embedding):
        info = self.proj(embedding)
        info = info.unsqueeze(1).expand(-1, x.shape[1], -1)
        gate_input = torch.cat([x, info], dim=-1)
        return self.gate(gate_input)

class TTGate(nn.Module):
    def __init__(self, dims, num_types=4):
        super().__init__()
        self.gate_projections = nn.ModuleList([
            nn.Sequential(Linear(dims, 1), nn.Sigmoid())
            for _ in range(num_types)])
        self.type_classifier = nn.Sequential(
            Linear(dims, num_types),
            nn.Softmax(dim=-1))

    def forward(self, x):
        type_probs = self.type_classifier(x)
        gates = torch.stack([gate(x) for gate in self.gate_projections], dim=-1)
        combined_gate = torch.sum(gates * type_probs.unsqueeze(2), dim=-1)
        return combined_gate

class MGate(nn.Module):
    def __init__(self, dims, memory_size=64):
        super().__init__()
        self.mkey = nn.Parameter(torch.randn(memory_size, dims))
        self.mvalue = nn.Parameter(torch.randn(memory_size, 1))
        self.gate_proj = nn.Sequential(Linear(dims, dims // 2), nn.SiLU(), Linear(dims // 2, 1))

    def forward(self, x):
        dgate = torch.sigmoid(self.gate_proj(x))
        attention = torch.matmul(x, self.mkey.transpose(0, 1))
        attention = F.softmax(attention / math.sqrt(x.shape[-1]), dim=-1)
        mgate = torch.matmul(attention, self.mvalue)
        mgate = torch.sigmoid(mgate)
        return 0.5 * (dgate + mgate)

class CMGate(nn.Module):
    def __init__(self, dims):
        super().__init__()
        self.sgate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
        self.wgate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
        self.pgate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
        self.integration = Linear(dims * 3, dims)

    def forward(self, x, features):
        # features: dict mapping feature name -> tensor shaped like x
        sfeat = features.get("spectrogram", x)
        wfeat = features.get("waveform", x)
        pfeat = features.get("pitch", x)
        spec = self.sgate(x) * sfeat
        wave = self.wgate(x) * wfeat
        pitch = self.pgate(x) * pfeat

        combined = torch.cat([spec, wave, pitch], dim=-1)
        return self.integration(combined)

class Residual(nn.Module):
    _seen = set()

    def __init__(self, dims: int, head: int, ctx, act, cross_attn=True, debug: List[str] = [],
                 fgate=False, tgate=False, mgate=False, cgate=False,
                 memory_size=512, features=None):
        super().__init__()
        self.ctx = ctx
        self._counter = 0
        self.dropout = 0.01
        self.dims = dims
        self.head = head
        self.head_dim = dims // head
        self.cross_attn = cross_attn
        self.debug = debug
        self.features = features
        self.blend = nn.Parameter(torch.tensor(0.5))

        act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(),
                   "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(),
                   "softplus": nn.Softplus(), "softshrink": nn.Softshrink(),
                   "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
        act_fn = act_map.get(act, nn.GELU())

        self.attna = MultiheadA(dims, head, rotary_emb=True, debug=debug)
        self.attnb = (MultiheadA(dims, head, rotary_emb=True, debug=debug) if cross_attn else None)

        mlp = dims * 4
        self.mlp = nn.Sequential(Linear(dims, mlp), act_fn, Linear(mlp, dims))

        self.fgate = FCGate(dims=dims, dim=dims) if fgate else None
        self.tgate = TTGate(dims=dims, num_types=4) if tgate else None
        self.mgate = MGate(dims=dims, memory_size=memory_size) if mgate else None
        self.cgate = CMGate(dims=dims) if cgate else None

        self.lna = RMSNorm(dims)
        self.lnb = RMSNorm(dims) if cross_attn else None
        self.lnc = RMSNorm(dims)

        if not any([fgate, tgate, mgate, cgate]):
            self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())

    def forward(self, x, xa=None, mask=None, f0=None, mode=None):
        x = x + self.attna(self.lna(x), mask=mask, f0=f0)[0]

        if self.attnb and xa is not None:
            cross = self.attnb(self.lnb(x), xa, f0=f0, mask=None)[0]
            blend = torch.sigmoid(self.blend)
            x = blend * x + (1 - blend) * cross

        normx = self.lnc(x)
        mlp_out = self.mlp(normx)

        if self.tgate:
            gate = self.tgate(normx)
            x = x + gate * mlp_out

        elif self.fgate:
            embedding = f0.mean(dim=1) if f0 is not None else xa.mean(dim=1)
            gate = self.fgate(normx, embedding)
            x = x + gate * mlp_out

        elif self.mgate:
            gate = self.mgate(normx)
            x = x + gate * mlp_out

        elif self.cgate and mode is not None:
            # mode is expected to be a dict of per-feature tensors for CMGate
            gate_output = self.cgate(normx, mode)
            x = x + gate_output

        else:
            if hasattr(self, 'mlp_gate'):
                mlp_gate = self.mlp_gate(normx)
                x = x + mlp_gate * mlp_out
            else:
                x = x + mlp_out
        if "residual" in self.debug and self._counter % 100 == 0:
            print(f"Step {self._counter}: Residual block output shape: {x.shape}, xa shape: {xa.shape if xa is not None else None}")
        self._counter += 1
        return x

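# Gating note (illustrative sketch): with no gate flags set, the block falls
# back to a single learned sigmoid gate over the MLP branch:
#   >>> block = Residual(dims=512, head=4, ctx=1500, act="gelu")
#   >>> y = block(torch.randn(2, 100, 512))     # (2, 100, 512)
# Cross-attention output (xa path) is mixed with the self-attention path
# through the learned scalar self.blend rather than a plain residual sum.
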
class PEncoder(nn.Module):
    def __init__(self, input_dims, dims, head, layer, kernel_size, act):
        super().__init__()

        self.head_dim = dims // head
        self.dropout = 0.01

        act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(),
                   "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(),
                   "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
        act_fn = act_map.get(act, nn.GELU())

        self.encoder = nn.Sequential(
            Conv1d(input_dims, dims//4, kernel_size=7, stride=8, padding=3), act_fn,
            Conv1d(dims//4, dims//2, kernel_size=5, stride=4, padding=2), act_fn,
            Conv1d(dims//2, dims, kernel_size=5, stride=5, padding=2), act_fn)

        # positional/norm were referenced in forward but missing from __init__
        self.positional = lambda length: sinusoids(length, dims)
        self.norm = RMSNorm(dims)

    def forward(self, x, f0=None):
        x = self.encoder(x).permute(0, 2, 1)
        x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self.norm(x)
        return x

class WEncoder(nn.Module):
    def __init__(self, input_dims, dims, head, layer, kernel_size, act):
        super().__init__()

        self.head_dim = dims // head
        self.dropout = 0.01

        act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(),
                   "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(),
                   "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
        act_fn = act_map.get(act, nn.GELU())

        self.downsample = nn.Sequential(
            Conv1d(input_dims, dims//8, kernel_size=15, stride=8, padding=7), act_fn,
            Conv1d(dims//8, dims//4, kernel_size=7, stride=4, padding=3), act_fn,
            Conv1d(dims//4, dims, kernel_size=9, stride=5, padding=4), act_fn)

        self.encoder = nn.Sequential(
            Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims//8), act_fn,
            Conv1d(dims, dims, kernel_size=1), act_fn)

        self.positional = lambda length: sinusoids(length, dims)
        self.norm = RMSNorm(dims)

    def forward(self, x, f0=None):
        x = self.downsample(x)
        x = self.encoder(x)
        x = x.permute(0, 2, 1)
        x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        return self.norm(x)

class FEncoder(nn.Module):
    def __init__(self, input_dims, dims, head, layer, kernel_size, act, stride=1):
        super().__init__()

        self.head_dim = dims // head
        self.dropout = 0.01

        act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(),
                   "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(),
                   "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
        act_fn = act_map.get(act, nn.GELU())

        self.encoder = nn.Sequential(
            Conv1d(input_dims, dims, kernel_size=kernel_size, stride=stride, padding=kernel_size//2), act_fn,
            Conv1d(dims, dims, kernel_size=5, padding=2), act_fn,
            Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims), act_fn)

        self.positional = lambda length: sinusoids(length, dims)
        self.norm = RMSNorm(dims)

    def forward(self, x, f0=None):
        x = self.encoder(x).permute(0, 2, 1)
        x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self.norm(x)
        return x

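# Rate sketch for the encoders above (strides multiply):
#   PEncoder / WEncoder front-ends: 8 * 4 * 5 = 160x time reduction, so 16 kHz
#   waveform input comes out at 100 frames/s.
#   FEncoder keeps the input frame rate by default (stride=1); the pitch
#   variant below is built with stride=2.
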
class AudioEncoder(nn.Module):
    _seen = set()

    def __init__(self, mels: int, layer: int, dims: int, head: int, ctx: int, features: List[str],
                 debug: List[str], f0_rotary: bool = False, act: str = "gelu"):
        super(AudioEncoder, self).__init__()

        self.debug = debug
        self.features = features
        self._counter = 0
        self.dropout = 0.01
        self.f0_rotary = f0_rotary
        self.dims = dims
        self.ctx = ctx
        self.head = head
        self.head_dim = dims // head

        self.rope = rotary(
            dims=self.head_dim,
            debug=debug)

        cgate = features == ["spectrogram", "waveform", "pitch"]

        def residuals(use_cgate=False):
            return [Residual(dims=dims, head=head, ctx=ctx, act=act, debug=debug,
                             cgate=use_cgate, features=features) for _ in range(layer)]

        front_ends = {
            "spectrogram": lambda: [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act)] + residuals(cgate),
            "waveform": lambda: [WEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=11, act=act)] + residuals(cgate),
            "pitch": lambda: [FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)] + residuals(cgate),
            "spec_envelope": lambda: [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act)] + residuals(),
            "spec_phase": lambda: [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act)] + residuals(),
        }
        # only build stacks for requested features; ModuleDict entries may not be None
        self.blocks = nn.ModuleDict({
            f: nn.ModuleList(front_ends[f]()) for f in features if f in front_ends})

    def forward(self, x, f0=None):
        outputs = {}
        if self.f0_rotary:
            f0 = f0 if f0 is not None else x.get("pitch")
        else:
            f0 = None
        for y in self.features:
            if y in x and y in self.blocks:
                f = x[y]
                for block in self.blocks[y]:
                    f = block(f, f0=f0)
                outputs[y] = f

        if "encoder" in self.debug and self._counter % 100 == 0:
            names = list(x.keys())
            shapes = {k: v.shape for k, v in x.items()}
            print(f"Step {self._counter}: mode: {names}")
            print(f"shapes: {shapes}")
        self._counter += 1
        return outputs

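# Usage sketch for AudioEncoder (illustrative dict input):
#   >>> enc = AudioEncoder(mels=128, layer=4, dims=512, head=4, ctx=1500,
#   ...                    features=["spectrogram"], debug=[])
#   >>> feats = {"spectrogram": torch.randn(2, 128, 300)}
#   >>> out = enc(feats)                        # {"spectrogram": (2, 300, 512)}
# Each requested feature gets its own front-end plus residual stack, and the
# outputs are returned per-feature for the decoder to blend.
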
class TextDecoder(nn.Module):
    def __init__(self, vocab: int, layer: int, dims: int, head: int, ctx: int, cross_attn: bool,
                 features: List[str], debug: List[str], f0_rotary: bool = False, sequential=False):
        super(TextDecoder, self).__init__()

        self._counter = 0
        self.dropout = 0.01
        self.debug = debug
        self.sequential = sequential
        self.features = features
        self.f0_rotary = f0_rotary

        self.token = nn.Embedding(num_embeddings=vocab, embedding_dim=dims)
        with torch.no_grad():
            self.token.weight[0].zero_()
        self.positional = nn.Parameter(data=torch.empty(ctx, dims), requires_grad=True)
        init.normal_(self.positional, mean=0.0, std=0.02)

        self._blocks = nn.ModuleList([
            Residual(dims=dims, head=head, ctx=ctx, act="gelu", cross_attn=cross_attn, debug=debug, features=features)
            for _ in range(layer)])

        self.blocks = nn.ModuleDict({
            f: nn.ModuleList([Residual(dims=dims, head=head, ctx=ctx, act="gelu", cross_attn=cross_attn, debug=debug, features=features)
                              for _ in range(layer)]) for f in features})

        self.blend = nn.ParameterDict({f: nn.Parameter(torch.tensor(0.5)) for f in features})

        self.ln_dec = RMSNorm(dims)

        # additive causal mask: 0 on and below the diagonal, -inf above
        mask = torch.triu(torch.full((ctx, ctx), float("-inf")), diagonal=1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x, enc, order=None, f0=None) -> Tensor:
        x = x.to(self.token.weight.device)
        if not self.f0_rotary:
            f0 = None
        if order is None:
            order = self.features
        mask = self.mask[:x.shape[1], :x.shape[1]]
        x = self.token(x) + self.positional[:x.shape[1]]
        x = F.dropout(x, p=self.dropout, training=self.training)
        for block in self._blocks:
            x = block(x, f0=f0, mask=mask)
        for f in order:
            if f in enc:
                xa = enc[f]
                for block in self.blocks[f]:
                    out = block(x=x, xa=xa, f0=f0, mask=None)
                    a = torch.sigmoid(self.blend[f])
                    x = a * out + (1 - a) * x
        x = self.ln_dec(x)
        return x @ self.token.weight.to(x.dtype).transpose(0, 1)

class Echo(nn.Module):
    def __init__(self, param: Dimensions):
        super().__init__()
        self.param = param

        self.encoder = AudioEncoder(
            mels=param.mels,
            ctx=param.aud_ctx,
            dims=param.aud_dims,
            head=param.aud_head,
            layer=param.aud_idx,
            act=param.act,
            debug=param.debug,
            features=param.features,
            f0_rotary=param.f0_rotary,
        )

        self.decoder = TextDecoder(
            vocab=param.vocab,
            ctx=param.text_ctx,
            dims=param.text_dims,
            head=param.text_head,
            layer=param.text_idx,
            cross_attn=param.cross_attn,
            debug=param.debug,
            features=param.features,
            f0_rotary=param.f0_rotary,
        )

        all_head = torch.zeros(self.param.text_idx, self.param.text_head, dtype=torch.bool)
        all_head[self.param.text_idx // 2:] = True
        self.register_buffer("alignment_head", all_head.to_sparse(), persistent=False)

    def set_alignment_head(self, dump: bytes):
        array = np.frombuffer(
            gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
        mask = torch.from_numpy(array).reshape(
            self.param.text_idx, self.param.text_head)
        self.register_buffer("alignment_head", mask.to_sparse(), persistent=False)

    def embed_audio(self, spectrogram: torch.Tensor):
        return self.encoder(spectrogram)

    def logits(self, input_ids: torch.Tensor, encoder_output: torch.Tensor):
        return self.decoder(input_ids, encoder_output)

    def forward(self,
                decoder_input_ids=None,
                labels=None,
                waveform: Optional[torch.Tensor] = None,
                input_ids=None,
                spectrogram: torch.Tensor = None,
                pitch: Optional[torch.Tensor] = None,
                f0: Optional[torch.Tensor] = None,
                envelope: Optional[torch.Tensor] = None,
                phase: Optional[torch.Tensor] = None,
                ) -> Dict[str, torch.Tensor]:

        decoder_input_ids = input_ids
        encoder_inputs = {}
        if spectrogram is not None:
            encoder_inputs["spectrogram"] = spectrogram
        if waveform is not None:
            encoder_inputs["waveform"] = waveform
        if pitch is not None:
            encoder_inputs["pitch"] = pitch
        if envelope is not None:
            encoder_inputs["envelope"] = envelope
        if phase is not None:
            encoder_inputs["phase"] = phase

        encoder_outputs = self.encoder(encoder_inputs, f0=f0)
        logits = self.decoder(input_ids, encoder_outputs, f0=f0)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)

        return {
            "logits": logits,
            "loss": loss,
            "labels": labels,
            "input_ids": input_ids,
            "decoder_input_ids": decoder_input_ids,
            "encoder_output": encoder_outputs,
        }

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    def _init_weights(self, module):
        std = 0.02
        # called once per submodule via self.apply(...); acts on `module` only
        if isinstance(module, Linear):
            nn.init.xavier_uniform_(module.linear.weight)
            if module.linear.bias is not None:
                nn.init.zeros_(module.linear.bias)
            self.init_counts["Linear"] += 1
        elif isinstance(module, Conv1d):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Conv1d"] += 1
        elif isinstance(module, RMSNorm):
            nn.init.ones_(module.weight)
            self.init_counts["RMSNorm"] += 1
        elif isinstance(module, MultiheadA):
            self.init_counts["MultiheadA"] += 1
        elif isinstance(module, Conv2d):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Conv2d"] += 1
        elif isinstance(module, TextDecoder):
            self.init_counts["TextDecoder"] += 1
        elif isinstance(module, AudioEncoder):
            self.init_counts["AudioEncoder"] += 1
        elif isinstance(module, Residual):
            self.init_counts["Residual"] += 1

    def init_weights(self):
        print("Initializing all weights")
        self.init_counts = {
            "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
            "Conv2d": 0, "SEBlock": 0, "TextDecoder": 0, "AudioEncoder": 0,
            "Residual": 0, "MultiheadA": 0, "MultiheadB - Cross Attention": 0,
            "MultiheadC": 0, "MultiheadD": 0, "FEncoder": 0,
            "WEncoder": 0, "PEncoder": 0}
        self.apply(self._init_weights)
        print("Initialization summary:")
        for module_type, count in self.init_counts.items():
            if count > 0:
                print(f"{module_type}: {count}")

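# End-to-end sketch (illustrative; tiny dims to keep it cheap):
#   >>> p = Dimensions(vocab=100, text_ctx=64, text_dims=64, text_head=2,
#   ...                text_idx=2, mels=128, aud_ctx=300, aud_dims=64,
#   ...                aud_head=2, aud_idx=2, act="gelu", debug=[],
#   ...                cross_attn=True, features=["spectrogram"], f0_rotary=False)
#   >>> m = Echo(p)
#   >>> out = m(input_ids=torch.randint(0, 100, (1, 10)),
#   ...         spectrogram=torch.randn(1, 128, 300),
#   ...         labels=torch.randint(0, 100, (1, 10)))
#   >>> out["logits"].shape                     # (1, 10, 100)
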
metric = evaluate.load(path="wer")

@dataclass
class DataCollator:
    tokenizer: Any

    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        pad_token_id = getattr(self.tokenizer, "pad_token_id", 0)
        bos_token_id = getattr(self.tokenizer, "bos_token_id", 1)

        batch = {}

        if "spectrogram" in features[0] and features[0]["spectrogram"] is not None:
            spectrogram_list = [f["spectrogram"] for f in features]
            max_len_feat = max(f.shape[-1] for f in spectrogram_list)
            pad_spectrogram = []
            for feat in spectrogram_list:
                current_len = feat.shape[-1]
                padding = max_len_feat - current_len
                if padding > 0:
                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_feat = feat
                pad_spectrogram.append(pad_feat)
            batch["spectrogram"] = torch.stack(pad_spectrogram)

        if "waveform" in features[0] and features[0]["waveform"] is not None:
            waveform_list = [f["waveform"] for f in features]
            max_len_wav = max(w.shape[-1] for w in waveform_list)
            pad_waveforms = []
            for wav in waveform_list:
                current_len = wav.shape[-1]
                padding = max_len_wav - current_len
                if padding > 0:
                    if wav.ndim == 1:
                        wav = wav.unsqueeze(0)
                    pad_wav = F.pad(wav, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_wav = wav
                pad_waveforms.append(pad_wav)
            batch["waveform"] = torch.stack(pad_waveforms)

        if "label" in features[0] and features[0]["label"] is not None:
            labels_list = [f["label"] for f in features]
            max_len = max(len(l) for l in labels_list)
            all_ids = []
            all_labels = []

            for label in labels_list:
                label_list = label.tolist() if isinstance(label, torch.Tensor) else label
                decoder_input = [bos_token_id] + label_list
                label_eos = label_list + [pad_token_id]
                input_len = max_len + 1 - len(decoder_input)
                label_len = max_len + 1 - len(label_eos)
                padded_input = decoder_input + [pad_token_id] * input_len
                padded_labels = label_eos + [pad_token_id] * label_len
                all_ids.append(padded_input)
                all_labels.append(padded_labels)
            batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
            batch["labels"] = torch.tensor(all_labels, dtype=torch.long)

        if "pitch" in features[0] and features[0]["pitch"] is not None:
            pitch_list = [f["pitch"] for f in features]
            max_len_pitch = max(e.shape[-1] for e in pitch_list)
            pad_pitch = []
            for pitch in pitch_list:
                current_len = pitch.shape[-1]
                padding = max_len_pitch - current_len
                if padding > 0:
                    pad_pitch_item = F.pad(pitch, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_pitch_item = pitch
                pad_pitch.append(pad_pitch_item)
            batch["pitch"] = torch.stack(pad_pitch)

        if "f0" in features[0] and features[0]["f0"] is not None:
            all_f0 = torch.cat([f["f0"] for f in features])
            batch["f0"] = all_f0.unsqueeze(0)

        if "envelope" in features[0] and features[0]["envelope"] is not None:
            env_list = [f["envelope"] for f in features]
            max_len_env = max(f.shape[-1] for f in env_list)
            pad_env = []
            for feat in env_list:
                current_len = feat.shape[-1]
                padding = max_len_env - current_len
                if padding > 0:
                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_feat = feat
                pad_env.append(pad_feat)
            batch["envelope"] = torch.stack(pad_env)

        if "phase" in features[0] and features[0]["phase"] is not None:
            ph_list = [f["phase"] for f in features]
            max_len_ph = max(f.shape[-1] for f in ph_list)
            pad_ph = []
            for feat in ph_list:
                current_len = feat.shape[-1]
                padding = max_len_ph - current_len
                if padding > 0:
                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_feat = feat
                pad_ph.append(pad_feat)
            batch["phase"] = torch.stack(pad_ph)
        return batch

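# Collation sketch (illustrative; tokenizer comes from setup_tokenizer below):
# variable-length features are right-padded to the batch max, and labels get
# BOS/EOS framing:
#   >>> coll = DataCollator(tokenizer=tokenizer)
#   >>> batch = coll([
#   ...     {"spectrogram": torch.randn(128, 90), "label": [5, 6, 7]},
#   ...     {"spectrogram": torch.randn(128, 120), "label": [8, 9]}])
#   >>> batch["spectrogram"].shape              # (2, 128, 120)
#   >>> batch["input_ids"].tolist()             # [[BOS, 5, 6, 7], [BOS, 8, 9, PAD]]
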
def hilbert_transform(x):
    N = x.shape[-1]
    xf = torch.fft.rfft(x)
    h = torch.zeros(N // 2 + 1, device=x.device, dtype=x.dtype)
    if N % 2 == 0:
        h[0] = h[N//2] = 1
        h[1:N//2] = 2
    else:
        h[0] = 1
        h[1:(N+1)//2] = 2
    return torch.fft.irfft(xf * h, n=N)

def analytic_signal(x):
    return x + 1j * hilbert_transform(x)

def hilbert_transform_2d(x, dim=-1):
    # reduce the general-dim case to the last-dim case by swapping axes
    if dim != -1 and dim != x.ndim - 1:
        return hilbert_transform_2d(x.transpose(dim, -1), dim=-1).transpose(dim, -1)
    N = x.shape[-1]
    xf = torch.fft.rfft(x)
    h_shape = [1] * x.ndim
    h_shape[-1] = N // 2 + 1
    h = torch.zeros(h_shape, device=x.device, dtype=x.dtype)
    if N % 2 == 0:
        h[..., 0] = h[..., -1] = 1
        h[..., 1:-1] = 2
    else:
        h[..., 0] = 1
        h[..., 1:] = 2
    return torch.fft.irfft(xf * h, n=N)

def hilbert_transform_true_2d(x):
    xf = torch.fft.rfft2(x)
    # rfft2 keeps the full frequency axis on dim -2 and the half axis on dim -1
    h1, h2 = torch.meshgrid(
        torch.fft.fftfreq(x.shape[-2]) * 2 - 1,
        torch.fft.rfftfreq(x.shape[-1]) * 2 - 1,
        indexing='ij')
    h = -1j / (math.pi * (h1 + 1j*h2))
    h[0, 0] = 0
    return torch.fft.irfft2(xf * h.to(x.device))

def process_spectrogram_with_hilbert(spec):
    analytic = spec + 1j * hilbert_transform(spec)
    envelope = torch.abs(analytic)
    phase = torch.angle(analytic)
    return envelope, phase

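# Envelope/phase sketch: for a real 1-D signal the analytic signal separates
# amplitude from instantaneous phase (illustrative):
#   >>> t = torch.linspace(0, 1, 1000)
#   >>> sig = torch.sin(2 * math.pi * 50 * t) * torch.exp(-3 * t)
#   >>> env, ph = process_spectrogram_with_hilbert(sig)
#   >>> env.shape, ph.shape                     # both (1000,)
# env tracks the decaying amplitude; ph advances with the 50 Hz carrier.
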
def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, f0=False,
                     hop_length=128, fmin=0, fmax=8000, n_mels=128, n_fft=1024, sampling_rate=16000,
                     pad_mode="constant", center=True, power=2.0, window_fn=torch.hann_window, mel_scale="htk",
                     norm=None, normalized=False, downsamples=False, period=False, hilbert=False):

    dtype = torch.float32
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    audio = batch["audio"]

    wav = torch.tensor(audio["array"]).float()
    sr = audio["sampling_rate"]

    if sr != sampling_rate:
        original_length = wav.shape[-1]
        target_length = int(original_length * (sampling_rate / sr))

        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sampling_rate)
        wav = resampler(wav)

        # trim or zero-pad along time if the resampled length is off by more than one sample
        if abs(wav.shape[-1] - target_length) > 1:
            if wav.shape[-1] > target_length:
                wav = wav[..., :target_length]
            else:
                wav = F.pad(wav, (0, target_length - wav.shape[-1]))

    if spectrogram:
        transform = torchaudio.transforms.MelSpectrogram(
            f_max=fmax,
            f_min=fmin,
            n_mels=n_mels,
            sample_rate=sampling_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            norm=norm,
            normalized=normalized,
            power=power,
            center=center,
            mel_scale=mel_scale,
            window_fn=window_fn,
            pad_mode=pad_mode)

        mel_spectrogram = transform(wav)
        log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
        log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
        spec = (log_mel + 4.0) / 4.0
        batch["spectrogram"] = spec

        if hilbert:
            envelope_list = []
            phase_list = []

            for ch_idx in range(spec.shape[0]):
                envelope, phase = process_spectrogram_with_hilbert(spec[ch_idx])
                envelope_list.append(envelope)
                phase_list.append(phase)

            batch["envelope"] = torch.stack(envelope_list)
            batch["phase"] = torch.stack(phase_list)

    wav_1d = wav.unsqueeze(0)

    if waveforms:
        batch["waveform"] = wav_1d

    if pitch:
        if period:
            pit, periodicity = torchcrepe.predict(
                wav_1d,
                sampling_rate,
                hop_length,
                fmin=80,
                fmax=800,
                model="tiny",
                decoder=torchcrepe.decode.viterbi,
                return_periodicity=True,
                device=device,
                pad=True
            )
            batch["pitch"] = pit
            batch["period"] = periodicity
        else:
            pit = torchcrepe.predict(
                wav_1d,
                sampling_rate,
                hop_length,
                fmin=80,
                fmax=800,
                model="tiny",
                decoder=torchcrepe.decode.viterbi,
                return_periodicity=False,
                device=device,
                pad=True
            )
            batch["pitch"] = pit

    if f0:
        wav_np = wav.numpy().astype(np.float64)
        f0_track, t = pw.dio(wav_np, sampling_rate,
                             frame_period=hop_length / sampling_rate * 1000)
        f0_track = pw.stonemask(wav_np, f0_track, t, sampling_rate)
        batch["f0"] = torch.from_numpy(f0_track).float()

    if spectrogram and waveforms and pitch:
        spec_mean = batch["spectrogram"].mean()
        spec_std = batch["spectrogram"].std() + 1e-6
        batch["spectrogram"] = (batch["spectrogram"] - spec_mean) / spec_std

        wav_mean = batch["waveform"].mean()
        wav_std = batch["waveform"].std() + 1e-6
        batch["waveform"] = (batch["waveform"] - wav_mean) / wav_std

        if batch["pitch"].max() > 1.0:
            pitch_min = 50.0
            pitch_max = 600.0
            batch["pitch"] = (batch["pitch"] - pitch_min) / (pitch_max - pitch_min)

    batch["label"] = tokenizer.encode(batch["transcription"], add_special_tokens=False)
    return batch

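# Frame-rate arithmetic for the defaults above: at sampling_rate=16000 and
# hop_length=128, the mel spectrogram and torchcrepe both produce
# 16000 / 128 = 125 frames per second, and pw.dio is called with
# frame_period = 128 / 16000 * 1000 = 8 ms so the f0 track lines up.
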
def compute_metrics(eval_pred, compute_result: bool = True,
                    print_pred: bool = False, num_samples: int = 0, tokenizer=None, pitch=None, model=None):

    pred_logits = eval_pred.predictions
    label_ids = eval_pred.label_ids

    if hasattr(pred_logits, "cpu"):
        pred_logits = pred_logits.cpu()
    if hasattr(label_ids, "cpu"):
        label_ids = label_ids.cpu()
    if isinstance(pred_logits, tuple):
        pred_ids = pred_logits[0]
    else:
        pred_ids = pred_logits
    if hasattr(pred_ids, "ndim") and pred_ids.ndim == 3:
        if not isinstance(pred_ids, torch.Tensor):
            pred_ids = torch.tensor(pred_ids)
        pred_ids = pred_ids.argmax(dim=-1)
    pred_ids = pred_ids.tolist()

    if hasattr(label_ids, "tolist"):
        label_ids = label_ids.tolist()

    label_ids = [[0 if token == -100 else token for token in seq] for seq in label_ids]
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=False)

    if print_pred:
        for i in range(min(num_samples, len(pred_str))):
            print(f"Preds: {pred_str[i]}")
            print(f"Label: {label_str[i]}")
            print(f"preds: {pred_ids[i]}")
            print(f"label: {label_ids[i]}")
            print("--------------------------------")

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    if model is None:
        global global_model
        if 'global_model' in globals():
            model = global_model

    if model is not None:
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000
        if trainable_params > 0:
            efficiency_score = (100 - wer) / trainable_params
        else:
            print("Warning: Zero trainable parameters detected")
            efficiency_score = 0.0
    else:
        print("Warning: Model not available for parameter counting")
        trainable_params = 0.0
        efficiency_score = 0.0

    if hasattr(wer, "item"):
        wer = wer.item()

    metrics = {
        "wer": float(wer),
        "trainable_params_M": float(trainable_params),
        "efficiency_score": float(efficiency_score),
    }

    print(f"Computed metrics: WER={wer:.2f}%, Params={trainable_params:.2f}M, Efficiency={efficiency_score:.4f}")
    return metrics

logger = logging.getLogger(__name__)

def create_model(param: Dimensions) -> Echo:
    model = Echo(param).to(device)
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Trainable parameters: {trainable_params:,}")
    logger.info(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Total parameters: {total_params:,}")
    model.init_weights()
    return model

def setup_tokenizer(token: str, local_tokenizer_path: str = "D:/newmodel/model/tokenn/"):
    from tokenizers import Tokenizer
    tokenizer = Tokenizer.from_file(f"{local_tokenizer_path}/tokenizer.json")
    orig_encode = tokenizer.encode

    def enc(text, add_special_tokens=True):
        ids = orig_encode(text).ids
        if not add_special_tokens:
            sp_ids = [tokenizer.token_to_id(t) for t in ["<PAD>", "<BOS>", "<EOS>"]]
            ids = [id for id in ids if id not in sp_ids]
        return ids

    def bdec(ids_list, skip_special_tokens=True):
        results = []
        for ids in ids_list:
            if skip_special_tokens:
                ids = [id for id in ids if id not in [0, 1, 2]]
            results.append(tokenizer.decode(ids))
        return results

    def save_pretrained(save_dir):
        os.makedirs(save_dir, exist_ok=True)
        tokenizer.save(f"{save_dir}/tokenizer.json")

    tokenizer.encode = enc
    tokenizer.batch_decode = bdec
    tokenizer.save_pretrained = save_pretrained
    tokenizer.pad_token_id = 0
    tokenizer.bos_token_id = 1
    tokenizer.eos_token_id = 2
    return tokenizer

def prepare_datasets(tokenizer, token: str, sanity_check: bool = False, dataset_config: Optional[Dict] = None) -> Tuple[Any, Any]:
    if dataset_config is None:
        dataset_config = {
            "spectrogram": True,
            "waveforms": True,
            "pitch": True,
            "f0": True,
            "downsamples": True,
            "hop_length": 128,
            "fmin": 50,
            "fmax": 2000,
            "n_mels": 128,
            "n_fft": 1024,
            "sampling_rate": 16000,
        }

    dataset = load_dataset(
        "google/fleurs",
        "en_us",
        token=token,
        trust_remote_code=True,
        streaming=False
    )
    dataset = dataset.cast_column(column="audio", feature=Audio(sampling_rate=16000))

    if sanity_check:
        dataset = dataset["test"].take(10).shuffle()
        dataset = dataset.select_columns(["audio", "transcription"])
        logger.info(f"Sanity dataset size: {dataset.num_rows}")
        print(f"Sanity dataset size: {dataset.num_rows}")
        prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)

        dataset = dataset.map(
            function=prepare_fn,
            remove_columns=["audio", "transcription"]
        ).with_format(type="torch")
        train_dataset = dataset
        test_dataset = dataset
    else:
        def filter_func(x):
            # keep non-empty transcripts under 512 chars and audio shorter than
            # 1500 * 160 = 240000 samples (15 s at 16 kHz)
            return (0 < len(x["transcription"]) < 512 and
                    len(x["audio"]["array"]) > 0 and
                    len(x["audio"]["array"]) < 1500 * 160)

        dataset = dataset.filter(filter_func).shuffle()
        logger.info(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}")
        print(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}")
        prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
        train_dataset = dataset["train"]
        test_dataset = dataset["test"]
        columns_to_remove = list(next(iter(dataset.values())).features)

        train_dataset = train_dataset.map(
            function=prepare_fn,
            remove_columns=columns_to_remove
        ).with_format(type="torch")

        test_dataset = test_dataset.map(
            function=prepare_fn,
            remove_columns=columns_to_remove
        ).with_format(type="torch")

    return train_dataset, test_dataset

def get_training_args(
    log_dir: str,
    batch_eval_metrics: bool = False,
    max_steps: int = 10,
    save_steps: int = 1000,
    eval_steps: int = 100,
    warmup_steps: int = 0,
    num_train_epochs: int = 1,
    logging_steps: int = 10,
    eval_on_start: bool = False,
    learning_rate: float = 1e-4,
    weight_decay: float = 0.01,
    max_grad_norm: float = 1.0,
) -> Seq2SeqTrainingArguments:

    return Seq2SeqTrainingArguments(
        output_dir=log_dir,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=1,
        eval_accumulation_steps=1,
        tf32=True,
        bf16=True,
        eval_strategy="steps",
        save_strategy="steps",
        max_steps=max_steps,
        save_steps=save_steps,
        eval_steps=eval_steps,
        warmup_steps=warmup_steps,
        num_train_epochs=num_train_epochs,
        logging_steps=logging_steps,
        logging_dir=log_dir,
        logging_strategy="steps",
        report_to=["tensorboard"],
        push_to_hub=False,
        disable_tqdm=False,
        save_total_limit=1,
        label_names=["labels"],
        optim="adamw_torch",
        lr_scheduler_type="cosine",
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        save_safetensors=False,
        eval_on_start=eval_on_start,
        batch_eval_metrics=batch_eval_metrics,
        max_grad_norm=max_grad_norm,
    )

def main():

    token = ""
    log_dir = os.path.join('./output/logs', datetime.now().strftime(format='%m-%d_%H'))
    os.makedirs(name=log_dir, exist_ok=True)
    tokenizer = setup_tokenizer(token)

    def sanity(sanity: bool):
        if sanity:
            training_args = get_training_args(
                log_dir,
                batch_eval_metrics=False,
                max_steps=10,
                save_steps=0,
                eval_steps=1,
                warmup_steps=0,
                logging_steps=1,
                eval_on_start=True,
                learning_rate=5e-6,
                weight_decay=0.01,
            )
        else:
            training_args = get_training_args(
                log_dir,
                batch_eval_metrics=False,
                max_steps=10000,
                save_steps=10000,
                eval_steps=1000,
                warmup_steps=1000,
                logging_steps=100,
                eval_on_start=False,
                learning_rate=2.5e-4,
                weight_decay=0.01,
            )
        return training_args

    param = Dimensions(
        mels=128,
        aud_ctx=1500,
        aud_head=4,
        aud_dims=512,
        aud_idx=4,
        vocab=40000,
        text_ctx=512,
        text_head=4,
        text_dims=512,
        text_idx=4,
        act="swish",
        debug=[],
        cross_attn=True,
        f0_rotary=True,
        features=["spectrogram"],
    )

    sanity_check = False

    training_args = sanity(sanity_check)

    dataset_config = {
        "spectrogram": True,
        "waveforms": False,
        "pitch": False,
        "downsamples": False,
        "f0": True,
        "hilbert": False,
        "hop_length": 128,
        "fmin": 150,
        "fmax": 2000,
        "n_mels": 128,
        "n_fft": 1024,
        "sampling_rate": 16000,
        "pad_mode": "constant",
        "center": True,
        "power": 2.0,
        "window_fn": torch.hann_window,
        "mel_scale": "htk",
        "norm": None,
        "normalized": False}

    model = create_model(param)

    global global_model
    global_model = model

    metrics_fn = partial(compute_metrics, print_pred=False, num_samples=5,
                         tokenizer=tokenizer, model=model)

    print(f"{'Sanity check' if sanity_check else 'Training'} mode")
    train_dataset, test_dataset = prepare_datasets(
        tokenizer=tokenizer,
        token=token,
        sanity_check=sanity_check,
        dataset_config=dataset_config)

    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=DataCollator(tokenizer=tokenizer),
        compute_metrics=metrics_fn,
    )

    trainer.train()

if __name__ == "__main__":
    main()