Sin2pi committed on
Commit 4adcb60 · verified · parent: de2e988

Update model.py

Files changed (1): model.py (+1021 −277)

model.py CHANGED
@@ -1,6 +1,7 @@
 
 import pyworld as pw
 import os
 import math, random
 import warnings
 import logging
@@ -10,22 +11,32 @@ import torch
 import torchaudio
 import torch.nn.functional as F
 import torch.nn.init as init
-from torch import nn, Tensor
 import numpy as np
 from typing import Optional, Dict, Union, List, Tuple, Any
 from functools import partial
 from datetime import datetime
-from datasets import load_dataset, Audio, concatenate_datasets
 from transformers.trainer_seq2seq import Seq2SeqTrainer
 from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
 import transformers
 import evaluate
 from dataclasses import dataclass
-import matplotlib.pyplot as plt

-device = torch.device(device="cuda:0")
 dtype = torch.float32

 extractor = None
 tokenizer = None
 optimizer = None
@@ -171,6 +182,11 @@ def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_
     plt.show()
     return fig

 def exists(v):
     return v is not None

@@ -226,7 +242,7 @@ def get_device():
 def get_dtype():
     return torch.float32 if torch.cuda.is_available() else torch.float64

-def get_tox():
     return {"device": get_device(), "dtype": get_dtype()}

 def sinusoids(length, channels, max_timescale=10000):
@@ -237,55 +253,93 @@ def sinusoids(length, channels, max_timescale=10000):
     scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
     return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

-class ParameterCycler:
-    def __init__(self, parameters):
-        self.parameters = parameters
-        self.current_idx = 0
-    def toggle_requires_grad(self):
-        x = random.randint(0, len(self.parameters) - 1)
-        for x, param in enumerate(self.parameters):
-            param.requires_grad = (x == self.current_idx)
-            print(f"Parameter {x}: requires_grad={param.requires_grad}")
-        self.current_idx = (self.current_idx + 1) % len(self.parameters)
-
-def extract_f0(waveform, sampling_rate=16000, hop_length=128, device="cuda:0"):
-    """Extract F0 from waveform - handle various input types"""
-    if waveform is None:
-        return None
-
-    if isinstance(waveform, list):
-        if len(waveform) == 0:
-            return None
-        waveform = waveform[0]
-        print(f"DEBUG: Converted list to tensor, new type: {type(waveform)}")
-
-    if not isinstance(waveform, torch.Tensor):
-        waveform = torch.tensor(waveform)
-
-    if isinstance(waveform, torch.Tensor):
-        if waveform.dim() == 3:
-            waveform = waveform.squeeze(1)
-        if waveform.dim() == 2:
-            waveform = waveform[0]
-
-        wav_np = waveform.detach().cpu().numpy().astype(np.float64)
     else:
-        wav_np = np.array(waveform).astype(np.float64)
-
-    f0, t = pw.dio(wav_np, sampling_rate,
-                   frame_period=hop_length/sampling_rate*1000)
-    f0 = pw.stonemask(wav_np, f0, t, sampling_rate)

-    f0_tensor = torch.from_numpy(f0).float().to(device)
-    return f0_tensor.unsqueeze(0).unsqueeze(0)

 class rotary(nn.Module):
     _seen = set()
     def __init__(self, dims, max_ctx=1500, theta=10000, learned_freq=False, radii=False,
-                 learned_radius=False, learned_theta=False, learned_pitch=False, debug: List[str] = [], use_pbias = False):
         super().__init__()

-        self.use_pbias = use_pbias
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.dtype = torch.float32
         self.debug = debug
@@ -305,117 +359,146 @@ class rotary(nn.Module):

         self.theta = nn.Parameter(torch.tensor(theta, device=self.device, dtype=self.dtype), requires_grad=True)
         self.pitch_scale = nn.Parameter(torch.tensor(pitch_scale, device=self.device, dtype=self.dtype), requires_grad=True)
-        freqs = 1. / (theta ** (torch.arange(0, dims, 2, device=self.device, dtype=self.dtype)[:(dims // 2)].float() / dims))
-        self.freqs = nn.Parameter(torch.tensor(freqs, device=self.device, dtype=self.dtype), requires_grad=True)
         self.radius = nn.Parameter(torch.ones(radius, device=self.device, dtype=self.dtype), requires_grad=True)

-    def forward(self, x=None, layer=None, enc=None) -> Tensor:
-        f0 = enc.get("f0") if enc else None
         if isinstance(x, int):
             ctx = x
-        else:
             batch, ctx, dims = x.shape
-        t = torch.arange(ctx, device=self.device).float()
-
         if f0 is not None:
-            f0_mean=f0.mean()+1e-8
-            theta=f0_mean*self.pitch_scale
-            freqs = 1. / (theta ** (torch.arange(0, self.dims, 2, device=self.device, dtype=self.dtype)[:(self.dims // 2)].float() /self.dims))
-        else:
             freqs = self.freqs

-        freqs = torch.einsum('i,j->ij', t, freqs)
-        freqs = freqs.float()
-        # print(f"{layer} : {f0_mean} : {theta:.2f} : {ctx} ")
         if self.radii:
-            # radius = self.align_f0(f0, ctx)
-            radius = enc.get("f0d") if enc else self.radius
-            radius = radius.float()
-        else:
-            radius = self.radius
-        # freqs = torch.polar(self.radius.unsqueeze(-1), freqs)
-        freqs = torch.polar(radius.unsqueeze(-1), freqs)
-
-        if "rotary" in self.debug:
             if f0 is not None:
-                key = f"{self._counter}_{theta:.2f}"
-                if key not in rotary._seen:
-                    if not hasattr(self, '_prev_f0_theta'):
-                        self._prev_f0_theta = theta
-                        # print(f"Step {self._counter}: Theta: {theta:.2f} Hz")
-                    elif abs(self._prev_f0_theta - theta) > 100.0:
-                        # print(f"Step {self._counter}: Theta: {theta:.2f} Hz, freqs: {freqs.shape}")
-                        print(f"{layer} : {f0_mean} : Theta: {theta:.2f} : {theta:.2f} : {ctx} ")
-                        if self.radii:
-                            print(f"radius: {radius} Hz, enc: {layer} Hz, ctx: {ctx}")
-                        self._prev_f0_theta = theta
-                        rotary._seen.add(key)
-        self._counter += 1
-        return freqs

     @staticmethod
     def apply_rotary(x, freqs):
-        multihead_format = len(freqs.shape) == 4
-        if multihead_format:
-            x1 = x[..., :freqs.shape[-1]*2]
-            x2 = x[..., freqs.shape[-1]*2:]
-            x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
-            x1 = torch.view_as_complex(x1)
-            x1 = x1 * freqs
-            x1 = torch.view_as_real(x1).flatten(-2)
-            return torch.cat([x1.type_as(x), x2], dim=-1)
-        else:
-            x1 = x[..., :freqs.shape[-1]*2]
-            x2 = x[..., freqs.shape[-1]*2:]
-
-            if x.ndim == 2:
-                x1 = x1.unsqueeze(0)
-                x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
-                x1 = torch.view_as_complex(x1)
-                x1 = x1 * freqs
-                x1 = torch.view_as_real(x1).flatten(-2)
-                x1 = x1.squeeze(0)
-                return torch.cat([x1.type_as(x), x2], dim=-1)
-            else:
-                x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
-                x1 = torch.view_as_complex(x1)
-                x1 = x1 * freqs
-                x1 = torch.view_as_real(x1).flatten(-2)
-                return torch.cat([x1.type_as(x), x2], dim=-1)

 class MultiheadA(nn.Module):
     _seen = set()
     rbf = False
     def __init__(self, dims: int, head: int, rotary_emb: bool = True,
-                 zero_val: float = 0.0001, minz: float = 0.0, maxz: float = 0.001, debug: List[str] = [], optim_attn=False):
         super(MultiheadA, self).__init__()

         self.dims = dims
         self.head = head
         self.head_dim = dims // head
-
-        self.q = Linear(dims, dims)
-        self.k = Linear(dims, dims, bias=False)
-        self.v = Linear(dims, dims)
-        self.o = Linear(dims, dims)
-
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        self.dtype = torch.float32
         self.debug = debug
         self._counter = 0

         self.pad_token = 0
         self.rotary_emb = rotary_emb
         self.minz = minz
         self.maxz = maxz
         self.zero_val = zero_val
         self.optim_attn = optim_attn
-        self.fzero = nn.Parameter(torch.tensor(zero_val, dtype=torch.float32), requires_grad=False)

         if rotary_emb:
             self.rope = rotary(
                 dims=self.head_dim,
                 debug = debug,
@@ -426,8 +509,9 @@ class MultiheadA(nn.Module):
                 learned_radius=False,
             )
         else:
-            self.rope = None
-
    def enhanced_attention_scores(self, q, k, rbf_sigma=1.0, rbf_ratio=0.0):
        scale = (self.dims // self.head) ** -0.25
        dot_scores = torch.matmul(q, k.transpose(-1, -2)) * scale
@@ -440,45 +524,52 @@ class MultiheadA(nn.Module):
         rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
         return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores

-    def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None, feat=None, layer = None) -> tuple:
-
         scale = (self.dims // self.head) ** -0.25

-        z = xa if xa is not None else x
-        q = self.q(x).to(x.dtype)
-        k = self.k(z).to(x.dtype)
-        v = self.v(z).to(x.dtype)
-        batch, ctx, dims = q.shape

         if self.rotary_emb:
-            qf = self.rope(q.size(1), layer=layer, feat=feat)
-            kf = self.rope(k.size(1), layer=layer, feat=feat)
-
             q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
             k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
             v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-
-            q = self.rope.apply_rotary(q, qf)
-            k = self.rope.apply_rotary(k, kf)
-
         else:
             q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
             k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
             v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
         batch, head, ctx, head_dim = q.shape
-
         if self.rbf:
             qk = self.enhanced_attention_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)

         qk = (q * scale) @ (k * scale).transpose(-1, -2)
-        if self.rope.use_pbias:
-            pbias = self.rope.pbias(feat.get("f0"))
         if pbias is not None:
             qk = qk + pbias[:,:,:q.shape[2],:q.shape[2]]
         token_ids = k[:, :, :, 0]
         zscale = torch.ones_like(token_ids)
         fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
-        zscale[token_ids.float() == self.pad_token] = fzero.to(q.device, q.dtype)

         if mask is not None:
             mask = mask[:q.shape[2], :q.shape[2]]
@@ -488,9 +579,7 @@ class MultiheadA(nn.Module):
         wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)

         if "multihead" in self.debug and self._counter % 100 == 0:
-            print(f"Step {self._counter}: Using rotary embeddings: {self.rotary_emb}")
-            print(f"MHA: q={q.shape}, k={k.shape}, v={v.shape}")
-            print(f"Attention shape: {qk.shape}, wv shape: {wv.shape}")
         self._counter += 1
         return self.o(wv), qk.detach()
@@ -539,7 +628,6 @@ class c_gate(nn.Module):
         s = self.s_gate(x) * s_feat
         w = self.w_gate(x) * w_feat
         p = self.p_gate(x) * p_feat
-
         comb = torch.cat([s, w, p], dim=-1)
         return self.integ(comb)

@@ -549,8 +637,6 @@ class Residual(nn.Module):
                  tgate=True, mgate=False, cgate=False, mem_size=512, features=None):
         super().__init__()

-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        self.dtype = torch.float32
         self.dims = dims
         self.head = head
         self.ctx = ctx
@@ -590,12 +676,16 @@ class Residual(nn.Module):
         if not any([t_gate, m_gate, c_gate]):
             self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())

-    def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None, feat=None, layer = None):
         bln = self.blend
-        x = x + self.attna(self.lna(x), xa=None, mask=mask, layer=layer, feat=feat)[0]

         if self.attnb and xa is not None:
-            c = self.attnb(self.lnb(x), xa, mask=None, layer=layer, feat=feat)[0]
             b = torch.sigmoid(bln)
             x = b * x + (1 - b) * c

@@ -610,7 +700,7 @@ class Residual(nn.Module):
             gate = self.m_gate(normx)
             x = x + gate * mlp_out

-        elif self.c_gate is not None:
             gate_output = self.c_gate(normx, self.features)
             x = x + gate_output

@@ -635,34 +725,74 @@ class Residual(nn.Module):

         return x

-class PEncoder(nn.Module):
-    def __init__(self, input_dims, dims, head, layer, kernel_size, act):
         super().__init__()

-        self.head_dim = dims // head
-        self.dropout = 0.01

         act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
         act_fn = act_map.get(act, nn.GELU())

         self.encoder = nn.Sequential(
-            Conv1d(input_dims, dims//4, kernel_size=7, stride=8, padding=3), act_fn,
-            Conv1d(dims//4, dims//2, kernel_size=5, stride=4, padding=2), act_fn,
-            Conv1d(dims//2, dims, kernel_size=5, stride=5, padding=2), act_fn)

-    def forward(self, x, feat=None, layer=None):
         x = self.encoder(x).permute(0, 2, 1)
-        x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
-        x = self.norm(x)
         return x
-
 class WEncoder(nn.Module):
-    def __init__(self, input_dims, dims, head, layer, kernel_size, act):
         super().__init__()

         self.head_dim = dims // head
         self.dropout = 0.01

         act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
         act_fn = act_map.get(act, nn.GELU())
@@ -675,78 +805,83 @@ class WEncoder(nn.Module):
         self.encoder = nn.Sequential(
             Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims//8), act_fn,
             Conv1d(dims, dims, kernel_size=1), act_fn)
-
-        self.positional = lambda length: sinusoids(length, dims)
         self.norm = RMSNorm(dims)

-    def forward(self, x, feat=None, layer=None):
         x = self.downsample(x)
         x = self.encoder(x)
         x = x.permute(0, 2, 1)
-        x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
         return self.norm(x)

-class FEncoder(nn.Module):
-    def __init__(self, input_dims, dims, head, layer, kernel_size, act, stride=1):
         super().__init__()

-        self.head_dim = dims // head
-        self.dropout = 0.01

         act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
         act_fn = act_map.get(act, nn.GELU())

         self.encoder = nn.Sequential(
-            Conv1d(input_dims, dims, kernel_size=kernel_size, stride=stride, padding=kernel_size//2), act_fn,
-            Conv1d(dims, dims, kernel_size=5, padding=2), act_fn,
-            Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims), act_fn)

-        self.positional = lambda length: sinusoids(length, dims)
         self.norm = RMSNorm(dims)
-        self._norm = RMSNorm(dims)

-    def forward(self, x, feat=None, layer=None):
-        x = self.encoder(x).permute(0, 2, 1)
-        x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
-        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
-        x = self._norm(x)
         return x
-
-class F0Encoder(nn.Module):
-    def __init__(self, input_dims, dims, head, layer, kernel_size, act, stride=1):
-        super().__init__()
-
-        self.head_dim = dims // head
-        self.dropout = 0.01
-
-        act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(),
-                   "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(),
-                   "softplus": nn.Softplus(), "softshrink": nn.Softshrink(),
-                   "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
-        act_fn = act_map.get(act, nn.GELU())
-
-        self.encoder = nn.Sequential(
-            Conv1d(input_dims, dims, kernel_size=kernel_size, stride=stride, padding=kernel_size//2), act_fn,
-            Conv1d(dims, dims, kernel_size=5, padding=2), act_fn,
-            Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims), act_fn)
-
-        self.positional = lambda length: sinusoids(length, dims)
-        self.norm = RMSNorm(dims)
-        self._norm = RMSNorm(dims)
-
-    def forward(self, x, feat=None, layer=None):
-        if x.dim() == 3 and x.shape[0] == 1 and x.shape[1] == 1:
-            pass
-        elif x.dim() == 2:
-            x = x.unsqueeze(1)
-        elif x.dim() == 1:
-            x = x.unsqueeze(0).unsqueeze(0)
-        x = self.encoder(x)
-        x = x.permute(0, 2, 1)
-        x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
-        x = self._norm(x)
         return x

 class AudioEncoder(nn.Module):
@@ -759,20 +894,15 @@ class AudioEncoder(nn.Module):
         self.head = head
         self.ctx = ctx
         self.head_dim = dims // head
-
-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        dtype = torch.float32
-        self.device = device
-        self.dtype = dtype
         self.debug = debug
         self._counter = 0

         self.features = features
         self.dropout = 0.01
-        self.f0_rotary = f0_rotary

         self.rope = rotary(
-            dims=self.head_dim)

         act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(),
                    "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
@@ -809,34 +939,44 @@ class AudioEncoder(nn.Module):
             FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)
             for _ in range(layer)])

-    def forward(self, feat, layer="encoder"):
-
         if self._counter < 1:
-            s = feat.get("spectrogram")
-            w = feat.get("waveform")
-            p = default(feat.get("f0"), feat.get("pitch"))
             plot_waveform(x=s, w=w, p=p, hop_length=128)

-        enc = {}
-        enc.update(feat)
-
         for f in self.features:
-            if f in feat and f in self.blocks:
-                x = feat[f]
                 for block in self.blocks[f]:
-                    x = block(x, feat=feat, layer=layer)
-                enc[f] = x
-
         if "encoder" in self.debug and self._counter % 100 == 0:
-            names = list(feat.keys())
-            shapes = {k: v.shape for k, v in feat.items()}
-            print(f"Step {self._counter}: mode: {names}")
-            print(f"shapes: {shapes}")
-            for name, param in self.named_parameters():
-                if param.requires_grad:
-                    print(f"ENCODER LAYER {name}: grad_norm={param.median():.4f}")
         self._counter += 1
-        return enc

 class TextDecoder(nn.Module):
     def __init__(self, vocab: int, ctx: int, dims: int, head: int, layer: int, cross_attn: bool,
@@ -848,10 +988,8 @@ class TextDecoder(nn.Module):
         self.ctx = ctx
         self.head_dim = dims // head

-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        dtype = torch.float32
-        self.device = device
-        self.dtype = dtype
         self.debug = debug
         self._counter = 0

@@ -878,48 +1016,36 @@ class TextDecoder(nn.Module):

         mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
         self.register_buffer("mask", mask, persistent=False)
-
-        rotary_emb = False
-        if rotary_emb:
-            self.rope = rotary(
-                dims=self.head_dim,
-                debug = debug,
-                radii=False,
-                learned_pitch=False,
-                learned_freq=False,
-                learned_theta=False,
-                learned_radius=False,
-            )
-        else:
-            self.rope = None

-    def forward(self, x, feat, order=None, layer='decoder') -> Tensor:
-
-        bln = self.blend
         x = x.to(device)
         if order is None:
             order = self.features
         mask = self.mask[:x.shape[1], :x.shape[1]]
         x = self.token(x) + self.positional[:x.shape[1]]
         x = F.dropout(x, p=self.dropout, training=self.training)
-
         for block in self.block:
-            x = block(x, xa=None, mask=mask, feat=feat, layer=layer)

         for f in order:
-            if f in feat:
-                xa = feat[f]
                 for block in self.blocks[f]:
-                    out = block(x=x, xa=xa, mask=None, feat=feat, layer=layer)
                     a = torch.sigmoid(bln[f])
                     x = a * out + (1 - a) * x
-        x = self.ln_dec(x)
-
         if "decoder" in self.debug and self._counter % 100 == 0:
-            for name, param in self.named_parameters():
-                if param.requires_grad:
-                    print(f"DECODER LAYER {name}: grad_norm={param.median():.4f}")
-            self._counter += 1
         return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()

 class Echo(nn.Module):
@@ -999,8 +1125,8 @@ class Echo(nn.Module):
         if f0d is not None:
             encoder_inputs["f0d"] = f0d

-        encoder_outputs = self.encoder(encoder_inputs)
-        logits = self.decoder(input_ids, encoder_outputs)

         loss = None
         if labels is not None:
@@ -1017,6 +1143,7 @@ class Echo(nn.Module):
             "encoder_output": encoder_outputs,
         }

     def device(self):
         return next(self.parameters()).device
     @property
@@ -1071,7 +1198,7 @@ class Echo(nn.Module):
             print(f"{module_type}: {count}")

     def register_gradient_hooks(self):
-
         for name, param in self.named_parameters():
             if param.requires_grad:
                 if "encoder" in name:
@@ -1096,6 +1223,623 @@ class Echo(nn.Module):
         return None

     def reset_counter(self):
         self._counter = 0
         print("Counter reset to 0.")

 import pyworld as pw
 import os
+from torch.amp import autocast
 import math, random
 import warnings
 import logging

 import torchaudio
 import torch.nn.functional as F
 import torch.nn.init as init
+from torch import nn, einsum, broadcast_tensors, Tensor
 import numpy as np
+from einops import rearrange, repeat
+import matplotlib.pyplot as plt
 from typing import Optional, Dict, Union, List, Tuple, Any
 from functools import partial
 from datetime import datetime
+from datasets import load_dataset, Audio
 from transformers.trainer_seq2seq import Seq2SeqTrainer
 from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
 import transformers
 import evaluate
 from dataclasses import dataclass
+from math import pi, log

+torch.backends.cudnn.allow_tf32 = True
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.set_float32_matmul_precision('high')
+transformers.utils.logging.set_verbosity_error()
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 dtype = torch.float32

+warnings.filterwarnings("ignore")
+logging.basicConfig(level=logging.ERROR)
+
 extractor = None
 tokenizer = None
 optimizer = None
     plt.show()
     return fig

+def dict_to(d, device, dtype=dtype):
+    """Move every tensor value of a dict to the given device and dtype."""
+    return {k: v.to(device, dtype) if isinstance(v, torch.Tensor) else v
+            for k, v in d.items()}
+
 def exists(v):
     return v is not None

 def get_dtype():
     return torch.float32 if torch.cuda.is_available() else torch.float64

+def tox():
     return {"device": get_device(), "dtype": get_dtype()}

 def sinusoids(length, channels, max_timescale=10000):
     scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
     return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

+def rotate_half(x):
+    x = rearrange(x, '... (d r) -> ... d r', r = 2)
+    x1, x2 = x.unbind(dim = -1)
+    x = torch.stack((-x2, x1), dim = -1)
+    return rearrange(x, '... d r -> ... (d r)')
+
+def broadcat(tensors, dim = -1):
+    broadcasted_tensors = broadcast_tensors(*tensors)
+    return torch.cat(broadcasted_tensors, dim = dim)
+
+def slice_at_dim(t, dim_slice: slice, *, dim):
+    dim += (t.ndim if dim < 0 else 0)
+    colons = [slice(None)] * t.ndim
+    colons[dim] = dim_slice
+    return t[tuple(colons)]
+
+def align_f0(f0, ctx):
+    # Note: shadowed by the more general align_f0 defined immediately below.
+    b, l = f0.shape
+    if l == ctx:
+        return f0.squeeze(0).float()
+    frames_per_token = l / ctx
+    idx = torch.arange(ctx, device=device, dtype=dtype)
+    src_idx = (idx * frames_per_token).long().clamp(0, l-1)
+    batch_idx = torch.arange(b, device=device, dtype=torch.long).unsqueeze(1)  # fixed: index dtype must be integral
+    f0 = f0[batch_idx, src_idx]
+    return f0.squeeze(0).float()
+
+def align_f0(f0, target_length, method='nearest', device=device, dtype=dtype):
+    if device is None:
+        device = f0.device
+    if dtype is None:
+        dtype = f0.dtype
+    original_shape = f0.shape
+    squeeze_batch = False
+    reshape_back = None
+    if f0.dim() == 1:
+        f0 = f0.unsqueeze(0)
+        squeeze_batch = True
+    elif f0.dim() == 2:
+        pass
+    elif f0.dim() == 3:
+        batch_size, seq_len, length = f0.shape
+        f0 = f0.view(-1, length)
+        reshape_back = (batch_size, seq_len)
     else:
+        raise ValueError(f"F0 tensor must be 1D, 2D, or 3D, got {f0.dim()}D")
+    batch_size, current_length = f0.shape
+    if current_length == target_length:
+        result = f0
+    elif method == 'nearest':
+        frames_per_token = current_length / target_length
+        target_indices = torch.arange(target_length, device=device, dtype=torch.float32)
+        source_indices = (target_indices * frames_per_token).long().clamp(0, current_length - 1)
+        batch_indices = torch.arange(batch_size, device=device, dtype=torch.long).unsqueeze(1)
+        result = f0[batch_indices, source_indices]
+    else:
+        import torch.nn.functional as F
+        f0_for_interp = f0.unsqueeze(1)
+        # Caution: 'cubic' maps to bicubic, which needs 4D input; on (N, C, L) tensors only 'linear' runs.
+        mode_map = {'linear': 'linear', 'cubic': 'bicubic'}
+        if method not in mode_map:
+            raise ValueError(f"Method '{method}' not supported. Use 'nearest', 'linear', or 'cubic'")
+        result = F.interpolate(
+            f0_for_interp.float(),
+            size=target_length,
+            mode=mode_map[method],
+            align_corners=False
+        ).squeeze(1)
+    if reshape_back is not None:
+        result = result.view(reshape_back[0], reshape_back[1], target_length)
+    elif squeeze_batch:
+        result = result.squeeze(0)
+    return result.to(dtype)
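
A minimal usage sketch of the overload defined here (shapes hypothetical; note this name is re-bound by a later definition near the collator). Nearest-neighbor alignment keeps one F0 frame per target position, so a 1500-frame pitch track shrinks to one value per token; device/dtype are passed explicitly so the sketch also runs on CPU-only machines:

    # Hypothetical shapes: two 1500-frame F0 tracks aligned to a 120-token context.
    f0 = torch.rand(2, 1500) * 300.0
    f0_tok = align_f0(f0, 120, method='nearest', device=f0.device, dtype=f0.dtype)
    print(f0_tok.shape)   # torch.Size([2, 120])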

 class rotary(nn.Module):
     _seen = set()
     def __init__(self, dims, max_ctx=1500, theta=10000, learned_freq=False, radii=False,
+                 learned_radius=False, learned_theta=False, learned_pitch=False, debug: List[str] = [],
+                 use_pbias=False, use_2d_axial=False, spec_shape=None):
         super().__init__()

+        self.use_pbias = use_pbias   # fixed: was hard-coded to False, silently ignoring the argument
+        self.use_2d_axial = use_2d_axial
+        self.spec_shape = spec_shape
+        self.last_f0_theta = None
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.dtype = torch.float32
         self.debug = debug

         self.theta = nn.Parameter(torch.tensor(theta, device=self.device, dtype=self.dtype), requires_grad=True)
         self.pitch_scale = nn.Parameter(torch.tensor(pitch_scale, device=self.device, dtype=self.dtype), requires_grad=True)
+
+        if use_2d_axial and spec_shape is not None:
+            time_frames, freq_bins = spec_shape
+            self.time_frames = time_frames
+            self.freq_bins = freq_bins
+
+            time_theta = 50.0
+            time_freqs = 1.0 / (time_theta ** (torch.arange(0, dims, 4)[:(dims // 4)].float() / dims))
+            self.register_buffer('time_freqs', time_freqs)
+
+            freq_theta = 100.0
+            freq_freqs = 1.0 / (freq_theta ** (torch.arange(0, dims, 4)[:(dims // 4)].float() / dims))
+            self.register_buffer('freq_freqs', freq_freqs)
+        else:
+            freqs = 1. / (theta ** (torch.arange(0, dims, 2, device=self.device, dtype=self.dtype)[:(dims // 2)].float() / dims))
+            self.freqs = nn.Parameter(torch.tensor(freqs, device=self.device, dtype=self.dtype), requires_grad=True)
         self.radius = nn.Parameter(torch.ones(radius, device=self.device, dtype=self.dtype), requires_grad=True)

+    def compute_2d_axial_freqs(self, seq_len):
+        if not self.use_2d_axial:
+            return None
+        time_frames = self.time_frames
+        freq_bins = self.freq_bins
+
+        t = torch.arange(seq_len, device=self.device, dtype=self.dtype)
+        t_x = (t % time_frames).float()
+        t_y = torch.div(t, time_frames, rounding_mode='floor').float()
+        freqs_x = torch.outer(t_x, self.time_freqs)
+        freqs_y = torch.outer(t_y, self.freq_freqs)
+        freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
+        freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
+        return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
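
The axial scheme treats a flattened spectrogram sequence as a (time, mel-bin) grid: position p maps to x = p % time_frames and y = p // time_frames, each axis gets its own frequency bank, and the two banks of unit phasors are concatenated. A standalone sketch of the same mapping (constants illustrative, no module state needed):

    # Axial 2D rotary sketch: a separate phase bank per grid axis.
    dims, time_frames = 64, 1500
    time_freqs = 1.0 / (50.0 ** (torch.arange(0, dims, 4)[:dims // 4].float() / dims))
    p = torch.arange(300)                     # flattened positions
    t_x = (p % time_frames).float()           # time coordinate of each position
    phases_x = torch.polar(torch.ones(300, dims // 4), torch.outer(t_x, time_freqs))
    print(phases_x.shape)                     # torch.Size([300, 16]), complex64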
+
+    def align_f0(self, f0, ctx):
+        b, l = f0.shape
+        if l == ctx:
+            return f0.squeeze(0).float()
+        frames_per_token = l / ctx
+        idx = torch.arange(ctx, device=self.device, dtype=self.dtype)
+        src_idx = (idx * frames_per_token).long().clamp(0, l-1)
+        batch_idx = torch.arange(b, device=self.device, dtype=torch.long).unsqueeze(1)  # fixed: float dtypes cannot index
+        f0 = f0[batch_idx, src_idx]
+        return f0.squeeze(0).float()
+
+    def get_pitch_bias(self, f0):
+        if f0 is None:
+            return None
+        f0_flat = f0.squeeze().float()
+        f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
+        f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
+                                        f0_norm.unsqueeze(1)) * self.pitch_scale)
+        return f0_sim.unsqueeze(0).unsqueeze(0)

+    def forward(self, x=None, f0=None, layer=None, input_type="audio") -> Tensor:
         if isinstance(x, int):
             ctx = x
+        elif isinstance(x, torch.Tensor) and x.ndim == 3:
             batch, ctx, dims = x.shape
+        else:
+            batch, head, ctx, head_dim = x.shape
+
+        if self.use_2d_axial and input_type == "spectrogram":
+            freqs_2d = self.compute_2d_axial_freqs(ctx)
+            if freqs_2d is not None:
+                return freqs_2d.unsqueeze(0)
+
+        t = torch.arange(ctx, device=self.device, dtype=self.dtype)
+
         if f0 is not None:
+            f0_mean = f0.mean() + 1e-8
+            theta = f0_mean * self.pitch_scale
+            freqs = 1.0 / (theta ** (torch.arange(0, self.dims, 2, device=self.device, dtype=self.dtype)[:(self.dims // 2)].float() / self.dims))
+            if "rotary" in self.debug:
+                print(f"{layer}: {theta:.2f} : {f0_mean:.2f} : {ctx} ")
         else:
             freqs = self.freqs

+        freqs = t[:, None] * freqs[None, :]
+
         if self.radii:
             if f0 is not None:
+                radius = self.align_f0(f0, ctx)
+            else:
+                radius = self.radius
+            if "rotary" in self.debug:
+                print(f"{layer} radius: {radius} ctx: {ctx}")
+            # fixed: use the aligned F0 (or learned radius) as the phasor magnitude;
+            # torch.polar(torch.ones_like(radius), ...) discarded the radius entirely
+            freqs = torch.polar(radius.unsqueeze(-1).expand_as(freqs), freqs)
+        else:
+            freqs = torch.polar(torch.ones_like(freqs), freqs)
+
+        self._counter += 1
+        return freqs.unsqueeze(0)
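
When an F0 track is supplied, the base theta is replaced by mean-F0 times the learned pitch_scale, so higher-pitched audio gets a faster-rotating frequency bank. A standalone sketch of that computation (values illustrative):

    # Pitch-conditioned inverse-frequency bank, outside the module for clarity.
    dims, pitch_scale = 64, 1.0
    f0 = torch.full((1500,), 220.0)          # e.g. a 220 Hz voice
    theta = f0.mean() * pitch_scale + 1e-8
    freqs = 1.0 / (theta ** (torch.arange(0, dims, 2)[:dims // 2].float() / dims))
    angles = torch.arange(8)[:, None] * freqs[None, :]
    print(angles.shape)                      # torch.Size([8, 32]) phase matrix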

     @staticmethod
     def apply_rotary(x, freqs):
+        x1 = x[..., :freqs.shape[-1]*2]
+        x2 = x[..., freqs.shape[-1]*2:]
+        orig_shape = x1.shape
+        if x1.ndim == 2:
+            x1 = x1.unsqueeze(0)
+        x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
+        x1 = torch.view_as_complex(x1) * freqs
+        x1 = torch.view_as_real(x1).flatten(-2)
+        x1 = x1.view(orig_shape)
+        return torch.cat([x1.type_as(x), x2], dim=-1)
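
apply_rotary rotates only the first 2 * freqs.shape[-1] channels and passes any remainder through unchanged; a quick shape check using the staticmethod defined above (sizes illustrative):

    # Rotate a (batch, head, ctx, head_dim) tensor with unit phasors of shape (ctx, F).
    x = torch.randn(1, 4, 10, 64)
    phasors = torch.polar(torch.ones(10, 32), torch.randn(10, 32))
    out = rotary.apply_rotary(x, phasors)
    print(out.shape)                         # torch.Size([1, 4, 10, 64])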
 class MultiheadA(nn.Module):
     _seen = set()
     rbf = False
     def __init__(self, dims: int, head: int, rotary_emb: bool = True,
+                 zero_val: float = 1e-4, minz: float = 1e-6, maxz: float = 1e-3, debug: List[str] = [], optim_attn=False):
         super(MultiheadA, self).__init__()

         self.dims = dims
         self.head = head
         self.head_dim = dims // head
         self.debug = debug
         self._counter = 0

+        self.q = Linear(dims, dims).to(device, dtype)
+        self.k = Linear(dims, dims, bias=False).to(device, dtype)
+        self.v = Linear(dims, dims).to(device, dtype)
+        self.o = Linear(dims, dims).to(device, dtype)
+
         self.pad_token = 0
         self.rotary_emb = rotary_emb
         self.minz = minz
         self.maxz = maxz
         self.zero_val = zero_val
         self.optim_attn = optim_attn
+        self.fzero = nn.Parameter(torch.tensor(zero_val, device=device, dtype=dtype), requires_grad=False)

         if rotary_emb:
+            self.rope_2d = rotary(
+                dims=self.head_dim,
+                use_2d_axial=False,
+                spec_shape=(1500, 128),
+                debug=debug
+            )
             self.rope = rotary(
                 dims=self.head_dim,
                 debug = debug,
                 learned_radius=False,
             )
         else:
+            self.rope_2d = None
+            self.rope = None
+
     def enhanced_attention_scores(self, q, k, rbf_sigma=1.0, rbf_ratio=0.0):
         scale = (self.dims // self.head) ** -0.25
         dot_scores = torch.matmul(q, k.transpose(-1, -2)) * scale

         rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
         return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
+    def forward(self, x: Tensor, xa: Tensor = None, f0: Tensor = None, mask: Tensor = None, layer = None, feature_type="audio") -> tuple:
+        x = x.to(device, dtype)
+        if xa is not None:
+            xa = xa.to(device, dtype)
+
+        batch, ctx, dims = x.shape
         scale = (self.dims // self.head) ** -0.25

+        z = default(xa, x).to(device, dtype)
+        q = self.q(x)
+        k = self.k(z)
+        v = self.v(z)
+        qlen = q.shape[1]
+        klen = k.shape[1]

         if self.rotary_emb:
             q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
             k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
             v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
+            qlen = q.shape[2]
+            klen = k.shape[2]
+
+            if feature_type == "spectrogram":
+                input_type = "spectrogram"
+            else:
+                input_type = "audio"
+            q = self.rope.apply_rotary(q, self.rope(qlen, f0=f0, layer=layer, input_type=input_type))
+            k = self.rope.apply_rotary(k, self.rope(klen, f0=f0, layer=layer, input_type=input_type))
         else:
             q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
             k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
             v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
         batch, head, ctx, head_dim = q.shape
+
         if self.rbf:
             qk = self.enhanced_attention_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)

         qk = (q * scale) @ (k * scale).transpose(-1, -2)
+        if f0 is not None and self.rope.use_pbias:
+            pbias = self.rope.get_pitch_bias(f0)   # fixed: use_pbias is a flag, not a callable
             if pbias is not None:
                 qk = qk + pbias[:,:,:q.shape[2],:q.shape[2]]
         token_ids = k[:, :, :, 0]
         zscale = torch.ones_like(token_ids)
         fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
+        zscale[token_ids.float() == self.pad_token] = fzero

         if mask is not None:
             mask = mask[:q.shape[2], :q.shape[2]]

         wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)

         if "multihead" in self.debug and self._counter % 100 == 0:
+            print(f"MHA: q={q.shape}, k={k.shape}, v={v.shape} - {qk.shape}, wv shape: {wv.shape}")
         self._counter += 1
         return self.o(wv), qk.detach()

         s = self.s_gate(x) * s_feat
         w = self.w_gate(x) * w_feat
         p = self.p_gate(x) * p_feat
         comb = torch.cat([s, w, p], dim=-1)
         return self.integ(comb)

                  tgate=True, mgate=False, cgate=False, mem_size=512, features=None):
         super().__init__()

         self.dims = dims
         self.head = head
         self.ctx = ctx

         if not any([t_gate, m_gate, c_gate]):
             self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())

+    def forward(self, x, xa=None, mask=None, f0=None, mode=None, layer=None, feature_type="audio") -> Tensor:
+        x = x.to(device, dtype)
+        if xa is not None:
+            xa = xa.to(device, dtype)
+
         bln = self.blend
+        x = x + self.attna(self.lna(x), xa=None, mask=mask, f0=f0, layer=layer, feature_type=feature_type)[0]

         if self.attnb and xa is not None:
+            c = self.attnb(self.lnb(x), xa=xa, f0=f0, mask=None, layer=layer, feature_type=feature_type)[0]
             b = torch.sigmoid(bln)
             x = b * x + (1 - b) * c

             gate = self.m_gate(normx)
             x = x + gate * mlp_out

+        elif self.c_gate and mode is not None:
             gate_output = self.c_gate(normx, self.features)
             x = x + gate_output

         return x

+class FEncoder(nn.Module):
+    def __init__(self, input_dims, dims, head, layer, kernel_size, act, stride=1, use_rope=False, spec_shape=None):
         super().__init__()

+        self.head = head
+        self.head_dim = dims // head
+        self.dropout = 0.01
+        self.use_rope = use_rope
+        self.dims = dims

         act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
         act_fn = act_map.get(act, nn.GELU())

         self.encoder = nn.Sequential(
+            Conv1d(input_dims, dims, kernel_size=kernel_size, stride=stride, padding=kernel_size//2), act_fn,
+            Conv1d(dims, dims, kernel_size=5, padding=2), act_fn,
+            Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims), act_fn)

+        if use_rope:
+            if spec_shape is not None:
+                self.rope = rotary(
+                    dims=self.head_dim,
+                    use_2d_axial=True,
+                    spec_shape=spec_shape, debug=[])
+            else:
+                self.rope = rotary(
+                    dims=self.head_dim,
+                    use_2d_axial=False, debug=[])
+        else:
+            self.rope = None
+        self.positional = lambda length: sinusoids(length, dims)
+
+        self.norm = RMSNorm(dims)
+        self._norm = RMSNorm(dims)
+
+    def apply_rope_to_features(self, x, f0=None, layer=None, feature_type="audio"):
+        if not self.use_rope or self.rope is None:
+            return x
+
+        batch, seq_len, dims = x.shape
+        x = x.view(batch, seq_len, self.head, self.head_dim).permute(0, 2, 1, 3)
+        if feature_type == "spectrogram" and hasattr(self.rope, 'use_2d_axial') and self.rope.use_2d_axial:
+            rope_freqs = self.rope(seq_len, f0=f0, layer=layer, input_type="spectrogram")
+        else:
+            rope_freqs = self.rope(seq_len, f0=f0, layer=layer, input_type="audio")
+        x = self.rope.apply_rotary(x, rope_freqs)
+        x = x.permute(0, 2, 1, 3).contiguous().view(batch, seq_len, dims)
+        return x
+
+    def forward(self, x, f0=None, layer=None, feature_type="audio"):
         x = self.encoder(x).permute(0, 2, 1)
+        if self.use_rope:
+            x = self.apply_rope_to_features(x, f0=f0, layer=layer, feature_type=feature_type)
+        else:
+            x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+        x = self._norm(x)
         return x
+
 class WEncoder(nn.Module):
+    def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False):
         super().__init__()

+        self.head = head
         self.head_dim = dims // head
         self.dropout = 0.01
+        self.use_rope = use_rope
+        self.dims = dims

         act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
         act_fn = act_map.get(act, nn.GELU())

         self.encoder = nn.Sequential(
             Conv1d(dims, dims, kernel_size=3, padding=1, groups=dims//8), act_fn,
             Conv1d(dims, dims, kernel_size=1), act_fn)
+        if use_rope:
+            self.rope = rotary(
+                dims=self.head_dim,
+                use_2d_axial=False,
+                theta=50.0, debug=[])
+        else:
+            self.rope = None
+        self.positional = lambda length: sinusoids(length, dims)
         self.norm = RMSNorm(dims)
+
+    def apply_rope_to_features(self, x, f0=None, layer=None):
+        if not self.use_rope or self.rope is None:
+            return x
+        batch, seq_len, dims = x.shape
+        x = x.view(batch, seq_len, self.head, self.head_dim).permute(0, 2, 1, 3)
+        rope_freqs = self.rope(seq_len, f0=f0, layer=layer, input_type="waveform")
+        x = self.rope.apply_rotary(x, rope_freqs)
+        x = x.permute(0, 2, 1, 3).contiguous().view(batch, seq_len, dims)
+        return x

+    def forward(self, x, f0=None, layer=None, feature_type="waveform"):
         x = self.downsample(x)
         x = self.encoder(x)
         x = x.permute(0, 2, 1)
+        if self.use_rope:
+            x = self.apply_rope_to_features(x, f0=f0, layer=layer)
+        else:
+            x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
         return self.norm(x)

+class PEncoder(nn.Module):
+    def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False):
         super().__init__()

+        self.head = head
+        self.head_dim = dims // head
+        self.dropout = 0.01
+        self.use_rope = use_rope
+        self.dims = dims

         act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
         act_fn = act_map.get(act, nn.GELU())

         self.encoder = nn.Sequential(
+            Conv1d(input_dims, dims//4, kernel_size=7, stride=8, padding=3), act_fn,
+            Conv1d(dims//4, dims//2, kernel_size=5, stride=4, padding=2), act_fn,
+            Conv1d(dims//2, dims, kernel_size=5, stride=5, padding=2), act_fn)

+        if use_rope:
+            self.rope = rotary(
+                dims=self.head_dim,
+                use_2d_axial=False,
+                theta=100.0, debug=[])
+        else:
+            self.rope = None
+        self.positional = lambda length: sinusoids(length, dims)
         self.norm = RMSNorm(dims)

+    def apply_rope_to_features(self, x, f0=None, layer=None):
+        if not self.use_rope or self.rope is None:
+            return x
+        batch, seq_len, dims = x.shape
+        x = x.view(batch, seq_len, self.head, self.head_dim).permute(0, 2, 1, 3)
+        rope_freqs = self.rope(seq_len, f0=f0, layer=layer, input_type="pitch")
+        x = self.rope.apply_rotary(x, rope_freqs)
+        x = x.permute(0, 2, 1, 3).contiguous().view(batch, seq_len, dims)
+        return x

+    def forward(self, x, f0=None, layer=None, feature_type="pitch"):
+        x = self.encoder(x).permute(0, 2, 1)
+        if self.use_rope:
+            x = self.apply_rope_to_features(x, f0=f0, layer=layer)
+        else:
+            x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+        x = self.norm(x)
         return x

 class AudioEncoder(nn.Module):

         self.head = head
         self.ctx = ctx
         self.head_dim = dims // head
         self.debug = debug
         self._counter = 0

         self.features = features
         self.dropout = 0.01

         self.rope = rotary(
+            dims=self.head_dim,
+        )

         act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(),
                    "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}

             FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)
             for _ in range(layer)])

+        self.rope_2d = rotary(
+            dims=self.head_dim,
+            use_2d_axial=True,
+            spec_shape=(ctx, mels),
+            debug=debug
+        )
+
+        self.rope_1d = rotary(
+            dims=self.head_dim,
+            use_2d_axial=False,
+            debug=debug
+        )
+
+    def forward(self, enc, f0=None, layer="ENC"):
+        enc = dict_to(enc, device, dtype)
+
         if self._counter < 1:
+            s = enc.get("spectrogram")
+            w = enc.get("waveform")
+            p = f0 if f0 is not None else default(enc.get("pitch"), enc.get("f0"))
             plot_waveform(x=s, w=w, p=p, hop_length=128)

+        out = {}
+        out.update(enc)
+
         for f in self.features:
+            if f in enc and f in self.blocks:
+                x = enc[f]
                 for block in self.blocks[f]:
+                    x = block(x, f0=f0, layer=layer, feature_type=f)
+                out[f] = x
+
         if "encoder" in self.debug and self._counter % 100 == 0:
+            names = list(out.keys())                      # fixed: was list(x.keys()) on a tensor
+            shapes = {k: v.shape for k, v in out.items()}
+            print(f"Step {self._counter}: mode: {names}: shapes: {shapes}")
         self._counter += 1
+        return out
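
The encoder consumes a dict of feature streams and returns the same dict with each configured stream replaced by its encoded representation; unused keys pass through. A hedged sketch of the calling convention (shapes illustrative; constructing AudioEncoder itself needs the full config and is elided):

    # Hypothetical input dict for AudioEncoder.forward.
    enc_in = {
        "spectrogram": torch.randn(1, 128, 300),   # (batch, mels, frames)
        "waveform": torch.randn(1, 1, 48000),      # (batch, 1, samples)
    }
    # out = audio_encoder(enc_in, f0=None)
    # out["spectrogram"] -> (1, frames', dims); other keys are passed through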

 class TextDecoder(nn.Module):
     def __init__(self, vocab: int, ctx: int, dims: int, head: int, layer: int, cross_attn: bool,

         self.ctx = ctx
         self.head_dim = dims // head

+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.dtype = torch.float32
         self.debug = debug
         self._counter = 0

         mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
         self.register_buffer("mask", mask, persistent=False)

+    def forward(self, x, enc, f0=None, order=None, layer='DEC') -> Tensor:
+        enc = dict_to(enc, device, dtype)
         x = x.to(device)
+        bln = self.blend
+
         if order is None:
             order = self.features
+
         mask = self.mask[:x.shape[1], :x.shape[1]]
         x = self.token(x) + self.positional[:x.shape[1]]
         x = F.dropout(x, p=self.dropout, training=self.training)
+
         for block in self.block:
+            x = block(x, xa=None, f0=f0, mask=mask, layer=layer)

         for f in order:
+            if f in enc:
+                xa = enc[f]
                 for block in self.blocks[f]:
+                    out = block(x=x, xa=xa, f0=f0, mask=None, layer=layer)
+
                     a = torch.sigmoid(bln[f])
                     x = a * out + (1 - a) * x

         if "decoder" in self.debug and self._counter % 100 == 0:
+            print(f"Step {self._counter}: Decoder output shape: {x.shape}, enc keys: {list(enc.keys())}, order: {order}")
+        self._counter += 1   # fixed: was indented under the debug branch, stalling the counter
+
+        x = self.ln_dec(x)
         return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()

 class Echo(nn.Module):

         if f0d is not None:
             encoder_inputs["f0d"] = f0d

+        encoder_outputs = self.encoder(encoder_inputs, f0=f0)
+        logits = self.decoder(input_ids, encoder_outputs, f0=f0d)

         loss = None
         if labels is not None:

             "encoder_output": encoder_outputs,
         }

+    @property
     def device(self):
         return next(self.parameters()).device
     @property

             print(f"{module_type}: {count}")

     def register_gradient_hooks(self):
+        """Register per-parameter gradient hooks on encoder and decoder weights."""
         for name, param in self.named_parameters():
             if param.requires_grad:
                 if "encoder" in name:

         return None

     def reset_counter(self):
+        """Reset the internal step counter used by the debug logging."""
         self._counter = 0
         print("Counter reset to 0.")

+metric = evaluate.load(path="wer")
+
+def align_f0(f0, ctx):
+    # Note: this final definition shadows the module-level align_f0 overloads above.
+    ctx = torch.tensor(ctx)
+    bat, length = f0.shape
+    if length == ctx:
+        return f0
+    frames = length / ctx
+    idx = torch.arange(ctx, device=f0.device)
+    idx = (idx * frames).long()
+    batch_idx = torch.arange(bat, device=f0.device).unsqueeze(1)
+    return f0[batch_idx, idx.unsqueeze(0).expand(bat, -1)]
+
+@dataclass
+class DataCollator:
+    tokenizer: Any
+    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
+        # Note: reads the module-level `tokenizer`, not self.tokenizer.
+        pad_token_id = tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else 0
+        bos_token_id = tokenizer.bos_token_id if hasattr(tokenizer, 'bos_token_id') else 1
+
+        batch = {}
+
+        if "spectrogram" in features[0] and features[0]["spectrogram"] is not None:
+            spectrogram_list = [f["spectrogram"] for f in features]
+            max_len_feat = max(f.shape[-1] for f in spectrogram_list)
+            pad_spectrogram = []
+            for feat in spectrogram_list:
+                current_len = feat.shape[-1]
+                padding = max_len_feat - current_len
+                if padding > 0:
+                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
+                else:
+                    pad_feat = feat
+                pad_spectrogram.append(pad_feat)
+            batch["spectrogram"] = torch.stack(pad_spectrogram)
+
+        if "waveform" in features[0] and features[0]["waveform"] is not None:
+            waveform_list = [f["waveform"] for f in features]
+            max_len_wav = max(w.shape[-1] for w in waveform_list)
+            pad_waveforms = []
+            for wav in waveform_list:
+                current_len = wav.shape[-1]
+                padding = max_len_wav - current_len
+                if padding > 0:
+                    if wav.ndim == 1:
+                        wav = wav.unsqueeze(0)
+                    pad_wav = F.pad(wav, (0, padding), mode='constant', value=pad_token_id)
+                else:
+                    pad_wav = wav
+                pad_waveforms.append(pad_wav)
+            batch["waveform"] = torch.stack(pad_waveforms)
+
+        if "label" in features[0] and features[0]["label"] is not None:
+            labels_list = [f["label"] for f in features]
+            max_len = max(len(l) for l in labels_list)
+            all_ids = []
+            all_labels = []
+
+            for label in labels_list:
+                label_list = label.tolist() if isinstance(label, torch.Tensor) else label
+                decoder_input = [bos_token_id] + label_list
+                label_eos = label_list + [pad_token_id]
+                input_len = max_len + 1 - len(decoder_input)
+                label_len = max_len + 1 - len(label_eos)
+                padded_input = decoder_input + [pad_token_id] * input_len
+                padded_labels = label_eos + [pad_token_id] * label_len
+                all_ids.append(padded_input)
+                all_labels.append(padded_labels)
+            batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
+            batch["labels"] = torch.tensor(all_labels, dtype=torch.long)
+
+        if "pitch" in features[0] and features[0]["pitch"] is not None:
+            pitch_list = [f["pitch"] for f in features]
+            max_len_pitch = max(e.shape[-1] for e in pitch_list)
+            pad_pitch = []
+            for pitch in pitch_list:
+                current_len = pitch.shape[-1]
+                padding = max_len_pitch - current_len
+                if padding > 0:
+                    pad_pitch_item = F.pad(pitch, (0, padding), mode='constant', value=pad_token_id)
+                else:
+                    pad_pitch_item = pitch
+                pad_pitch.append(pad_pitch_item)
+            batch["pitch"] = torch.stack(pad_pitch)
+
+        if "f0" in features[0] and features[0]["f0"] is not None:
+            input_ids_batch = batch.get("input_ids", None)
+            if input_ids_batch is not None:
+                target_length = input_ids_batch.shape[-1]
+                aligned_list = []
+                original_list = []
+                for feature in features:
+                    f0 = feature["f0"]
+                    original_list.append(f0)
+                    if f0.shape[-1] != target_length:
+                        aligned_f0 = align_f0(f0.unsqueeze(0), target_length).squeeze(0)
+                    else:
+                        aligned_f0 = f0
+                    aligned_list.append(aligned_f0)
+                batch["f0d"] = torch.stack(aligned_list)
+                batch["f0"] = torch.stack(original_list)   # note: assumes equal-length f0 tracks in the batch
+
+        if "envelope" in features[0] and features[0]["envelope"] is not None:
+            env_list = [f["envelope"] for f in features]
+            max_len = max(f.shape[-1] for f in env_list)
+            pad_env = []
+            for feat in env_list:
+                current_len = feat.shape[-1]
+                padding = max_len - current_len   # fixed: was max_len_feat, defined only in the spectrogram branch
+                if padding > 0:
+                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
+                else:
+                    pad_feat = feat
+                pad_env.append(pad_feat)
+            batch["envelope"] = torch.stack(pad_env)
+
+        if "phase" in features[0] and features[0]["phase"] is not None:
+            ph_list = [f["phase"] for f in features]
+            max_len = max(f.shape[-1] for f in ph_list)
+            pad_ph = []
+            for feat in ph_list:
+                current_len = feat.shape[-1]
+                padding = max_len - current_len   # fixed: same max_len_feat slip as above
+                if padding > 0:
+                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
+                else:
+                    pad_feat = feat
+                pad_ph.append(pad_feat)
+            batch["phase"] = torch.stack(pad_ph)
+
+        return batch
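
The collator right-pads each modality to the longest example in the batch and builds BOS-shifted decoder inputs alongside EOS-padded labels. A hedged usage sketch: since __call__ reads the module-level `tokenizer` (left as None by default, so the fallback pad=0/bos=1 ids apply), the dataclass field is effectively unused as written:

    features = [
        {"spectrogram": torch.randn(128, 290), "label": torch.tensor([5, 9, 2])},
        {"spectrogram": torch.randn(128, 300), "label": torch.tensor([7, 3])},
    ]
    batch = DataCollator(tokenizer=None)(features)
    # batch["spectrogram"]: (2, 128, 300); batch["input_ids"], batch["labels"]: (2, 4)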
+
+def hilbert_transform(x):
+    # Fixed: build the analytic component with a full complex FFT. The previous
+    # irfft-based version collapsed back to a real signal and did not return the
+    # 90-degree-shifted component.
+    N = x.shape[-1]
+    xf = torch.fft.fft(x, dim=-1)
+    h = torch.zeros(N, device=x.device, dtype=xf.dtype)
+    if N % 2 == 0:
+        h[0] = h[N//2] = 1
+        h[1:N//2] = 2
+    else:
+        h[0] = 1
+        h[1:(N+1)//2] = 2
+    return torch.fft.ifft(xf * h, dim=-1).imag
+
+def analytic_signal(x):
+    return x + 1j * hilbert_transform(x)
+
+def hilbert_transform_2d(x, dim=-1):
+    N = x.shape[dim]
+    if dim == -1 or dim == len(x.shape) - 1:
+        xf = torch.fft.rfft(x)
+    else:
+        xf = torch.fft.rfft(x, dim=dim)
+    h_shape = [1] * len(x.shape)
+    h_shape[dim] = N // 2 + 1
+    h = torch.zeros(h_shape, device=x.device, dtype=x.dtype)
+    if dim == -1 or dim == len(x.shape) - 1:
+        if N % 2 == 0:
+            h[..., 0] = h[..., -1] = 1
+            h[..., 1:-1] = 2
+        else:
+            h[..., 0] = 1
+            h[..., 1:] = 2
+    else:
+        pass   # note: the multiplier for non-final dims is left unimplemented here
+    return torch.fft.irfft(xf * h, n=N, dim=dim)
+
+def hilbert_transform_true_2d(x):
+    xf = torch.fft.rfft2(x)
+    h1, h2 = torch.meshgrid(
+        torch.fft.rfftfreq(x.shape[-2]) * 2 - 1,
+        torch.fft.rfftfreq(x.shape[-1]) * 2 - 1,
+        indexing='ij')
+    h = -1j / (math.pi * (h1 + 1j*h2))
+    h[0, 0] = 0
+    return torch.fft.irfft2(xf * h.to(x.device))
+
+def process_spectrogram_with_hilbert(spec):
+    analytic = spec + 1j * hilbert_transform(spec)
+    envelope = torch.abs(analytic)
+    phase = torch.angle(analytic)
+    return envelope, phase
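
A quick sanity check of the analytic-signal construction: for a pure tone the envelope should be numerically constant. This assumes the corrected hilbert_transform above, which returns the 90-degree-shifted component:

    n = torch.arange(4000)
    x = torch.sin(2 * torch.pi * 50 * n / 4000)        # exactly 50 cycles
    env = torch.sqrt(x ** 2 + hilbert_transform(x) ** 2)
    print(float(env.std()) < 1e-3)                     # True: flat envelope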

+def load_wave(wave_data, sample_rate):
+    if isinstance(wave_data, str):
+        waveform, sr = torchaudio.load(uri=wave_data, normalize=False)
+    elif isinstance(wave_data, dict):
+        waveform = torch.tensor(data=wave_data["array"]).float()
+        sr = wave_data["sampling_rate"]
+    else:
+        raise TypeError("Invalid wave_data format.")
+
+    if waveform.dim() == 1:
+        waveform = waveform.unsqueeze(0)
+
+    if sr != sample_rate:
+        original_length = waveform.shape[1]
+        target_length = int(original_length * (sample_rate / sr))
+
+        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
+        waveform = resampler(waveform)
+
+    return waveform.flatten()
+
+ def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, frequency=False,
1435
+ hop_length=128, fmin=0, fmax=8000, n_mels=128, n_fft=1024, sampling_rate=16000,
1436
+ pad_mode="constant", center=True, power=2.0, window_fn=torch.hann_window, mel_scale="htk",
1437
+ norm=None, normalized=False, downsamples=False, period=False, hilbert=False):
1438
+
1439
+ dtype = torch.float32
1440
+ device = torch.device("cuda:0")
1441
+ audio = batch["audio"]
1442
+ sampling_rate = audio["sampling_rate"]
1443
+ sr = audio["sampling_rate"]
1444
+ wav = load_wave(wave_data=audio, sample_rate=sr)
1445
+
+     if spectrogram:
+         transform = torchaudio.transforms.MelSpectrogram(
+             f_max=fmax,
+             f_min=fmin,
+             n_mels=n_mels,
+             sample_rate=sr,
+             n_fft=n_fft,
+             hop_length=hop_length,
+             norm=norm,
+             normalized=normalized,
+             power=power,
+             center=center,
+             mel_scale=mel_scale,
+             window_fn=window_fn,
+             pad_mode=pad_mode)
+
+         mel_spectrogram = transform(wav)
+         log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
+         log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
+         spec = (log_mel + 4.0) / 4.0
+         batch["spectrogram"] = spec
+
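+         # The clamp/scale above mirrors Whisper-style log-mel normalization:
+         # keep 8 log10 units (80 dB) of dynamic range below the peak, then map
+         # (log_mel + 4) / 4, so a peak near 0 lands roughly in [-1, 1].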
+         if hilbert:
+             envelope_list = []
+             phase_list = []
+
+             for ch_idx in range(spec.shape[0]):
+                 envelope, phase = process_spectrogram_with_hilbert(spec[ch_idx])
+                 envelope_list.append(envelope)
+                 phase_list.append(phase)
+
+             batch["envelope"] = torch.stack(envelope_list)
+             batch["phase"] = torch.stack(phase_list)
+
+     wav_1d = wav.unsqueeze(0)
+
+     if waveforms:
+         batch["waveform"] = wav_1d
+
+     if pitch:
+         wav_np = wav.numpy().astype(np.float64)
+         f0, t = pw.dio(wav_np, sr,
+                        frame_period=hop_length / sr * 1000)
+         f0 = pw.stonemask(wav_np, f0, t, sr)
+         f0 = torch.from_numpy(f0).float()
+         batch["pitch"] = f0.unsqueeze(0)
+
+     if frequency:
+         wav_np = wav.numpy().astype(np.float64)
+         f0, t = pw.dio(wav_np, sr, frame_period=hop_length / sr * 1000)
+         f0 = pw.stonemask(wav_np, f0, t, sr)
+         f0 = torch.from_numpy(f0).float()
+         batch["f0"] = f0
+
+     if spectrogram and waveforms and pitch:
+         spec_mean = batch["spectrogram"].mean()
+         spec_std = batch["spectrogram"].std() + 1e-6
+         batch["spectrogram"] = (batch["spectrogram"] - spec_mean) / spec_std
+
+         wav_mean = batch["waveform"].mean()
+         wav_std = batch["waveform"].std() + 1e-6
+         batch["waveform"] = (batch["waveform"] - wav_mean) / wav_std
+
+         if batch["pitch"].max() > 1.0:
+             pitch_min = 50.0   # rough lower bound of speech F0 in Hz
+             pitch_max = 500.0  # rough upper bound of speech F0 in Hz
+             batch["pitch"] = (batch["pitch"] - pitch_min) / (pitch_max - pitch_min)
+
+     batch["label"] = tokenizer.encode(batch["transcription"], add_special_tokens=False)
+     return batch
+
+ def compute_metrics(eval_pred, compute_result: bool = True,
+                     print_pred: bool = False, num_samples: int = 0, tokenizer=None, pitch=None, model=None):
+
+     pred_logits = eval_pred.predictions
+     label_ids = eval_pred.label_ids
+
+     if hasattr(pred_logits, "cpu"):
+         pred_logits = pred_logits.cpu()
+     if hasattr(label_ids, "cpu"):
+         label_ids = label_ids.cpu()
+     if isinstance(pred_logits, tuple):
+         pred_ids = pred_logits[0]
+     else:
+         pred_ids = pred_logits
+     if hasattr(pred_ids, "ndim") and pred_ids.ndim == 3:
+         if not isinstance(pred_ids, torch.Tensor):
+             pred_ids = torch.tensor(pred_ids)
+         pred_ids = pred_ids.argmax(dim=-1)
+     if hasattr(pred_ids, "tolist"):
+         pred_ids = pred_ids.tolist()
+
+     if hasattr(label_ids, "tolist"):
+         label_ids = label_ids.tolist()
+
+     label_ids = [[0 if token == -100 else token for token in seq] for seq in label_ids]
+     pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
+     label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=False)
+
+     if print_pred:
+         for i in range(min(num_samples, len(pred_str))):
+             print(f"Preds: {pred_str[i]}")
+             print(f"Label: {label_str[i]}")
+             print(f"preds: {pred_ids[i]}")
+             print(f"label: {label_ids[i]}")
+             print("--------------------------------")
+
+     pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
+     label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
+     # `metric` is the module-level WER metric loaded elsewhere in this file
+     wer = 100 * metric.compute(predictions=pred_str, references=label_str)
+
+     if model is None:
+         global global_model
+         if 'global_model' in globals():
+             model = global_model
+
+     if model is not None:
+         trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000
+         if trainable_params > 0:
+             efficiency_score = (100 - wer) / trainable_params
+         else:
+             print("Warning: Zero trainable parameters detected")
+             efficiency_score = 0.0
+     else:
+         print("Warning: Model not available for parameter counting")
+         trainable_params = 0.0
+         efficiency_score = 0.0
+
+     if hasattr(wer, "item"):
+         wer = wer.item()
+
+     metrics = {
+         "wer": float(wer),
+         "trainable_params_M": float(trainable_params),
+         "efficiency_score": float(efficiency_score),
+     }
+
+     return metrics
+
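+ # Note on efficiency_score: it is a size-normalized heuristic, the WER
+ # complement (100 - WER) divided by trainable parameters in millions, so a
+ # smaller model that reaches the same WER scores higher.
+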
1585
+ logger = logging.getLogger(__name__)
1586
+
1587
+ def create_model(param: Dimensions) -> Echo:
1588
+ model = Echo(param).to('cuda')
1589
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
1590
+ total_params = sum(p.numel() for p in model.parameters())
1591
+ logger.info(f"Trainable parameters: {trainable_params:,}")
1592
+ logger.info(f"Total parameters: {total_params:,}")
1593
+ print(f"Trainable parameters: {trainable_params:,}")
1594
+ print(f"Total parameters: {total_params:,}")
1595
+
1596
+ return model
1597
+
+ def setup_tokenizer(token: str, local_tokenizer_path: str = "D:/newmodel/model/tokenn/"):
+     from tokenizers import Tokenizer
+     tokenizer = Tokenizer.from_file(f"{local_tokenizer_path}/tokenizer.json")
+     orig_encode = tokenizer.encode
+     def enc(text, add_special_tokens=True):
+         ids = orig_encode(text).ids
+         if not add_special_tokens:
+             sp_ids = [tokenizer.token_to_id(t) for t in ["<PAD>", "<BOS>", "<EOS>"]]
+             ids = [i for i in ids if i not in sp_ids]
+         return ids
+     def bdec(ids_list, skip_special_tokens=True):
+         results = []
+         for ids in ids_list:
+             if skip_special_tokens:
+                 ids = [i for i in ids if i not in [0, 1, 2]]
+             results.append(tokenizer.decode(ids))
+         return results
+     def save_pretrained(save_dir):
+         os.makedirs(save_dir, exist_ok=True)
+         tokenizer.save(f"{save_dir}/tokenizer.json")
+     tokenizer.encode = enc
+     tokenizer.batch_decode = bdec
+     tokenizer.save_pretrained = save_pretrained
+     tokenizer.pad_token_id = 0
+     tokenizer.bos_token_id = 1
+     tokenizer.eos_token_id = 2
+     return tokenizer
+
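+ # setup_tokenizer monkey-patches a `tokenizers.Tokenizer` so it exposes the
+ # small surface the trainer needs (encode / batch_decode / save_pretrained
+ # plus pad/bos/eos ids 0/1/2). A minimal sketch, assuming the tokenizer.json
+ # path above exists:
+ #     tokenizer = setup_tokenizer(token="")
+ #     ids = tokenizer.encode("hello world", add_special_tokens=False)
+ #     text = tokenizer.batch_decode([ids])[0]
+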
+ def prepare_datasets(tokenizer, token: str, sanity_check: bool = False, dataset_config: Optional[Dict] = None) -> Tuple[Any, Any]:
+     if dataset_config is None:
+         dataset_config = {
+             "spectrogram": True,
+             "waveforms": True,
+             "pitch": True,
+             "frequency": True,
+             "downsamples": True,
+             "hop_length": 128,
+             "fmin": 50,
+             "fmax": 2000,
+             "n_mels": 128,
+             "n_fft": 1024,
+             "sampling_rate": 16000,
+         }
+
+     dataset = load_dataset(
+         "google/fleurs",
+         "en_us",
+         token=token,
+         trust_remote_code=True,
+         streaming=False)
+
+     dataset = dataset.cast_column(column="audio", feature=Audio(sampling_rate=16000)).select_columns(["audio", "transcription"])
+
+     if sanity_check:
+         dataset = dataset["test"].take(10)
+         logger.info(f"Sanity dataset size: {dataset.num_rows}")
+         print(f"Sanity dataset size: {dataset.num_rows}")
+         prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
+
+         dataset = dataset.map(
+             function=prepare_fn,
+             remove_columns=["audio", "transcription"]
+         ).with_format(type="torch")
+         train_dataset = dataset
+         test_dataset = dataset
+     else:
+         def filter_func(x):
+             return (0 < len(x["transcription"]) < 512 and
+                     len(x["audio"]["array"]) > 0 and
+                     len(x["audio"]["array"]) < 1500 * 160)  # cap clips at ~15 s of 16 kHz audio
+
+         dataset = dataset.filter(filter_func).shuffle(seed=4)
+         logger.info(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}")
+         print(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}")
+         prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
+         columns_to_remove = list(next(iter(dataset.values())).features)
+         train_dataset = dataset["train"]
+         test_dataset = dataset["test"].take(50)
+         logger.info(f"Train dataset size: {train_dataset.num_rows}, Test dataset size: {test_dataset.num_rows}")
+
+         train_dataset = train_dataset.map(
+             function=prepare_fn,
+             remove_columns=columns_to_remove
+         ).with_format(type="torch")
+
+         test_dataset = test_dataset.map(
+             function=prepare_fn,
+             remove_columns=columns_to_remove
+         ).with_format(type="torch")
+
+     return train_dataset, test_dataset
+
+ def get_training_args(
+     log_dir: str,
+     batch_eval_metrics: bool = False,
+     max_steps: int = 10,
+     save_steps: int = 1000,
+     eval_steps: int = 1,
+     warmup_steps: int = 0,
+     num_train_epochs: int = 1,
+     logging_steps: int = 1,
+     eval_on_start: bool = False,
+     learning_rate: float = 1e-4,
+     weight_decay: float = 0.01,
+     max_grad_norm: float = 1.0,
+ ) -> Seq2SeqTrainingArguments:
+
+     return Seq2SeqTrainingArguments(
+         output_dir=log_dir,
+         per_device_train_batch_size=1,
+         per_device_eval_batch_size=1,
+         gradient_accumulation_steps=1,
+         eval_accumulation_steps=1,
+         eval_strategy="steps",
+         save_strategy="steps",
+         max_steps=max_steps,
+         save_steps=save_steps,
+         eval_steps=eval_steps,
+         warmup_steps=warmup_steps,
+         num_train_epochs=num_train_epochs,
+         logging_steps=logging_steps,
+         logging_dir=log_dir,
+         logging_strategy="steps",
+         report_to=["tensorboard"],
+         push_to_hub=False,
+         disable_tqdm=False,
+         save_total_limit=1,
+         label_names=["labels"],
+         optim="adamw_torch",
+         lr_scheduler_type="cosine",
+         learning_rate=learning_rate,
+         weight_decay=weight_decay,
+         save_safetensors=False,
+         eval_on_start=eval_on_start,
+         batch_eval_metrics=batch_eval_metrics,
+         max_grad_norm=max_grad_norm,
+     )
+
+ def main():
+
+     token = ""
+     log_dir = os.path.join('./output/logs', datetime.now().strftime(format='%m-%d_%H_%M_%S'))
+     os.makedirs(name=log_dir, exist_ok=True)
+     tokenizer = setup_tokenizer(token)
+
+     def make_training_args(sanity: bool):
+
+         if sanity:
+             training_args = get_training_args(
+                 log_dir,
+                 batch_eval_metrics=False,
+                 max_steps=10,
+                 save_steps=0,
+                 eval_steps=1,
+                 warmup_steps=0,
+                 logging_steps=1,
+                 eval_on_start=False,
+                 learning_rate=5e-6,
+                 weight_decay=0.01,
+             )
+         else:
+             training_args = get_training_args(
+                 log_dir,
+                 batch_eval_metrics=False,
+                 max_steps=1000,
+                 save_steps=1005,
+                 eval_steps=100,
+                 warmup_steps=100,
+                 logging_steps=10,
+                 eval_on_start=False,
+                 learning_rate=2.5e-4,
+                 weight_decay=0.01,
+             )
+
+         return training_args
+
+     param = Dimensions(
+         mels=128,
+         aud_ctx=1500,
+         aud_head=4,
+         aud_dims=512,
+         aud_idx=4,
+         vocab=40000,
+         text_ctx=512,
+         text_head=4,
+         text_dims=512,
+         text_idx=4,
+         act="swish",
+         debug={"rotary"},
+         cross_attn=True,
+         f0_rotary=False,
+         features=["spectrogram"],
+     )
+
+     sanity_check = False
+     training_args = make_training_args(sanity_check)
+     dataset_config = {
+         "spectrogram": True,
+         "waveforms": False,
+         "pitch": False,
+         "downsamples": False,
+         "frequency": False,
+         "hilbert": False,
+         "hop_length": 128,
+         "fmin": 150,
+         "fmax": 2000,
+         "n_mels": 128,
+         "n_fft": 1024,
+         "sampling_rate": 16000,
+         "pad_mode": "constant",
+         "center": True,
+         "power": 2.0,
+         "window_fn": torch.hann_window,
+         "mel_scale": "htk",
+         "norm": None,
+         "normalized": False}
+
+     model = create_model(param)
+
+     global global_model
+     global_model = model
+
+     metrics_fn = partial(compute_metrics, print_pred=False, num_samples=5,
+                          tokenizer=tokenizer, model=model)
+
+     print(f"{'Sanity check' if sanity_check else 'Training'} mode")
+     train_dataset, test_dataset = prepare_datasets(
+         tokenizer=tokenizer,
+         token=token,
+         sanity_check=sanity_check,
+         dataset_config=dataset_config)
+
+     trainer = Seq2SeqTrainer(
+         args=training_args,
+         model=model,
+         train_dataset=train_dataset,
+         eval_dataset=test_dataset,
+         data_collator=DataCollator(tokenizer=tokenizer),
+         compute_metrics=metrics_fn,
+     )
+
+     model.init_weights()
+     trainer.train()
+
+ if __name__ == "__main__":
+     main()
+