Sin2pi committed on
Commit 51d0009 · verified · 1 Parent(s): 018f69e

Update model_hf.py

Files changed (1)
  1. model_hf.py +217 -209
model_hf.py CHANGED
@@ -1,7 +1,4 @@
 import os
-PATH = 'E:/hf'
-os.environ['HF_HOME'] = PATH
-os.environ['HF_DATASETS_CACHE'] = PATH
 import pyworld as pw
 import math
 import warnings
@@ -25,7 +22,9 @@ from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
 import transformers
 import evaluate
 from dataclasses import dataclass
-import aiohttp
+import pretty_errors
+from rich.traceback import install
+
 torch.backends.cudnn.allow_tf32 = True
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.set_float32_matmul_precision('high')
@@ -36,6 +35,25 @@ dtype = torch.float32
 
 warnings.filterwarnings("ignore")
 logging.basicConfig(level=logging.ERROR)
+install(show_locals=True)
+
+pretty_errors.configure(
+    separator_character = '*',
+    filename_display = pretty_errors.FILENAME_EXTENDED,
+    line_number_first = True,
+    display_link = True,
+    lines_before = 5,
+    lines_after = 2,
+    line_color = pretty_errors.RED + '> ' + pretty_errors.default_config.line_color,
+    code_color = ' ' + pretty_errors.default_config.line_color,
+)
+
+PATH = 'E:/hf'
+os.environ['HF_HOME'] = PATH
+os.environ['HF_DATASETS_CACHE'] = PATH
+os.environ['TORCH_HOME'] = PATH
+os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 
 @dataclass
 class Dimensions:
@@ -306,14 +324,12 @@ class rotary(nn.Module):
         return f0.to(device=device, dtype=dtype)
 
     def synth_f0(self, f0, ctx):
-        # f0 = self.f0proj(f0)
         if f0.dim() == 1:
             length = f0.shape[0]
             if length == ctx:
                 return f0
             frames = length / ctx
             idx = torch.arange(ctx, device=f0.device)
-            # return torch.arange(1, ctx+1, device=f0.device, dtype=torch.float)
             return f0[idx]
 
     def align_f0(self, ctx, f0):
@@ -367,7 +383,6 @@ class rotary(nn.Module):
         else:
             theta = self.theta
         freqs = self.theta_freqs(theta)
-
         freqs = t[:, None] * freqs[None, :]
         if self.radii and f0 is not None and layer == "encoder":
             radius = f0.to(device, dtype)
@@ -377,7 +392,7 @@
             idx = torch.arange(ctx, device=f0.device)
             idx = (idx * F).long().clamp(0, L - 1)
             radius = radius[idx]
-            rad = radius
+
             radius = radius.unsqueeze(-1).expand(-1, freqs.shape[-1])
             radius = torch.sigmoid(radius)
         else:
@@ -445,7 +460,7 @@ class MultiheadA(nn.Module):
         else:
             self.rope = None
 
-    def enhanced_attention_scores(self, q, k, rbf_sigma=1.0, rbf_ratio=0.0):
+    def rbf_scores(self, q, k, rbf_sigma=1.0, rbf_ratio=0.0):
         scale = (self.dims // self.head) ** -0.25
         dot_scores = torch.matmul(q, k.transpose(-1, -2)) * scale
         if rbf_ratio <= 0.0:
@@ -457,30 +472,27 @@ class MultiheadA(nn.Module):
         rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
         return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
 
-    def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None, enc = None, layer = None, feature_type="audio") -> tuple:
+    def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None, enc = None, layer = None, feature_type="audio", need_weights=True) -> tuple:
+
         x = x.to(device, dtype)
         if xa is not None:
             xa = xa.to(device, dtype)
