Sin2pi committed
Commit badbec0 · verified · 1 Parent(s): 685a41f

Update model.py

Files changed (1)
  model.py +444 -465
model.py CHANGED
@@ -1,8 +1,10 @@
 
-import pyworld as pw
 import os
+import pyworld as pw
 import math
 import warnings
+import time
+import random
 import logging
 import gzip
 import base64
@@ -11,6 +13,7 @@ import torchaudio
 import torch.nn.functional as F
 import torch.nn.init as init
 from torch import nn, Tensor
+from torch.utils.data import Dataset, DataLoader
 import numpy as np
 from einops import rearrange
 import matplotlib.pyplot as plt
@@ -18,16 +21,15 @@ from typing import Optional, Dict, Union, List, Tuple, Any
 from functools import partial
 from datetime import datetime
 from datasets import load_dataset, Audio
-from transformers.trainer_seq2seq import Seq2SeqTrainer
-from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
-import transformers
+from torch.utils.tensorboard import SummaryWriter
+import tqdm
+from tqdm import tqdm
 import evaluate
 from dataclasses import dataclass
-
+import aiohttp
 torch.backends.cudnn.allow_tf32 = True
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.set_float32_matmul_precision('high')
-transformers.utils.logging.set_verbosity_error()
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 dtype = torch.float32
@@ -40,8 +42,6 @@ tokenizer = None
 optimizer = None
 scheduler = None
 model = None
-Residual = None
-MultiheadA = None
 
 @dataclass
 class Dimensions:
@@ -284,22 +284,13 @@ class rotary(nn.Module):
         self.freqs.data.copy_(freqs)
         self.theta.data.copy_(theta)
 
-    def get_bias(self, f0, ctx):
+    def get_pitch_bias(self, f0):
         if f0 is None:
             return None
-        if f0.dim() == 1:
-            length = f0.shape[0]
-            if length == ctx:
-                return f0
-            frames = length / ctx
-            idx = torch.arange(ctx, device=f0.device)
-            idx = (idx * frames).long().clamp(0, length - 1)
-            f0 = f0[idx]
-        f0_norm = (f0 - f0.mean()) / (f0.std() + 1e-8)
-        f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
+        f0_flat = f0.squeeze().float()
+        f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
+        f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
                                         f0_norm.unsqueeze(1)))
-        # diff = f0_norm[:, None] - f0_norm[None, :]
-        # f0_sim = torch.exp(-diff.pow(2))
         return f0_sim.unsqueeze(0).unsqueeze(0)
 
     def f0proj(self, f0):
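The new get_pitch_bias builds a pitch-similarity attention bias: z-normalize the f0 track, take pairwise distances, and map them through exp(-d) so frames with similar pitch get a bias near 1, shaped to broadcast over batch and heads. A minimal standalone sketch of the same computation (toy values; the pitch_bias name is hypothetical):

    import torch

    def pitch_bias(f0: torch.Tensor) -> torch.Tensor:
        # f0: (ctx,) fundamental-frequency track in Hz
        f0_norm = (f0 - f0.mean()) / (f0.std() + 1e-8)
        dist = torch.cdist(f0_norm.unsqueeze(1), f0_norm.unsqueeze(1))  # (ctx, ctx) pairwise distances
        return torch.exp(-dist).unsqueeze(0).unsqueeze(0)               # (1, 1, ctx, ctx)

    f0 = torch.tensor([220.0, 221.0, 330.0, 0.0])  # 0.0 = unvoiced frame
    print(pitch_bias(f0).shape)  # torch.Size([1, 1, 4, 4])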
@@ -313,7 +304,6 @@ class rotary(nn.Module):
         return f0.to(device=device, dtype=dtype)
 
     def synth_f0(self, f0, ctx):
-        # f0 = self.f0proj(f0)
         if f0.dim() == 1:
             length = f0.shape[0]
             if length == ctx:
@@ -321,7 +311,7 @@ class rotary(nn.Module):
             frames = length / ctx
             idx = torch.arange(ctx, device=f0.device)
             return f0[idx]
-
+
     def align_f0(self, ctx, f0):
         f0 = self.f0proj(f0)
         if f0.dim() == 3:
@@ -361,26 +351,22 @@ class rotary(nn.Module):
         batch, head, ctx, head_dim = x.shape
         t = torch.arange(ctx, device=device, dtype=dtype)
 
-        f0 = enc.get("f0") if enc is not None else None
         if f0 is not None and f0.dim() == 2:
             if f0.shape[0] == 1:
                 f0 = f0.squeeze(0)
             else:
                 f0 = f0.view(-1)
 
-        if f0 is not None:
+        if f0 is not None and layer == "encoder":
             f0_mean = f0.mean()
             theta = f0_mean + self.theta
         else:
-            theta = 10000.0
+            theta = self.theta
         freqs = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
                 self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
 
-        if "rot2" in self.debug and self.counter % 100 == 0:
-            print(f" [Rotary] {layer}{self.counter} --- [f0] {f0.shape if f0 is not None else None} [Theta] {theta.item():.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
-
         freqs = t[:, None] * freqs[None, :]
-        if self.radii and f0 is not None and layer == "encoder":
+        if self.radii and f0 is not None:
             radius = f0.to(device, dtype)
             L = radius.shape[0]
             if L != ctx:
@@ -403,7 +389,6 @@ class rotary(nn.Module):
             theta_value = theta.item() if isinstance(theta, torch.Tensor) else theta
             print(f" [{layer}] [f0] {f0.shape if f0 is not None else None} [Theta] {theta_value:.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx} [Radius] {radius.shape} {radius.mean():.2f}")
 
-
         if "rot3" in self.debug and self.counter % 100 == 0:
             print(f" [Rotary] {layer}{self.counter} --- [f0] {f0.shape if f0 is not None else None} [Theta] {theta.item():.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx} [Radius] {radius.shape} {radius.mean():.2f}")
 
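The schedule above is mel-spaced rather than the usual log-uniform RoPE spacing: linspace over mel values up to 2595 * log10(1 + 8000/700), inverted back to kHz, and scaled by theta/220 so a higher per-utterance f0 stretches the whole schedule. A small sketch under those assumptions (the mel_freqs name is hypothetical):

    import torch

    def mel_freqs(theta: float, dim: int) -> torch.Tensor:
        mel_max = 2595 * torch.log10(torch.tensor(1 + 8000 / 700))
        mels = torch.linspace(0, mel_max, dim // 2)
        # invert the mel scale, normalize to kHz, then scale by theta / 220
        return (theta / 220.0) * 700 * (torch.pow(10, mels / 2595) - 1) / 1000

    print(mel_freqs(220.0, 8))  # four mel-spaced values rising from 0.0 to 8.0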
@@ -428,19 +413,6 @@ class rotary(nn.Module):
         x1 = x1.view(orig_shape)
         return torch.cat([x1.type_as(x), x2], dim=-1)
 
-    @staticmethod
-    def apply_rotary(x, freqs):
-        x1 = x[..., :freqs.shape[-1]*2]
-        x2 = x[..., freqs.shape[-1]*2:]
-        orig_shape = x1.shape
-        if x1.ndim == 2:
-            x1 = x1.unsqueeze(0)
-        x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
-        x1 = torch.view_as_complex(x1) * freqs
-        x1 = torch.view_as_real(x1).flatten(-2)
-        x1 = x1.view(orig_shape)
-        return torch.cat([x1.type_as(x), x2], dim=-1)
-
 class MultiheadA(nn.Module):
     _seen = set()
     rbf = False
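The rotation kept above (the duplicate staticmethod is dropped by this hunk) treats consecutive channel pairs as complex numbers and multiplies them by unit-magnitude phases. A self-contained sketch of that step:

    import torch

    x = torch.randn(1, 4, 6)                                  # (batch, ctx, head_dim)
    freqs = torch.polar(torch.ones(4, 2), torch.randn(4, 2))  # unit-magnitude phases, (ctx, rotated_dim // 2)

    x1 = x[..., :freqs.shape[-1] * 2]                         # channels that get rotated
    x2 = x[..., freqs.shape[-1] * 2:]                         # pass-through remainder
    x1c = torch.view_as_complex(x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous())
    rotated = torch.view_as_real(x1c * freqs).flatten(-2)     # rotate each channel pair by its phase
    out = torch.cat([rotated.type_as(x), x2], dim=-1)
    print(out.shape)  # torch.Size([1, 4, 6]): same shape, pairs rotated in place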
@@ -472,8 +444,7 @@ class MultiheadA(nn.Module):
                 dims=dims,
                 head=head,
                 debug=debug,
-                radii=True if "radii" in debug else False,
-                use_pbias=True if "pbias" in debug else False,
+                radii=True,
             )
         else:
             self.rope = None
@@ -525,12 +496,12 @@ class MultiheadA(nn.Module):
 
         qk = (q * scale) @ (k * scale).transpose(-1, -2)
         if self.rope.use_pbias:
-            f0 = enc.get("f0", None) if enc is not None else None
-            pbias = self.rope.get_bias(f0, q2)
+            f0 = enc.get("f0", None) if enc is not None else None
+            pbias = self.rope.use_pbias(f0)
             if pbias is not None:
-                qk = qk + pbias
+                qk = qk + pbias[:,:,:q.shape[2],:q.shape[2]]
         token_ids = k[:, :, :, 0]
-        zscale = torch.ones_like(token_ids)
+        zscale = torch.ones_like(token_ids, device=device, dtype=dtype)
         fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
         zscale[token_ids.float() == self.pad_token] = fzero
 
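The zscale path downweights attention keys whose first feature equals the pad token, using a learned scalar squashed by softplus and clamped to [minz, maxz]. A toy sketch of just that masking step (the clamp bounds here are hypothetical stand-ins for self.minz and self.maxz):

    import torch
    import torch.nn.functional as F

    token_ids = torch.tensor([[5.0, 7.0, 0.0, 0.0]])        # last two key positions are padding (pad_token = 0)
    fzero_raw = torch.tensor(-2.0)                          # stand-in for the learned self.fzero
    fzero = torch.clamp(F.softplus(fzero_raw), 1e-4, 1e-3)  # softplus(-2) ~ 0.127, clamped down to 1e-3
    zscale = torch.ones_like(token_ids)
    zscale[token_ids == 0.0] = fzero
    print(zscale)  # tensor([[1.0000, 1.0000, 0.0010, 0.0010]])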
@@ -619,6 +590,7 @@ class Residual(nn.Module):
         self.t_gate = tgate
         self.m_gate = mgate
         self.c_gate = cgate
+        self.skip_gates=True
 
         self.blend = nn.Parameter(torch.tensor(0.5))
 
@@ -628,8 +600,8 @@ class Residual(nn.Module):
             "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
         act_fn = act_map.get(act, nn.GELU())
 
-        self.attna = MultiheadA(dims, head, rotary_emb=True, debug=debug)
-        self.attnb = (MultiheadA(dims, head, rotary_emb=True, debug=debug) if cross_attn else None)
+        self.attna = MultiheadA(dims=dims, head=head, rotary_emb=True, debug=debug)
+        self.attnb = (MultiheadA(dims=dims, head=head, rotary_emb=True, debug=debug) if cross_attn else None)
 
         mlp = dims * 4
         self.mlp = nn.Sequential(Linear(dims, mlp), act_fn, Linear(mlp, dims))
@@ -660,25 +632,28 @@ class Residual(nn.Module):
 
         normx = self.lnc(x)
         mlp_out = self.mlp(normx)
-
-        if self.t_gate:
-            gate = self.t_gate(normx)
-            x = x + gate * mlp_out
-
-        elif self.m_gate:
-            gate = self.m_gate(normx)
-            x = x + gate * mlp_out
-
-        elif self.c_gate:
-            gate_output = self.c_gate(normx, self.features)
-            x = x + gate_output
 
+        if self.skip_gates:
+            x = x + mlp_out
         else:
-            if hasattr(self, 'mlp_gate'):
-                mlp_gate = self.mlp_gate(normx)
-                x = x + mlp_gate * mlp_out
+            if self.t_gate:
+                gate = self.t_gate(normx)
+                x = x + gate * mlp_out
+
+            elif self.m_gate:
+                gate = self.m_gate(normx)
+                x = x + gate * mlp_out
+
+            elif self.c_gate:
+                gate_output = self.c_gate(normx, self.features)
+                x = x + gate_output
+
             else:
-                x = x + mlp_out
+                if hasattr(self, 'mlp_gate'):
+                    mlp_gate = self.mlp_gate(normx)
+                    x = x + mlp_gate * mlp_out
+                else:
+                    x = x + mlp_out
 
         if "residual" in self.debug and self.counter % 100 == 0:
             print(f"Step {self.counter}: Residual block output shape: {x.shape}, xa shape: {xa.shape if xa is not None else None}")
@@ -897,7 +872,7 @@ class AudioEncoder(nn.Module):
             )
         })
 
-    def forward(self, enc, order=None, layer="encoder"):
+    def forward(self, enc, layer="encoder"):
         enc = dict_to(enc, device, dtype)
 
         if self.counter < 1:
@@ -906,13 +881,10 @@ class AudioEncoder(nn.Module):
             p = default(enc.get("pitch"), enc.get("f0"))
             plot_waveform(x=s, w=w, p=p, hop_length=128)
 
-        if order is None:
-            order = self.features
-
         out = {}
         out.update(enc)
 
-        for f in order:
+        for f in self.features:
             if f in enc and f in self.blocks:
                 x = enc[f]
                 for block in self.blocks[f]:
@@ -921,7 +893,7 @@ class AudioEncoder(nn.Module):
 
         if "encoder" in self.debug and self.counter % 100 == 0:
             shapes = {k: v.shape for k, v in enc.items()}
-            print(f"Step {self.counter}: mode: {list(enc.keys()) }: shapes: {shapes}, order: {order}")
+            print(f"Step {self.counter}: mode: {list(enc.keys()) }: shapes: {shapes}")
         self.counter += 1
         return out
 
@@ -969,13 +941,12 @@ class TextDecoder(nn.Module):
         mask = self.mask[:x.shape[1], :x.shape[1]]
         x = self.token(x) + self.positional[:x.shape[1]]
         x = F.dropout(x, p=self.dropout, training=self.training)
-
+
         for block in self.block:
-            x = block(x, xa=None, mask=mask, enc=enc, layer=layer)
+            x = block(x, xa=None, mask=mask, enc=None, layer=layer)
 
         for f in order:
             if f in enc:
-
                 xa = enc[f]
                 for block in self.blocks[f]:
                     out = block(x=x, xa=xa, mask=None, enc=None, layer=layer)
@@ -1029,12 +1000,10 @@ class Echo(nn.Module):
         for name, module in self.encoder.named_modules():
             if isinstance(module, (rotary)):
                 module.update_base(f0)
-                module.return_f0(f0)
 
         for name, module in self.decoder.named_modules():
             if isinstance(module, (rotary)):
                 module.update_base(f0)
-                module.return_f0(f0)
 
     def set_alignment_head(self, dump: bytes):
         array = np.frombuffer(
@@ -1050,16 +1019,19 @@ class Echo(nn.Module):
         return self.decoder(input_ids, encoder_output)
 
     def forward(self,
+        decoder_input_ids=None,
         labels=None,
         waveform: Optional[torch.Tensor]=None,
         input_ids=None,
         spectrogram: torch.Tensor=None,
         pitch: Optional[torch.Tensor]=None,
         f0: Optional[torch.Tensor]=None,
+        f0d: Optional[torch.Tensor]=None,
         envelope: Optional[torch.Tensor]=None,
         phase: Optional[torch.Tensor]=None,
         ) -> Dict[str, torch.Tensor]:
 
+        decoder_input_ids = input_ids
         encoder_inputs = {}
         if spectrogram is not None:
             encoder_inputs["spectrogram"] = spectrogram
@@ -1073,7 +1045,7 @@ class Echo(nn.Module):
             encoder_inputs["phase"] = phase
         if f0 is not None:
             encoder_inputs["f0"] = f0
-
+
         encoder_outputs = self.encoder(encoder_inputs)
         logits = self.decoder(input_ids, encoder_outputs)
 
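Echo.forward assembles encoder_inputs from whichever features were passed, so the encoder dict only ever contains keys that are actually present. The same pattern in isolation, with toy tensors:

    import torch

    spectrogram = torch.randn(1, 128, 300)   # (batch, n_mels, frames)
    f0 = torch.randn(1, 300).abs() * 200     # toy pitch track
    waveform = None                          # feature not provided this call

    candidates = {"spectrogram": spectrogram, "waveform": waveform, "f0": f0}
    encoder_inputs = {k: v for k, v in candidates.items() if v is not None}
    print(list(encoder_inputs.keys()))  # ['spectrogram', 'f0']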
@@ -1170,122 +1142,58 @@ class Echo(nn.Module):
         self.counter = 0
         print("Counter reset to 0.")
 
-metric = evaluate.load(path="wer")
-
-@dataclass
-class DataCollator:
-    tokenizer: Any
-    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
-        pad_token_id = tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else 0
-        bos_token_id = tokenizer.bos_token_id if hasattr(tokenizer, 'bos_token_id') else 1
-
-        batch = {}
-
-        if "spectrogram" in features[0] and features[0]["spectrogram"] is not None:
-            spectrogram_list = [f["spectrogram"] for f in features]
-            max_len_feat = max(f.shape[-1] for f in spectrogram_list)
-            pad_spectrogram = []
-            for feat in spectrogram_list:
-                current_len = feat.shape[-1]
-                padding = max_len_feat - current_len
-                if padding > 0:
-                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
-                else:
-                    pad_feat = feat
-                pad_spectrogram.append(pad_feat)
-            batch["spectrogram"] = torch.stack(pad_spectrogram)
-
-        if "waveform" in features[0] and features[0]["waveform"] is not None:
-            waveform_list = [f["waveform"] for f in features]
-            max_len_wav = max(w.shape[-1] for w in waveform_list)
-            pad_waveforms = []
-            for wav in waveform_list:
-                current_len = wav.shape[-1]
-                padding = max_len_wav - current_len
-                if padding > 0:
-                    if wav.ndim == 1:
-                        wav = wav.unsqueeze(0)
-                    pad_wav = F.pad(wav, (0, padding), mode='constant', value=pad_token_id)
-                else:
-                    pad_wav = wav
-                pad_waveforms.append(pad_wav)
-            batch["waveform"] = torch.stack(pad_waveforms)
-
-        if "label" in features[0] and features[0]["label"] is not None:
-            labels_list = [f["label"] for f in features]
-            max_len = max(len(l) for l in labels_list)
-            all_ids = []
-            all_labels = []
-            for label in labels_list:
-                label_list = label.tolist() if isinstance(label, torch.Tensor) else label
-                decoder_input = [bos_token_id] + label_list
-                label_eos = label_list + [pad_token_id]
-                input_len = max_len + 1 - len(decoder_input)
-                label_len = max_len + 1 - len(label_eos)
-                padded_input = decoder_input + [pad_token_id] * input_len
-                padded_labels = label_eos + [pad_token_id] * label_len
-                all_ids.append(padded_input)
-                all_labels.append(padded_labels)
-            batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
-            batch["labels"] = torch.tensor(all_labels, dtype=torch.long)
-
-        if "pitch" in features[0] and features[0]["pitch"] is not None:
-            pitch_list = [f["pitch"] for f in features]
-            max_len_pitch = max(e.shape[-1] for e in pitch_list)
-            pad_pitch = []
-            for pitch in pitch_list:
-                current_len = pitch.shape[-1]
-                padding = max_len_pitch - current_len
-                if padding > 0:
-                    pad_pitch_item = F.pad(pitch, (0, padding), mode='constant', value=pad_token_id)
-                else:
-                    pad_pitch_item = pitch
-                pad_pitch.append(pad_pitch_item)
-            batch["pitch"] = torch.stack(pad_pitch)
-
-        if "f0" in features[0] and features[0]["f0"] is not None:
-            f0_list = [f["f0"] for f in features]
-            max_len_f0 = max(f.shape[-1] for f in f0_list)
-            pad_f0 = []
-            for f0 in f0_list:
-                current_len = f0.shape[-1]
-                padding = max_len_f0 - current_len
-                if padding > 0:
-                    pad_f0_item = F.pad(f0, (0, padding), mode='constant', value=pad_token_id)
-                else:
-                    pad_f0_item = f0
-                pad_f0.append(pad_f0_item)
-            batch["f0"] = torch.stack(pad_f0)
-
-        if "envelope" in features[0] and features[0]["envelope"] is not None:
-            env_list = [f["envelope"] for f in features]
-            max_len = max(f.shape[-1] for f in env_list)
-            pad_env = []
-            for feat in env_list:
-                current_len = feat.shape[-1]
-                padding = max_len - current_len
-                if padding > 0:
-                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
-                else:
-                    pad_feat = feat
-                pad_env.append(pad_feat)
-            batch["envelope"] = torch.stack(pad_env)
-
-        if "phase" in features[0] and features[0]["phase"] is not None:
-            ph_list = [f["phase"] for f in features]
-            max_len = max(f.shape[-1] for f in ph_list)
-            pad_ph = []
-            for feat in ph_list:
-                current_len = feat.shape[-1]
-                padding = max_len - current_len
-                if padding > 0:
-                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
-                else:
-                    pad_feat = feat
-                pad_ph.append(pad_feat)
-            batch["phase"] = torch.stack(pad_ph)
-        return batch
+def ctx_to_samples(audio_ctx, hop_length):
+    samples_token = hop_length * 2
+    n_samples = audio_ctx * samples_token
+    return n_samples
+
+def load_wave(wave_data, sample_rate):
+    if isinstance(wave_data, str):
+        waveform, sr = torchaudio.load(uri=wave_data, normalize=False)
+    elif isinstance(wave_data, dict):
+        waveform = torch.tensor(data=wave_data["array"]).float()
+        sr = wave_data["sampling_rate"]
+    else:
+        raise TypeError("Invalid wave_data format.")
+
+    if sr != sample_rate:
+        original_length = waveform.shape[1]
+        target_length = int(original_length * (sample_rate / sr))
+
+        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
+        waveform = resampler(waveform)
+
+    return waveform
+
+def pad(array, target_length, axis=-1, dtype: torch.dtype = torch.float32):
+    if isinstance(array, np.ndarray):
+        array = torch.from_numpy(array).to(dtype)
+    if torch.is_tensor(array):
+        if array.shape[axis] > target_length:
+            array = array.index_select(
+                dim=axis,
+                index=torch.arange(
+                    end=target_length, device=array.device, dtype=torch.long
+                ),
+            )
+        if array.shape[axis] < target_length:
+            pad_widths = [(0, 0)] * array.ndim
+            pad_widths[axis] = (0, target_length - array.shape[axis])
+            array = F.pad(
+                input=array, pad=[pad for sizes in pad_widths[::-1] for pad in sizes]
+            )
+        array = array.to(dtype=dtype)
+    else:
+        raise TypeError(
+            f"Unsupported input type: {type(array)}. Expected torch.Tensor or np.ndarray."
+        )
+    return array
+
+def exact_div(x, y):
+    assert x % y == 0
+    return x // y
+
+metrics = evaluate.load(path="wer")
 
 def hilbert_transform(x):
     N = x.shape[-1]
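Quick usage of the helpers added in the hunk above (pad, exact_div, ctx_to_samples), assuming they are in scope:

    import torch

    x = torch.randn(128, 250)
    print(pad(x, 300).shape)          # torch.Size([128, 300]), right-padded with zeros on the last axis
    print(pad(x, 200).shape)          # torch.Size([128, 200]), truncated instead
    print(exact_div(16000, 160))      # 100; AssertionError if not evenly divisible
    print(ctx_to_samples(1500, 128))  # 384000: hop_length * 2 samples per context frame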
@@ -1338,26 +1246,51 @@ def process_spectrogram_with_hilbert(spec):
     phase = torch.angle(analytic)
     return envelope, phase
 
-def load_wave(wave_data, sample_rate):
-    if isinstance(wave_data, str):
-        waveform, sr = torchaudio.load(uri=wave_data, normalize=False)
-    elif isinstance(wave_data, dict):
-        waveform = torch.tensor(data=wave_data["array"]).float()
-        sr = wave_data["sampling_rate"]
-    else:
-        raise TypeError("Invalid wave_data format.")
-
-    if waveform.dim() == 1:
-        waveform = waveform.unsqueeze(0)
-
-    if sr != sample_rate:
-        original_length = waveform.shape[1]
-        target_length = int(original_length * (sample_rate / sr))
-
-        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
-        waveform = resampler(waveform)
-
-    return waveform.flatten()
+@dataclass
+class DataCollator:
+    tokenizer: Any
+
+    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
+        all_keys = set()
+        for f in features:
+            all_keys.update(f.keys())
+        batch = {}
+        pad_token_id = getattr(self.tokenizer, 'pad_token_id', 0)
+        bos_token_id = getattr(self.tokenizer, 'bos_token_id', 1)
+
+        for key in all_keys:
+            if key == "label":
+                labels_list = [f["label"] for f in features]
+                max_len = max(len(l) for l in labels_list)
+                all_ids, all_labels = [], []
+                for label in labels_list:
+                    label_list = label.tolist() if isinstance(label, torch.Tensor) else label
+                    decoder_input = [bos_token_id] + label_list
+                    label_eos = label_list + [pad_token_id]
+                    input_len = max_len + 1 - len(decoder_input)
+                    label_len = max_len + 1 - len(label_eos)
+                    padded_input = decoder_input + [pad_token_id] * input_len
+                    padded_labels = label_eos + [pad_token_id] * label_len
+                    all_ids.append(padded_input)
+                    all_labels.append(padded_labels)
+                batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
+                batch["labels"] = torch.tensor(all_labels, dtype=torch.long)
+
+            elif key in ["spectrogram", "waveform", "pitch", "f0", "envelope", "phase"]:
+                items = [f[key] for f in features if key in f]
+                max_len = max(item.shape[-1] for item in items)
+                padded = []
+                for item in items:
+                    pad_width = max_len - item.shape[-1]
+                    if pad_width > 0:
+                        pad_item = F.pad(item, (0, pad_width), mode='constant', value=pad_token_id)
+                    else:
+                        pad_item = item
+                    padded.append(pad_item)
+                batch[key] = torch.stack(padded)
+                if key == "spectrogram":
+                    batch["spectrogram"] = batch[key]
+        return batch
 
 def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, frequency=False,
                      hop_length=128, fmin=0, fmax=8000, n_mels=128, n_fft=1024, sampling_rate=16000,
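The new DataCollator pads every tensor feature to the longest item in the batch and builds shifted decoder inputs from the labels (BOS-prefixed inputs, pad-suffixed targets). A toy run, assuming the class above is in scope (TinyTok is a stand-in for the real tokenizer):

    import torch

    class TinyTok:
        pad_token_id = 0
        bos_token_id = 1

    features = [
        {"spectrogram": torch.randn(128, 90),  "label": [5, 6, 7]},
        {"spectrogram": torch.randn(128, 120), "label": [8, 9]},
    ]
    batch = DataCollator(tokenizer=TinyTok())(features)
    print(batch["spectrogram"].shape)  # torch.Size([2, 128, 120]): padded to the longest clip
    print(batch["input_ids"][1])       # tensor([1, 8, 9, 0]): BOS + label, right-padded
    print(batch["labels"][1])          # tensor([8, 9, 0, 0]): label + pad, right-padded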
@@ -1443,72 +1376,20 @@ def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, frequency=
     batch["label"] = tokenizer.encode(batch["transcription"], add_special_tokens=False)
     return batch
 
-def compute_metrics(eval_pred, compute_result: bool = True,
-    print_pred: bool = False, num_samples: int = 0, tokenizer=None, pitch=None, model=None):
-
-    pred_logits = eval_pred.predictions
-    label_ids = eval_pred.label_ids
-
-    if hasattr(pred_logits, "cpu"):
-        pred_logits = pred_logits.cpu()
-    if hasattr(label_ids, "cpu"):
-        label_ids = label_ids.cpu()
-    if isinstance(pred_logits, tuple):
-        pred_ids = pred_logits[0]
+def compute_metrics(pred, tokenizer):
+    pred_ids = pred["predictions"]
+    label_ids = pred["label_ids"]
+    if isinstance(pred_ids, tuple):
+        pred_ids = pred_ids[0]
     else:
-        pred_ids = pred_logits
-    if hasattr(pred_ids, "ndim") and pred_ids.ndim == 3:
-        if not isinstance(pred_ids, torch.Tensor):
-            pred_ids = torch.tensor(pred_ids)
-        pred_ids = pred_ids.argmax(dim=-1)
-    pred_ids = pred_ids.tolist()
-
-    if hasattr(label_ids, "tolist"):
-        label_ids = label_ids.tolist()
-
-    label_ids = [[0 if token == -100 else token for token in seq] for seq in label_ids]
-    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
-    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=False)
-
-    if print_pred:
-        for i in range(min(num_samples, len(pred_str))):
-            print(f"Preds: {pred_str[i]}")
-            print(f"Label: {label_str[i]}")
-            print(f"preds: {pred_ids[i]}")
-            print(f"label: {label_ids[i]}")
-            print("--------------------------------")
-
+        pred_ids = pred_ids
+    if pred_ids.ndim == 3:
+        pred_ids = np.argmax(pred_ids, axis=-1)
+    label_ids[label_ids == -100] = tokenizer.pad_token_id
     pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
     label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
-    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
-
-    if model is None:
-        global global_model
-        if 'global_model' in globals():
-            model = global_model
-
-    if model is not None:
-        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000
-        if trainable_params > 0:
-            efficiency_score = (100 - wer) / trainable_params
-        else:
-            print("Warning: Zero trainable parameters detected")
-            efficiency_score = 0.0
-    else:
-        print("Warning: Model not available for parameter counting")
-        trainable_params = 0.0
-        efficiency_score = 0.0
-
-    if hasattr(wer, "item"):
-        wer = wer.item()
-
-    metrics = {
-        "wer": float(wer),
-        "trainable_params_M": float(trainable_params),
-        "efficiency_score": float(efficiency_score),
-    }
-
-    return metrics
+    wer = metrics.compute(predictions=pred_str, references=label_str)
+    return {"wer": wer}
 
 logger = logging.getLogger(__name__)
 
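The rewritten compute_metrics returns the raw WER from the evaluate library's "wer" metric, as a fraction rather than the old 100-scaled value. A direct check of that metric on its own:

    import evaluate

    wer_metric = evaluate.load("wer")
    print(wer_metric.compute(predictions=["hello world"], references=["hello word"]))
    # 0.5, i.e. one substitution over a two-word reference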
@@ -1533,13 +1414,16 @@ def setup_tokenizer(token: str, local_tokenizer_path: str = "D:/newmodel/model/t
         sp_ids = [tokenizer.token_to_id(t) for t in ["<PAD>", "<BOS>", "<EOS>"]]
         ids = [id for id in ids if id not in sp_ids]
         return ids
+
     def bdec(ids_list, skip_special_tokens=True):
         results = []
         for ids in ids_list:
+            if not isinstance(ids, list):
+                ids = ids.tolist()
             if skip_special_tokens:
                 ids = [id for id in ids if id not in [0, 1, 2]]
             results.append(tokenizer.decode(ids))
-            return results
+        return results
     def save_pretrained(save_dir):
         os.makedirs(save_dir, exist_ok=True)
         tokenizer.save(f"{save_dir}/tokenizer.json")
@@ -1552,229 +1436,324 @@ def setup_tokenizer(token: str, local_tokenizer_path: str = "D:/newmodel/model/t
     return tokenizer
 
 def prepare_datasets(tokenizer, token: str, sanity_check: bool = False, dataset_config: Optional[Dict] = None) -> Tuple[any, any]:
-    if dataset_config is None:
-        dataset_config = {
-            "spectrogram": True,
-            "waveforms": True,
-            "pitch": True,
-            "frequency": True,
-            "downsamples": True,
-            "hop_length": 128,
-            "fmin": 50,
-            "fmax": 2000,
-            "n_mels": 128,
-            "n_fft": 1024,
-            "sampling_rate": 16000,
-        }
-
-    dataset = load_dataset(
-        "google/fleurs",
-        "en_us",
-        token=token,
-        trust_remote_code=True,
-        streaming=False)
-
-    dataset = dataset.cast_column(column="audio", feature=Audio(sampling_rate=16000)).select_columns(["audio", "transcription"])
-
+
     if sanity_check:
+
+        dataset = load_dataset(
+            "./librispeech_asr.py", "clean", "train.100",
+            storage_options={'client_kwargs': {'timeout': aiohttp.ClientTimeout(total=3600)}},
+            token=token, trust_remote_code=True, streaming=False)
+
+        dataset = dataset.rename_column("text", "transcription")
+        dataset = dataset.cast_column(column="audio", feature=Audio(sampling_rate=16000)).select_columns(["audio", "transcription"])
+
         dataset = dataset["test"].take(10)
         dataset = dataset.select_columns(["audio", "transcription"])
-        logger.info(f"Sanity dataset size: {dataset.num_rows}")
-        print(f"Sanity dataset size: {dataset.num_rows}")
         prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
-
-        dataset = dataset.map(
-            function=prepare_fn,
-            remove_columns=["audio", "transcription"]
-        ).with_format(type="torch")
+        dataset = dataset.map(function=prepare_fn, remove_columns=["audio", "transcription"]).with_format(type="torch")
         train_dataset = dataset
         test_dataset = dataset
     else:
+        cache_dir = "./processed_datasets"
+        os.makedirs(cache_dir, exist_ok=True)
+        cache_file_train = os.path.join(cache_dir, "train.arrow")
+        cache_file_test = os.path.join(cache_dir, "test.arrow")
+
+        if os.path.exists(cache_file_train) and os.path.exists(cache_file_test):
+            from datasets import Dataset
+            train_dataset = Dataset.load_from_disk(cache_file_train)
+            test_dataset = Dataset.load_from_disk(cache_file_test)
+            return train_dataset, test_dataset
+
+        if dataset_config is None:
+            dataset_config = {
+                "spectrogram": True,
+                "waveforms": True,
+                "pitch": True,
+                "frequency": True,
+                "downsamples": True,
+                "hop_length": 128,
+                "fmin": 50,
+                "fmax": 2000,
+                "n_mels": 128,
+                "n_fft": 1024,
+                "sampling_rate": 16000,
+            }
+
+        dataset = load_dataset(
+            "./librispeech_asr.py", "clean", "train.100",
+            storage_options={'client_kwargs': {'timeout': aiohttp.ClientTimeout(total=3600)}},
+            token=token, trust_remote_code=True, streaming=False)
+
+        dataset = dataset.rename_column("text", "transcription")
+        dataset = dataset.cast_column(column="audio", feature=Audio(sampling_rate=16000)).select_columns(["audio", "transcription"])
+
         def filter_func(x):
             return (0 < len(x["transcription"]) < 512 and
                     len(x["audio"]["array"]) > 0 and
                     len(x["audio"]["array"]) < 1500 * 160)
 
-        dataset = dataset.filter(filter_func).shuffle(seed=4)
-        logger.info(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}")
-        print(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}")
+        dataset = dataset.filter(filter_func)
         prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
-        columns_to_remove = list(next(iter(dataset.values())).features)
-        train_dataset = dataset["train"]
-        test_dataset = dataset["test"].take(50)
-        logger.info(f"Train dataset size: {train_dataset.num_rows}, Test dataset size: {test_dataset.num_rows}")
-
+
+        train_dataset = dataset["train.100"].take(10000)
+        test_dataset = dataset["test"].take(1000)
         train_dataset = train_dataset.map(
             function=prepare_fn,
-            remove_columns=columns_to_remove
+            remove_columns=["audio", "transcription"]
         ).with_format(type="torch")
 
         test_dataset = test_dataset.map(
             function=prepare_fn,
-            remove_columns=columns_to_remove
+            remove_columns=["audio", "transcription"]
         ).with_format(type="torch")
-
+
+        train_dataset.save_to_disk(cache_file_train)
+        test_dataset.save_to_disk(cache_file_test)
+
     return train_dataset, test_dataset
 
-def get_training_args(
-    log_dir: str,
-    batch_eval_metrics: bool = False,
-    max_steps: int = 10,
-    save_steps: int = 1000,
-    eval_steps: int = 1,
-    warmup_steps: int = 0,
-    num_train_epochs: int = 1,
-    logging_steps: int = 1,
-    eval_on_start: bool = False,
-    learning_rate: float = 1e-4,
-    weight_decay: float = 0.01,
-    max_grad_norm: float = 1.0,
-    ) -> Seq2SeqTrainingArguments:
-
-    return Seq2SeqTrainingArguments(
-        output_dir=log_dir,
-        per_device_train_batch_size=1,
-        per_device_eval_batch_size=1,
-        gradient_accumulation_steps=1,
-        eval_accumulation_steps=None,
-        eval_strategy="steps",
-        save_strategy="no",
-        max_steps=max_steps,
-        save_steps=save_steps,
-        eval_steps=eval_steps,
-        warmup_steps=warmup_steps,
-        num_train_epochs=num_train_epochs,
-        logging_steps=logging_steps,
-        logging_dir=log_dir,
-        logging_strategy="steps",
-        report_to=["tensorboard"],
-        push_to_hub=False,
-        disable_tqdm=False,
-        save_total_limit=1,
-        label_names=["labels"],
-        optim="adamw_torch",
-        lr_scheduler_type="cosine",
-        learning_rate=learning_rate,
-        weight_decay=weight_decay,
-        save_safetensors=False,
-        eval_on_start=eval_on_start,
-        batch_eval_metrics=batch_eval_metrics,
-        max_grad_norm=max_grad_norm,
-    )
+@dataclass
+class DataCollator:
+    tokenizer: Any
+
+    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
+        all_keys = set()
+        for f in features:
+            all_keys.update(f.keys())
+        batch = {}
+        pad_token_id = getattr(self.tokenizer, 'pad_token_id', 0)
+        bos_token_id = getattr(self.tokenizer, 'bos_token_id', 1)
+
+        for key in all_keys:
+            if key == "label":
+                labels_list = [f["label"] for f in features]
+                max_len = max(len(l) for l in labels_list)
+                all_ids, all_labels = [], []
+                for label in labels_list:
+                    label_list = label.tolist() if isinstance(label, torch.Tensor) else label
+                    decoder_input = [bos_token_id] + label_list
+                    label_eos = label_list + [pad_token_id]
+                    input_len = max_len + 1 - len(decoder_input)
+                    label_len = max_len + 1 - len(label_eos)
+                    padded_input = decoder_input + [pad_token_id] * input_len
+                    padded_labels = label_eos + [pad_token_id] * label_len
+                    all_ids.append(padded_input)
+                    all_labels.append(padded_labels)
+                batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
+                batch["labels"] = torch.tensor(all_labels, dtype=torch.long)
+            elif key in ["spectrogram", "waveform", "pitch", "f0", "envelope", "phase"]:
+                items = [f[key] for f in features if key in f]
+                max_len = max(item.shape[-1] for item in items)
+                padded = []
+                for item in items:
+                    pad_width = max_len - item.shape[-1]
+                    if pad_width > 0:
+                        pad_item = F.pad(item, (0, pad_width), mode='constant', value=pad_token_id)
+                    else:
+                        pad_item = item
+                    padded.append(pad_item)
+                batch[key] = torch.stack(padded)
+                if key == "spectrogram":
+                    batch["spectrogram"] = batch[key]
+        return batch
+
+def train_and_evaluate(
+    model, tokenizer, train_loader, eval_loader, optimizer, scheduler, loss_fn,
+    max_steps=10000, device='cuda', accumulation_steps=1, clear_cache=True,
+    log_interval=10, eval_interval=100, save_interval=1000,
+    checkpoint_dir="checkpoint_dir", log_dir="log_dir"
+    ):
+    model.to(device)
+    global_step = 0
+    scaler = torch.GradScaler()
+    writer = SummaryWriter(log_dir=log_dir)
+    train_iterator = iter(train_loader)
+    total_loss = 0
+    step_in_report = 0
+    dataset_epochs = 0
+
+    progress_bar = tqdm(total=max_steps, desc="Training Progress", leave=True, colour='green')
+
+    model.train()
+    optimizer.zero_grad()
+
+    while global_step < max_steps:
+        try:
+            batch = next(train_iterator)
+        except StopIteration:
+            train_iterator = iter(train_loader)
+            batch = next(train_iterator)
+            dataset_epochs += 1
+            print(f"Starting dataset epoch {dataset_epochs}")
+
+            if step_in_report > 0:
+                avg_loss = total_loss / step_in_report
+                logging.info(f"Dataset iteration complete - Steps: {global_step}, Avg Loss: {avg_loss:.4f}")
+                total_loss = 0
+                step_in_report = 0
+
+        start_time = time.time()
+
+        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
+
+        with torch.autocast(device_type="cuda"):
+            output = model(**batch) if hasattr(model, '__call__') else model.forward(**batch)
+            logits = output["logits"] if isinstance(output, dict) and "logits" in output else output
+            labels = batch["labels"]
+            active_logits = logits.view(-1, logits.size(-1))
+            active_labels = labels.view(-1)
+            active_mask = active_labels != 0
+            active_logits = active_logits[active_mask]
+            active_labels = active_labels[active_mask]
+            loss = loss_fn(active_logits, active_labels)
+        total_loss += loss.item()
+        loss = loss / accumulation_steps
+
+        scaler.scale(loss).backward()
+
+        if (global_step + 1) % accumulation_steps == 0:
+            scaler.unscale_(optimizer)
+            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+            if clear_cache:
+                torch.cuda.empty_cache()
+
+        end_time = time.time()
+        samples_per_sec = batch["spectrogram"].size(0) / (end_time - start_time)
+
+        if global_step % log_interval == 0:
+            writer.add_scalar(tag='Loss/train', scalar_value=total_loss / (global_step + 1), global_step=global_step)
+            lr = scheduler.get_last_lr()[0]
+            writer.add_scalar(tag='LearningRate', scalar_value=lr, global_step=global_step)
+            writer.add_scalar(tag='SamplesPerSec', scalar_value=samples_per_sec, global_step=global_step)
+
+        if global_step % eval_interval == 0:
+            model.eval()
+            eval_start_time = time.time()
+            eval_loss = 0
+            all_predictions = []
+            all_labels = []
+            batch_count = 0
+            total_samples = 0
+
+            with torch.no_grad():
+                for eval_batch in eval_loader:
+                    eval_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in eval_batch.items()}
+                    output = model(**eval_batch) if hasattr(model, '__call__') else model.forward(**eval_batch)
+                    logits = output["logits"] if isinstance(output, dict) and "logits" in output else output
+                    labels = eval_batch["labels"]
+                    batch_size = logits.size(0)
+                    total_samples += batch_size
+                    loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
+                    eval_loss += loss.item()
+                    all_predictions.extend(torch.argmax(logits, dim=-1).cpu().numpy().tolist())
+                    all_labels.extend(labels.cpu().numpy().tolist())
+                    batch_count += 1
+
+            eval_time = time.time() - eval_start_time
+            loss_avg = eval_loss / batch_count if batch_count > 0 else 0
+            predictions = {"predictions": np.array(all_predictions, dtype=object), "label_ids": np.array(all_labels, dtype=object)}
+            metrics = compute_metrics(pred=predictions, tokenizer=tokenizer)
+
+            writer.add_scalar('Loss/eval', loss_avg, global_step)
+            writer.add_scalar('WER', metrics['wer'], global_step)
+            writer.add_scalar('EvalSamples', total_samples, global_step)
+            writer.add_scalar('EvalTimeSeconds', eval_time, global_step)
+
+            lr = scheduler.get_last_lr()[0]
+            print(f"• STEP:{global_step} • samp:{samples_per_sec:.1f} • WER:{metrics['wer']:.2f}% • Loss:{loss_avg:.4f} • LR:{lr:.8f}")
+            logging.info(f"EVALUATION STEP {global_step} - WER: {metrics['wer']:.2f}%, Loss: {loss_avg:.4f}, LR: {lr:.8f}")
+            model.train()
+
+        if global_step % save_interval == 0:
+            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_step_{global_step}.pt')
+            torch.save(model.state_dict(), checkpoint_path)
+            logging.info(f"Model saved at step {global_step} to {checkpoint_path}")
+
+        lr = scheduler.get_last_lr()[0]
+        scheduler.step()
+        global_step += 1
+        step_in_report += 1
+
+        avg_loss = total_loss / (global_step + 1)
+        postfix_dict = {
+            'loss': f'{avg_loss:.4f}',
+            'lr': f'{lr:.6f}',
+            'samp': f'{samples_per_sec:.1f}'
+        }
+        progress_bar.set_postfix(postfix_dict, refresh=True)
+        progress_bar.update(1)
+
+    final_model_path = os.path.join(checkpoint_dir, 'final_model.pt')
+    torch.save(model.state_dict(), final_model_path)
+    print(f"Training completed after {global_step} steps. Final model saved to {final_model_path}")
+    writer.close()
+    progress_bar.close()
+
+def get_optimizer(model, lr=5e-4, weight_decay=0.01):
+    return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, eps=1e-6, betas=(0.9, 0.98))
+
+def get_scheduler(optimizer, total_steps=10000):
+    return torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.25, total_iters=total_steps, last_epoch=-1)
+
+def get_loss_fn():
+    return torch.nn.CrossEntropyLoss(ignore_index=0)
 
 def main():
-
     token = ""
-    log_dir = os.path.join('./output/logs', datetime.now().strftime(format='%m-%d_%H_%M_%S'))
-    os.makedirs(name=log_dir, exist_ok=True)
+    log_dir = os.path.join('./output/logs', datetime.now().strftime('%m-%d_%H_%M_%S'))
+    os.makedirs(log_dir, exist_ok=True)
     tokenizer = setup_tokenizer(token)
 
-    def sanity(sanity: bool):
-
-        if sanity:
-            training_args = get_training_args(
-                log_dir,
-                batch_eval_metrics = False,
-                max_steps = 10,
-                save_steps = 0,
-                eval_steps = 1,
-                warmup_steps = 0,
-                logging_steps = 1,
-                eval_on_start = False,
-                learning_rate = 5e-6,
-                weight_decay = 0.01,
-            )
-        else:
-            training_args = get_training_args(
-                log_dir,
-                batch_eval_metrics = False,
-                max_steps = 1000,
-                save_steps = 1005,
-                eval_steps = 100,
-                warmup_steps = 100,
-                logging_steps = 10,
-                eval_on_start = False,
-                learning_rate = 2.5e-4,
-                weight_decay = 0.01,
-            )
-
-        return training_args
-
     param = Dimensions(
-        mels=128,
-        aud_ctx=1500,
-        aud_head=4,
-        aud_dims=512,
-        aud_idx=4,
-        vocab=40000,
-        text_ctx=512,
-        text_head=4,
-        text_dims=512,
-        text_idx=4,
-        act="swish",
-        debug={},
-        cross_attn=True,
-        features = ["spectrogram"]
-    )
-
-    sanity_check = False
+        mels=128, aud_ctx=1500, aud_head=4, aud_dims=512, aud_idx=4,
+        vocab=40000, text_ctx=512, text_head=4, text_dims=512, text_idx=4,
+        act="swish", debug={}, cross_attn=True, features=["spectrogram"]
+    )
 
-    training_args = sanity(sanity_check)
     dataset_config = {
-        "spectrogram": True,
-        "waveforms": True,
-        "pitch": False,
-        "downsamples": False,
-        "frequency": False,
-        "hilbert": False,
-        "hop_length": 128,
-        "fmin": 150,
-        "fmax": 2000,
-        "n_mels": 128,
-        "n_fft": 1024,
-        "sampling_rate": 16000,
-        "pad_mode": "constant",
-        "center": True,
-        "power": 2.0,
-        "window_fn": torch.hann_window,
-        "mel_scale": "htk",
-        "norm": None,
-        "normalized": False}
-
+        "spectrogram": True, "waveforms": False, "pitch": False, "downsamples": False,
+        "frequency": True, "hilbert": False, "hop_length": 128, "fmin": 150, "fmax": 2000,
+        "n_mels": 128, "n_fft": 1024, "sampling_rate": 16000, "pad_mode": "constant",
+        "center": True, "power": 2.0, "window_fn": torch.hann_window, "mel_scale": "htk",
+        "norm": None, "normalized": False
+    }
+
     model = create_model(param)
-
-    global global_model
-    global_model = model
-
-    metrics_fn = partial(compute_metrics, print_pred=False, num_samples=5,
-        tokenizer=tokenizer, model=model)
-
-    print(f"{'Sanity check' if sanity_check else 'Training'} mode")
     train_dataset, test_dataset = prepare_datasets(
-        tokenizer=tokenizer,
-        token=token,
-        sanity_check=sanity_check,
-        dataset_config=dataset_config)
+        tokenizer=tokenizer, token=token, sanity_check=False, dataset_config=dataset_config
+    )
+
+    collator = DataCollator(tokenizer=tokenizer)
+    train_loader = DataLoader(train_dataset, batch_size=1, collate_fn=collator, num_workers=0)
+    eval_loader = DataLoader(test_dataset, batch_size=1, collate_fn=collator, num_workers=0)
 
-    trainer = Seq2SeqTrainer(
-        args=training_args,
+    optimizer = get_optimizer(model)
+    scheduler = get_scheduler(optimizer)
+    loss_fn = get_loss_fn()
+
+    train_and_evaluate(
         model=model,
-        train_dataset=train_dataset,
-        eval_dataset=test_dataset,
-        data_collator=DataCollator(tokenizer=tokenizer),
-        compute_metrics=metrics_fn,
-    )
-
-    model.init_weights()
-    trainer.train()
+        tokenizer=tokenizer,
+        train_loader=train_loader,
+        eval_loader=eval_loader,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        loss_fn=loss_fn,
+        max_steps=10000,
+        device='cuda',
+        accumulation_steps=1,
+        clear_cache=False,
+        log_interval=10,
+        eval_interval=500,
+        save_interval=10000,
+        checkpoint_dir="./checkpoints",
+        log_dir=log_dir
+    )
 
 if __name__ == "__main__":
     main()
 
-# from tensorboard import program
-# log_dir = "./output/logs"
-# tb = program.TensorBoard()
-# tb.configure(argv=[None, '--logdir', log_dir])
-# url = tb.launch()
-# print(f"TensorBoard started at {url}")
-
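The custom loop added above replaces Seq2SeqTrainer. Its core is the divide-and-accumulate pattern: scale each micro-batch loss by 1/accumulation_steps, then clip and step only every accumulation_steps batches. A CPU-runnable skeleton of just that pattern (AMP/GradScaler and the eval branch omitted; data is synthetic):

    import torch

    model = torch.nn.Linear(8, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, eps=1e-6, betas=(0.9, 0.98))
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.25, total_iters=8)
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=0)
    accumulation_steps = 4

    for step in range(8):
        x, y = torch.randn(2, 8), torch.randint(1, 4, (2,))
        loss = loss_fn(model(x), y) / accumulation_steps  # scale so summed grads match one full batch
        loss.backward()
        if (step + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()
        scheduler.step()  # the loop above steps the scheduler every micro-batch as well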
 