Sin2pi commited on
Commit
79a5277
·
verified ·
1 Parent(s): e0e4084

Update modelA.py

Browse files
Files changed (1) hide show
  1. modelA.py +445 -1149
modelA.py CHANGED
@@ -1,313 +1,29 @@
1
  import os
2
- import pyworld as pw
3
  import math
4
  import warnings
5
  import logging
 
6
  import torch
7
- import torchaudio
8
  import torch.nn.functional as F
9
- import torch.nn.init as init
10
  from torch import nn, Tensor
11
-
12
- import matplotlib.pyplot as plt
13
- from typing import Optional, Dict, Union, List, Tuple, Any
14
  import numpy as np
15
  from functools import partial
16
  from datetime import datetime
17
- from datasets import load_dataset, Audio
18
  from transformers.trainer_seq2seq import Seq2SeqTrainer
19
  from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
20
- import transformers
21
- from dataclasses import dataclass
22
- from opimizer import MaxFactor
23
- from transformers.generation.configuration_utils import GenerationConfig
24
- torch.backends.cudnn.allow_tf32 = True
25
- torch.backends.cuda.matmul.allow_tf32 = True
26
- torch.set_float32_matmul_precision('high')
27
- transformers.utils.logging.set_verbosity_error()
28
 
29
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
30
  dtype = torch.float32
31
-
32
  warnings.filterwarnings("ignore")
33
  logging.basicConfig(level=logging.ERROR)
34
 