-
-        batch, ctx, dims = x.shape
         scale = (self.dims // self.head) ** -0.25
 
         z = default(xa, x).to(device, dtype)
        q = self.q(x)
        k = self.k(z)
        v = self.v(z)
-        qlen = q.shape[1]
-        klen = k.shape[1]
 
         if self.rotary_emb:
             q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
             k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
             v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-            qlen = q.shape[2]
-            klen = k.shape[2]
+            q2 = q.shape[2]
+            k2 = k.shape[2]
 
-            q = self.rope.apply_rotary(q, (self.rope(qlen, enc=enc, layer=layer)))
-            k = self.rope.apply_rotary(k, (self.rope(klen, enc=enc, layer=layer)))
+            q = self.rope.apply_rotary(q, (self.rope(q2, enc=enc, layer=layer)))
+            k = self.rope.apply_rotary(k, (self.rope(k2, enc=enc, layer=layer)))
         else:
             q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
             k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
@@ -488,21 +500,21 @@ class MultiheadA(nn.Module):
         batch, head, ctx, head_dim = q.shape
 
         if self.rbf:
-            qk = self.enhanced_attention_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
+            qk = self.rbf_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
 
         qk = (q * scale) @ (k * scale).transpose(-1, -2)
         if self.rope.use_pbias:
             f0 = enc.get("f0", None) if enc is not None else None
             pbias = self.rope.use_pbias(f0)
             if pbias is not None:
-                qk = qk + pbias[:,:,:q.shape[2],:q.shape[2]]
+                qk = qk + pbias[:,:,:q2,:q2]
         token_ids = k[:, :, :, 0]
         zscale = torch.ones_like(token_ids)
         fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
         zscale[token_ids.float() == self.pad_token] = fzero
 
         if mask is not None:
-            mask = mask[:q.shape[2], :q.shape[2]]
+            mask = mask[:q2, :q2]
             qk = qk + mask.unsqueeze(0).unsqueeze(0) * zscale.unsqueeze(-2).expand(qk.shape)
         qk = qk * zscale.unsqueeze(-2)
         w = F.softmax(qk, dim=-1).to(q.dtype)
@@ -511,8 +523,129 @@ class MultiheadA(nn.Module):
         if "multihead" in self.debug and self.counter % 100 == 0:
             print(f"MHA: q={q.shape}, k={k.shape}, v={v.shape} - {qk.shape}, wv shape: {wv.shape}")
         self.counter += 1
-        return self.o(wv), qk.detach()
+        return self.o(wv), qk
+
+class SpanPredictor(nn.Module):
+    def __init__(self, dims):
+        super().__init__()
+        self.linear = nn.Linear(in_features=dims, out_features=1)
+
+    def forward(self, global_out):
+        scale = torch.sigmoid(self.linear(global_out))
+        return scale
+
+class FocusA(nn.Module):
+    def __init__(self, base: int, dims: int, head: int, max_dist: int, sharpen: bool,
+                 win_size: int = 32, max_span: int = 32, slid_win: int = 32,
+                 temp_scale: float = 0.01, num_iterations: int = 3):
+
+        super().__init__()
+        self.base = base
+        self.dims = dims
+        self.head = head
+        self.max_dist = max_dist
+        self.sharpen = sharpen
+        self.win_size = win_size
+        self.max_span = max_span
+        self.slid_win = slid_win
+        self.temp_scale = temp_scale
+        self.num_iterations = num_iterations
+        self.span_predictor = SpanPredictor(dims=dims)
+        self.span_scale_param = nn.Parameter(torch.tensor(1.0))
+
+        self.attn_local = nn.MultiheadAttention(embed_dim=dims, num_heads=head, batch_first=True)
+        self.attn_global = nn.MultiheadAttention(embed_dim=dims, num_heads=head, batch_first=True)
+
+        self.ln_local = nn.LayerNorm(normalized_shape=dims)
+        self.ln_global = nn.LayerNorm(normalized_shape=dims)
+        self.projection = nn.Linear(in_features=2 * dims, out_features=dims)
+
+    def _focus(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, span_scale: torch.Tensor) -> torch.Tensor:
+
+        max_iterations = 1
+        iteration = 0
+        prev_attn_out = torch.zeros_like(query)
+        base_threshold = 1e-4
+        scaling_factor = 0.1
+
+        while iteration < max_iterations:
+            span_len = int(self.max_span * span_scale.mean().item())
+            span_len = min(span_len, query.size(1), key.size(1), value.size(1))
+            eff_span = min(span_len, self.max_dist)
+
+            q_span = query[:, :eff_span, :]
+            k_span = key[:, :eff_span, :]
+            v_span = value[:, :eff_span, :]
+
+            batch, ctx, dims = q_span.size()
+            scale_factor = (dims // self.head) ** -0.25
+
+            q = q_span.view(batch, ctx, self.head, -1).permute(0, 2, 1, 3)
+            k = k_span.view(batch, ctx, self.head, -1).permute(0, 2, 1, 3)
+            v = v_span.view(batch, ctx, self.head, -1).permute(0, 2, 1, 3)
+
+            if self.sharpen:
+                temperature = 1.0 + self.temp_scale * (1.0 - span_scale.mean().item())
+            else:
+                temperature = 0.5 + self.temp_scale * span_scale.mean().item()
+
+            attn_scores = torch.matmul(q, k.transpose(-2, -1))
+            attn_weights = torch.softmax((attn_scores / temperature) * scale_factor, dim=-1)
+            attn_out = torch.matmul(attn_weights, v)
+
+            attn_out = attn_out.permute(0, 2, 1, 3).contiguous().view(batch, ctx, -1)
+
+            diff = torch.abs(attn_out - prev_attn_out).mean()
+            dynamic_threshold = base_threshold + scaling_factor * diff
+
+            if diff < dynamic_threshold:
+                break
+
+            prev_attn_out = attn_out
+            query = query + attn_out
+            iteration += 1
+
+        return attn_out, attn_weights
+
+    def _window(self, x: torch.Tensor, win_size: int, span_len: int, span_scale: torch.Tensor) -> torch.Tensor:
+
+        batch, ctx, dims = x.size()
+        num_windows = (ctx + win_size - 1) // win_size
+
+        output = torch.zeros_like(x, device=x.device)
+
+        for i in range(num_windows):
+            start_idx = i * win_size
+            end_idx = min((i + 1) * win_size, ctx)
+            query = x[:, start_idx:end_idx, :]
+
+            key_start = max(0, start_idx - span_len + win_size)
+            key_end = min(start_idx + span_len, ctx)
+            key = x[:, key_start:key_end, :]
+            value = x[:, key_start:key_end, :]
+
+            attn_out = self._focus(query, key, value, span_scale)
+            output[:, start_idx:end_idx, :] = attn_out
 
+        return output
+
+    def forward(self, x, xa=None, mask=None, kv_cache=None) -> torch.Tensor:
+        span_scale = self.span_predictor(x)
+        span_scale = torch.sigmoid(span_scale)
+
+        local_attn_out = self.attn_local(x, x, x)
+        local_attn_out = self.ln_local(local_attn_out)
+
+        global_attn_out = self.attn_global(x, x, x)
+        global_attn_out = self.ln_global(global_attn_out)
+
+        attn_out = torch.cat((local_attn_out, global_attn_out), dim=-1)
+        attn_out = self.projection(attn_out)
+
+        windowed_attn_out = self._window(attn_out, self.win_size, self.max_span, span_scale)
+        focused_attn_out = self._focus(windowed_attn_out, windowed_attn_out, windowed_attn_out, span_scale)
+        return focused_attn_out
+
 class t_gate(nn.Module):
     def __init__(self, dims, num_types=4):
         super().__init__()
@@ -567,7 +700,6 @@ class c_gate(nn.Module):
         comb = torch.cat([s, w, p, e, ph], dim=-1)
         return self.integ(comb)
 
-
 class Residual(nn.Module):
     _seen = set()
     def __init__(self, ctx, dims, head, act, cross_attn=True, debug: List[str] = [],
@@ -667,7 +799,6 @@ class Residual(nn.Module):
         self.counter += 1
         return x
 
-
 class FEncoder(nn.Module):
     def __init__(self, input_dims, dims, head, layer, kernel_size, act, stride=1, use_rope=False, spec_shape=None):
         super().__init__()
@@ -1030,7 +1161,6 @@ class Echo(nn.Module):
         phase: Optional[torch.Tensor]=None,
     ) -> Dict[str, torch.Tensor]:
 
-        decoder_input_ids = input_ids
         encoder_inputs = {}
         if spectrogram is not None:
             encoder_inputs["spectrogram"] = spectrogram
@@ -1142,120 +1272,51 @@ class Echo(nn.Module):
         print("Counter reset to 0.")
 
 metric = evaluate.load(path="wer")
-
+
 @dataclass
 class DataCollator:
     tokenizer: Any
+
     def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
-        pad_token_id = tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else 0
-        bos_token_id = tokenizer.bos_token_id if hasattr(tokenizer, 'bos_token_id') else 1
-
+        all_keys = set()
+        for f in features:
+            all_keys.update(f.keys())
         batch = {}
-
-        if "spectrogram" in features[0] and features[0]["spectrogram"] is not None:
-            spectrogram_list = [f["spectrogram"] for f in features]
-            max_len_feat = max(f.shape[-1] for f in spectrogram_list)
-            pad_spectrogram = []
-            for feat in spectrogram_list:
-                current_len = feat.shape[-1]
-                padding = max_len_feat - current_len
-                if padding > 0:
-                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
-                else:
-                    pad_feat = feat
-                pad_spectrogram.append(pad_feat)
-            batch["spectrogram"] = torch.stack(pad_spectrogram)
-
-        if "waveform" in features[0] and features[0]["waveform"] is not None:
-            waveform_list = [f["waveform"] for f in features]
-            max_len_wav = max(w.shape[-1] for w in waveform_list)
-            pad_waveforms = []
-            for wav in waveform_list:
-                current_len = wav.shape[-1]
-                padding = max_len_wav - current_len
-                if padding > 0:
-                    if wav.ndim == 1:
-                        wav = wav.unsqueeze(0)
-                    pad_wav = F.pad(wav, (0, padding), mode='constant', value=pad_token_id)
-                else:
-                    pad_wav = wav
-                pad_waveforms.append(pad_wav)
-            batch["waveform"] = torch.stack(pad_waveforms)
-
-        if "label" in features[0] and features[0]["label"] is not None:
-            labels_list = [f["label"] for f in features]
-            max_len = max(len(l) for l in labels_list)
-            all_ids = []
-            all_labels = []
-
-            for label in labels_list:
-                label_list = label.tolist() if isinstance(label, torch.Tensor) else label
-                decoder_input = [bos_token_id] + label_list
-                label_eos = label_list + [pad_token_id]
-                input_len = max_len + 1 - len(decoder_input)
-                label_len = max_len + 1 - len(label_eos)
-                padded_input = decoder_input + [pad_token_id] * input_len
-                padded_labels = label_eos + [pad_token_id] * label_len
-                all_ids.append(padded_input)
-                all_labels.append(padded_labels)
-            batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
-            batch["labels"] = torch.tensor(all_labels, dtype=torch.long)
-
-        if "pitch" in features[0] and features[0]["pitch"] is not None:
-            pitch_list = [f["pitch"] for f in features]
-            max_len_pitch = max(e.shape[-1] for e in pitch_list)
-            pad_pitch = []
-            for pitch in pitch_list:
-                current_len = pitch.shape[-1]
-                padding = max_len_pitch - current_len
-                if padding > 0:
-                    pad_pitch_item = F.pad(pitch, (0, padding), mode='constant', value=pad_token_id)
-                else:
-                    pad_pitch_item = pitch
-                pad_pitch.append(pad_pitch_item)
-            batch["pitch"] = torch.stack(pad_pitch)
-
-        if "f0" in features[0] and features[0]["f0"] is not None:
-            f0_list = [f["f0"] for f in features]
-            max_len_f0 = max(f.shape[-1] for f in f0_list)
-            pad_f0 = []
-            for f0 in f0_list:
-                current_len = f0.shape[-1]
-                padding = max_len_f0 - current_len
-                if padding > 0:
-                    pad_f0_item = F.pad(f0, (0, padding), mode='constant', value=pad_token_id)
-                else:
-                    pad_f0_item = f0
-                pad_f0.append(pad_f0_item)
-            batch["f0"] = torch.stack(pad_f0)
-
-        if "envelope" in features[0] and features[0]["envelope"] is not None:
-            env_list = [f["envelope"] for f in features]
-            max_len = max(f.shape[-1] for f in env_list)
-            pad_env = []
-            for feat in env_list:
-                current_len = feat.shape[-1]
-                padding = max_len - current_len
-                if padding > 0:
-                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
-                else:
-                    pad_feat = feat
-                pad_env.append(pad_feat)
-            batch["envelope"] = torch.stack(pad_env)
-
-        if "phase" in features[0] and features[0]["phase"] is not None:
-            ph_list = [f["phase"] for f in features]
-            max_len = max(f.shape[-1] for f in ph_list)
-            pad_ph = []
-            for feat in ph_list:
-                current_len = feat.shape[-1]
-                padding = max_len - current_len
-                if padding > 0:
-                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
-                else:
-                    pad_feat = feat
-                pad_ph.append(pad_feat)
-            batch["phase"] = torch.stack(pad_ph)
+        pad_token_id = getattr(self.tokenizer, 'pad_token_id', 0)
+        bos_token_id = getattr(self.tokenizer, 'bos_token_id', 1)
+        eos_token_id = getattr(self.tokenizer, 'eos_token_id', 2)
+
+        for key in all_keys:
+            if key == "label":
+                labels_list = [f["label"] for f in features]
+                max_len = max(len(l) for l in labels_list)
+                all_ids, all_labels = [], []
+                for label in labels_list:
+                    label_list = label.tolist() if isinstance(label, torch.Tensor) else label
+                    decoder_input = [bos_token_id] + label_list
+                    label_eos = label_list + [eos_token_id]
+                    input_len = max_len + 1 - len(decoder_input)
+                    label_len = max_len + 1 - len(label_eos)
+                    padded_input = decoder_input + [pad_token_id] * input_len
+                    padded_labels = label_eos + [pad_token_id] * label_len
+                    all_ids.append(padded_input)
+                    all_labels.append(padded_labels)
+                batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
+                batch["labels"] = torch.tensor(all_labels, dtype=torch.long)
+            elif key in ["spectrogram", "waveform", "pitch", "f0", "env", "phase"]:
+                items = [f[key] for f in features if key in f]
+                max_len = max(item.shape[-1] for item in items)
+                padded = []
+                for item in items:
+                    pad_width = max_len - item.shape[-1]
+                    if pad_width > 0:
+                        pad_item = F.pad(item, (0, pad_width), mode='constant', value=pad_token_id)
+                    else:
+                        pad_item = item
+                    padded.append(pad_item)
+                batch[key] = torch.stack(padded)
+                if key == "spectrogram":
+                    batch["spectrogram"] = batch[key]
         return batch
 
 def hilbert_transform(x):
@@ -1335,8 +1396,6 @@ def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, frequency=
         pad_mode="constant", center=True, power=2.0, window_fn=torch.hann_window, mel_scale="htk",
         norm=None, normalized=False, downsamples=False, period=False, hilbert=False):
 
-    dtype = torch.float32
-    device = torch.device("cuda:0")
     audio = batch["audio"]
     sampling_rate = audio["sampling_rate"]
     sr = audio["sampling_rate"]
@@ -1414,75 +1473,37 @@ def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, frequency=
     batch["label"] = tokenizer.encode(batch["transcription"], add_special_tokens=False)
     return batch
 
-def compute_metrics(eval_pred, compute_result: bool = True,
-        print_pred: bool = False, num_samples: int = 0, tokenizer=None, model=None):
+def compute_metrics(pred, compute_result: bool = True, print_pred: bool = False, num_samples: int = 0, tokenizer = None, model = None):
 
-    pred_logits = eval_pred.predictions
-    label_ids = eval_pred.label_ids
+    pred_ids = pred.predictions
+    label_ids = pred.label_ids
 
-    if hasattr(pred_logits, "cpu"):
-        pred_logits = pred_logits.cpu()
+    if isinstance(pred_ids, tuple):
+        pred_ids = pred_ids[0]
     else:
-        pred_logits = torch.tensor(pred_logits).cpu()
-    if hasattr(label_ids, "cpu"):
-        label_ids = label_ids.cpu()
-    else:
-        label_ids = torch.tensor(label_ids).cpu()
+        pred_ids = pred_ids
 
-    if isinstance(pred_logits, tuple):
-        pred_ids = pred_logits[0]
-    else:
-        pred_ids = pred_logits
-    if hasattr(pred_ids, "ndim") and pred_ids.ndim == 3:
-        if not isinstance(pred_ids, torch.Tensor):
-            pred_ids = torch.tensor(pred_ids)
+    if pred_ids.ndim == 3:
         pred_ids = pred_ids.argmax(dim=-1)
-    pred_ids = pred_ids.tolist()
-
-    if hasattr(label_ids, "tolist"):
-        label_ids = label_ids.tolist()
-
-    label_ids = [[0 if token == -100 else token for token in seq] for seq in label_ids]
-    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
-    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=False)
+
+    pred_ids = pred_ids.tolist()
+    label_ids = label_ids.tolist()
 
     if print_pred:
