Sin2pi committed on
Commit 7963664 · verified · 1 Parent(s): ba60ad6

Update model.py

Files changed (1)
  1. model.py +198 -223
model.py CHANGED
@@ -1,4 +1,3 @@
-
  import pyworld as pw
  import os
  import math
@@ -6,6 +5,7 @@ import warnings
  import logging
  import gzip
  import base64
+ from einops import rearrange, repeat
  import torch
  import torchaudio
  import torchcrepe
@@ -16,12 +16,13 @@ import numpy as np
  from typing import Optional, Dict, Union, List, Tuple, Any
  from functools import partial
  from datetime import datetime
- from datasets import load_dataset, Audio
+ from datasets import load_dataset, Audio, concatenate_datasets
  from transformers.trainer_seq2seq import Seq2SeqTrainer
  from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
  import transformers
  import evaluate
  from dataclasses import dataclass
+ from math import pi, log

  torch.backends.cudnn.allow_tf32 = True
  torch.backends.cuda.matmul.allow_tf32 = True
@@ -36,6 +37,11 @@ warnings.filterwarnings("ignore")
  logging.basicConfig(level=logging.ERROR)
  tox = {"device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), "dtype": torch.float32}

+ # %xmode Minimal
+ # %xmode Plain
+ # %xmode Context
+ # %xmode Verbose
+
  extractor = None
  tokenizer = None
  optimizer = None
@@ -61,7 +67,7 @@ class Dimensions:
      cross_attn: bool
      features: List[str]
      f0_rotary: bool
-
+
  def exists(v):
      return v is not None

@@ -99,7 +105,7 @@ class RMSNorm(nn.Module):
          self.eps = eps
          self.elementwise_affine = elementwise_affine
          if self.elementwise_affine:
-             self.weight = nn.Parameter(torch.empty(self.normalized_shape))
+             self.weight = nn.Parameter(torch.empty(self.normalized_shape))  # type: ignore
              init.ones_(self.weight)
          else:
              self.register_parameter("weight", None)
@@ -128,6 +134,7 @@ def sinusoids(length, channels, max_timescale=10000):
      scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
      return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

+
  class rotary(nn.Module):
      _seen = set()
      def __init__(self, dims, max_ctx=1500, theta=10000, learned_freq=False, radii=False,
@@ -144,32 +151,29 @@ class rotary(nn.Module):
          self.dims = dims
          self.max_ctx = max_ctx
          self.radii = radii
-         pitch_scale = 1.0
-         # theta_rescale = 1.0
-         # theta *= theta_rescale ** (dims / (dims - 2))
+         max_freq = 10.0
+         f0_scale_factor = 0.5
+         self.learned_adaptation: bool = False
+         pitch_scale = 1.0  # theta_rescale = 1.0

-         self.min_theta = nn.Parameter(
-             torch.tensor(20.0), requires_grad=learned_theta)
-         self.max_theta = nn.Parameter(
-             torch.tensor(400.0), requires_grad=learned_theta)
+         if self.learned_adaptation:
+             self.f0_scale = nn.Parameter(torch.tensor(f0_scale_factor))
+         else:
+             self.register_buffer('f0_scale', torch.tensor(f0_scale_factor))

-         self.theta = nn.Parameter(
-             torch.tensor(float(theta)), requires_grad=learned_theta)
+         self.theta = nn.Parameter(torch.tensor(float(theta)), requires_grad=learned_theta)

-         self.pitch_scale = nn.Parameter(torch.tensor(pitch_scale),
-             requires_grad=learned_pitch)
+         self.pitch_scale = nn.Parameter(torch.tensor(pitch_scale), requires_grad=learned_pitch)

          freqs = 1. / (theta ** (torch.arange(0, dims, 2)[:(dims // 2)].float() / dims))
          self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)

          if radii:
-             self.radius = nn.Parameter(torch.ones(dims // 2),
-                 requires_grad=learned_radius)
+             self.radius = nn.Parameter(torch.ones(dims // 2), requires_grad=learned_radius)

      def get_pitch_bias(self, f0):
          if f0 is None:
              return None
-
          f0_flat = f0.squeeze().float()
          f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
          f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
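
Note (editor's sketch, not part of the commit): the `__init__` change above toggles whether `f0_scale` is trainable. A minimal sketch of that Parameter-vs-buffer pattern; the standalone class name `F0Scale` is invented for illustration:

```python
import torch
import torch.nn as nn

class F0Scale(nn.Module):
    def __init__(self, f0_scale_factor=0.5, learned_adaptation=False):
        super().__init__()
        if learned_adaptation:
            # trained by the optimizer, saved in the state dict
            self.f0_scale = nn.Parameter(torch.tensor(f0_scale_factor))
        else:
            # fixed value; still moves with .to(device) and is saved in the state dict
            self.register_buffer('f0_scale', torch.tensor(f0_scale_factor))

m = F0Scale(learned_adaptation=False)
print(m.f0_scale, any(p is m.f0_scale for p in m.parameters()))  # tensor(0.5000) False
```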
@@ -195,21 +199,18 @@
      def align_f0(self, f0, token_length):
          batch_size, f0_length = f0.shape
          if f0_length == token_length:
              return f0
          frames_per_token = f0_length / token_length
-
          indices = torch.arange(token_length, device=f0.device)
          indices = (indices * frames_per_token).long()  #.clamp(max=f0_length-1)
-         #center_positions = ((indices + 0.5) * frames_per_token).long()
-         batch_indices = torch.arange(batch_size, device=f0.device).unsqueeze(1)
+         batch_indices = torch.arange(batch_size, device=f0.device).unsqueeze(1)  #center_positions = ((indices + 0.5) * frames_per_token).long()
          return f0[batch_indices, indices.unsqueeze(0).expand(batch_size, -1)]

      def scale_f0(self, f0):
          f0_min = f0.min(dim=1, keepdim=True)[0]
          f0_max = f0.max(dim=1, keepdim=True)[0]
          denom = f0_max - f0_min + 1e-8
-         normalized_f0 = (f0 - f0_min) / denom
-         # normalized_f0 = (f0 - f0_min) / (f0_max - f0_min)
+         normalized_f0 = (f0 - f0_min) / denom  # normalized_f0 = (f0 - f0_min) / (f0_max - f0_min)
          normalized_f0 = torch.clamp(normalized_f0, 0.0, 1.0)
          return normalized_f0

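Note (editor's sketch, not part of the commit): `align_f0` resamples an F0 track onto the token grid by nearest-frame integer striding. A toy check, re-implemented standalone to mirror the method above:

```python
import torch

def align_f0(f0, token_length):
    batch_size, f0_length = f0.shape
    if f0_length == token_length:
        return f0
    frames_per_token = f0_length / token_length
    indices = (torch.arange(token_length) * frames_per_token).long()
    batch_indices = torch.arange(batch_size).unsqueeze(1)
    return f0[batch_indices, indices.unsqueeze(0).expand(batch_size, -1)]

f0 = torch.arange(8.).unsqueeze(0)   # one sequence, 8 pitch frames
print(align_f0(f0, 4))               # tensor([[0., 2., 4., 6.]])
```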
@@ -231,47 +232,68 @@
          log_freq = torch.log(freq)
          log_min_freq = torch.log(min_freq)
          log_max_freq = torch.log(max_freq)
-
          mapped_log_freq = ((log_freq - log_min_freq) / (log_max_freq - log_min_freq)) * torch.log(torch.tensor(target_max, device=self.device))
          return mapped_log_freq

+     def get_f0_adapted_freqs(self, ctx, f0=None):
+         f0_min = 80.0    # typical human voice low range
+         f0_max = 500.0   # typical human voice high range
+         base_freq = 1.0
+         positions = torch.arange(ctx, device=self.device, dtype=torch.float)
+         freqs = torch.full((self.dims // 2,), base_freq, device=self.device)
+         if f0 is not None:
+             f0_norm = torch.clamp((f0 - f0_min) / (f0_max - f0_min), 0.0, 1.0)
+             freq_mod = torch.pow(torch.linspace(0.5, 1.5, self.dims // 2, device=self.device),
+                                  f0_norm.unsqueeze(-1) * self.f0_scale)
+             freqs = freqs * freq_mod
+         freqs = torch.outer(positions, freqs)
+         return torch.polar(torch.ones_like(freqs), freqs)
+
      def forward(self, x=None, f0=None, stage=None) -> Tensor:
          if isinstance(x, int):
-             seq_len = x
+             ctx = x
          else:
-             batch, seq_len, _ = x.shape
-             t = torch.arange(seq_len, device=self.device).float()
+             batch, ctx, dims = x.shape
+         t = torch.arange(ctx, device=self.device).float()  # needed on both the int and tensor paths
+
+         if self.learned_adaptation:
+             freqs = self.get_f0_adapted_freqs(ctx, f0)
+             x_complex = torch.view_as_complex(
+                 x.float().reshape(*x.shape[:-1], -1, 2).contiguous())
+             x_rotated = x_complex * freqs.unsqueeze(0).unsqueeze(0)
+             freqs = torch.view_as_real(x_rotated).flatten(3).type_as(x)

          if f0 is not None:
-             f0_mean = f0.mean() + 1e-8
-             theta = self.theta
-             f0_theta = f0_mean  #(theta / 2) * (f0_mean * 1e-2 + 1.0)
-             freqs = 1.0 / (f0_theta ** (torch.arange(0, self.dims, 2, device=self.device) / self.dims))
+             f0_mean = f0.mean() + 1e-8
+             pitch_scale = self.pitch_scale
+             theta = f0_mean * pitch_scale  # f0_theta = f0_mean #(f0_mean * 1e-2 + 1.0) * (theta / 2)
+             freqs = 1.0 / (theta ** (torch.arange(0, self.dims, 2, device=self.device) / self.dims))
          else:
              freqs = self.freqs

          freqs = torch.einsum('i,j->ij', t, freqs)
          freqs = freqs.float()

          if self.radii and f0 is not None:
-             radius = self.align_f0(f0, seq_len)
+             radius = self.align_f0(f0, ctx)
              # radius = self.scale_f0(radius)
              radius = F.softplus(self.radius) * radius
-             # radius = radius.unsqueeze(-1)
+             # radius = radius.unsqueeze(-1) # Ensure radius is of shape (batch, ctx, dims//2)
              freqs = torch.polar(radius.unsqueeze(-1), freqs.unsqueeze(0))
          else:
              freqs = torch.polar(torch.ones_like(freqs), freqs.unsqueeze(0))
          # print(f"Step {self._counter}: Block: {stage}: Radius: {radius}")
          if "rotary" in self.debug:
              if f0 is not None:
-                 key = f"{self._counter}_{f0_theta:.2f}"
+                 key = f"{self._counter}_{theta:.2f}"
                  if key not in rotary._seen:
                      if not hasattr(self, '_prev_f0_theta'):
-                         self._prev_f0_theta = f0_theta
-                         print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
-                     elif abs(self._prev_f0_theta - f0_theta) > 1000.0:
-                         print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
-                         self._prev_f0_theta = f0_theta
+                         self._prev_f0_theta = theta
+                         print(f"Step {self._counter}: Using raw F0 as theta: {theta:.2f} Hz")
+                     elif abs(self._prev_f0_theta - theta) > 200.0:
+                         print(f"Step {self._counter}: Using raw F0 as theta: {theta:.2f} Hz")
+                         self._prev_f0_theta = theta
                      rotary._seen.add(key)
          self._counter += 1
          return freqs
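
Note (editor's sketch, not part of the commit): the new `forward()` uses the scaled mean F0 as the RoPE base `theta`, then builds the usual geometric series 1/theta^(2i/d) and unit complex rotations. A minimal standalone version, assuming `pitch_scale` is 1.0:

```python
import torch

dims, ctx = 64, 6
f0 = torch.tensor([220.0, 230.0, 210.0])         # voiced-frame F0 estimates, Hz
theta = f0.mean() * 1.0                           # theta = f0_mean * pitch_scale
inv_freqs = 1.0 / (theta ** (torch.arange(0, dims, 2).float() / dims))
angles = torch.einsum('i,j->ij', torch.arange(ctx).float(), inv_freqs)
rot = torch.polar(torch.ones_like(angles), angles)   # (ctx, dims//2), unit radius
print(rot.shape, rot.abs().allclose(torch.ones_like(angles)))
```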
@@ -306,81 +328,6 @@
      x1 = torch.view_as_real(x1).flatten(-2)
      return torch.cat([x1.type_as(x), x2], dim=-1)

- class SliceAttention(nn.Module):
-     def __init__(self, dims, heads, dropout=0.0):
-         super().__init__()
-         self.dims = dims
-         self.heads = heads
-         self.head_dim = dims // heads
-         self.scale = self.head_dim ** -0.5
-
-         self.q_proj = Linear(dims, dims)
-         self.k_proj = Linear(dims, dims)
-         self.v_proj = Linear(dims, dims)
-         self.out_proj = Linear(dims, dims)
-         self.dropout = nn.Dropout(dropout)
-
-         assert dims % heads == 0, f"Dimensions {dims} not divisible by heads {heads}"
-
-     def parallel_slice(self, q, k, v, mask=None):
-         batch, heads, ctx, dims = q.shape
-         head_dim = self.head_dim
-         batch, ctx, dims = q.shape
-         ctx_len = k.shape[1]
-         num_heads = dims // head_dim
-
-         scores = torch.zeros(batch, num_heads, ctx, ctx_len, device=q.device)
-
-         for h in range(num_heads):
-             start_idx = h * head_dim
-             end_idx = start_idx + head_dim
-             q_h = q[:, :, start_idx:end_idx]
-             k_h = k[:, :, start_idx:end_idx]
-
-             scores[:, h] = torch.bmm(q_h, k_h.transpose(1, 2)) / math.sqrt(head_dim)
-
-         if mask is not None:
-             scores = scores + mask.unsqueeze(0).unsqueeze(0)
-
-         attn_weights = F.softmax(scores, dim=-1)
-
-         output = torch.zeros_like(q)
-         for h in range(num_heads):
-             start_idx = h * head_dim
-             end_idx = start_idx + head_dim
-             v_h = v[:, :, start_idx:end_idx]
-             output[:, :, start_idx:end_idx] = torch.bmm(attn_weights[:, h], v_h)
-         return output
-
-     def forward(self, x, context=None, mask=None):
-         batch, ctx, _ = x.shape
-         if context is None:
-             context = x
-
-         ctx_len = context.shape[1]
-         q = self.q_proj(x)
-         k = self.k_proj(context)
-         v = self.v_proj(context)
-         output = torch.zeros_like(q)
-
-         for h in range(self.heads):
-             start_idx = h * self.head_dim
-             end_idx = start_idx + self.head_dim
-
-             q_h = q[:, :, start_idx:end_idx]
-             k_h = k[:, :, start_idx:end_idx]
-             v_h = v[:, :, start_idx:end_idx]
-
-             attn_scores = torch.bmm(q_h, k_h.transpose(1, 2)) * self.scale
-             if mask is not None:
-                 attn_scores = attn_scores + mask[:ctx, :ctx_len].unsqueeze(0)
-
-             attn_weights = F.softmax(attn_scores, dim=-1)
-             attn_weights = self.dropout(attn_weights)
-             head_output = torch.bmm(attn_weights, v_h)
-             output[:, :, start_idx:end_idx] = head_output
-         return self.out_proj(output)
-
  def optim_attn(q, k, v, mask=None, scale=None, pad_token=0, fzero_val=0.0001):

      batch, heads, ctx, dims = q.shape
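
Note (editor's sketch, not part of the commit): the deleted `SliceAttention` looped over heads with `torch.bmm` on dim slices. A quick numerical check that one batched matmul over a `(batch, heads, ctx, head_dim)` view gives identical scores:

```python
import math
import torch

batch, heads, ctx, head_dim = 2, 4, 5, 8
q = torch.randn(batch, ctx, heads * head_dim)
k = torch.randn(batch, ctx, heads * head_dim)

loop = torch.zeros(batch, heads, ctx, ctx)
for h in range(heads):
    s, e = h * head_dim, (h + 1) * head_dim
    loop[:, h] = torch.bmm(q[:, :, s:e], k[:, :, s:e].transpose(1, 2)) / math.sqrt(head_dim)

qh = q.view(batch, ctx, heads, head_dim).permute(0, 2, 1, 3)
kh = k.view(batch, ctx, heads, head_dim).permute(0, 2, 1, 3)
batched = qh @ kh.transpose(-1, -2) / math.sqrt(head_dim)
print(torch.allclose(loop, batched, atol=1e-5))  # True
```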
@@ -403,44 +350,50 @@
  class MultiheadA(nn.Module):
      _seen = set()
      rbf = False
-     def __init__(self, dims: int, head: int, rotary_emb: bool = False,
+     def __init__(self, dims: int, head: int, rotary_emb: bool = True,
                   zero_val: float = 0.0001, minz: float = 0.0, maxz: float = 0.001, debug: List[str] = [], optim_attn=False):

          super(MultiheadA, self).__init__()

-         self.debug = debug
-         self.pad_token = 0
          self.dims = dims
          self.head = head
          self.head_dim = dims // head
-         self.rotary_emb = rotary_emb
-         self.minz = minz
-         self.maxz = maxz
-         self.zero_val = zero_val
-         self.optim_attn = optim_attn
-         self._counter = 0
-         if dims % head != 0:
-             raise ValueError(f"Dimensions {dims} must be divisible by number of heads {head}.")
-         if zero_val < minz or zero_val > maxz:
-             raise ValueError(f"Zero value {zero_val} must be between {minz} and {maxz}.")

          self.q = Linear(dims, dims)
          self.k = Linear(dims, dims, bias=False)
          self.v = Linear(dims, dims)
          self.o = Linear(dims, dims)
+
+         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         dtype = torch.float32
+         self.device = device
+         self.dtype = dtype
+         self.debug = debug
+         self._counter = 0
+
+         self.pad_token = 0
+         self.rotary_emb = rotary_emb
+         self.minz = minz
+         self.maxz = maxz
+         self.zero_val = zero_val
+         self.optim_attn = optim_attn
          self.fzero = nn.Parameter(torch.tensor(zero_val, dtype=torch.float32), requires_grad=True)

          if rotary_emb:
              self.rope = rotary(
                  dims=self.head_dim,
-                 debug = debug,
-                 max_ctx=1500,
-                 )
+                 debug=debug,
+                 radii=False,
+                 learned_pitch=False,
+                 learned_freq=False,
+                 learned_theta=False,
+                 learned_radius=False,
+                 )
          else:
              self.rope = None

      def enhanced_attention_scores(self, q, k, rbf_sigma=1.0, rbf_ratio=0.0):
-         scale = (self.dims // self.head) ** -0.25
+         scale = self.head_dim ** -0.25
          dot_scores = torch.matmul(q, k.transpose(-1, -2)) * scale
          if rbf_ratio <= 0.0:
              return dot_scores
@@ -452,37 +405,25 @@ class MultiheadA(nn.Module):
          return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores

      def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None,
-                 return_attn: bool = False, f0: Tensor = None) -> tuple:
+                 return_attn: bool = False, f0: Tensor = None, stage=None) -> tuple:

          batch, ctx, dims = x.shape
-         scale = (self.dims // self.head) ** -0.25
+         scale = self.head_dim ** -0.25

          z = default(xa, x)
          q = self.q(x).to(x.dtype)
          k = self.k(z).to(x.dtype)
          v = self.v(z).to(x.dtype)

-         if self.rotary_emb:
-             if f0 is not None:
-                 qf = self.rope(q.size(1), f0=f0)
-                 kf = self.rope(k.size(1), f0=f0)
-             else:
-                 qf = self.rope(q.size(1))
-                 kf = self.rope(k.size(1))
-
-             q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-             k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-             v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-
-             q = self.rope.apply_rotary(q, qf)
-             k = self.rope.apply_rotary(k, kf)
-
-         else:
-             q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-             k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-             v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-         batch, head, ctx, head_dim = q.shape
+         if self.rotary_emb:
+             qf = self.rope(q.size(1), f0=f0, stage=stage)
+             kf = self.rope(k.size(1), f0=f0, stage=stage)
+
+             q = self.rope.apply_rotary(q, qf)
+             k = self.rope.apply_rotary(k, kf)
+
+         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=self.head), [q, k, v])  # [batch, ctx, dims] -> [batch*head, ctx, head_dim]

          if self.optim_attn and not return_attn:
              wv = optim_attn(q * scale, k * scale, v, mask=mask,
                  pad_token=self.pad_token, fzero_val=torch.clamp(F.softplus(self.fzero), self.minz, self.maxz).item())
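
Note (editor's sketch, not part of the commit): the rewritten `forward()` folds heads into the batch axis with einops. `rearrange('b n (h d) -> (b h) n d')` is exactly `view` + `permute` + `reshape`:

```python
import torch
from einops import rearrange

b, n, h, d = 2, 5, 4, 8
x = torch.randn(b, n, h * d)
via_einops = rearrange(x, 'b n (h d) -> (b h) n d', h=h)
via_view = x.view(b, n, h, d).permute(0, 2, 1, 3).reshape(b * h, n, d)
print(torch.equal(via_einops, via_view))  # True
```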
@@ -490,33 +431,43 @@

          if self.rbf:
              qk = self.enhanced_attention_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
+         else:
+             qk = (q * scale) @ (k * scale).transpose(-1, -2)

-         qk = (q * scale) @ (k * scale).transpose(-1, -2)
-         if f0 is not None and self.rope.use_pbias:
+         if f0 is not None and self.rotary_emb and self.rope.use_pbias:
              pbias = self.rope.pbias(f0)
              if pbias is not None:
-                 qk = qk + pbias[:,:,:q.shape[2],:q.shape[2]]
-         token_ids = k[:, :, :, 0]
+                 pbias = rearrange(pbias, 'b h i j -> (b h) i j')  # batch * head, ctx, ctx
+                 qk = qk + pbias[:, :q.shape[1], :q.shape[1]]
+
+         token_ids = rearrange(k, '(b h) n d -> b h n d', b=batch)[:, :, :, 0]
          zscale = torch.ones_like(token_ids)
          fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
          zscale[token_ids.float() == self.pad_token] = fzero.to(q.device, q.dtype)
+         zscale = rearrange(zscale, 'b h n -> (b h) n')

          if mask is not None:
-             mask = mask[:q.shape[2], :q.shape[2]]
-             qk = qk + mask.unsqueeze(0).unsqueeze(0) * zscale.unsqueeze(-2).expand(qk.shape)
+             mask = mask[:q.shape[1], :q.shape[1]]
+             expanded_mask = mask.unsqueeze(0).expand(batch * self.head, -1, -1)
+             qk = qk + expanded_mask * zscale.unsqueeze(-2)
          qk = qk * zscale.unsqueeze(-2)
+
          if return_attn:
-             return qk, v
+             v_reshaped = rearrange(v, '(b h) n d -> b h n d', b=batch)  # (batch, head, ctx, head_dim)
+             return qk, v_reshaped
+
          w = F.softmax(qk, dim=-1).to(q.dtype)
-         wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
+         wv = w @ v
+         wv = rearrange(wv, '(b h) n d -> b n (h d)', b=batch, h=self.head)  # (batch, ctx, dims)

          if "multihead" in self.debug and self._counter % 100 == 0:
              print(f"Step {self._counter}: Using rotary embeddings: {self.rotary_emb}")
              print(f"MHA: q={q.shape}, k={k.shape}, v={v.shape}")
              print(f"Attention shape: {qk.shape}, wv shape: {wv.shape}")
          self._counter += 1
+
          return self.o(wv), qk.detach()
-
+
  class FCGate(nn.Module):
      def __init__(self, dims, dim):
          super().__init__()
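
Note (editor's sketch, not part of the commit): the `fzero`/`zscale` logic above squashes the attention logits of pad-token key positions toward zero with a tiny learned factor, rather than masking them to -inf. A toy illustration with made-up values:

```python
import torch

qk = torch.randn(1, 4, 4)                  # (batch*head, ctx, ctx) attention logits
token_ids = torch.tensor([[5, 0, 7, 0]])   # 0 = pad_token, as in the model
fzero = 0.0005                             # stands in for clamp(softplus(self.fzero))
zscale = torch.ones(1, 4)
zscale[token_ids == 0] = fzero
qk = qk * zscale.unsqueeze(-2)             # broadcast over the query axis
print(qk[0, 0])                            # pad-key columns squashed toward 0.0
```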
@@ -572,7 +523,7 @@ class CMGate(nn.Module):
          self.integration = Linear(dims*3, dims)

      def forward(self, x, features):
-         sfeat = features.get("spectrogram", x)
+         sfeat = features.get("spectrogram", x)  # Default to input if missing
          wfeat = features.get("waveform", x)
          pfeat = features.get("pitch", x)
          spec = self.sgate(x) * sfeat
@@ -582,18 +533,26 @@
          combined = torch.cat([spec, wave, pitch], dim=-1)
          return self.integration(combined)

- class Residual(nn.Module):
+ class Residual(nn.Module):  # noqa: F811
      _seen = set()
      def __init__(self, dims: int, head: int, ctx, act, cross_attn=True, debug: List[str] = [],
                   fgate=False, tgate=False, mgate=False, cgate=False,
                   memory_size=512, features=None):
          super().__init__()
-         self.ctx = ctx
-         self._counter = 0
-         self.dropout = 0.01
+
          self.dims = dims
          self.head = head
+         self.ctx = ctx
          self.head_dim = dims // head
+
+         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         dtype = torch.float32
+         self.device = device
+         self.dtype = dtype
+         self.debug = debug
+         self._counter = 0
+
+         self.dropout = 0.01
          self.cross_attn = cross_attn
          self.debug = debug
          self.fgate = fgate
@@ -627,12 +586,13 @@
          if not any([fgate, tgate, mgate, cgate]):
              self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())

-     def forward(self, x, xa=None, mask=None, f0=None, mode=None):
-         x = x + self.attna(self.lna(x), mask=mask, f0=f0)[0]
+     def forward(self, x, xa=None, mask=None, f0=None, mode=None, stage=None):
+         bln = self.blend
+         x = x + self.attna(self.lna(x), mask=mask, f0=f0, stage=stage)[0]

          if self.attnb and xa is not None:
-             cross = self.attnb(self.lnb(x), xa, f0=f0, mask=None)[0]
-             blend = torch.sigmoid(self.blend)
+             cross = self.attnb(self.lnb(x), xa, f0=f0, mask=None, stage=stage)[0]
+             blend = torch.sigmoid(bln)
              x = blend * x + (1 - blend) * cross

          normx = self.lnc(x)
@@ -661,11 +621,12 @@
              x = x + mlp_gate * mlp_out
          else:
              x = x + mlp_out
+
          if "residual" in self.debug and self._counter % 100 == 0:
              print(f"Step {self._counter}: Residual block output shape: {x.shape}, xa shape: {xa.shape if xa is not None else None}")
          self._counter += 1
          return x
-
+
  class PEncoder(nn.Module):
      def __init__(self, input_dims, dims, head, layer, kernel_size, act):
          super().__init__()
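
Note (editor's sketch, not part of the commit): `Residual.forward` (and `TextDecoder` below) mixes the self-attended and cross-attended paths with a single learned scalar gate, `x = sigmoid(w) * x + (1 - sigmoid(w)) * cross`. Minimal sketch with made-up tensors:

```python
import torch
import torch.nn as nn

blend = nn.Parameter(torch.tensor(0.5))   # initialised as in the diff
x = torch.randn(2, 6, 512)                 # self-attention output
cross = torch.randn(2, 6, 512)             # cross-attention output
a = torch.sigmoid(blend)                   # ~0.62 at init, learned thereafter
out = a * x + (1 - a) * cross
print(out.shape, float(a))                 # torch.Size([2, 6, 512]) ~0.622
```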
@@ -681,7 +642,7 @@
              Conv1d(dims//4, dims//2, kernel_size=5, stride=4, padding=2), act_fn,
              Conv1d(dims//2, dims, kernel_size=5, stride=5, padding=2), act_fn)

-     def forward(self, x, f0=None):
+     def forward(self, x, f0=None, stage=None):
          x = self.encoder(x).permute(0, 2, 1)
          x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
          x = nn.functional.dropout(x, p=self.dropout, training=self.training)
@@ -710,7 +671,7 @@
          self.positional = lambda length: sinusoids(length, dims)
          self.norm = RMSNorm(dims)

-     def forward(self, x, f0=None):
+     def forward(self, x, f0=None, stage=None):
          x = self.downsample(x)
          x = self.encoder(x)
          x = x.permute(0, 2, 1)
@@ -737,32 +698,34 @@
          self.norm = RMSNorm(dims)
          self._norm = RMSNorm(dims)

-     def forward(self, x, f0=None):
+     def forward(self, x, f0=None, stage=None):
          x = self.encoder(x).permute(0, 2, 1)
          x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
          x = nn.functional.dropout(x, p=self.dropout, training=self.training)
          x = self._norm(x)
          return x
-
+
  class AudioEncoder(nn.Module):
      _seen = set()
      def __init__(self, mels: int, layer: int, dims: int, head: int, ctx: int, features: List[str],
                   debug: List[str], f0_rotary: bool = False, act: str = "gelu"):
          super(AudioEncoder, self).__init__()
-
+
+         self.dims = dims
+         self.head = head
+         self.ctx = ctx
+         self.head_dim = dims // head
+
+         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         dtype = torch.float32
+         self.device = device
+         self.dtype = dtype
          self.debug = debug
-         self.features = features
          self._counter = 0
+
+         self.features = features
          self.dropout = 0.01
          self.f0_rotary = f0_rotary
-         self.dims = dims
-         self.ctx = ctx
-         self.head = head
-         self.head_dim = dims // head
-
-         self.rope = rotary(
-             dims=self.head_dim,
-             debug=debug,)

          act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
          act_fn = act_map.get(act, nn.GELU())
@@ -794,7 +757,7 @@
              [Residual(dims=dims, head=head, ctx=ctx, act=act, debug=debug) for _ in range(layer)] if "spec_phase" in features else None),
          })

-     def forward(self, x, f0=None):
+     def forward(self, x, f0=None, stage="encoder"):
          outputs = {}
          if self.f0_rotary:
              f0 = f0 if f0 is not None else x.get("pitch")
@@ -804,7 +767,7 @@
              if y in x and y in self.blocks:
                  f = x[y]
                  for block in self.blocks[y]:
-                     f = block(f, f0=f0)
+                     f = block(f, f0=f0, stage=stage)
                  outputs[y] = f

          if "encoder" in self.debug and self._counter % 100 == 0:
@@ -820,9 +783,19 @@ class TextDecoder(nn.Module):
                   features: List[str], debug: List[str], f0_rotary: bool = False, sequential=False):
          super(TextDecoder, self).__init__()

+         self.dims = dims
+         self.head = head
+         self.ctx = ctx
+         self.head_dim = dims // head
+
+         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         dtype = torch.float32
+         self.device = device
+         self.dtype = dtype
+         self.debug = debug
          self._counter = 0
+
          self.dropout = 0.01
-         self.debug = debug
          self.sequential = sequential
          self.features = features
          self.f0_rotary = f0_rotary
@@ -841,13 +814,13 @@
              for _ in range(layer)]) for f in features})

          self.blend = nn.ParameterDict({f: nn.Parameter(torch.tensor(0.5)) for f in features})
-
          self.ln_dec = RMSNorm(dims)

          mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
          self.register_buffer("mask", mask, persistent=False)

-     def forward(self, x, enc, order=None, f0=None) -> Tensor:
+     def forward(self, x, enc, order=None, f0=None, stage='decoder') -> Tensor:
+         bln = self.blend
          x = x.to(device)
          if self.f0_rotary:
              f0 = f0
@@ -859,18 +832,18 @@
          x = self.token(x) + self.positional[:x.shape[1]]
          x = F.dropout(x, p=self.dropout, training=self.training)
          for block in self._blocks:
-             x = block(x, f0=f0, mask=mask)
+             x = block(x, f0=f0, mask=mask, stage=stage)
          for f in order:
              if f in enc:
                  xa = enc[f]
                  for block in self.blocks[f]:
-                     out = block(x=x, xa=xa, f0=f0, mask=None)
-                     a = torch.sigmoid(self.blend[f])
+                     out = block(x=x, xa=xa, f0=f0, mask=None, stage=stage)
+                     a = torch.sigmoid(bln[f])
                      x = a * out + (1 - a) * x
          x = self.ln_dec(x)
          return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()

- class Echo(nn.Module):
+ class Echo(nn.Module):  # Echo the research model
      def __init__(self, param: Dimensions):
          super().__init__()
          self.param = param
@@ -1054,7 +1027,7 @@ class DataCollator:

          if "label" in features[0] and features[0]["label"] is not None:
              labels_list = [f["label"] for f in features]
-             max_len = max(len(l) for l in labels_list)
+             max_len = max(len(l) for l in labels_list)  # noqa: E741
              all_ids = []
              all_labels = []

@@ -1230,6 +1203,8 @@ def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, f0=False,
          batch["envelope"] = torch.stack(envelope_list)
          batch["phase"] = torch.stack(phase_list)

+     # batch["spec_envelope_freq"] = hilbert_transform_2d(spec, dim=-2)
+
      wav_1d = wav.unsqueeze(0)

      if waveforms:
@@ -1271,6 +1246,7 @@
          f0, t = pw.dio(wav_np, sampling_rate,
                         frame_period=hop_length/sampling_rate*1000)
          f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
+         f0 = f0
          batch["f0"] = torch.from_numpy(f0).float()

      if spectrogram and waveforms and pitch:
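
Note (editor's sketch, not part of the commit): usage of the pyworld F0 path above. `dio` produces a coarse F0 track on a fixed frame grid and `stonemask` refines it; pyworld expects float64 audio. The random signal here is a stand-in for real speech:

```python
import numpy as np
import pyworld as pw

sampling_rate, hop_length = 16000, 128
wav_np = np.random.randn(16000).astype(np.float64)
f0, t = pw.dio(wav_np, sampling_rate,
               frame_period=hop_length / sampling_rate * 1000)  # ms per frame
f0 = pw.stonemask(wav_np, f0, t, sampling_rate)  # refined F0; 0.0 marks unvoiced frames
print(f0.shape, t[:3])
```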
@@ -1287,6 +1263,7 @@
          pitch_max = 600.0
          batch["pitch"] = (batch["pitch"] - pitch_min) / (pitch_max - pitch_min)

+     # print(f"Spectrogram shape: {batch['spectrogram'].shape}, Waveform shape: {batch['waveform'].shape if 'waveform' in batch else None}, Pitch shape: {batch['pitch'].shape if 'pitch' in batch else None}, F0 shape: {batch['f0'].shape if 'f0' in batch else None}, Envelope shape: {batch['envelope'].shape if 'envelope' in batch else None}, Phase shape: {batch['phase'].shape if 'phase' in batch else None}")
      batch["label"] = tokenizer.encode(batch["transcription"], add_special_tokens=False)
      return batch

@@ -1356,6 +1333,7 @@ def compute_metrics(eval_pred, compute_result: bool = True,
      }

      print(f"Computed metrics: WER={wer:.2f}%, Params={trainable_params:.2f}M, Efficiency={efficiency_score:.4f}")
+
      return metrics

  logger = logging.getLogger(__name__)
@@ -1420,12 +1398,12 @@ def prepare_datasets(tokenizer, token: str, sanity_check: bool = False, dataset_
          "en_us",
          token=token,
          trust_remote_code=True,
-         streaming=False
-         )
-     dataset = dataset.cast_column(column="audio", feature=Audio(sampling_rate=16000))
+         streaming=False)
+
+     dataset = dataset.cast_column(column="audio", feature=Audio(sampling_rate=16000)).select_columns(["audio", "transcription"])

      if sanity_check:
-         dataset = dataset["test"].take(10).shuffle()
+         dataset = dataset["test"].take(10)
          dataset = dataset.select_columns(["audio", "transcription"])
          logger.info(f"Sanity dataset size: {dataset.num_rows}")
          print(f"Sanity dataset size: {dataset.num_rows}")
@@ -1443,13 +1421,14 @@
              len(x["audio"]["array"]) > 0 and
              len(x["audio"]["array"]) < 1500 * 160)

-     dataset = dataset.filter(filter_func).shuffle()
+     dataset = dataset.filter(filter_func).shuffle(seed=4)
      logger.info(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}")
      print(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}")
      prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
-     train_dataset = dataset["train"]
-     test_dataset = dataset["test"]
      columns_to_remove = list(next(iter(dataset.values())).features)
+     train_dataset = dataset["train"]
+     test_dataset = dataset["test"].take(50)  # Limit test set size for faster processing
+     logger.info(f"Train dataset size: {train_dataset.num_rows}, Test dataset size: {test_dataset.num_rows}")

      train_dataset = train_dataset.map(
          function=prepare_fn,
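
Note (editor's sketch, not part of the commit): the dataset pipeline after this hunk, condensed. The dataset id below is a placeholder (the real one is configured elsewhere in `prepare_datasets`), and `Dataset.take()` on a non-streaming dataset needs a recent `datasets` release:

```python
from datasets import load_dataset, Audio

dataset = load_dataset("some/asr-dataset", "en_us", streaming=False)  # hypothetical id
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
dataset = dataset.select_columns(["audio", "transcription"])

def filter_func(x):
    arr = x["audio"]["array"]
    return 0 < len(arr) < 1500 * 160  # cap clip length, as in the diff

dataset = dataset.filter(filter_func).shuffle(seed=4)
train_dataset = dataset["train"]
test_dataset = dataset["test"].take(50)
```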
@@ -1468,10 +1447,10 @@ def get_training_args(
      batch_eval_metrics: bool = False,
      max_steps: int = 10,
      save_steps: int = 1000,
-     eval_steps: int = 100,
+     eval_steps: int = 1,
      warmup_steps: int = 0,
      num_train_epochs: int = 1,
-     logging_steps: int = 10,
+     logging_steps: int = 1,
      eval_on_start: bool = False,
      learning_rate: float = 1e-4,
      weight_decay: float = 0.01,
@@ -1530,7 +1509,7 @@ def main():
          eval_steps = 1,
          warmup_steps = 0,
          logging_steps = 1,
-         eval_on_start = True,
+         eval_on_start = False,
          learning_rate = 5e-6,
          weight_decay = 0.01,
          )
@@ -1538,11 +1517,11 @@
      training_args = get_training_args(
          log_dir,
          batch_eval_metrics = False,
-         max_steps = 10000,
-         save_steps = 10000,
-         eval_steps = 1000,
-         warmup_steps = 1000,
-         logging_steps = 100,
+         max_steps = 1000,
+         save_steps = 1000,
+         eval_steps = 100,
+         warmup_steps = 100,
+         logging_steps = 10,
          eval_on_start = False,
          learning_rate = 2.5e-4,
          weight_decay = 0.01,
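
Note (editor's sketch, not part of the commit): what `get_training_args` plausibly builds from these values. The helper's real fields may differ; `eval_on_start` requires transformers >= 4.42, and older releases spell `eval_strategy` as `evaluation_strategy`:

```python
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./logs",      # placeholder for log_dir
    max_steps=1000,
    save_steps=1000,
    eval_steps=100,
    warmup_steps=100,
    logging_steps=10,
    eval_strategy="steps",    # makes eval_steps take effect
    eval_on_start=False,
    learning_rate=2.5e-4,
    weight_decay=0.01,
)
```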
@@ -1562,22 +1541,20 @@
          text_dims=512,
          text_idx=4,
          act="swish",
-         debug={}, #{"encoder", "decoder", "residual", "rotary"}, debug prints for specific modules
+         debug={"rotary"},  #{"encoder", "decoder", "residual", "rotary"},
          cross_attn=True,
-         f0_rotary=True,
-         features = ["spectrogram"], # ["spectrogram", "waveform", "pitch"] any combo and order matters
-         )
+         f0_rotary=False,
+         features = ["spectrogram"],  # ["spectrogram", "waveform", "pitch"]
+         )  # features = ["spectrogram", "spec_envelope", "spec_phase"],

      sanity_check = False
-
      training_args = sanity(sanity_check)
-
      dataset_config = {
          "spectrogram": True,
          "waveforms": False,
          "pitch": False,
          "downsamples": False,
-         "f0": True,
+         "f0": False,
          "hilbert": False,
          "hop_length": 128,
          "fmin": 150,
@@ -1608,7 +1585,6 @@
          sanity_check=sanity_check,
          dataset_config=dataset_config)

-
      trainer = Seq2SeqTrainer(
          args=training_args,
          model=model,
@@ -1617,10 +1593,9 @@
          data_collator=DataCollator(tokenizer=tokenizer),
          compute_metrics=metrics_fn,
      )
-
+
      trainer.train()

  if __name__ == "__main__":
      main()
-