35
- PATH = 'E:/hf'
36
- os.environ['HF_HOME'] = PATH
37
- os.environ['HF_DATASETS_CACHE'] = PATH
38
- os.environ['TORCH_HOME'] = PATH
39
- os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
40
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
41
-
42
- def get_activation(act: str) -> nn.Module:
43
- """Get activation function by name."""
44
- act_map = {
45
- "gelu": nn.GELU(),
46
- "relu": nn.ReLU(),
47
- "sigmoid": nn.Sigmoid(),
48
- "tanh": nn.Tanh(),
49
- "swish": nn.SiLU(),
50
- "tanhshrink": nn.Tanhshrink(),
51
- "softplus": nn.Softplus(),
52
- "softshrink": nn.Softshrink(),
53
- "leaky_relu": nn.LeakyReLU(),
54
- "elu": nn.ELU()
55
- }
56
- return act_map.get(act, nn.GELU())
57
-
58
- @dataclass
59
- class Dimensions:
60
- vocab: int
61
- mels: int
62
- ctx: int
63
- dims: int
64
- head: int
65
- layer: int
66
- act: str
67
- debug: List[str]
68
- features: List[str]
69
-
70
- def get_generation_config(param):
71
- return GenerationConfig(
72
- max_length=param.text_ctx,
73
- pad_token_id=getattr(param, "pad_token_id", 0),
74
- bos_token_id=getattr(param, "bos_token_id", 1),
75
- eos_token_id=getattr(param, "eos_token_id", 2),
76
- do_sample=False,
77
- num_beams=1,
78
- early_stopping=False,
79
- length_penalty=1.0,
80
- no_repeat_ngram_size=0,
81
- repetition_penalty=1.0,
82
- temperature=1.0,
83
- decoder_start_token_id=1,
84
- is_multilingual=False,
85
- use_cache=False,
86
- return_timestamps=False)
87
-
88
- def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_length=160,
89
- title="", markers=None, marker_labels=None,
90
- show_voiced_regions=True, show_energy=False):
91
- num_plots = sum([x is not None, w is not None, p is not None, per is not None])
92
- if num_plots == 0:
93
- raise ValueError("No data to plot. Please provide at least one input tensor.")
94
- t_spans = []
95
-
96
- if w is not None:
97
- w_np = w[sample_idx].detach().cpu().numpy()
98
- if w_np.ndim > 1:
99
- w_np = w_np.squeeze()
100
- t_spans.append(len(w_np) / sr)
101
- if x is not None:
102
- x_np = x[sample_idx].detach().cpu().numpy()
103
- if x_np.shape[0] < x_np.shape[1]:
104
- x_np = x_np.T
105
- t_spans.append(x_np.shape[0] * hop_length / sr)
106
- if p is not None:
107
- p_np = p[sample_idx].detach().cpu().numpy()
108
- if p_np.ndim > 1:
109
- p_np = p_np.squeeze()
110
- t_spans.append(len(p_np) * hop_length / sr)
111
- if per is not None:
112
- per_np = per[sample_idx].detach().cpu().numpy()
113
- if per_np.ndim > 1:
114
- per_np = per_np.squeeze()
115
- t_spans.append(len(per_np) * hop_length / sr)
116
- max_t = max(t_spans) if t_spans else 0
117
- fig, axs = plt.subplots(num_plots, 1, figsize=(14, 4*num_plots), sharex=True)
118
- if num_plots == 1:
119
- axs = [axs]
120
- if show_voiced_regions and per is not None:
121
- per_np = per[sample_idx].detach().cpu().numpy()
122
- if per_np.ndim > 1:
123
- per_np = per_np.squeeze()
124
- t_per = np.arange(len(per_np)) * hop_length / sr
125
- threshold = 0.5
126
- for ax in axs:
127
- for i in range(len(per_np)-1):
128
- if per_np[i] > threshold:
129
- ax.axvspan(t_per[i], t_per[i+1], color='lightblue', alpha=0.2, zorder=0)
130
- cu_ax = 0
131
- if w is not None:
132
- w_np = w[sample_idx].detach().cpu().numpy()
133
- if w_np.ndim > 1:
134
- w_np = w_np.squeeze()
135
- t = np.arange(len(w_np)) / sr
136
- axs[cu_ax].plot(t, w_np, color="tab:blue")
137
-
138
- if show_energy:
139
- frame_length = hop_length
140
- hop_length_energy = hop_length // 2
141
- energy = []
142
- for i in range(0, len(w_np)-frame_length, hop_length_energy):
143
- frame = w_np[i:i+frame_length]
144
- energy.append(np.sqrt(np.mean(frame**2)))
145
- energy = np.array(energy)
146
- energy = energy / np.max(energy) * 0.8 * max(abs(w_np.min()), abs(w_np.max()))
147
- t_energy = np.arange(len(energy)) * hop_length_energy / sr
148
- axs[cu_ax].plot(t_energy, energy, color="red", alpha=0.7, label="Energy")
149
- axs[cu_ax].legend(loc='upper right')
150
- axs[cu_ax].set_title("Waveform")
151
- axs[cu_ax].set_ylabel("Amplitude")
152
- axs[cu_ax].set_xlim([0, max_t])
153
- axs[cu_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
154
- cu_ax += 1
155
-
156
- if x is not None:
157
- x_np = x[sample_idx].detach().cpu().numpy()
158
- if x_np.shape[0] < x_np.shape[1]:
159
- x_np = x_np.T
160
- axs[cu_ax].imshow(x_np.T, aspect="auto", origin="lower", cmap="magma",
161
- extent=[0, x_np.shape[0]*hop_length/sr, 0, x_np.shape[1]])
162
- axs[cu_ax].set_title("Spectrogram")
163
- axs[cu_ax].set_ylabel("Mel Bin")
164
- axs[cu_ax].set_xlim([0, max_t])
165
- axs[cu_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
166
- cu_ax += 1
167
-
168
- if p is not None:
169
- p_np = p[sample_idx].detach().cpu().numpy()
170
- if p_np.ndim > 1:
171
- p_np = p_np.squeeze()
172
- t_p = np.arange(len(p_np)) * hop_length / sr
173
- axs[cu_ax].plot(t_p, p_np, color="tab:green")
174
- axs[cu_ax].set_title("Pitch")
175
- axs[cu_ax].set_ylabel("Frequency (Hz)")
176
- axs[cu_ax].set_xlim([0, max_t])
177
- axs[cu_ax].grid(True, axis='both', linestyle='--', alpha=0.3)
178
- axs[cu_ax].set_ylim([0, min(1000, p_np.max() * 1.2)])
179
- cu_ax += 1
180
-
181
- if per is not None:
182
- per_np = per[sample_idx].detach().cpu().numpy()
183
- if per_np.ndim > 1:
184
- per_np = per_np.squeeze()
185
- t_per = np.arange(len(per_np)) * hop_length / sr
186
- axs[cu_ax].plot(t_per, per_np, color="tab:red")
187
- axs[cu_ax].set_title("Period (Voice Activity)")
188
- axs[cu_ax].set_ylabel("periodocity")
189
- axs[cu_ax].set_xlim([0, max_t])
190
- axs[cu_ax].grid(True, axis='both', linestyle='--', alpha=0.3)
191
- axs[cu_ax].set_ylim([-0.05, 1.05])
192
- axs[cu_ax].axhline(y=0.5, color='k', linestyle='--', alpha=0.3)
193
-
194
- if markers is not None:
195
- for i, t in enumerate(markers):
196
- label = marker_labels[i] if marker_labels and i < len(marker_labels) else None
197
- for ax in axs:
198
- ax.axvline(x=t, color='k', linestyle='-', alpha=0.7, label=label if i == 0 else None)
199
- if marker_labels:
200
- axs[0].legend(loc='upper right', fontsize='small')
201
- axs[-1].set_xlabel("t (s)")
202
- fig.suptitle(title, fontsize=16)
203
- plt.tight_layout(rect=[0, 0, 1, 0.97])
204
- plt.show()
205
- return fig
206
-
207
- def valid(default_value, *items):
208
- """Get first non-None item"""
209
- for item in items:
210
- if item is not None:
211
- return item
212
- return default_value
213
-
214
- def dict_to(d, device, dtype=dtype):
215
- return {k: v.to(device, dtype) if isinstance(v, torch.Tensor) else v
216
- for k, v in d.items()}
217
-
218
- def exists(v):
219
- return v is not None
220
-
221
- def default(v, b):
222
- return v if exists(v) else b
223
-
224
- class Conv1d(nn.Conv1d):
225
- def _conv_forward(
226
- self, x: Tensor, weight: Tensor, bias) -> Tensor:
227
- return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))
228
-
229
- class Conv2d(nn.Conv2d):
230
- def _conv_forward(
231
- self, x: Tensor, weight: Tensor, bias) -> Tensor:
232
- return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))
233
-
234
- class Linear(nn.Module):
235
- def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
236
- super(Linear, self).__init__()
237
- self.linear = nn.Linear(in_features, out_features, bias=bias)
238
- init.xavier_uniform_(self.linear.weight)
239
- if bias:
240
- init.zeros_(self.linear.bias)
241
- def forward(self, x: Tensor) -> Tensor:
242
- return self.linear(x)
243
-
244
- class RMSNorm(nn.Module):
245
- def __init__(self, dims: Union[int, Tensor, List, Tuple],
246
- eps = 1e-8, elementwise_affine = True):
247
- super(RMSNorm, self).__init__()
248
- if isinstance(dims, int):
249
- self.normalized_shape = (dims,)
250
- else:
251
- self.normalized_shape = tuple(dims)
252
- self.eps = eps
253
- self.elementwise_affine = elementwise_affine
254
- if self.elementwise_affine:
255
- self.weight = nn.Parameter(torch.empty(self.normalized_shape))
256
- init.ones_(self.weight)
257
- else:
258
- self.register_parameter("weight", None)
259
- def forward(self, x):
260
- return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
261
-
262
- def LayerNorm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple],
263
- weight: Optional[Tensor] = None, bias: Optional[Tensor] = None,
264
- eps: float = 1e-5) -> Tensor:
265
- return F.layer_norm(x, normalized_shape, weight, bias, eps)
266
-
267
- def get_device():
268
- return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
269
-
270
- def get_dtype():
271
- return torch.float32 if torch.cuda.is_available() else torch.float64
272
-
273
- def tox():
274
- return {"device": get_device(), "dtype": get_dtype()}
275
-
276
-
277
-
278
- class Sinusoids(nn.Module):
279
- def __init__(self, length, channels, max_tscale=10000):
280
- super().__init__()
281
- assert channels % 2 == 0
282
- log_tscale_increment = np.log(max_tscale) / (channels // 2 - 1)
283
- inv_tscales = torch.exp(-log_tscale_increment * torch.arange(channels // 2))
284
- scaled_t = torch.arange(length)[:, None] * inv_tscales[None, :]
285
- pos1 = torch.sin(scaled_t)
286
- pos2 = torch.cos(scaled_t)
287
- positions = torch.cat([pos1, pos2], dim=1)
288
- self.embedding = nn.Embedding.from_pretrained(positions, freeze=False)
289
- def forward(self, positions):
290
- return self.embedding(positions)
291
-
292
- def sinusoids(length, channels, max_tscale=10000):
293
- assert channels % 2 == 0
294
- log_tscale_increment = np.log(max_tscale) / (channels // 2 - 1)
295
- inv_tscales = torch.exp(-log_tscale_increment * torch.arange(channels // 2))
296
- scaled_t = torch.arange(length)[:, None] * inv_tscales[None, :]
297
- pos1 = torch.sin(scaled_t)
298
- pos2 = torch.cos(scaled_t)
299
- positions = torch.cat([pos1, pos2], dim=1)
300
- return nn.Parameter(positions.clone())
301
-
302
- def accumulate_phase_mod(f0, t_frame, phi0=0.0):
303
- omega = 2 * torch.pi * f0
304
- dphi = omega * t_frame
305
- phi = torch.cumsum(dphi, dim=0) + phi0
306
- phi = torch.remainder(phi, 2 * torch.pi)
307
- return phi
308
-
309
  class rotary(nn.Module):
310
- def __init__(self, dims, head, max_ctx=1500, radii=True, debug: List[str] = [], use_pbias=False, axial=False, spec_shape=None, relative=False, freq_bins=None):
 
311
  super(rotary, self).__init__()
312
  self.use_pbias = use_pbias
313
  self.dims = dims
@@ -316,50 +32,27 @@ class rotary(nn.Module):
316
  self.radii = radii
317
  self.debug = debug
318
  self.counter = 0
319
-
320
  self.axial = axial
 
 
 
 
 
 
321
  if axial and spec_shape is not None:
322
  time_frames, freq_bins = spec_shape
323
  self.time_frames = time_frames
324
  self.freq_bins = freq_bins
 
325
  time_theta = 50.0
326
- time_freqs = 1.0 / (time_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
327
  self.register_buffer('time_freqs', time_freqs)
 
328
  freq_theta = 100.0
329
- freq_freqs = 1.0 / (freq_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
330
  self.register_buffer('freq_freqs', freq_freqs)
331
 
332
- self.bias = nn.Parameter(torch.zeros(max_ctx, dims // 2), requires_grad=True if use_pbias else False)
333
- theta = (torch.tensor(10000, device=device, dtype=dtype))
334
- self.theta = nn.Parameter(theta, requires_grad=True)
335
- self.theta_values = []
336
-
337
- self.relative = relative
338
- self.freq_bins = freq_bins
339
- self.true2d_dim = (dims // head) // 2
340
- self.omega_t = nn.Parameter(torch.randn(self.true2d_dim))
341
- self.omega_f = nn.Parameter(torch.randn(self.true2d_dim))
342
-
343
- def axial(self, seq_len):
344
- if not self.use_2d_axial:
345
- return None
346
- time_frames = self.time_frames
347
- freq_bins = self.freq_bins
348
- t = torch.arange(seq_len, device=device, dtype=dtype)
349
- t_x = (t % time_frames).float()
350
- t_y = torch.div(t, time_frames, rounding_mode='floor').float()
351
- freqs_x = torch.outer(t_x, self.time_freqs)
352
- freqs_y = torch.outer(t_y, self.freq_freqs)
353
- freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
354
- freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
355
- return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
356
-
357
- def mel_scale_scalar(self, freq: float) -> float:
358
- return 1127.0 * math.log(1.0 + freq / 700.0)
359
-
360
- def mel_scale(self, freq: Tensor) -> Tensor:
361
- return 1127.0 * (1.0 + freq / 700.0).log()
362
-
363
  def pitch_bias(self, f0):
364
  if f0 is None:
365
  return None
@@ -369,13 +62,6 @@ class rotary(nn.Module):
369
  f0_norm.unsqueeze(1)))
370
  return f0_sim.unsqueeze(0).unsqueeze(0)
371
 
372
- def accumulate_phase_mod(self, f0, t_frame, phi0=0.0):
373
- omega = 2 * torch.pi * f0
374
- dphi = omega * t_frame
375
- phi = torch.cumsum(dphi, dim=0) + phi0
376
- phi = torch.remainder(phi, 2 * torch.pi)
377
- return phi
378
-
379
  def theta_freqs(self, theta):
380
  if theta.dim() == 0:
381
  theta = theta.unsqueeze(0)
@@ -400,40 +86,52 @@ class rotary(nn.Module):
400
  return torch.polar(torch.ones_like(freqs), freqs), None
401
 
402
  def check_f0(self, f0, f0t, ctx):
403
- if f0 is not None and f0.dim() == 2:
404
- f0 = f0.squeeze(0)
405
- if f0t is not None and f0t.dim() == 2:
406
- f0t = f0t.squeeze(0)
407
- if f0 is not None and f0.shape[0] == ctx:
408
  return f0
409
- elif f0t is not None and f0t.shape[0] == ctx:
410
  return f0t
411
  else:
412
  return None
413
 
414
- def forward(self, x=None, enc=None, layer=None, feature=None) -> Tensor:
415
- ctx = x
416
- if self.axial and feature == "spectrogram":
417
- freqs_2d = self.axial_freqs(ctx)
418
- if freqs_2d is not None:
419
- return freqs_2d.unsqueeze(0)
420
-
421
- f0 = enc.get("f0") if enc is not None else None
422
- f0t = enc.get("f0t") if enc is not None else None
423
- f0 = self.check_f0(f0, f0t, ctx)
 
 
 
 
424
 
425
- theta = f0 + self.theta if f0 is not None else self.theta
 
 
 
426
 
427
- theta = f0
 
 
 
 
 
 
428
  freqs = self.theta_freqs(theta)
429
-
430
  t = torch.arange(ctx, device=device, dtype=dtype)
431
  freqs = t[:, None] * freqs
432
  freqs, radius = self._apply_radii(freqs, f0, ctx)
433
-
434
- if "radius" in self.debug and self.counter == 10:
435
- print(f" [{layer}] [Radius] {radius.shape if radius is not None else None} {radius.mean() if radius is not None else None} [Theta] {theta.mean() if theta is not None else None} [f0] {f0.shape if f0 is not None else None} [ctx] {ctx}")
436
 
 
 
 
 
 
 
 
437
  self.counter += 1
438
  return freqs.unsqueeze(0)
439
 
@@ -450,161 +148,19 @@ class rotary(nn.Module):
450
  x1 = x1.view(orig_shape)
451
  return torch.cat([x1.type_as(x), x2], dim=-1)
452
 
453
-
454
- # @staticmethod
455
- # def apply_rotary(x, freqs):
456
- # # x: [batch, head, seq, head_dim]
457
- # # freqs: [1, seq, head_dim] or [1, seq, 2*head_dim] for 2D
458
- # if freqs.shape[-1] == x.shape[-1]:
459
- # # 1D rotary
460
- # x1 = x
461
- # orig_shape = x1.shape
462
- # if x1.ndim == 2:
463
- # x1 = x1.unsqueeze(0)
464
- # x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
465
- # x1 = torch.view_as_complex(x1) * freqs
466
- # x1 = torch.view_as_real(x1).flatten(-2)
467
- # x1 = x1.view(orig_shape)
468
- # return x1.type_as(x)
469
- # else:
470
- # # 2D rotary: split x and apply to each axis
471
- # head_dim = x.shape[-1] // 2
472
- # x_time = x[..., :head_dim]
473
- # x_freq = x[..., head_dim:]
474
- # f_time = freqs[..., :head_dim]
475
- # f_freq = freqs[..., head_dim:]
476
- # # Apply rotary to each axis
477
- # def apply_axis(xa, freqs):
478
- # orig_shape = xa.shape
479
- # xa = xa.float().reshape(*xa.shape[:-1], -1, 2).contiguous()
480
- # xa = torch.view_as_complex(xa) * freqs
481
- # xa = torch.view_as_real(xa).flatten(-2)
482
- # xa = xa.view(orig_shape)
483
- # return xa.type_as(x)
484
- # x_time = apply_axis(x_time, f_time)
485
- # x_freq = apply_axis(x_freq, f_freq)
486
- # return torch.cat([x_time, x_freq], dim=-1)
487
-
488
- # def true2d_relative_angle(self, t_q, f_q, t_k, f_k):
489
- # # t_q, f_q, t_k, f_k: [seq]
490
- # delta_t = t_q[:, None] - t_k[None, :] # [seq, seq]
491
- # delta_f = f_q[:, None] - f_k[None, :] # [seq, seq]
492
- # angle = delta_t[..., None] * self.omega_t + delta_f[..., None] * self.omega_f # [seq, seq, true2d_dim]
493
- # angle = torch.cat([angle, angle], dim=-1) # [seq, seq, head_dim]
494
- # return angle
495
-
496
- # def true2d_apply_rotary(self, q, k, freqs):
497
- # # q, k: [batch, head, seq, head_dim]
498
- # # freqs: [seq, seq, head_dim//2] complex, or [seq, seq, head_dim] if you want
499
- # b, h, seq, d = q.shape
500
- # d2 = d // 2
501
- # q_exp = q.unsqueeze(3).expand(b, h, seq, seq, d)
502
- # k_exp = k.unsqueeze(2).expand(b, h, seq, seq, d)
503
- # # Convert to complex
504
- # def to_complex(x):
505
- # return torch.complex(x[..., 0::2], x[..., 1::2]) # [b, h, seq, seq, d2]
506
- # q_c = to_complex(q_exp)
507
- # k_c = to_complex(k_exp)
508
- # # Multiply by freqs (which should be [seq, seq, d2] complex)
509
- # q_rot = q_c * freqs
510
- # k_rot = k_c * freqs
511
- # # Back to real
512
- # def to_real(x):
513
- # return torch.stack([x.real, x.imag], dim=-1).flatten(-2)
514
- # q_rot = to_real(q_rot)
515
- # k_rot = to_real(k_rot)
516
- # return q_rot, k_rot
517
-
518
-
519
- def parallel_slice(self, q, k, v, mask=None):
520
- batch, head, ctx, dims = q.shape
521
- head_dim = self.head_dim
522
- batch, ctx, dims = q.shape
523
- ctx_len = k.shape[1]
524
- head = dims // head_dim
525
- scores = torch.zeros(batch, head, ctx, ctx_len, device=q.device)
526
- for h in range(head):
527
- start_idx = h * head_dim
528
- end_idx = start_idx + head_dim
529
- q_h = q[:, :, start_idx:end_idx]
530
- k_h = k[:, :, start_idx:end_idx]
531
- scores[:, h] = torch.bmm(q_h, k_h.transpose(1, 2)) / math.sqrt(head_dim)
532
- if mask is not None:
533
- scores = scores + mask.unsqueeze(0).unsqueeze(0)
534
- attn_weights = F.softmax(scores, dim=-1)
535
- output = torch.zeros_like(q)
536
- for h in range(head):
537
- start_idx = h * head_dim
538
- end_idx = start_idx + head_dim
539
- v_h = v[:, :, start_idx:end_idx]
540
- output[:, :, start_idx:end_idx] = torch.bmm(attn_weights[:, h], v_h)
541
- return output
542
-
543
- class curiosity(nn.Module):
544
- def __init__(self, d, h, bias=True):
545
- super().__init__()
546
- self.h = h
547
- self.dh = d // h
548
- self.qkv = nn.Linear(d, d * 3, bias=bias)
549
- self.qkv_aux = nn.Linear(d, d * 3, bias=bias)
550
- self.o = nn.Linear(d, d, bias=bias)
551
- self.g = nn.Parameter(torch.zeros(h))
552
-
553
- def split(self, x):
554
- b, t, _ = x.shape
555
- return x.view(b, t, self.h, self.dh).transpose(1, 2)
556
-
557
- def merge(self, x):
558
- b, h, t, dh = x.shape
559
- return x.transpose(1, 2).contiguous().view(b, t, h * dh)
560
-
561
- def forward(self, x, xa, mask=None):
562
- q, k, v = self.qkv(x).chunk(3, -1)
563
- qa, ka, va = self.qkv_aux(xa).chunk(3, -1)
564
- q, k, v = map(self.split, (q, k, v))
565
- qa, ka, va = map(self.split, (qa, ka, va))
566
- dots = (q @ k.transpose(-2, -1)) / self.dh**0.5
567
- dots_aux = (q @ ka.transpose(-2, -1)) / self.dh**0.5
568
- if mask is not None: dots = dots.masked_fill(mask, -9e15)
569
- p = dots.softmax(-1)
570
- pa = dots_aux.softmax(-1)
571
- h_main = p @ v
572
- h_aux = pa @ va
573
- g = torch.sigmoid(self.g).view(1, -1, 1, 1)
574
- out = self.merge(h_main * (1 - g) + h_aux * g)
575
- return self.o(out)
576
-
577
- class OneShot(nn.Module):
578
- def __init__(self, dims: int, head: int, scale: float = 0.3):
579
- super().__init__()
580
- self.head = head
581
- self.hdim = dims // head
582
- self.scale = scale
583
- self.q_proj = Linear(dims, dims)
584
- self.k_proj = Linear(dims, dims)
585
-
586
- def forward(self, x: Tensor, guide: Tensor) -> Tensor | None:
587
- B, Q, _ = x.shape
588
- K = guide.size(1)
589
- q = self.q_proj(x ).view(B, Q, self.head, self.hdim).transpose(1,2)
590
- k = self.k_proj(guide).view(B, K, self.head, self.hdim).transpose(1,2)
591
- bias = (q @ k.transpose(-1, -2)) * self.scale / math.sqrt(self.hdim)
592
- return bias
593
-
594
  class MultiheadA(nn.Module):
 
 
595
  def __init__(self, dims: int, head: int, rotary_emb: bool = True,
596
- zero_val: float = 1e-7, minz: float = 1e-8, maxz: float = 1e-6, debug: List[str] = [], use_pbias=False, relative=False, freq_bins=None, radii=True, axial=False, spec_shape=None, rbf=False):
597
-
598
  super(MultiheadA, self).__init__()
 
599
  self.dims = dims
600
  self.head = head
601
  self.head_dim = dims // head
602
  self.debug = debug
603
  self.counter = 0
604
  self.use_pbias = use_pbias
605
- self.relative = relative
606
- self.freq_bins = freq_bins
607
- self.rbf = rbf
608
 
609
  self.q = nn.Linear(dims, dims).to(device, dtype)
610
  self.k = nn.Linear(dims, dims, bias=False).to(device, dtype)
@@ -615,7 +171,8 @@ class MultiheadA(nn.Module):
615
  self.rotary_emb = rotary_emb
616
  self.minz = minz
617
  self.maxz = maxz
618
- self.zero_val = zero_val
 
619
  self.fzero = nn.Parameter(torch.tensor(zero_val, device=device, dtype=dtype), requires_grad=False)
620
 
621
  if rotary_emb:
@@ -623,10 +180,8 @@ class MultiheadA(nn.Module):
623
  dims=dims,
624
  head=head,
625
  debug=debug,
626
- radii=radii,
627
- relative=relative,
628
- freq_bins=freq_bins,
629
- )
630
  else:
631
  self.rope = None
632
 
@@ -651,7 +206,7 @@ class MultiheadA(nn.Module):
651
  rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
652
  return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
653
 
654
- def forward(self, x: Tensor, xa = None, mask = None, enc = None, layer = None, feature=None) -> tuple:
655
 
656
  x = x.to(device, dtype)
657
  if xa is not None:
@@ -670,23 +225,8 @@ class MultiheadA(nn.Module):
670
  q2 = q.shape[2]
671
  k2 = k.shape[2]
672
 
673
- if self.relative and feature == "spectrogram":
674
- seq_len = q2
675
- freq_bins = self.freq_bins
676
- idxs = torch.arange(seq_len, device=q.device)
677
- t_idx = idxs // freq_bins
678
- f_idx = idxs % freq_bins
679
- angle = self.rope.relative(t_idx, f_idx, t_idx, f_idx)
680
- q_rot, k_rot = self.rope.d2rotary(q, k, angle)
681
- scale = (self.dims // self.head) ** -0.25
682
- qk = (q_rot * scale * k_rot * scale).sum(-1)
683
- w = F.softmax(qk, dim=-1).to(q.dtype)
684
- wv = torch.einsum('bhij,bhjd->bhid', w, v.unsqueeze(2).expand(-1, -1, seq_len, -1, -1))
685
- wv = wv.permute(0, 2, 1, 3).flatten(start_dim=2)
686
- return self.o(wv), qk
687
- else:
688
- q = self.rope.apply_rotary(q, (self.rope(x=q2, enc=enc, layer=layer, feature=feature)))
689
- k = self.rope.apply_rotary(k, (self.rope(x=k2, enc=enc, layer=layer, feature=feature)))
690
  else:
691
  q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
692
  k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
@@ -697,30 +237,34 @@ class MultiheadA(nn.Module):
697
  if self.rbf:
698
  qk = self.rbf_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
699
  if self.use_pbias:
700
- pbias = self.rope.pitch_bias(f0 = enc.get("f0", None) if enc is not None else None)
701
  if pbias is not None:
702
  qk = qk + pbias[:,:,:q2,:q2]
703
 
704
- if mask is not None:
705
- mask = mask[:q2, :q2]
706
-
707
  token_ids = k[:, :, :, 0]
708
  zscale = torch.ones_like(token_ids)
709
  fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
710
  zscale[token_ids.float() == self.pad_token] = fzero
711
 
712
- if xa is not None:
 
 
 
713
  qk = qk + mask * zscale.unsqueeze(-2).expand(qk.shape)
 
714
  qk = qk * zscale.unsqueeze(-2)
715
  w = F.softmax(qk, dim=-1).to(q.dtype)
716
  wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
717
-
718
  if "multihead" in self.debug and self.counter % 100 == 0:
719
  print(f"MHA: q={q.shape}, k={k.shape}, v={v.shape} - {qk.shape}, wv shape: {wv.shape}")
720
  self.counter += 1
721
  return self.o(wv), qk
722
 
723
-
 
 
 
724
 
725
  class t_gate(nn.Module):
726
  def __init__(self, dims, num_types=4, enabled=True):
@@ -788,20 +332,15 @@ class c_gate(nn.Module):
788
  return self.integ(comb)
789
 
790
  class mlp_gate(nn.Module):
791
- def __init__(self, dims, head, enabled=True, one_shot=False):
792
  super().__init__()
793
  self.enabled = enabled
794
  if enabled:
795
  self.gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
796
 
797
- if one_shot:
798
- self.one_shot = OneShot(dims, head)
799
-
800
- def forward(self, x, xa=None):
801
  if not self.enabled:
802
  return None
803
- if self.one_shot:
804
- x = self.one_shot(x, xa)
805
  return self.gate(x)
806
 
807
  class Residual(nn.Module):
@@ -823,7 +362,7 @@ class Residual(nn.Module):
823
  self.blend = nn.Parameter(torch.tensor(0.5))
824
  act_fn = get_activation(act)
825
  self.attn = MultiheadA(dims, head, rotary_emb=True, debug=debug)
826
- self.one_shot = OneShot(dims, head) if one_shot else None
827
 
828
  if not any([tgate, mgate, cgate]):
829
  self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
@@ -842,10 +381,10 @@ class Residual(nn.Module):
842
  self.lnb = RMSNorm(dims)
843
  self.lnc = RMSNorm(dims)
844
 
845
- def forward(self, x, xa=None, mask=None, enc=None, layer=None, feature=None) -> Tensor:
846
 
847
  b = torch.sigmoid(self.blend)
848
- ax = x + self.attn(self.lna(x), xa=xa, mask=mask, enc=enc, layer=layer, feature=feature)[0]
849
  bx = b * ax + (1 - b) * x
850
  cx = self.lnb(bx)
851
  dx = self.mlp(cx)
@@ -854,8 +393,85 @@ class Residual(nn.Module):
854
  gx = self.lnc(fx)
855
  return gx
856
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
857
  class FEncoder(nn.Module):
858
- def __init__(self, input_dims, dims, head, layer, kernel_size, act, stride=1, use_rope=False, spec_shape=None):
859
  super().__init__()
860
 
861
  self.head = head
@@ -863,54 +479,67 @@ class FEncoder(nn.Module):
863
  self.dropout = 0.01
864
  self.use_rope = use_rope
865
  self.dims = dims
866
-
867
  act_fn = get_activation(act)
868
-
 
 
 
 
 
 
 
 
 
 
 
 
869
  self.encoder = nn.Sequential(
870
- Conv1d(input_dims, dims, kernel_size=kernel_size, stride=stride, padding=kernel_size//2), act_fn,
871
- Conv1d(dims, dims, kernel_size=5, padding=2), act_fn,
872
- Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims), act_fn)
873
-
874
  if use_rope:
875
  if spec_shape is not None:
876
- self.rope = rotary(
877
- dims=dims,
878
- head=head,
879
- use_2d_axial=True,
880
- spec_shape=spec_shape, debug=[])
881
- else:
882
- self.rope = rotary(
883
- dims=dims,
884
- head=head,
885
- use_2d_axial=False, debug=[])
886
  else:
887
  self.rope = None
888
- self.sinusoid_pos = lambda length, dims: sinusoids(length, dims, max_tscale=10000)
889
-
890
  self.norm = RMSNorm(dims)
891
 
892
- def apply_rope_to_features(self, x, layer="FEncoder", feature="spectrogram"):
893
  batch, ctx, dims = x.shape
894
  x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
895
- if feature == "spectrogram" and self.rope is not None:
896
- rope_freqs = self.rope(ctx, layer=layer, feature="spectrogram")
897
- else:
898
- rope_freqs = self.rope(ctx, layer=layer, feature="audio")
899
- x = self.rope.apply_rotary(x, rope_freqs)
900
  x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
 
901
  return x
902
 
903
- def forward(self, x, enc=None, feature="spectrogram", layer="FEncoder"):
904
  x = self.encoder(x).permute(0, 2, 1)
905
  if self.use_rope:
906
- x = self.apply_rope_to_features(x, layer=layer, feature=feature)
907
  else:
908
- x = x + self.sinusoid_pos(x.shape[1], x.shape[-1]).to(x.device, x.dtype)
 
 
 
 
 
 
 
 
 
 
 
 
909
  x = nn.functional.dropout(x, p=self.dropout, training=self.training)
910
- return self.norm(x)
 
911
 
912
  class WEncoder(nn.Module):
913
- def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False):
914
  super().__init__()
915
 
916
  self.head = head
@@ -918,230 +547,198 @@ class WEncoder(nn.Module):
918
  self.dropout = 0.01
919
  self.use_rope = use_rope
920
  self.dims = dims
921
-
922
  act_fn = get_activation(act)
923
- self.downsample = nn.Sequential(
924
- Conv1d(input_dims, dims//8, kernel_size=15, stride=8, padding=7), act_fn,
925
- Conv1d(dims//8, dims//4, kernel_size=7, stride=4, padding=3), act_fn,
926
- Conv1d(dims//4, dims, kernel_size=9, stride=5, padding=4), act_fn)
927
-
928
  self.encoder = nn.Sequential(
929
- Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims//8), act_fn,
930
- Conv1d(dims, dims, kernel_size=1), act_fn)
 
 
931
  if use_rope:
932
- self.rope = rotary(
933
- dims=self.head_dim,
934
- head=self.head,
935
- debug=[])
936
  else:
937
  self.rope = None
938
- self.sinusoid_pos = lambda length, dims: sinusoids(length, dims, max_tscale=10000)
939
  self.norm = RMSNorm(dims)
940
 
941
- def apply_rope_to_features(self, x, layer="WEncoder", feature="waveform"):
942
- if not self.use_rope or self.rope is None:
943
- return x
944
  batch, ctx, dims = x.shape
945
  x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
946
- rope_freqs = self.rope(ctx, layer=layer, feature=feature)
947
- x = self.rope.apply_rotary(x, rope_freqs)
948
  x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
949
  return x
950
 
951
- def forward(self, x, enc=None, feature="waveform", layer="WEncoder"):
952
- x = self.downsample(x)
953
- x = self.encoder(x)
954
- x = x.permute(0, 2, 1)
955
  if self.use_rope:
956
- x = self.apply_rope_to_features(x, layer=layer)
957
  else:
958
- x = x + self.sinusoid_pos(x.shape[1], x.shape[-1]).to(x.device, x.dtype)
959
  x = nn.functional.dropout(x, p=self.dropout, training=self.training)
 
 
 
960
  return self.norm(x)
961
 
962
  class PEncoder(nn.Module):
963
- def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False, one_shot=False):
964
  super().__init__()
965
 
966
  self.head = head
967
  self.head_dim = dims // head
 
968
  self.dropout = 0.01
969
  self.use_rope = use_rope
970
- self.dims = dims
971
- self.one_shot = one_shot
972
  act_fn = get_activation(act)
973
-
974
  self.encoder = nn.Sequential(
975
- Conv1d(input_dims, dims, kernel_size=kernel_size, stride=1, padding=kernel_size//2), act_fn,
976
- Conv1d(dims, dims, kernel_size=5, padding=2), act_fn,
977
- Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims), act_fn)
978
 
979
-
980
  if use_rope:
981
- self.rope = rotary(
982
- dims=self.head_dim,
983
- head=self.head,
984
- debug=[])
985
  else:
986
  self.rope = None
987
- self.sinusoid_pos = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
 
988
  self.norm = RMSNorm(dims)
989
-
990
- def apply_rope_to_features(self, x, layer="PEncoder", feature="pitch"):
991
- if not self.use_rope or self.rope is None:
992
- return x
993
  batch, ctx, dims = x.shape
994
  x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
995
- rope_freqs = self.rope(ctx, layer=layer, feature=feature)
996
- x = self.rope.apply_rotary(x, rope_freqs)
997
  x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
998
  return x
999
 
1000
- def forward(self, xa, enc=None, layer="PEncoder", feature="pitch"):
1001
- xa = self.encoder(xa).permute(0, 2, 1)
 
 
 
 
1002
  if self.use_rope:
1003
- xa = self.apply_rope_to_features(xa, layer=layer)
1004
  else:
1005
- xa = xa + self.sinusoid_pos(xa.shape[1], xa.shape[-1], 10000).to(xa.device, xa.dtype)
1006
- if self.one_shot:
1007
- x = enc["input_ids"]
1008
- xa = self.one_shot(x, xa)
1009
- xa = nn.functional.dropout(xa, p=self.dropout, training=self.training)
1010
- return self.norm(xa)
1011
-
1012
- def win_mask(text_ctx, aud_ctx):
1013
- mask = torch.tril(torch.ones(text_ctx, text_ctx, device=device), diagonal=0)
1014
- audio_mask = torch.tril(torch.ones(text_ctx, aud_ctx - text_ctx, device=device))
1015
- full_mask = torch.cat([mask, audio_mask], dim=-1)
1016
- return full_mask.unsqueeze(0).unsqueeze(0)
1017
-
1018
- def causal_mask(seq_len, device):
1019
- return torch.tril(torch.ones(seq_len, seq_len, device=device), diagonal=0).unsqueeze(0).unsqueeze(0)
1020
 
1021
  class theBridge(nn.Module):
1022
  def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int,
1023
  debug: List[str], features: List[str], act: str = "gelu"):
1024
  super(theBridge, self).__init__()
1025
 
1026
- self.ctx = ctx
1027
- self.dims = dims
1028
- self.head = head
1029
- self.head_dim = dims // head
1030
  self.debug = debug
1031
  self.counter = 0
1032
  self.dropout = 0.01
1033
  self.features = features
1034
  self.do_blend = "no_blend" not in self.debug
1035
  self.sequential = "sequential" in self.debug
 
1036
 
1037
  self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
1038
  self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
1039
  self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
1040
- self.ln_dec = RMSNorm(dims)
1041
- self.sinusoid_pos = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
 
1042
 
1043
  with torch.no_grad():
1044
  self.token.weight[0].zero_()
1045
 
1046
- self.block = nn.ModuleList([
1047
- Residual(ctx=ctx, dims=dims, head=head, act="gelu", debug=debug, features=features)
1048
- for _ in range(layer)])
1049
-
1050
- self.cross_attn = nn.ModuleList([
1051
- Residual(ctx=ctx, dims=dims, head=head, act="gelu", debug=debug, features=features)
1052
- for _ in range(layer)])
1053
-
1054
- self.cross_modal = nn.ModuleList([
1055
- Residual(ctx=ctx, dims=dims, head=head, act="gelu", debug=debug, features=features)
1056
- for _ in range(layer)])
1057
-
1058
- self.register_buffer("mask", causal_mask(ctx, device), persistent=False)
1059
- self.register_buffer("mask_win", win_mask(ctx, ctx), persistent=False)
1060
-
1061
  act_fn = get_activation(act)
1062
  if features == ["spectrogram", "waveform", "pitch"]:
1063
  cgate=True
1064
  else:
1065
  cgate = False
1066
-
1067
- self.blockA = nn.ModuleDict({
1068
- "spectrogram": nn.ModuleList(
1069
- [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
1070
- [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "spectrogram" in features else None),
1071
- "waveform": nn.ModuleList(
1072
- [WEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=11, act=act_fn)] +
1073
- [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "waveform" in features else None),
1074
- "pitch": nn.ModuleList(
1075
- [PEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=3, act=act, one_shot=False)] +
1076
- [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "pitch" in features else None),
1077
- "envelope": nn.ModuleList(
1078
- [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
1079
- [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "envelope" in features else None),
1080
- "phase": nn.ModuleList(
1081
- [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
1082
- [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "phase" in features else None)})
1083
-
1084
-
1085
-
1086
- def forward(self, x, enc, feature, layer='theBridge') -> Tensor:
1087
- f0 = enc.get("f0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1088
  out = {}
1089
- out.update(enc)
1090
- enc = dict_to(enc, device, dtype)
1091
- _text_len = x.shape[1]
1092
- x = self.token(x) + self.positional[:x.shape[1]]
1093
-
1094
- for f in enc:
1095
- if f in self.features:
1096
- xa = enc[f]
1097
- for block in self.blockA[f]:
1098
- xa = block(xa, enc=out, feature=feature, layer="enc_self")
1099
- xa = xa + self.sinusoid_pos(xa.shape[1], xa.shape[-1], 10000).to(xa.device, xa.dtype)
1100
- out[f] = xa
1101
-
1102
- for block in self.block:
1103
- x = block(x, xa=None, mask=self.mask, enc=enc, feature=feature, layer="dec_self")
1104
- out["input_ids"] = x
1105
-
1106
- if f in self.features:
1107
- out = block(x, xa=xa, mask=self.mask, enc=enc, feature=feature, layer="dec_cross")
1108
- if self.sequential:
1109
- x = out
1110
- else:
1111
- a = torch.sigmoid(self.blend)
1112
- x = a * out + (1 - a) * x
1113
- x = self.token(x) + self.positional[:x.shape[1]]
1114
- out[f] = x
1115
-
1116
- for block in self.cross_attn:
1117
- if f in self.features:
1118
- x = block(x, xa=xa, mask=self.mask, enc=enc, feature=feature, layer="dec_cross")
1119
- xa = block(xa, xa=x, mask=self.mask, enc=enc, feature=feature, layer="enc_cross")
1120
- out = block(x, xa=xa, mask=self.mask, enc=enc, feature=feature, layer="dec_cross")
1121
- if self.sequential:
1122
- x = out
1123
- else:
1124
- a = torch.sigmoid(self.blend)
1125
- x = a * out + (1 - a) * x
1126
- x = self.token(x) + self.positional[:x.shape[1]]
1127
- out[f] = x
1128
-
1129
- for block in self.cross_modal:
1130
- if f in self.features:
1131
- xcat = torch.cat([x, xa], dim=1)
1132
- x = block(xcat, xa=None, mask=self.mask, enc=enc, feature=feature, layer="cross_modal")
1133
- x = x[:, :_text_len]
1134
- out[f] = x
1135
 
1136
  if self.counter < 1 and "encoder" in self.debug:
1137
- shapes = {k: v.shape for k, v in enc.items()}
1138
- print(f"Step {self.counter}: mode: {list(enc.keys()) }: shapes: {shapes}")
1139
  self.counter += 1
1140
-
1141
- x = self.ln_dec(x)
1142
  x = x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
1143
- return x, out
1144
 
 
 
1145
  class Echo(nn.Module):
1146
  def __init__(self, param: Dimensions):
1147
  super().__init__()
@@ -1169,32 +766,33 @@ class Echo(nn.Module):
1169
  f0t: Optional[torch.Tensor]=None,
1170
  harmonic: Optional[torch.Tensor]=None,
1171
  aperiodic: Optional[torch.Tensor]=None,
 
1172
  ) -> Dict[str, Optional[torch.Tensor]]:
1173
 
1174
- enc = {}
1175
- if spectrogram is not None:
1176
- enc["spectrogram"] = spectrogram
1177
- feature = "spectrogram"
1178
- if waveform is not None:
1179
- enc["waveform"] = waveform
1180
- feature = "waveform"
1181
- if pitch is not None:
1182
- enc["pitch"] = pitch
1183
- feature = "pitch"
1184
  if f0 is not None:
1185
- enc["f0"] = f0
1186
  if f0t is not None:
1187
- enc["f0t"] = f0t
1188
  if harmonic is not None:
1189
- enc["harmonic"] = harmonic
1190
  if aperiodic is not None:
1191
- enc["aperiodic"] = aperiodic
1192
- if input_ids is not None:
1193
- enc["input_ids"] = input_ids
1194
- feature = "input_ids"
 
 
 
 
 
 
 
 
1195
 
1196
- logits, out = self.processor(input_ids, enc, feature)
1197
- self.out = out
1198
 
1199
  loss = None
1200
  if labels is not None:
@@ -1214,7 +812,7 @@ class Echo(nn.Module):
1214
  std = 0.02
1215
  self.init_counts = {
1216
  "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
1217
- "Conv2d": 0, "SEBlock": 0, "SpeechTransformer": 0,
1218
  "Residual": 0, "MultiheadA": 0,
1219
  "MultiheadC": 0, "MultiheadD": 0, "FEncoder": 0,
1220
  "WEncoder": 0, "PEncoder": 0}
@@ -1243,7 +841,17 @@ class Echo(nn.Module):
1243
  self.init_counts["MultiheadA"] += 1
1244
  elif isinstance(module, Residual):
1245
  self.init_counts["Residual"] += 1
1246
-
 
 
 
 
 
 
 
 
 
 
1247
  def init_weights(self):
1248
  print("Initializing model weights...")
1249
  self.apply(self._init_weights)
@@ -1307,412 +915,117 @@ class Echo(nn.Module):
1307
  })
1308
  return Config()
1309
 
1310
- def setup_tokenizer(token: str):
1311
- from tokenizers import Tokenizer
1312
- tokenizer = Tokenizer.from_file("./tokenizer.json")
1313
- orig_encode = tokenizer.encode
1314
- def enc(text, add_special_tokens=True):
1315
- ids = orig_encode(text).ids
1316
- if not add_special_tokens:
1317
- sp_ids = [tokenizer.token_to_id(t) for t in ["<PAD>", "<BOS>", "<EOS>"]]
1318
- ids = [id for id in ids if id not in sp_ids]
1319
- return ids
1320
-
1321
- def bdec(ids_list, skip_special_tokens=True, pad_token_id=0, bos_token_id=1, eos_token_id=2):
1322
- results = []
1323
- for ids in ids_list:
1324
- if isinstance(ids, torch.Tensor):
1325
- ids = ids.tolist()
1326
- ids = [int(id) for id in ids if id != -100]
1327
- if skip_special_tokens:
1328
- ids = [id for id in ids if id not in (pad_token_id, bos_token_id, eos_token_id)]
1329
-
1330
- if ids and ids and ids[0] == bos_token_id:
1331
- ids = ids[1:]
1332
- while ids and ids[-1] == eos_token_id:
1333
- ids = ids[:-1]
1334
- results.append(tokenizer.decode(ids))
1335
- return results
1336
-
1337
- def save_pretrained(save_dir):
1338
- os.makedirs(save_dir, exist_ok=True)
1339
- tokenizer.save(f"{save_dir}/tokenizer.json")
1340
- tokenizer.encode = enc
1341
- tokenizer.batch_decode = bdec
1342
- tokenizer.save_pretrained = save_pretrained
1343
- tokenizer.pad_token_id = 0
1344
- tokenizer.bos_token_id = 1
1345
- tokenizer.eos_token_id = 2
1346
- return tokenizer
1347
-
1348
- def tokenize_pitch(pitch_features, target_length):
1349
- pitch_len = pitch_features.shape[-1]
1350
- token_len = target_length
1351
- if pitch_len > token_len:
1352
- pitch_tokens = F.adaptive_avg_pool1d(pitch_features, token_len)
1353
- else:
1354
- pitch_tokens = F.interpolate(pitch_features, token_len)
1355
- return pitch_tokens
1356
-
1357
- def load_wave(wave_data, sample_rate):
1358
- if isinstance(wave_data, str):
1359
- waveform, sr = torchaudio.load(uri=wave_data, normalize=False)
1360
- elif isinstance(wave_data, dict):
1361
- waveform = torch.tensor(data=wave_data["array"]).float()
1362
- sr = wave_data["sampling_rate"]
1363
- else:
1364
- raise TypeError("Invalid wave_data format.")
1365
-
1366
- return waveform
1367
-
1368
- def world_to_mel(sp, ap, sample_rate=16000, n_mels=128):
1369
- import librosa
1370
- mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=1024, n_mels=n_mels)
1371
- mel_basis = torch.from_numpy(mel_basis).float()
1372
- sp_mel = torch.matmul(sp, mel_basis.T)
1373
- ap_mel = torch.matmul(ap, mel_basis.T)
1374
- return sp_mel, ap_mel
1375
-
1376
- def extract_features(batch, tokenizer, waveform=False, spec=False, f0=True, f0t=True, pitch=True, harmonics=False, sample_rate=16000, hop_length=256, mode="mean", debug=False, **dataset_config):
1377
- dataset_config = {
1378
- "hop_length": 256,
1379
- "f_min": 150,
1380
- "f_max": 2000,
1381
- "n_mels": 128,
1382
- "n_fft": 1024,
1383
- "sample_rate": 16000,
1384
- "pad_mode": "constant",
1385
- "center": True,
1386
- "power": 1.0,
1387
- "window_fn": torch.hann_window,
1388
- "mel_scale": "htk",
1389
- "norm": None,
1390
- "normalized": False,
1391
- }
1392
-
1393
- audio = batch["audio"]
1394
- sr = audio["sampling_rate"]
1395
- labels = tokenizer.encode(batch["transcription"])
1396
-
1397
- wav = wavnp = f0_np = t = None
1398
- spectrogram = f0_tensor = f0t_tensor = harmonic = aperiodic = p_tensor = None
1399
-
1400
- if waveform or spec or f0 or f0t or harmonics or pitch:
1401
- wav = load_wave(wave_data=audio, sample_rate=sr)
1402
- wavnp = wav.numpy().astype(np.float64)
1403
-
1404
- if spec:
1405
- transform = torchaudio.transforms.MelSpectrogram(**dataset_config)
1406
- mel_spectrogram = transform(wav)
1407
- log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
1408
- log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
1409
- spectrogram = (log_mel + 4.0) / 4.0
1410
- spectrogram = torch.tensor(spectrogram)
1411
-
1412
- if f0 or f0t or harmonics or pitch:
1413
- f0_np, t = pw.dio(wavnp, sample_rate,
1414
- frame_period=hop_length / sample_rate * 1000)
1415
- f0_np = pw.stonemask(wavnp, f0_np, t, sample_rate)
1416
- t = torch.tensor(t)
1417
-
1418
- if f0:
1419
- f0_tensor = torch.from_numpy(f0_np)
1420
- t_frame = torch.mean(t[1:] - t[:-1])
1421
- f0_tensor = accumulate_phase_mod(f0_tensor, t_frame)
1422
-
1423
- if f0t:
1424
- audio_duration = len(wavnp) / sample_rate
1425
- T = len(labels)
1426
- tok_dur_sec = audio_duration / T
1427
- token_starts = torch.arange(T) * tok_dur_sec
1428
- token_ends = token_starts + tok_dur_sec
1429
- start_idx = torch.searchsorted(t, token_starts, side="left")
1430
- end_idx = torch.searchsorted(t, token_ends, side="right")
1431
- pitch_tok = torch.zeros(T, dtype=torch.float32)
1432
- for i in range(T):
1433
- lo, hi = start_idx[i], max(start_idx[i]+1, end_idx[i])
1434
- segment = f0_np[lo:hi]
1435
- if mode == "mean":
1436
- pitch_tok[i] = segment.mean()
1437
- elif mode == "median":
1438
- pitch_tok[i] = torch.median(segment)
1439
- else:
1440
- pitch_tok[i] = segment[-1]
1441
- pitch_tok[pitch_tok < 100.0] = 0.0
1442
- bos_pitch = pitch_tok[0] if len(pitch_tok) > 0 else 0.0
1443
- f0t_tensor = torch.from_numpy(np.concatenate([[bos_pitch], pitch_tok]))
1444
- f0t_tensor = accumulate_phase_mod(f0t_tensor, t_frame)
1445
-
1446
- if pitch:
1447
- p_tensor = torch.from_numpy(f0_np)
1448
- p_tensor = p_tensor.unsqueeze(0)
1449
-
1450
- if harmonics:
1451
- spnp = pw.cheaptrick(wavnp, f0_np, t, sample_rate, fft_size=256)
1452
- apnp = pw.d4c(wavnp, f0_np, t, sample_rate, fft_size=256)
1453
- harmonic = torch.from_numpy(spnp)
1454
- aperiodic = torch.from_numpy(apnp)
1455
- harmonic = harmonic[:, :128].contiguous().T
1456
- aperiodic = aperiodic[:, :128].contiguous().T
1457
- harmonic = torch.where(harmonic == 0.0, torch.zeros_like(harmonic), harmonic / 1.0)
1458
- aperiodic = torch.where(aperiodic == 0.0, torch.zeros_like(aperiodic), aperiodic / 1.0)
1459
-
1460
- if debug:
1461
- print(f"['f0']: {f0_tensor.shape if f0 is not None else None}")
1462
- print(f"['f0t']: {f0t_tensor.shape if f0t is not None else None}")
1463
- print(f"['harmonic']: {harmonic.shape if harmonic is not None else None}")
1464
- print(f"['aperiodic']: {aperiodic.shape if aperiodic is not None else None}")
1465
- print(f"['spectrogram']: {spectrogram.shape if spectrogram is not None else None}")
1466
- print(f"['waveform']: {wav.shape if wav is not None else None}")
1467
- print(f"['labels']: {len(labels) if labels is not None else None}")
1468
-
1469
- return {
1470
- "waveform": wav if waveform else None,
1471
- "spectrogram": spectrogram if spec else None,
1472
- "f0": f0_tensor if f0 else None,
1473
- "f0t": f0t_tensor if f0t else None,
1474
- "pitch": p_tensor if pitch else None,
1475
- "harmonic": harmonic if harmonics else None,
1476
- "aperiodic": aperiodic if harmonics else None,
1477
- "labels": labels,
1478
- }
1479
-
1480
- def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=False, **dataset_config):
1481
-
1482
- if sanity_check:
1483
- test = load_dataset(
1484
- "google/fleurs", "en_us", token=token, split="test", trust_remote_code=True
1485
- ).cast_column("audio", Audio(sampling_rate=sample_rate)).take(10)
1486
- dataset = test.map(
1487
- lambda x: extract_features(x, tokenizer, **dataset_config),
1488
- remove_columns=test.column_names)
1489
-
1490
- train_dataset = dataset
1491
- test_dataset = dataset
1492
- return train_dataset, test_dataset
1493
- else:
1494
-
1495
- cache_dir = "./processed_datasets"
1496
- os.makedirs(cache_dir, exist_ok=True)
1497
- cache_file_train = os.path.join(cache_dir, "train.arrow")
1498
- cache_file_test = os.path.join(cache_dir, "test.arrow")
1499
-
1500
- if os.path.exists(cache_file_train) and os.path.exists(cache_file_test):
1501
- from datasets import Dataset
1502
- train_dataset = Dataset.load_from_disk(cache_file_train)
1503
- test_dataset = Dataset.load_from_disk(cache_file_test)
1504
- return train_dataset, test_dataset
1505
-
1506
- def filter_func(x):
1507
- return (0 < len(x["transcription"]) < 2048 and
1508
- len(x["audio"]["array"]) > 0 and
1509
- len(x["audio"]["array"]) < 2048 * 160)
1510
-
1511
- raw_train = load_dataset(
1512
- "google/fleurs", "en_us", token=token, split="train", trust_remote_code=True, streaming=streaming).take(1000)
1513
- raw_test = load_dataset(
1514
- "google/fleurs", "en_us", token=token, split="test", trust_remote_code=True, streaming=streaming).take(100)
1515
-
1516
- raw_train = raw_train.filter(filter_func)
1517
- raw_test = raw_test.filter(filter_func)
1518
-
1519
- raw_train = raw_train.cast_column("audio", Audio(sampling_rate=sample_rate))
1520
- raw_test = raw_test.cast_column("audio", Audio(sampling_rate=sample_rate))
1521
-
1522
- train_dataset = raw_train.map(
1523
- lambda x: extract_features(x, tokenizer, **dataset_config),
1524
- remove_columns=raw_train.column_names)
1525
- test_dataset = raw_test.map(
1526
- lambda x: extract_features(x, tokenizer, **dataset_config),
1527
- remove_columns=raw_test.column_names)
1528
-
1529
- train_dataset.save_to_disk(cache_file_train) if sanity_check is False else None
1530
- test_dataset.save_to_disk(cache_file_test) if sanity_check is False else None
1531
- return train_dataset, test_dataset
1532
-
1533
- @dataclass
1534
- class DataCollator:
1535
- tokenizer: Any
1536
-
1537
- def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
1538
- all_keys = set()
1539
- for f in features:
1540
- all_keys.update(f.keys())
1541
- batch = {}
1542
- pad_token_id = getattr(self.tokenizer, 'pad_token_id', 0)
1543
- bos_token_id = getattr(self.tokenizer, 'bos_token_id', 1)
1544
- eos_token_id = getattr(self.tokenizer, 'eos_token_id', 2)
1545
-
1546
- for key in all_keys:
1547
- if key == "labels":
1548
- labels_list = [f["labels"] for f in features]
1549
- max_len = max(len(l) for l in labels_list)
1550
- all_ids, all_labels = [], []
1551
- for label in labels_list:
1552
- label_list = label.tolist() if isinstance(label, torch.Tensor) else label
1553
- decoder_input = [bos_token_id] + label_list
1554
- label_eos = label_list + [eos_token_id]
1555
- input_len = max_len + 1 - len(decoder_input)
1556
- label_len = max_len + 1 - len(label_eos)
1557
- padded_input = decoder_input + [pad_token_id] * input_len
1558
- padded_labels = label_eos + [pad_token_id] * label_len
1559
- all_ids.append(padded_input)
1560
- all_labels.append(padded_labels)
1561
- batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
1562
- batch["labels"] = torch.tensor(all_labels, dtype=torch.long)
1563
-
1564
- elif key in ["spectrogram", "waveform", "pitch", "harmonic", "aperiodic", "f0t", "f0"]:
1565
- items = [f[key] for f in features if key in f]
1566
- items = [item for item in items if item is not None]
1567
- if not items:
1568
- continue
1569
- items = [torch.tensor(item) if not isinstance(item, torch.Tensor) else item for item in items]
1570
- max_len = max(item.shape[-1] for item in items)
1571
- padded = []
1572
- for item in items:
1573
- pad_width = max_len - item.shape[-1]
1574
- if pad_width > 0:
1575
- pad_item = F.pad(item, (0, pad_width), mode='constant', value=pad_token_id)
1576
- else:
1577
- pad_item = item
1578
- padded.append(pad_item)
1579
- batch[key] = torch.stack(padded)
1580
- if key == "spectrogram":
1581
- batch["spectrogram"] = batch[key]
1582
- return batch
1583
-
1584
- def levenshtein(reference_words, hypothesis_words):
1585
- m, n = len(reference_words), len(hypothesis_words)
1586
- dist_matrix = [[0 for _ in range(n+1)] for _ in range(m+1)]
1587
- for i in range(m+1):
1588
- dist_matrix[i][0] = i
1589
- for j in range(n+1):
1590
- dist_matrix[0][j] = j
1591
- for i in range(1, m+1):
1592
- for j in range(1, n+1):
1593
- if reference_words[i-1] == hypothesis_words[j-1]:
1594
- dist_matrix[i][j] = dist_matrix[i-1][j-1]
1595
- else:
1596
- substitution = dist_matrix[i-1][j-1] + 1
1597
- insertion = dist_matrix[i][j-1] + 1
1598
- deletion = dist_matrix[i-1][j] + 1
1599
- dist_matrix[i][j] = min(substitution, insertion, deletion)
1600
- return dist_matrix[m][n]
1601
-
1602
- def wer_batch(references, hypotheses):
1603
- total_errors = 0
1604
- total_words = 0
1605
- for ref, hyp in zip(references, hypotheses):
1606
- ref_words = ref.lower().split()
1607
- errors = levenshtein(ref_words, hyp.lower().split())
1608
- total_errors += errors
1609
- total_words += len(ref_words)
1610
- return (total_errors / total_words) * 100 if total_words > 0 else 0.0
1611
-
1612
- def clean_ids(ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
1613
- if isinstance(ids, torch.Tensor):
1614
- ids = ids.tolist()
1615
- return [int(id) for id in ids if id != -100 and id != pad_token_id and id != bos_token_id and id != eos_token_id]
1616
-
1617
- def clean_batch(batch_ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
1618
- return [clean_ids(seq, pad_token_id, bos_token_id, eos_token_id) for seq in batch_ids]
1619
-
1620
- def compute_metrics(pred, tokenizer=None, model=None, print_pred=False, num_samples=0, optimizer=None, scheduler=None):
1621
-
1622
- label_ids = pred.label_ids
1623
- pred_ids = pred.predictions[0]
1624
-
1625
- label_ids = clean_batch(label_ids, pad_token_id=tokenizer.pad_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id)
1626
- pred_ids = clean_batch(pred_ids, pad_token_id=tokenizer.pad_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id)
1627
-
1628
- label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
1629
- pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
1630
-
1631
- if print_pred:
1632
- for i in range(min(num_samples, len(pred_ids))):
1633
- print(f"Pred tokens: {pred_ids[i]}")
1634
- print(f"Label tokens: {label_ids[i]}")
1635
- print(f"Pred: '{pred_str[i]}'")
1636
- print(f"Label: '{label_str[i]}'")
1637
- print("-" * 40)
1638
-
1639
- wer = wer_batch(label_str, pred_str)
1640
- if model is not None:
1641
- trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1000000
1642
- efficiency_score = (100 - wer) / trainable_params if trainable_params > 0 else 0.0
1643
- else:
1644
- trainable_params = 0.0
1645
- efficiency_score = 0.0
1646
- return {
1647
- "wer": float(wer),
1648
- "efficiency_score": float(efficiency_score),
1649
- }
1650
-
1651
- def preprocess_logits_for_metrics(logits, labels):
1652
- pred_ids = torch.argmax(logits, dim=-1)
1653
- labels = torch.where(labels == -100, 0, labels)
1654
- pred_ids = torch.where(pred_ids == -100, 0, pred_ids)
1655
-
1656
- return pred_ids, labels
1657
-
1658
  def main():
1659
  token = ""
1660
  log_dir = os.path.join('./output/logs', datetime.now().strftime('%m-%d_%H_%M_%S'))
1661
  os.makedirs(log_dir, exist_ok=True)
1662
- tokenizer = setup_tokenizer(token)
1663
- train_dataset, test_dataset = prepare_datasets(
1664
- tokenizer,
1665
- token,
1666
- sanity_check=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1667
 
1668
- )
1669
-
1670
  param = Dimensions(
1671
  vocab=40000,
1672
  mels=128,
1673
- ctx=1500,
1674
  dims=512,
1675
  head=4,
1676
  layer=4,
1677
  act="swish",
1678
- debug={"radius", "encoder"},
1679
- features = ["pitch"],
1680
  )
1681
-
 
 
 
1682
  model = Echo(param).to('cuda')
1683
  print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
1684
  print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
1685
-
1686
- training_args = Seq2SeqTrainingArguments(
1687
- output_dir=log_dir,
1688
- per_device_train_batch_size=1,
1689
- per_device_eval_batch_size=1,
1690
- max_steps=1000,
1691
- eval_steps=100,
1692
- save_steps=1000,
1693
- warmup_steps=100,
1694
- logging_steps=10,
1695
- logging_dir=log_dir,
1696
- eval_strategy="steps",
1697
- save_strategy="no",
1698
- report_to=["tensorboard"],
1699
- push_to_hub=False,
1700
- disable_tqdm=False,
1701
- save_total_limit=1,
1702
- label_names=["labels"],
1703
- save_safetensors=False,
1704
- eval_on_start=True,
1705
- batch_eval_metrics=False,
1706
- )
1707
  from functools import partial
1708
  metrics_fn = partial(compute_metrics,
1709
- print_pred=False,
1710
- num_samples=2,
1711
  tokenizer=tokenizer, model=model)
1712
 
1713
- optimizer = torch.optim.AdamW(model.parameters(), lr=0.00025, eps=1e-8, weight_decay=0.025, betas=(0.9, 0.999),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1714
  amsgrad=False, foreach=False, fused=False, capturable=False, differentiable=False, maximize=False)
1715
-
1716
  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_args.max_steps, eta_min=1e-9, last_epoch=-1)
1717
 
1718
  trainer = Seq2SeqTrainer(
@@ -1728,24 +1041,7 @@ def main():
1728
 
1729
  model.init_weights()
1730
  trainer.train()
1731
-
1732
  if __name__ == "__main__":
1733
- main()
1734
-
1735
-
1736
 
1737
-
1738
-
1739
-
1740
-
1741
-
1742
-
1743
-
1744
-
1745
-
1746
-
1747
-
1748
-
1749
-
1750
-
1751
 
 
1
  import os
 
2
  import math
3
  import warnings
4
  import logging
5
+ from itertools import chain
6
  import torch
 
7
  import torch.nn.functional as F
 
8
  from torch import nn, Tensor
9
+ from tensordict import TensorDict
10
+ from typing import Optional, Dict, Union, List, Tuple
 
11
  import numpy as np
12
  from functools import partial
13
  from datetime import datetime
14
+ from tensordict import TensorDict
15
  from transformers.trainer_seq2seq import Seq2SeqTrainer
16
  from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
17
+ from echoutils import *
 
 
 
 
 
 
 
18
 
19
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
20
  dtype = torch.float32
 
21
  warnings.filterwarnings("ignore")
22
  logging.basicConfig(level=logging.ERROR)
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  class rotary(nn.Module):
25
+ def __init__(self, dims, head, max_ctx=1500, radii=False, debug: List[str] = [], use_pbias=False, axial=False, spec_shape=None):
26
+
27
  super(rotary, self).__init__()
28
  self.use_pbias = use_pbias
29
  self.dims = dims
 
32
  self.radii = radii
33
  self.debug = debug
34
  self.counter = 0
35
+ self.last_theta = None
36
  self.axial = axial
37
+
38
+ self.bias = nn.Parameter(torch.zeros(max_ctx, dims // 2), requires_grad=True if use_pbias else False)
39
+ theta = (torch.tensor(10000, device=device, dtype=dtype))
40
+ self.theta = nn.Parameter(theta, requires_grad=True)
41
+ self.theta_values = []
42
+
43
  if axial and spec_shape is not None:
44
  time_frames, freq_bins = spec_shape
45
  self.time_frames = time_frames
46
  self.freq_bins = freq_bins
47
+
48
  time_theta = 50.0
49
+ time_freqs = 1.0 / (time_theta ** (torch.arange(0, dims, 4)[:(dims // 4)].float() / dims))
50
  self.register_buffer('time_freqs', time_freqs)
51
+
52
  freq_theta = 100.0
53
+ freq_freqs = 1.0 / (freq_theta ** (torch.arange(0, dims, 4)[:(dims // 4)].float() / dims))
54
  self.register_buffer('freq_freqs', freq_freqs)
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  def pitch_bias(self, f0):
57
  if f0 is None:
58
  return None
 
62
  f0_norm.unsqueeze(1)))
63
  return f0_sim.unsqueeze(0).unsqueeze(0)
64
 
 
 
 
 
 
 
 
65
  def theta_freqs(self, theta):
66
  if theta.dim() == 0:
67
  theta = theta.unsqueeze(0)
 
86
  return torch.polar(torch.ones_like(freqs), freqs), None
87
 
88
  def check_f0(self, f0, f0t, ctx):
89
+ if f0 is not None and f0.shape[1] == ctx:
 
 
 
 
90
  return f0
91
+ elif f0t is not None and f0t.shape[1] == ctx:
92
  return f0t
93
  else:
94
  return None
95
 
96
    def axial_freqs(self, ctx):
        """Build combined 2-D (time x frequency) rotary phases for a flattened
        spectrogram of length ``ctx``.

        Flat positions are unflattened row-major: ``t_x`` is the frame index
        within a row of ``time_frames`` columns, ``t_y`` the row index.
        Returns unit-magnitude complex phases of shape
        (ctx, len(time_freqs) + len(freq_freqs)), or None when axial mode is off.
        """
        if not self.axial:
            return None
        time_frames = self.time_frames
        freq_bins = self.freq_bins  # NOTE(review): unused below — confirm whether intended

        # Map flat position -> (time, frequency) grid coordinates.
        t = torch.arange(ctx, device=device, dtype=dtype)
        t_x = (t % time_frames).float()
        t_y = torch.div(t, time_frames, rounding_mode='floor').float()
        # Per-axis angle tables from the buffers registered in __init__.
        freqs_x = torch.outer(t_x, self.time_freqs)
        freqs_y = torch.outer(t_y, self.freq_freqs)
        # Convert angles to unit complex exponentials and concatenate axes.
        freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
        freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
        return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
110
 
111
    def forward(self, x=None, en=None, f=None, layer=None) -> Tensor:
        """Produce rotary phases of shape (1, ctx, ...) for a sequence.

        x: the sequence length (an int), despite the tensor-style name.
        en: optional feature dict; a matching "f0"/"f0t" pitch track offsets
            the learned base ``theta``.
        f: feature tag; "spectrogram" with axial mode returns 2-D axial phases.
        layer: label used only in debug printing.
        """
        ctx=x
        f0 = en.get("f0") if en is not None else None
        f0t = en.get("f0t") if en is not None else None

        # Keep whichever pitch track matches this context length (f0 preferred).
        f0 = self.check_f0(f0, f0t, ctx)
        if f0 is not None:
            if f0.dim() == 2:
                f0 = f0.squeeze(0)
            # Pitch-conditioned rotary base frequency.
            theta = f0 + self.theta
        else:
            theta = self.theta
        # NOTE(review): theta_freqs appears to return a tuple in at least one
        # visible code path — confirm this assignment unpacks it correctly.
        freqs = self.theta_freqs(theta)
        t = torch.arange(ctx, device=device, dtype=dtype)
        freqs = t[:, None] * freqs
        freqs, radius = self._apply_radii(freqs, f0, ctx)

        # NOTE(review): the axial branch discards all of the 1-D work above;
        # hoisting this check to the top of the method would save that compute.
        if self.axial and f == "spectrogram":
            freqs_2d = self.axial_freqs(ctx)
            if freqs_2d is not None:
                return freqs_2d.unsqueeze(0)

        # One-shot diagnostic dump at the 10th call when "radius" debug is on.
        if "radius" in self.debug and self.counter == 10:
            print(f" [{layer}] [Radius] {radius.shape if radius is not None else None} {radius.mean() if radius is not None else None} [Theta] {theta.mean() if theta is not None else None} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
        self.counter += 1
        return freqs.unsqueeze(0)
137
 
 
148
  x1 = x1.view(orig_shape)
149
  return torch.cat([x1.type_as(x), x2], dim=-1)
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  class MultiheadA(nn.Module):
152
+
153
+ rbf = False
154
  def __init__(self, dims: int, head: int, rotary_emb: bool = True,
155
+ zero_val: float = 1e-7, minz: float = 1e-8, maxz: float = 1e-6, debug: List[str] = [], optim_attn=False, use_pbias=False):
 
156
  super(MultiheadA, self).__init__()
157
+
158
  self.dims = dims
159
  self.head = head
160
  self.head_dim = dims // head
161
  self.debug = debug
162
  self.counter = 0
163
  self.use_pbias = use_pbias
 
 
 
164
 
165
  self.q = nn.Linear(dims, dims).to(device, dtype)
166
  self.k = nn.Linear(dims, dims, bias=False).to(device, dtype)
 
171
  self.rotary_emb = rotary_emb
172
  self.minz = minz
173
  self.maxz = maxz
174
+ self.zero_val = zero_val
175
+ self.optim_attn = optim_attn
176
  self.fzero = nn.Parameter(torch.tensor(zero_val, device=device, dtype=dtype), requires_grad=False)
177
 
178
  if rotary_emb:
 
180
  dims=dims,
181
  head=head,
182
  debug=debug,
183
+ radii=False,
184
+ )
 
 
185
  else:
186
  self.rope = None
187
 
 
206
  rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
207
  return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
208
 
209
+ def forward(self, x: Tensor, xa = None, mask = None, en= None, layer = None, f=None) -> tuple:
210
 
211
  x = x.to(device, dtype)
212
  if xa is not None:
 
225
  q2 = q.shape[2]
226
  k2 = k.shape[2]
227
 
228
+ q = self.rope.apply_rotary(q, (self.rope(x=q2, en=en, f=f, layer=layer)))
229
+ k = self.rope.apply_rotary(k, (self.rope(x=k2, en=en, f=f, layer=layer)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  else:
231
  q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
232
  k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
 
237
  if self.rbf:
238
  qk = self.rbf_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
239
  if self.use_pbias:
240
+ pbias = self.rope.pitch_bias(f0 = en.get("f0", None) if en is not None else None)
241
  if pbias is not None:
242
  qk = qk + pbias[:,:,:q2,:q2]
243
 
 
 
 
244
  token_ids = k[:, :, :, 0]
245
  zscale = torch.ones_like(token_ids)
246
  fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
247
  zscale[token_ids.float() == self.pad_token] = fzero
248
 
249
+ if mask is not None:
250
+ if mask.dim() == 4:
251
+ mask = mask[0, 0]
252
+ mask = mask[:q2, :k2] if xa is not None else mask[:q2, :q2]
253
  qk = qk + mask * zscale.unsqueeze(-2).expand(qk.shape)
254
+
255
  qk = qk * zscale.unsqueeze(-2)
256
  w = F.softmax(qk, dim=-1).to(q.dtype)
257
  wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
258
+
259
  if "multihead" in self.debug and self.counter % 100 == 0:
260
  print(f"MHA: q={q.shape}, k={k.shape}, v={v.shape} - {qk.shape}, wv shape: {wv.shape}")
261
  self.counter += 1
262
  return self.o(wv), qk
263
 
264
+ @staticmethod
265
+ def split(X: Tensor) -> (Tensor, Tensor):
266
+ half_dim = X.shape[-1] // 2
267
+ return X[..., :half_dim], X[..., half_dim:]
268
 
269
  class t_gate(nn.Module):
270
  def __init__(self, dims, num_types=4, enabled=True):
 
332
  return self.integ(comb)
333
 
334
class mlp_gate(nn.Module):
    """Per-token scalar gate producing values in (0, 1).

    When ``enabled`` is False the module is inert and ``forward`` yields None.
    ``head``/``one_shot`` and the ``xa``/``f`` forward arguments are accepted
    for interface parity with the other gates but are unused here.
    """

    def __init__(self, dims, head, enabled=True, one_shot=True):
        super().__init__()
        self.enabled = enabled
        if enabled:
            self.gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())

    def forward(self, x, xa=None, f=None):
        return self.gate(x) if self.enabled else None
345
 
346
  class Residual(nn.Module):
 
362
  self.blend = nn.Parameter(torch.tensor(0.5))
363
  act_fn = get_activation(act)
364
  self.attn = MultiheadA(dims, head, rotary_emb=True, debug=debug)
365
+ self.curiosity = curiosity(dims, head)
366
 
367
  if not any([tgate, mgate, cgate]):
368
  self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
 
381
  self.lnb = RMSNorm(dims)
382
  self.lnc = RMSNorm(dims)
383
 
384
+ def forward(self, x, xa=None, mask=None, en=None, layer=None, f=None) -> Tensor:
385
 
386
  b = torch.sigmoid(self.blend)
387
+ ax = x + self.attn(self.lna(x), xa=xa, mask=mask, en=en, layer=layer, f=f)[0]
388
  bx = b * ax + (1 - b) * x
389
  cx = self.lnb(bx)
390
  dx = self.mlp(cx)
 
393
  gx = self.lnc(fx)
394
  return gx
395
 
396
class OneShot(nn.Module):
    """Cross-attention bias head.

    Computes scaled per-head q·k scores of ``x`` against ``guide`` and returns
    them as a (B, head, Q, K) bias tensor for a caller to add to attention
    logits; no softmax or value aggregation happens here.
    """

    def __init__(self, dims: int, head: int, scale: float = 0.3):
        super().__init__()
        self.head = head
        self.hdim = dims // head
        # Extra damping applied on top of the usual 1/sqrt(hdim) scaling.
        self.scale = scale
        self.q_proj = Linear(dims, dims)
        self.k_proj = Linear(dims, dims)

    def forward(self, x: Tensor, guide: Tensor, f=None) -> Tensor:
        # Fix: the original annotated the return as ``Tensor | None`` although
        # every path returns a Tensor; the annotation now matches behaviour.
        B, Q, _ = x.shape
        K = guide.size(1)
        q = self.q_proj(x).view(B, Q, self.head, self.hdim).transpose(1, 2)
        k = self.k_proj(guide).view(B, K, self.head, self.hdim).transpose(1, 2)
        bias = (q @ k.transpose(-1, -2)) * self.scale / math.sqrt(self.hdim)
        return bias
412
+
413
+ class curiosity(nn.Module):
414
+ def __init__(self, d, h, bias=True):
415
+ super().__init__()
416
+ self.h = h
417
+ self.dh = d // h
418
+ self.qkv = nn.Linear(d, d * 3, bias=bias)
419
+ self.qkv_aux = nn.Linear(d, d * 3, bias=bias)
420
+ self.o = nn.Linear(d, d, bias=bias)
421
+ self.g = nn.Parameter(torch.zeros(h))
422
+
423
+ def split(self, x):
424
+ b, t, _ = x.shape
425
+ return x.view(b, t, self.h, self.dh).transpose(1, 2)
426
+
427
+ def merge(self, x):
428
+ b, h, t, dh = x.shape
429
+ return x.transpose(1, 2).contiguous().view(b, t, h * dh)
430
+
431
+ def forward(self, x, xa, mask=None):
432
+ q, k, v = self.qkv(x).chunk(3, -1)
433
+ qa, ka, va = self.qkv_aux(xa).chunk(3, -1)
434
+ q, k, v = map(self.split, (q, k, v))
435
+ qa, ka, va = map(self.split, (qa, ka, va))
436
+ dots = (q @ k.transpose(-2, -1)) / self.dh**0.5
437
+ dots_aux = (q @ ka.transpose(-2, -1)) / self.dh**0.5
438
+ if mask is not None: dots = dots.masked_fill(mask, -9e15)
439
+ p = dots.softmax(-1)
440
+ pa = dots_aux.softmax(-1)
441
+ h_main = p @ v
442
+ h_aux = pa @ va
443
+ g = torch.sigmoid(self.g).view(1, -1, 1, 1)
444
+ out = self.merge(h_main * (1 - g) + h_aux * g)
445
+ return self.o(out)
446
+
447
class PositionalEncoding(nn.Module):
    """Classic sinusoidal absolute positional encoding.

    ``forward`` scales the input by sqrt(dims) and adds the first ``ctx`` rows
    of a precomputed sin/cos table of shape (1, ctx, dims).
    """

    def __init__(self, dims, ctx):
        super().__init__()
        self.dims = dims
        self.ctx = ctx
        # Plain attribute rather than a registered buffer, as in the original.
        self.pe = self.get_positional_encoding(max_ctx=ctx)

    def get_positional_encoding(self, max_ctx):
        """Build the (1, max_ctx, dims) sin/cos table on the module-level device."""
        position = torch.arange(0, max_ctx, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.dims, 2, dtype=torch.float32)
            * (-math.log(10000.0) / self.dims)
        )
        table = torch.zeros(max_ctx, self.dims)
        table[:, 0::2] = torch.sin(position * div_term)
        table[:, 1::2] = torch.cos(position * div_term)
        return table.unsqueeze(0).to(device)

    def forward(self, x):
        length = x.size(1)
        return x * math.sqrt(self.dims) + self.pe[:, :length, :]
472
+
473
  class FEncoder(nn.Module):
474
+ def __init__(self, mels, dims, head, layer, kernel_size, act, stride=1, use_rope=False, spec_shape=None, debug=[]):
475
  super().__init__()
476
 
477
  self.head = head
 
479
  self.dropout = 0.01
480
  self.use_rope = use_rope
481
  self.dims = dims
482
+ self.debug = debug
483
  act_fn = get_activation(act)
484
+ self.attend_pitch = False
485
+
486
+ if self.attend_pitch:
487
+ self.q, self.k, self.v, self.o, self.scale = qkv_init(dims, head)
488
+ self.mlp = nn.Sequential(
489
+ nn.Linear(dims, dims),
490
+ nn.ReLU(),
491
+ nn.Linear(dims, dims),
492
+ )
493
+ else:
494
+ self.q, self.k, self.v, self.o, self.scale = None, None, None, None, None
495
+ self.mlp = None
496
+
497
  self.encoder = nn.Sequential(
498
+ Conv1d(mels, dims, kernel_size=3, stride=1, padding=1), act_fn,
499
+ Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
500
+ Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
501
+
502
  if use_rope:
503
  if spec_shape is not None:
504
+ self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)
 
 
 
 
 
 
 
 
 
505
  else:
506
  self.rope = None
507
+ self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
 
508
  self.norm = RMSNorm(dims)
509
 
510
    def apply_rope_to_features(self, x, en=None, f=None, layer="audio"):
        """Rotate (B, T, D) features with rotary position embeddings, per head.

        Splits the feature dim into ``self.head`` heads, applies phases from
        ``self.rope`` (which may condition on pitch via ``en``), then flattens
        back to (B, T, D).
        """
        batch, ctx, dims = x.shape
        # (B, T, D) -> (B, head, T, head_dim); assumes dims == head * head_dim.
        # NOTE(review): self.head_dim is not assigned in the visible __init__ —
        # confirm it is set in the elided lines.
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        freqs = self.rope(ctx, en=en, f=f, layer=layer)
        x = self.rope.apply_rotary(x, freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)

        return x
518
 
519
    def forward(self, x: Tensor, en=None, f=None, layer = None):
        """Encode mel-style features to (B, T, dims) with positional information.

        Uses rotary embedding when ``use_rope`` is set, otherwise adds fixed
        sinusoids. The pitch cross-attention branch only runs when
        ``attend_pitch`` is True (hard-coded False in the visible __init__,
        which also leaves ``self.mlp`` as None).
        """
        x = self.encoder(x).permute(0, 2, 1)  # conv stack, then (B, C, T) -> (B, T, C)
        if self.use_rope:
            x = self.apply_rope_to_features(x, en=en, f=f, layer=layer)
        else:
            x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)

        if self.mlp is not None:  # only set when attend_pitch was enabled at init
            x = self.mlp(x)

        if self.attend_pitch:
            # Cross-attend token embeddings ("input_ids") against encoded audio.
            # NOTE(review): create_qkv / calculate_attention are defined outside
            # this chunk — confirm the argument roles (x=text, xa=audio).
            xa = en["input_ids"]
            if xa is not None:
                q, k, v = create_qkv(self.q, self.k, self.v, x=xa, xa=x, head=self.head)
                out, _ = calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True)
                out = self.o(out)
                x = x + out

        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self.norm(x)
        return x
540
 
541
  class WEncoder(nn.Module):
542
+ def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False, debug=[], spec_shape=None):
543
  super().__init__()
544
 
545
  self.head = head
 
547
  self.dropout = 0.01
548
  self.use_rope = use_rope
549
  self.dims = dims
550
+ self.debug = debug
551
  act_fn = get_activation(act)
552
+ self.target_length = None
 
 
 
 
553
  self.encoder = nn.Sequential(
554
+ Conv1d(input_dims, dims//4, kernel_size=15, stride=4, padding=7), act_fn,
555
+ Conv1d(dims//4, dims//2, kernel_size=7, stride=2, padding=3), act_fn,
556
+ Conv1d(dims//2, dims, kernel_size=5, stride=2, padding=2), act_fn)
557
+
558
  if use_rope:
559
+ if spec_shape is not None:
560
+ self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)
 
 
561
  else:
562
  self.rope = None
563
+ self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
564
  self.norm = RMSNorm(dims)
565
 
566
    def apply_rope_to_features(self, x, en=None, f=None, layer="audio"):
        """Rotate (B, T, D) features with rotary position embeddings, per head."""
        batch, ctx, dims = x.shape
        # NOTE(review): self.head_dim is not assigned in the visible __init__ —
        # confirm it is set in the elided lines.
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        freqs = self.rope(ctx, en=en, f=f, layer=layer)
        x = self.rope.apply_rotary(x, freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x
573
 
574
    def forward(self, x: Tensor, en= None, f=None, layer = None):
        """Encode a raw waveform through the strided conv stack.

        The conv strides (4, 2, 2 in the visible __init__) downsample time by
        ~16x; output is optionally pooled to ``target_length``, given positional
        information, and normalised.
        """
        x = self.encoder(x).permute(0, 2, 1)
        if self.target_length and x.shape[1] != self.target_length:
            x = F.adaptive_avg_pool1d(x.transpose(1, 2), self.target_length).transpose(1, 2)
        if self.use_rope:
            x = self.apply_rope_to_features(x, en=en, f=f, layer=layer)
        else:
            x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)

        # NOTE(review): self.ln is never defined in the visible __init__ (only
        # self.norm) — this line likely raises AttributeError at runtime, and x
        # would otherwise be normalised twice (ln here, norm below). Confirm.
        x = self.ln(x)
        print(f"X: {x.shape} {f}") if "encoder" in self.debug else None
        return self.norm(x)
587
 
588
  class PEncoder(nn.Module):
589
+ def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=True, debug=[], one_shot=False, spec_shape=None):
590
  super().__init__()
591
 
592
  self.head = head
593
  self.head_dim = dims // head
594
+ self.dims = dims
595
  self.dropout = 0.01
596
  self.use_rope = use_rope
597
+ self.debug = debug
 
598
  act_fn = get_activation(act)
599
+
600
  self.encoder = nn.Sequential(
601
+ Conv1d(input_dims, dims, kernel_size=7, stride=1, padding=3), act_fn,
602
+ Conv1d(dims, dims, kernel_size=5, stride=1, padding=2), act_fn,
603
+ Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
604
 
 
605
  if use_rope:
606
+ self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)
 
 
 
607
  else:
608
  self.rope = None
609
+ self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
610
+
611
  self.norm = RMSNorm(dims)
612
+
613
+ def rope_to_feature(self, x, en=None, f="pitch", layer="PEncoder"):
 
 
614
  batch, ctx, dims = x.shape
615
  x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
616
+ freqs = self.rope(ctx, en=en, f=f, layer=layer)
617
+ x = self.rope.apply_rotary(x, freqs)
618
  x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
619
  return x
620
 
621
+ def forward(self, x: Tensor, en= None, f="pitch", layer="PEncoder"):
622
+
623
+ if x.dim() == 2:
624
+ x = x.unsqueeze(0)
625
+
626
+ x = self.encoder(x).permute(0, 2, 1)
627
  if self.use_rope:
628
+ x = self.rope_to_feature(x, en=en, f=f, layer=layer)
629
  else:
630
+ x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
631
+ x = nn.functional.dropout(x, p=self.dropout, training=self.training)
632
+ x = self.norm(x)
633
+ print(f"X: {x.shape} {f}") if "PEncoder" in self.debug else None
634
+ return x
 
 
 
 
 
 
 
 
 
 
635
 
636
class theBridge(nn.Module):
    """Joint audio/text trunk: per-feature encoder stacks (blockA), a decoder
    stack with cross-attention blending (blockB), and a joint-sequence
    refinement stack (modal), ending in a tied-embedding projection to logits.
    """

    def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int,
                 debug: List[str], features: List[str], act: str = "gelu"):
        super(theBridge, self).__init__()

        # Gate flags forwarded to every Residual block; cgate is recomputed below.
        tgate = True
        mgate = False
        cgate = False

        self.debug = debug
        self.counter = 0
        self.dropout = 0.01
        self.features = features
        self.do_blend = "no_blend" not in self.debug
        self.sequential = "sequential" in self.debug
        self.layer = layer

        self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
        self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
        # Learned scalar mixing weight used in forward for both blend stages.
        self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
        self.norm = RMSNorm(dims)  # NOTE(review): reassigned again at the end of __init__
        self.sinusoid_pos = lambda length, dims, max_tscale: sinusoids(length, dims, 10000)
        self.rotary = rotary(dims=dims, head=head, debug=debug, radii=False)

        # Zero the padding-token embedding row.
        with torch.no_grad():
            self.token.weight[0].zero_()

        # NOTE(review): act_fn is an instantiated module, but the encoders below
        # call get_activation(act) on it again — confirm get_activation
        # tolerates non-string input.
        act_fn = get_activation(act)
        if features == ["spectrogram", "waveform", "pitch"]:
            cgate=True
        else:
            cgate = False

        self.blockA = nn.ModuleDict()
        # Precedence note: the conditional wraps the whole "[encoder] + [residuals]"
        # expression, so ModuleList receives None (-> empty list) when the
        # feature is absent.
        self.blockA["waveform"] = nn.ModuleList(
            [WEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=11, act=act_fn)] +
            [Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features)
            for _ in range(layer)] if "waveform" in features else None)

        # Spectrogram-like features share the FEncoder front-end.
        # NOTE(review): the trailing "if feature_type in features else None" is
        # redundant under the enclosing if — always true here.
        for feature_type in ["spectrogram", "aperiodic", "harmonic"]:
            if feature_type in features:
                self.blockA[feature_type] = nn.ModuleList(
                    [FEncoder(mels=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
                    [Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features) for _ in range(layer)] if feature_type in features else None)
            else:
                self.blockA[feature_type] = None

        # 1-D contour features share the PEncoder front-end.
        for feature_type in ["pitch", "phase"]:
            if feature_type in features:
                self.blockA[feature_type] = nn.ModuleList(
                    [PEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act_fn)] +
                    [Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features) for _ in range(layer)] if feature_type in features else None)
            else:
                self.blockA[feature_type] = None

        # Text decoder stack.
        self.blockB = nn.ModuleList([
            Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features)
            for _ in range(layer)])

        # Joint text+audio refinement stack.
        self.modal = nn.ModuleList([
            Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features)
            for _ in range(layer)])

        # Causal mask (lower-triangular of ones); sliced per sequence in forward.
        mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
        self.register_buffer("mask", mask, persistent=False)

        self.norm = RMSNorm(dims)

    def forward(self, x, xa, en, f, sequential=False) -> Tensor:
        """Decode token ids ``x`` against audio feature ``xa`` (tagged ``f``).

        Returns logits over the vocabulary via the tied embedding matrix.
        """
        mask = self.mask[:x.shape[1], :x.shape[1]]
        x = self.token(x.long()) + self.positional[:x.shape[1]]

        # Shared context dict passed to every block (embeddings + raw features).
        out = {}
        out["input_ids"] = x
        out.update(en)

        # Encode the audio feature through its dedicated stack.
        # NOTE(review): ``chain`` is presumably itertools.chain — its import is
        # not visible in this chunk; confirm. chain() over a single iterable
        # is equivalent to iterating it directly.
        for b in chain(self.blockA[f] or []):
            xa = b(x=xa, en=out, f=f, layer="en")

        # Decoder: each block runs self-attention then cross-attention, blended
        # by the learned sigmoid(self.blend) unless sequential mode is on.
        for b in chain(self.blockB or []):
            x = b(x=x, xa=None, mask=mask, en=out, f=f, layer="dec")
            y = b(x, xa=xa, mask=None, en=out, f=f, layer="cross")
            if sequential:
                x = y
            else:
                a = torch.sigmoid(self.blend)
                x = a * y + (1 - a) * x
        # Joint refinement: concatenate text and audio along time, process, then
        # re-split and cross-attend the text half against the audio half.
        for b in self.modal:
            xc = b(x=torch.cat([x, xa], dim=1), xa=None, mask=None, en=out, f=f, layer="modal")
            xm = b(x=xc[:, :x.shape[1]], xa=xc[:, x.shape[1]:], mask=None, en=out, f=f, layer="modal")
            if sequential:
                x = xm
            else:
                a = torch.sigmoid(self.blend)
                x = a * x + (1 - a) * xm

        # One-time shape dump on the first call when "encoder" debug is on.
        if self.counter < 1 and "encoder" in self.debug:
            shapes = {k: v.shape for k, v in en.items()}
            print(f"Step {self.counter}: mode: {list(en.keys()) }: shapes: {shapes}")
        self.counter += 1

        x = self.norm(x)
        # Tied-embedding output projection to vocabulary logits.
        x = x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()

        return x
741
+
742
  class Echo(nn.Module):
743
  def __init__(self, param: Dimensions):
744
  super().__init__()
 
766
  f0t: Optional[torch.Tensor]=None,
767
  harmonic: Optional[torch.Tensor]=None,
768
  aperiodic: Optional[torch.Tensor]=None,
769
+ phase: Optional[torch.Tensor]=None,
770
  ) -> Dict[str, Optional[torch.Tensor]]:
771
 
772
+ en= TensorDict(batch_size=[1], device=self.device, dtype=self.dtype)
773
+
774
+ en= {}
 
 
 
 
 
 
 
775
  if f0 is not None:
776
+ en["f0"] = f0
777
  if f0t is not None:
778
+ en["f0t"] = f0t
779
  if harmonic is not None:
780
+ en["harmonic"] = harmonic
781
  if aperiodic is not None:
782
+ en["aperiodic"] = aperiodic
783
+ if phase is not None:
784
+ en["phase"] = phase
785
+ if pitch is not None:
786
+ en["pitch"] = pitch
787
+ if waveform is not None:
788
+ en["waveform"] = waveform
789
+ if spectrogram is not None:
790
+ en["spectrogram"] = spectrogram
791
+
792
+ x = input_ids
793
+ for f, xa in en.items():
794
 
795
+ logits = self.processor(x, xa, en, f)
 
796
 
797
  loss = None
798
  if labels is not None:
 
812
  std = 0.02
813
  self.init_counts = {
814
  "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
815
+ "Conv2d": 0, "theBridge": 0, "Echo": 0,
816
  "Residual": 0, "MultiheadA": 0,
817
  "MultiheadC": 0, "MultiheadD": 0, "FEncoder": 0,
818
  "WEncoder": 0, "PEncoder": 0}
 
841
  self.init_counts["MultiheadA"] += 1
842
  elif isinstance(module, Residual):
843
  self.init_counts["Residual"] += 1
844
+ elif isinstance(module, PEncoder):
845
+ self.init_counts["PEncoder"] += 1
846
+ elif isinstance(module, FEncoder):
847
+ self.init_counts["FEncoder"] += 1
848
+ elif isinstance(module, WEncoder):
849
+ self.init_counts["WEncoder"] += 1
850
+ elif isinstance(module, theBridge):
851
+ self.init_counts["theBridge"] += 1
852
+ elif isinstance(module, Echo):
853
+ self.init_counts["Echo"] += 1
854
+
855
  def init_weights(self):
856
  print("Initializing model weights...")
857
  self.apply(self._init_weights)
 
915
  })
916
  return Config()
917
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
918
  def main():
919
  token = ""
920
  log_dir = os.path.join('./output/logs', datetime.now().strftime('%m-%d_%H_%M_%S'))
921
  os.makedirs(log_dir, exist_ok=True)
922
+ tokenizer = setup_tokenizer("./")
923
+
924
+ sanity_check = False
925
+ streaming = False
926
+ load_saved = False
927
+ save_dataset = False
928
+ cache_dir = None
929
+ extract_args = None
930
+
931
+ extract_args = {
932
+ "waveform": False,
933
+ "spec": True,
934
+ "f0": False,
935
+ "f0t": False,
936
+ "pitch": True,
937
+ "harmonics": False,
938
+ "aperiodics": False,
939
+ "phase_mod": False,
940
+ "crepe": False,
941
+ "sample_rate": 16000,
942
+ "hop_length": 256,
943
+ "mode": "mean",
944
+ "debug": False,
945
+ }
946
 
 
 
947
  param = Dimensions(
948
  vocab=40000,
949
  mels=128,
950
+ ctx=2048,
951
  dims=512,
952
  head=4,
953
  layer=4,
954
  act="swish",
955
+ debug={"encoder"},
956
+ features = ["spectrogram", "pitch"],
957
  )
958
+
959
+ train_dataset, test_dataset = prepare_datasets(tokenizer, token, sanity_check=sanity_check, sample_rate=16000, streaming=streaming,
960
+ load_saved=load_saved, save_dataset=save_dataset, cache_dir=cache_dir, extract_args=extract_args, max_ctx=param.ctx)
961
+
962
  model = Echo(param).to('cuda')
963
  print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
964
  print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
965
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
966
  from functools import partial
967
  metrics_fn = partial(compute_metrics,
968
+ print_pred=True,
969
+ num_samples=1,
970
  tokenizer=tokenizer, model=model)
971
 
972
+ if sanity_check:
973
+ training_args = Seq2SeqTrainingArguments(
974
+ output_dir=log_dir,
975
+ per_device_train_batch_size=1,
976
+ per_device_eval_batch_size=1,
977
+ max_steps=10,
978
+ eval_steps=5,
979
+ save_steps=0,
980
+ warmup_steps=0,
981
+ logging_steps=1,
982
+ logging_dir=log_dir,
983
+ eval_strategy="steps",
984
+ save_strategy="no",
985
+ logging_strategy="no",
986
+ report_to=["tensorboard"],
987
+ push_to_hub=False,
988
+ save_total_limit=1,
989
+ label_names=["labels"],
990
+ save_safetensors=False,
991
+ eval_on_start=True,
992
+ batch_eval_metrics=False,
993
+ disable_tqdm=False,
994
+ include_tokens_per_second=True,
995
+ include_num_input_tokens_seen=True,
996
+ learning_rate=1e-7,
997
+ weight_decay=0.01,
998
+ )
999
+ else:
1000
+ training_args = Seq2SeqTrainingArguments(
1001
+ output_dir=log_dir,
1002
+ per_device_train_batch_size=1,
1003
+ per_device_eval_batch_size=1,
1004
+ max_steps=1000,
1005
+ eval_steps=100,
1006
+ save_steps=1000,
1007
+ warmup_steps=100,
1008
+ logging_steps=10,
1009
+ logging_dir=log_dir,
1010
+ logging_strategy="steps",
1011
+ eval_strategy="steps",
1012
+ save_strategy="no",
1013
+ report_to=["tensorboard"],
1014
+ push_to_hub=False,
1015
+ save_total_limit=1,
1016
+ label_names=["labels"],
1017
+ save_safetensors=False,
1018
+ eval_on_start=True,
1019
+ batch_eval_metrics=False,
1020
+ disable_tqdm=False,
1021
+ include_tokens_per_second=True,
1022
+ include_num_input_tokens_seen=True,
1023
+ learning_rate=0.00025,
1024
+ weight_decay=0.025,
1025
+ )
1026
+
1027
+ optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate, eps=1e-8, weight_decay=training_args.weight_decay, betas=(0.9, 0.999),
1028
  amsgrad=False, foreach=False, fused=False, capturable=False, differentiable=False, maximize=False)
 
1029
  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_args.max_steps, eta_min=1e-9, last_epoch=-1)
1030
 
1031
  trainer = Seq2SeqTrainer(
 
1041
 
1042
  model.init_weights()
1043
  trainer.train()
 
1044
  if __name__ == "__main__":
 
 
 
1045
 
1046
+ main()
 
 
 
 
 
 
 
 
 
 
 
 
 
1047