+        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
+        label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=False)
         for i in range(min(num_samples, len(pred_str))):
             print(f"Preds: {pred_str[i]}")
             print(f"Label: {label_str[i]}")
+            print(f"Preds: {pred_ids[i]}")
+            print(f"Label: {label_ids[i]}")
         print("--------------------------------")
 
     pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
     label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
     wer = 100 * metric.compute(predictions=pred_str, references=label_str)
 
-    if model is None:
-        global global_model
-        if 'global_model' in globals():
-            model = global_model
-
-    if model is not None:
-        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000
-        if trainable_params > 0:
-            efficiency_score = (100 - wer) / trainable_params
-        else:
-            print("Warning: Zero trainable parameters detected")
-            efficiency_score = 0.0
-    else:
-        print("Warning: Model not available for parameter counting")
-        trainable_params = 0.0
-        efficiency_score = 0.0
-
-    if hasattr(wer, "item"):
-        wer = wer.item()
-
-    metrics = {
-        "wer": float(wer),
-        "trainable_params_M": float(trainable_params),
-        "efficiency_score": float(efficiency_score),
-    }
-
-    return metrics
+    return {"wer": wer}
 
 logger = logging.getLogger(__name__)
 
@@ -1548,18 +1569,6 @@ def prepare_datasets(tokenizer, token: str, sanity_check: bool = False, dataset_
         trust_remote_code=True,
         streaming=False)
 
-    # cache_dir = "./processed_datasets"
-    # os.makedirs(cache_dir, exist_ok=True)
-    # cache_file_train = os.path.join(cache_dir, "train.arrow")
-    # cache_file_test = os.path.join(cache_dir, "test.arrow")
-
-    # if os.path.exists(cache_file_train) and os.path.exists(cache_file_test):
-    #     from datasets import Dataset
-    #     train_dataset = Dataset.load_from_disk(cache_file_train)
-    #     test_dataset = Dataset.load_from_disk(cache_file_test)
-    #     return train_dataset, test_dataset
-
-
     dataset = dataset.cast_column(column="audio", feature=Audio(sampling_rate=16000)).select_columns(["audio", "transcription"])
 
     if sanity_check:
@@ -1577,9 +1586,8 @@ def prepare_datasets(tokenizer, token: str, sanity_check: bool = False, dataset_
 
     dataset = dataset.filter(filter_func)
     prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
-    # columns_to_remove = list(next(iter(dataset.values())).features)
-    train_dataset = dataset["train"].take(1000)
-    test_dataset = dataset["test"].take(100)
+    train_dataset = dataset["train"]
+    test_dataset = dataset["test"]
 
     train_dataset = train_dataset.map(
         function=prepare_fn,
@@ -1611,7 +1619,7 @@ def get_training_args(
     return Seq2SeqTrainingArguments(
         output_dir=log_dir,
         per_device_train_batch_size=1,
-        per_device_eval_batch_size=2,
+        per_device_eval_batch_size=1,
         gradient_accumulation_steps=1,
         eval_accumulation_steps=4,
         eval_strategy="steps",
@@ -1669,11 +1677,11 @@ def main():
     training_args = get_training_args(
         log_dir,
         batch_eval_metrics = False,
-        max_steps = 10000,
+        max_steps = 1000,
         save_steps = 1005,
-        eval_steps = 1000,
-        warmup_steps = 1000,
-        logging_steps = 100,
+        eval_steps = 100,
+        warmup_steps = 100,
+        logging_steps = 10,
         eval_on_start = False,
         learning_rate = 2.5e-4,
         weight_decay = 0.01,
 