Sin2pi committed (verified)
Commit ebb8a2a · Parent: 2c71a26

Update model_hf.py

Files changed (1):
  1. model_hf.py (+160 -320)

model_hf.py CHANGED
@@ -3,15 +3,12 @@ import pyworld as pw
3
  import math
4
  import warnings
5
  import logging
6
- import gzip
7
- import base64
8
  import torch
9
  import torchaudio
10
  import torch.nn.functional as F
11
  import torch.nn.init as init
12
  from torch import nn, Tensor
13
  import numpy as np
14
- from einops import rearrange
15
  import matplotlib.pyplot as plt
16
  from typing import Optional, Dict, Union, List, Tuple, Any
17
  from functools import partial
@@ -22,8 +19,7 @@ from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
22
  import transformers
23
  import evaluate
24
  from dataclasses import dataclass
25
- import pretty_errors
26
- from rich.traceback import install
27
 
28
  torch.backends.cudnn.allow_tf32 = True
29
  torch.backends.cuda.matmul.allow_tf32 = True
@@ -35,25 +31,7 @@ dtype = torch.float32
35
 
36
  warnings.filterwarnings("ignore")
37
  logging.basicConfig(level=logging.ERROR)
38
- install(show_locals=True)
39
-
40
- pretty_errors.configure(
41
- separator_character = '*',
42
- filename_display = pretty_errors.FILENAME_EXTENDED,
43
- line_number_first = True,
44
- display_link = True,
45
- lines_before = 5,
46
- lines_after = 2,
47
- line_color = pretty_errors.RED + '> ' + pretty_errors.default_config.line_color,
48
- code_color = ' ' + pretty_errors.default_config.line_color,
49
- )
50
-
51
- PATH = 'E:/hf'
52
- os.environ['HF_HOME'] = PATH
53
- os.environ['HF_DATASETS_CACHE'] = PATH
54
- os.environ['TORCH_HOME'] = PATH
55
- os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
56
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
57
 
58
  @dataclass
59
  class Dimensions:
@@ -72,35 +50,36 @@ class Dimensions:
72
  cross_attn: bool
73
  features: List[str]
74
 
 
75
  def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_length=160,
76
  title="", markers=None, marker_labels=None,
77
  show_voiced_regions=True, show_energy=False):
78
  num_plots = sum([x is not None, w is not None, p is not None, per is not None])
79
  if num_plots == 0:
80
  raise ValueError("No data to plot. Please provide at least one input tensor.")
81
- time_spans = []
82
 
83
  if w is not None:
84
  w_np = w[sample_idx].detach().cpu().numpy()
85
  if w_np.ndim > 1:
86
  w_np = w_np.squeeze()
87
- time_spans.append(len(w_np) / sr)
88
  if x is not None:
89
  x_np = x[sample_idx].detach().cpu().numpy()
90
  if x_np.shape[0] < x_np.shape[1]:
91
  x_np = x_np.T
92
- time_spans.append(x_np.shape[0] * hop_length / sr)
93
  if p is not None:
94
  p_np = p[sample_idx].detach().cpu().numpy()
95
  if p_np.ndim > 1:
96
  p_np = p_np.squeeze()
97
- time_spans.append(len(p_np) * hop_length / sr)
98
  if per is not None:
99
  per_np = per[sample_idx].detach().cpu().numpy()
100
  if per_np.ndim > 1:
101
  per_np = per_np.squeeze()
102
- time_spans.append(len(per_np) * hop_length / sr)
103
- max_time = max(time_spans) if time_spans else 0
104
  fig, axs = plt.subplots(num_plots, 1, figsize=(14, 4*num_plots), sharex=True)
105
  if num_plots == 1:
106
  axs = [axs]
@@ -114,13 +93,13 @@ def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_
114
  for i in range(len(per_np)-1):
115
  if per_np[i] > threshold:
116
  ax.axvspan(t_per[i], t_per[i+1], color='lightblue', alpha=0.2, zorder=0)
117
- current_ax = 0
118
  if w is not None:
119
  w_np = w[sample_idx].detach().cpu().numpy()
120
  if w_np.ndim > 1:
121
  w_np = w_np.squeeze()
122
  t = np.arange(len(w_np)) / sr
123
- axs[current_ax].plot(t, w_np, color="tab:blue")
124
 
125
  if show_energy:
126
  frame_length = hop_length
@@ -132,51 +111,51 @@ def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_
132
  energy = np.array(energy)
133
  energy = energy / np.max(energy) * 0.8 * max(abs(w_np.min()), abs(w_np.max()))
134
  t_energy = np.arange(len(energy)) * hop_length_energy / sr
135
- axs[current_ax].plot(t_energy, energy, color="red", alpha=0.7, label="Energy")
136
- axs[current_ax].legend(loc='upper right')
137
- axs[current_ax].set_title("Waveform")
138
- axs[current_ax].set_ylabel("Amplitude")
139
- axs[current_ax].set_xlim([0, max_time])
140
- axs[current_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
141
- current_ax += 1
142
 
143
  if x is not None:
144
  x_np = x[sample_idx].detach().cpu().numpy()
145
  if x_np.shape[0] < x_np.shape[1]:
146
  x_np = x_np.T
147
- im = axs[current_ax].imshow(x_np.T, aspect="auto", origin="lower", cmap="magma",
148
  extent=[0, x_np.shape[0]*hop_length/sr, 0, x_np.shape[1]])
149
- axs[current_ax].set_title("Spectrogram")
150
- axs[current_ax].set_ylabel("Mel Bin")
151
- axs[current_ax].set_xlim([0, max_time])
152
- axs[current_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
153
- current_ax += 1
154
 
155
  if p is not None:
156
  p_np = p[sample_idx].detach().cpu().numpy()
157
  if p_np.ndim > 1:
158
  p_np = p_np.squeeze()
159
  t_p = np.arange(len(p_np)) * hop_length / sr
160
- axs[current_ax].plot(t_p, p_np, color="tab:green")
161
- axs[current_ax].set_title("Pitch")
162
- axs[current_ax].set_ylabel("Frequency (Hz)")
163
- axs[current_ax].set_xlim([0, max_time])
164
- axs[current_ax].grid(True, axis='both', linestyle='--', alpha=0.3)
165
- axs[current_ax].set_ylim([0, min(1000, p_np.max() * 1.2)])
166
- current_ax += 1
167
 
168
  if per is not None:
169
  per_np = per[sample_idx].detach().cpu().numpy()
170
  if per_np.ndim > 1:
171
  per_np = per_np.squeeze()
172
  t_per = np.arange(len(per_np)) * hop_length / sr
173
- axs[current_ax].plot(t_per, per_np, color="tab:red")
174
- axs[current_ax].set_title("Period (Voice Activity)")
175
- axs[current_ax].set_ylabel("periodocity")
176
- axs[current_ax].set_xlim([0, max_time])
177
- axs[current_ax].grid(True, axis='both', linestyle='--', alpha=0.3)
178
- axs[current_ax].set_ylim([-0.05, 1.05])
179
- axs[current_ax].axhline(y=0.5, color='k', linestyle='--', alpha=0.3)
180
 
181
  if markers is not None:
182
  for i, t in enumerate(markers):
@@ -185,7 +164,7 @@ def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_
185
  ax.axvline(x=t, color='k', linestyle='-', alpha=0.7, label=label if i == 0 else None)
186
  if marker_labels:
187
  axs[0].legend(loc='upper right', fontsize='small')
188
- axs[-1].set_xlabel("Time (s)")
189
  fig.suptitle(title, fontsize=16)
190
  plt.tight_layout(rect=[0, 0, 1, 0.97])
191
  plt.show()
@@ -254,15 +233,16 @@ def get_dtype():
254
  def tox():
255
  return {"device": get_device(), "dtype": get_dtype()}
256
 
257
- def sinusoids(length, channels, max_timescale=10000):
258
  assert channels % 2 == 0
259
- log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
260
- inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
261
- scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
262
- return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
 
263
 
264
  class rotary(nn.Module):
265
- def __init__(self, dims, head, max_ctx=1500, theta=10000, radii=False, debug: List[str] = [], use_pbias=False):
266
  super(rotary, self).__init__()
267
 
268
  self.use_pbias = use_pbias
@@ -275,7 +255,7 @@ class rotary(nn.Module):
275
  self.counter = 0
276
  self.last_theta = None
277
 
278
- self.f0_proj = nn.Linear(1, self.head_dim // 2) if radii else None
279
  self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
280
 
281
  def theta_freqs(self, theta):
@@ -323,15 +303,6 @@ class rotary(nn.Module):
323
  f0 = f0.squeeze(0)
324
  return f0.to(device=device, dtype=dtype)
325
 
326
- def synth_f0(self, f0, ctx):
327
- if f0.dim() == 1:
328
- length = f0.shape[0]
329
- if length == ctx:
330
- return f0
331
- frames = length / ctx
332
- idx = torch.arange(ctx, device=f0.device)
333
- return f0[idx]
334
-
335
  def align_f0(self, ctx, f0):
336
  f0 = self.f0proj(f0)
337
  if f0.dim() == 3:
@@ -382,7 +353,9 @@ class rotary(nn.Module):
382
  theta = f0_mean + self.theta
383
  else:
384
  theta = self.theta
 
385
  freqs = self.theta_freqs(theta)
 
386
  freqs = t[:, None] * freqs[None, :]
387
 
388
  if self.radii and f0 is not None:
@@ -422,6 +395,8 @@ class rotary(nn.Module):
422
  x1 = x1.view(orig_shape)
423
  return torch.cat([x1.type_as(x), x2], dim=-1)
424
 
 
 
425
  class MultiheadA(nn.Module):
426
  _seen = set()
427
  rbf = False
@@ -435,10 +410,10 @@ class MultiheadA(nn.Module):
435
  self.debug = debug
436
  self.counter = 0
437
 
438
- self.q = Linear(dims, dims).to(device, dtype)
439
- self.k = Linear(dims, dims, bias=False).to(device, dtype)
440
- self.v = Linear(dims, dims).to(device, dtype)
441
- self.o = Linear(dims, dims).to(device, dtype)
442
 
443
  self.pad_token = 0
444
  self.rotary_emb = rotary_emb
@@ -458,6 +433,15 @@ class MultiheadA(nn.Module):
458
  else:
459
  self.rope = None
460
 
461
  def rbf_scores(self, q, k, rbf_sigma=1.0, rbf_ratio=0.0):
462
  scale = (self.dims // self.head) ** -0.25
463
  dot_scores = torch.matmul(q, k.transpose(-1, -2)) * scale
@@ -523,127 +507,7 @@ class MultiheadA(nn.Module):
523
  self.counter += 1
524
  return self.o(wv), qk
525
 
526
- class SpanPredictor(nn.Module):
527
- def __init__(self, dims):
528
- super().__init__()
529
- self.linear = nn.Linear(in_features=dims, out_features=1)
530
-
531
- def forward(self, global_out):
532
- scale = torch.sigmoid(self.linear(global_out))
533
- return scale
534
-
535
- class FocusA(nn.Module):
536
- def __init__(self, base: int, dims: int, head: int, max_dist: int, sharpen: bool,
537
- win_size: int = 32, max_span: int = 32, slid_win: int = 32,
538
- temp_scale: float = 0.01, num_iterations: int = 3):
539
-
540
- super().__init__()
541
- self.base = base
542
- self.dims = dims
543
- self.head = head
544
- self.max_dist = max_dist
545
- self.sharpen = sharpen
546
- self.win_size = win_size
547
- self.max_span = max_span
548
- self.slid_win = slid_win
549
- self.temp_scale = temp_scale
550
- self.num_iterations = num_iterations
551
- self.span_predictor = SpanPredictor(dims=dims)
552
- self.span_scale_param = nn.Parameter(torch.tensor(1.0))
553
-
554
- self.attn_local = nn.MultiheadAttention(embed_dim=dims, num_heads=head, batch_first=True)
555
- self.attn_global = nn.MultiheadAttention(embed_dim=dims, num_heads=head, batch_first=True)
556
-
557
- self.ln_local = nn.LayerNorm(normalized_shape=dims)
558
- self.ln_global = nn.LayerNorm(normalized_shape=dims)
559
- self.projection = nn.Linear(in_features=2 * dims, out_features=dims)
560
-
561
- def _focus(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, span_scale: torch.Tensor) -> torch.Tensor:
562
-
563
- max_iterations = 1
564
- iteration = 0
565
- prev_attn_out = torch.zeros_like(query)
566
- base_threshold = 1e-4
567
- scaling_factor = 0.1
568
-
569
- while iteration < max_iterations:
570
- span_len = int(self.max_span * span_scale.mean().item())
571
- span_len = min(span_len, query.size(1), key.size(1), value.size(1))
572
- eff_span = min(span_len, self.max_dist)
573
-
574
- q_span = query[:, :eff_span, :]
575
- k_span = key[:, :eff_span, :]
576
- v_span = value[:, :eff_span, :]
577
-
578
- batch, ctx, dims = q_span.size()
579
- scale_factor = (dims // self.head) ** -0.25
580
-
581
- q = q_span.view(batch, ctx, self.head, -1).permute(0, 2, 1, 3)
582
- k = k_span.view(batch, ctx, self.head, -1).permute(0, 2, 1, 3)
583
- v = v_span.view(batch, ctx, self.head, -1).permute(0, 2, 1, 3)
584
-
585
- if self.sharpen:
586
- temperature = 1.0 + self.temp_scale * (1.0 - span_scale.mean().item())
587
- else:
588
- temperature = 0.5 + self.temp_scale * span_scale.mean().item()
589
-
590
- attn_scores = torch.matmul(q, k.transpose(-2, -1))
591
- attn_weights = torch.softmax((attn_scores / temperature) * scale_factor, dim=-1)
592
- attn_out = torch.matmul(attn_weights, v)
593
-
594
- attn_out = attn_out.permute(0, 2, 1, 3).contiguous().view(batch, ctx, -1)
595
 
596
- diff = torch.abs(attn_out - prev_attn_out).mean()
597
- dynamic_threshold = base_threshold + scaling_factor * diff
598
-
599
- if diff < dynamic_threshold:
600
- break
601
-
602
- prev_attn_out = attn_out
603
- query = query + attn_out
604
- iteration += 1
605
-
606
- return attn_out, attn_weights
607
-
608
- def _window(self, x: torch.Tensor, win_size: int, span_len: int, span_scale: torch.Tensor) -> torch.Tensor:
609
-
610
- batch, ctx, dims = x.size()
611
- num_windows = (ctx + win_size - 1) // win_size
612
-
613
- output = torch.zeros_like(x, device=x.device)
614
-
615
- for i in range(num_windows):
616
- start_idx = i * win_size
617
- end_idx = min((i + 1) * win_size, ctx)
618
- query = x[:, start_idx:end_idx, :]
619
-
620
- key_start = max(0, start_idx - span_len + win_size)
621
- key_end = min(start_idx + span_len, ctx)
622
- key = x[:, key_start:key_end, :]
623
- value = x[:, key_start:key_end, :]
624
-
625
- attn_out = self._focus(query, key, value, span_scale)
626
- output[:, start_idx:end_idx, :] = attn_out
627
-
628
- return output
629
-
630
- def forward(self, x, xa=None, mask=None, kv_cache=None) -> torch.Tensor:
631
- span_scale = self.span_predictor(x)
632
- span_scale = torch.sigmoid(span_scale)
633
-
634
- local_attn_out = self.attn_local(x, x, x)
635
- local_attn_out = self.ln_local(local_attn_out)
636
-
637
- global_attn_out = self.attn_global(x, x, x)
638
- global_attn_out = self.ln_global(global_attn_out)
639
-
640
- attn_out = torch.cat((local_attn_out, global_attn_out), dim=-1)
641
- attn_out = self.projection(attn_out)
642
-
643
- windowed_attn_out = self._window(attn_out, self.win_size, self.max_span, span_scale)
644
- focused_attn_out = self._focus(windowed_attn_out, windowed_attn_out, windowed_attn_out, span_scale)
645
- return focused_attn_out
646
-
647
  class t_gate(nn.Module):
648
  def __init__(self, dims, num_types=4):
649
  super().__init__()
@@ -745,18 +609,14 @@ class Residual(nn.Module):
745
  self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
746
 
747
  def forward(self, x, xa=None, mask=None, enc=None, layer=None, feature_type="audio") -> Tensor:
748
- x = x.to(device, dtype)
749
- if xa is not None:
750
- xa = xa.to(device, dtype)
751
 
752
- blend = self.blend
753
  x = x + self.attna(self.lna(x), xa=None, mask=mask, enc=enc, layer=layer)[0]
754
  xb = x
755
  if self.attnb and xa is not None:
756
  x = x + self.attnb(self.lnb(x), xa=xa, mask=None, enc=enc, layer=layer)[0]
757
 
758
  if self.do_blend:
759
- b = torch.sigmoid(blend)
760
  x = b * xb + (1 - b) * x
761
 
762
  if self.skip_gates:
@@ -978,37 +838,35 @@ class AudioEncoder(nn.Module):
978
  cgate = False
979
 
980
  self.blocks = nn.ModuleDict({
 
981
  "spectrogram": nn.ModuleList(
982
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
983
- [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "spectrogram" in features else None
984
- ),
 
985
  "waveform": nn.ModuleList(
986
  [WEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=11, act=act_fn)] +
987
- [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "waveform" in features else None
988
- ),
 
989
  "pitch": nn.ModuleList(
990
  [FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)] +
991
- [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "pitch" in features else None
992
- ),
 
993
  "envelope": nn.ModuleList(
994
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
995
- [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "envelope" in features else None
996
- ),
 
997
  "phase": nn.ModuleList(
998
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
999
- [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)] if "phase" in features else None
1000
- )
1001
- })
1002
 
1003
  def forward(self, enc, layer="encoder"):
1004
  enc = dict_to(enc, device, dtype)
1005
-
1006
- if self.counter < 1:
1007
- s = enc.get("spectrogram")
1008
- w = enc.get("waveform")
1009
- p = default(enc.get("pitch"), enc.get("f0"))
1010
- plot_waveform(x=s, w=w, p=p, hop_length=128)
1011
-
1012
  out = {}
1013
  out.update(enc)
1014
 
@@ -1018,13 +876,18 @@ class AudioEncoder(nn.Module):
1018
  for block in self.blocks[f]:
1019
  x = block(x, enc=enc, layer=layer)
1020
  out[f] = x
1021
-
1022
- if "encoder" in self.debug and self.counter % 100 == 0:
 
 
 
 
1023
  shapes = {k: v.shape for k, v in enc.items()}
1024
  print(f"Step {self.counter}: mode: {list(enc.keys()) }: shapes: {shapes}")
1025
  self.counter += 1
1026
  return out
1027
 
 
1028
  class TextDecoder(nn.Module):
1029
  def __init__(self, vocab: int, ctx: int, dims: int, head: int, layer: int, cross_attn: bool,
1030
  debug: List[str], features: List[str]):
@@ -1038,6 +901,8 @@ class TextDecoder(nn.Module):
1038
  self.counter = 0
1039
  self.dropout = 0.01
1040
  self.features = features
 
 
1041
 
1042
  self.token = nn.Embedding(num_embeddings=vocab, embedding_dim=dims)
1043
  with torch.no_grad():
@@ -1058,10 +923,7 @@ class TextDecoder(nn.Module):
1058
  mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
1059
  self.register_buffer("mask", mask, persistent=False)
1060
 
1061
- def forward(self, x, enc, order=None, layer='decoder', sequential=False) -> Tensor:
1062
- enc = dict_to(enc, device, dtype)
1063
- x = x.to(device)
1064
- bln = self.blend
1065
 
1066
  if order is None:
1067
  order = self.features
@@ -1070,6 +932,7 @@ class TextDecoder(nn.Module):
1070
  x = self.token(x) + self.positional[:x.shape[1]]
1071
  x = F.dropout(x, p=self.dropout, training=self.training)
1072
 
 
1073
  for block in self.block:
1074
  x = block(x, xa=None, mask=mask, enc=None, layer=layer)
1075
 
@@ -1079,24 +942,25 @@ class TextDecoder(nn.Module):
1079
  for block in self.blocks[f]:
1080
  out = block(x=x, xa=xa, mask=None, enc=None, layer=layer)
1081
 
1082
- if sequential:
1083
  x = out
1084
  else:
1085
- a = torch.sigmoid(bln[f])
1086
  x = a * out + (1 - a) * x
1087
 
1088
- if "decoder" in self.debug and self.counter % 100 == 0:
1089
- print(f"Step {self.counter}: Decoder output shape: {x.shape}, enc keys: {list(enc.keys())}, order: {order}")
 
1090
  self.counter += 1
1091
 
1092
  x = self.ln_dec(x)
1093
  return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
1094
 
 
1095
  class Echo(nn.Module):
1096
  def __init__(self, param: Dimensions):
1097
  super().__init__()
1098
  self.param = param
1099
- self.count = 0
1100
 
1101
  self.encoder = AudioEncoder(
1102
  mels=param.mels,
@@ -1119,32 +983,6 @@ class Echo(nn.Module):
1119
  debug=param.debug,
1120
  features=param.features,
1121
  )
1122
-
1123
- all_head = torch.zeros(self.param.text_idx, self.param.text_head, dtype=torch.bool)
1124
- all_head[self.param.text_idx // 2 :] = True
1125
- self.register_buffer("alignment_head", all_head.to_sparse(), persistent=False)
1126
-
1127
- def update_base(self, f0):
1128
- for name, module in self.encoder.named_modules():
1129
- if isinstance(module, (rotary)):
1130
- module.update_base(f0)
1131
-
1132
- for name, module in self.decoder.named_modules():
1133
- if isinstance(module, (rotary)):
1134
- module.update_base(f0)
1135
-
1136
- def set_alignment_head(self, dump: bytes):
1137
- array = np.frombuffer(
1138
- gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
1139
- mask = torch.from_numpy(array).reshape(
1140
- self.param.text_idx, self.param.text_head)
1141
- self.register_buffer("alignment_head", mask.to_sparse(), persistent=False)
1142
-
1143
- def embed_audio(self, spectrogram: torch.Tensor):
1144
- return self.encoder(spectrogram)
1145
-
1146
- def logits(self,input_ids: torch.Tensor, encoder_output: torch.Tensor):
1147
- return self.decoder(input_ids, encoder_output)
1148
 
1149
  def forward(self,
1150
  decoder_input_ids=None,
@@ -1172,7 +1010,7 @@ class Echo(nn.Module):
1172
  encoder_inputs["phase"] = phase
1173
  if f0 is not None:
1174
  encoder_inputs["f0"] = f0
1175
-
1176
  encoder_outputs = self.encoder(encoder_inputs)
1177
  logits = self.decoder(input_ids, encoder_outputs)
1178
 
@@ -1181,11 +1019,7 @@ class Echo(nn.Module):
1181
  loss = F.cross_entropy(
1182
  logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)
1183
 
1184
- self.count += 1
1185
- return {
1186
- "logits": logits,
1187
- "loss": loss,
1188
- }
1189
 
1190
  @property
1191
  def device(self):
@@ -1241,34 +1075,6 @@ class Echo(nn.Module):
1241
  if count > 0:
1242
  print(f"{module_type}: {count}")
1243
 
1244
- def register_gradient_hooks(self):
1245
- for name, param in self.named_parameters():
1246
- if param.requires_grad:
1247
- if "encoder" in name:
1248
- param.register_hook(lambda grad, n=name: self._print_encoder_grad(n, grad))
1249
- elif "decoder" in name:
1250
- param.register_hook(lambda grad, n=name: self._print_decoder_grad(n, grad))
1251
-
1252
- print("Gradient debugging hooks registered")
1253
- return self
1254
-
1255
- def _print_encoder_grad(self, name, grad):
1256
- if grad is not None and self.count == 10:
1257
- norm = grad.median().item()
1258
- print(f"ENCODER GRAD: {name} = {norm:.6f}")
1259
-
1260
- return None
1261
-
1262
- def _print_decoder_grad(self, name, grad):
1263
- if grad is not None and self.count == 10:
1264
- norm = grad.median().item()
1265
- print(f"DECODER GRAD: {name} = {norm:.6f}")
1266
- return None
1267
-
1268
- def resetcounter(self):
1269
- self.counter = 0
1270
- print("Counter reset to 0.")
1271
-
1272
  metric = evaluate.load(path="wer")
1273
 
1274
  @dataclass
@@ -1480,13 +1286,18 @@ def compute_metrics(pred, compute_result: bool = True, print_pred: bool = False,
1480
  pred_ids = pred_ids[0]
1481
  else:
1482
  pred_ids = pred_ids
1483
-
1484
- if pred_ids.ndim == 3:
 
1485
  pred_ids = pred_ids.argmax(dim=-1)
1486
 
 
1487
  pred_ids = pred_ids.tolist()
1488
  label_ids = label_ids.tolist()
1489
 
 
 
 
1490
  if print_pred:
1491
  pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
1492
  label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=False)
@@ -1501,7 +1312,34 @@ def compute_metrics(pred, compute_result: bool = True, print_pred: bool = False,
1501
  label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
1502
  wer = 100 * metric.compute(predictions=pred_str, references=label_str)
1503
 
1504
- return {"wer": wer}
1505
 
1506
  logger = logging.getLogger(__name__)
1507
 
@@ -1526,6 +1364,8 @@ def setup_tokenizer(token: str, local_tokenizer_path: str = "./"):
1526
  sp_ids = [tokenizer.token_to_id(t) for t in ["<PAD>", "<BOS>", "<EOS>"]]
1527
  ids = [id for id in ids if id not in sp_ids]
1528
  return ids
 
 
1529
  def bdec(ids_list, skip_special_tokens=True):
1530
  results = []
1531
  for ids in ids_list:
@@ -1533,6 +1373,7 @@ def setup_tokenizer(token: str, local_tokenizer_path: str = "./"):
1533
  ids = [id for id in ids if id not in [0, 1, 2]]
1534
  results.append(tokenizer.decode(ids))
1535
  return results
 
1536
  def save_pretrained(save_dir):
1537
  os.makedirs(save_dir, exist_ok=True)
1538
  tokenizer.save(f"{save_dir}/tokenizer.json")
@@ -1570,7 +1411,7 @@ def prepare_datasets(tokenizer, token: str, sanity_check: bool = False, dataset_
1570
  dataset = dataset.cast_column(column="audio", feature=Audio(sampling_rate=16000)).select_columns(["audio", "transcription"])
1571
 
1572
  if sanity_check:
1573
- dataset = dataset["test"]
1574
  dataset = dataset.select_columns(["audio", "transcription"])
1575
  prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
1576
  dataset = dataset.map(function=prepare_fn, remove_columns=["audio", "transcription"]).with_format(type="torch")
@@ -1609,9 +1450,6 @@ def get_training_args(
1609
  num_train_epochs: int = 1,
1610
  logging_steps: int = 1,
1611
  eval_on_start: bool = False,
1612
- learning_rate: float = 1e-4,
1613
- weight_decay: float = 0.01,
1614
- max_grad_norm: float = 1.0,
1615
  ) -> Seq2SeqTrainingArguments:
1616
 
1617
  return Seq2SeqTrainingArguments(
@@ -1619,7 +1457,7 @@ def get_training_args(
1619
  per_device_train_batch_size=1,
1620
  per_device_eval_batch_size=1,
1621
  gradient_accumulation_steps=1,
1622
- eval_accumulation_steps=4,
1623
  eval_strategy="steps",
1624
  save_strategy="no",
1625
  max_steps=max_steps,
@@ -1635,17 +1473,9 @@ def get_training_args(
1635
  disable_tqdm=False,
1636
  save_total_limit=1,
1637
  label_names=["labels"],
1638
- optim="adamw_torch",
1639
- adam_beta1=0.9,
1640
- adam_beta2=0.999,
1641
- adam_epsilon=1e-8,
1642
- lr_scheduler_type="cosine",
1643
- learning_rate=learning_rate,
1644
- weight_decay=weight_decay,
1645
  save_safetensors=False,
1646
  eval_on_start=eval_on_start,
1647
  batch_eval_metrics=batch_eval_metrics,
1648
- max_grad_norm=max_grad_norm,
1649
  )
1650
 
1651
  def main():
@@ -1666,24 +1496,18 @@ def main():
1666
  eval_steps = 1,
1667
  warmup_steps = 0,
1668
  logging_steps = 1,
1669
- eval_on_start = False,
1670
- learning_rate = 5e-6,
1671
- weight_decay = 0.01,
1672
- max_grad_norm = 0.6,
1673
  )
1674
  else:
1675
  training_args = get_training_args(
1676
  log_dir,
1677
  batch_eval_metrics = False,
1678
  max_steps = 1000,
1679
- save_steps = 1005,
1680
  eval_steps = 100,
1681
  warmup_steps = 100,
1682
  logging_steps = 10,
1683
  eval_on_start = False,
1684
- learning_rate = 2.5e-4,
1685
- weight_decay = 0.01,
1686
- max_grad_norm = 0.6,
1687
  )
1688
 
1689
  return training_args
@@ -1723,7 +1547,7 @@ def main():
1723
  "sampling_rate": 16000,
1724
  "pad_mode": "constant",
1725
  "center": True,
1726
- "power": 2.0,
1727
  "window_fn": torch.hann_window,
1728
  "mel_scale": "htk",
1729
  "norm": None,
@@ -1734,7 +1558,7 @@ def main():
1734
  global global_model
1735
  global_model = model
1736
 
1737
- metrics_fn = partial(compute_metrics, print_pred=False, num_samples=5,
1738
  tokenizer=tokenizer, model=model)
1739
 
1740
  print(f"{'Sanity check' if sanity_check else 'Training'} mode")
@@ -1744,6 +1568,17 @@ def main():
1744
  sanity_check=sanity_check,
1745
  dataset_config=dataset_config)
1746
 
1747
  trainer = Seq2SeqTrainer(
1748
  args=training_args,
1749
  model=model,
@@ -1751,11 +1586,16 @@ def main():
1751
  eval_dataset=test_dataset,
1752
  data_collator=DataCollator(tokenizer=tokenizer),
1753
  compute_metrics=metrics_fn,
 
1754
  )
1755
 
1756
  model.init_weights()
1757
  trainer.train()
1758
 
 
1759
  if __name__ == "__main__":
1760
  main()
1761
 
 
 
 
 
3
  import math
4
  import warnings
5
  import logging
 
 
6
  import torch
7
  import torchaudio
8
  import torch.nn.functional as F
9
  import torch.nn.init as init
10
  from torch import nn, Tensor
11
  import numpy as np
 
12
  import matplotlib.pyplot as plt
13
  from typing import Optional, Dict, Union, List, Tuple, Any
14
  from functools import partial
 
19
  import transformers
20
  import evaluate
21
  from dataclasses import dataclass
22
+ from opimizer import MaxFactor
 
23
 
24
  torch.backends.cudnn.allow_tf32 = True
25
  torch.backends.cuda.matmul.allow_tf32 = True
 
31
 
32
  warnings.filterwarnings("ignore")
33
  logging.basicConfig(level=logging.ERROR)
34
+
35
 
36
  @dataclass
37
  class Dimensions:
 
50
  cross_attn: bool
51
  features: List[str]
52
 
53
+
54
  def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_length=160,
55
  title="", markers=None, marker_labels=None,
56
  show_voiced_regions=True, show_energy=False):
57
  num_plots = sum([x is not None, w is not None, p is not None, per is not None])
58
  if num_plots == 0:
59
  raise ValueError("No data to plot. Please provide at least one input tensor.")
60
+ t_spans = []
61
 
62
  if w is not None:
63
  w_np = w[sample_idx].detach().cpu().numpy()
64
  if w_np.ndim > 1:
65
  w_np = w_np.squeeze()
66
+ t_spans.append(len(w_np) / sr)
67
  if x is not None:
68
  x_np = x[sample_idx].detach().cpu().numpy()
69
  if x_np.shape[0] < x_np.shape[1]:
70
  x_np = x_np.T
71
+ t_spans.append(x_np.shape[0] * hop_length / sr)
72
  if p is not None:
73
  p_np = p[sample_idx].detach().cpu().numpy()
74
  if p_np.ndim > 1:
75
  p_np = p_np.squeeze()
76
+ t_spans.append(len(p_np) * hop_length / sr)
77
  if per is not None:
78
  per_np = per[sample_idx].detach().cpu().numpy()
79
  if per_np.ndim > 1:
80
  per_np = per_np.squeeze()
81
+ t_spans.append(len(per_np) * hop_length / sr)
82
+ max_t = max(t_spans) if t_spans else 0
83
  fig, axs = plt.subplots(num_plots, 1, figsize=(14, 4*num_plots), sharex=True)
84
  if num_plots == 1:
85
  axs = [axs]
 
93
  for i in range(len(per_np)-1):
94
  if per_np[i] > threshold:
95
  ax.axvspan(t_per[i], t_per[i+1], color='lightblue', alpha=0.2, zorder=0)
96
+ cu_ax = 0
97
  if w is not None:
98
  w_np = w[sample_idx].detach().cpu().numpy()
99
  if w_np.ndim > 1:
100
  w_np = w_np.squeeze()
101
  t = np.arange(len(w_np)) / sr
102
+ axs[cu_ax].plot(t, w_np, color="tab:blue")
103
 
104
  if show_energy:
105
  frame_length = hop_length
 
111
  energy = np.array(energy)
112
  energy = energy / np.max(energy) * 0.8 * max(abs(w_np.min()), abs(w_np.max()))
113
  t_energy = np.arange(len(energy)) * hop_length_energy / sr
114
+ axs[cu_ax].plot(t_energy, energy, color="red", alpha=0.7, label="Energy")
115
+ axs[cu_ax].legend(loc='upper right')
116
+ axs[cu_ax].set_title("Waveform")
117
+ axs[cu_ax].set_ylabel("Amplitude")
118
+ axs[cu_ax].set_xlim([0, max_t])
119
+ axs[cu_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
120
+ cu_ax += 1
121
 
122
  if x is not None:
123
  x_np = x[sample_idx].detach().cpu().numpy()
124
  if x_np.shape[0] < x_np.shape[1]:
125
  x_np = x_np.T
126
+ axs[cu_ax].imshow(x_np.T, aspect="auto", origin="lower", cmap="magma",
127
  extent=[0, x_np.shape[0]*hop_length/sr, 0, x_np.shape[1]])
128
+ axs[cu_ax].set_title("Spectrogram")
129
+ axs[cu_ax].set_ylabel("Mel Bin")
130
+ axs[cu_ax].set_xlim([0, max_t])
131
+ axs[cu_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
132
+ cu_ax += 1
133
 
134
  if p is not None:
135
  p_np = p[sample_idx].detach().cpu().numpy()
136
  if p_np.ndim > 1:
137
  p_np = p_np.squeeze()
138
  t_p = np.arange(len(p_np)) * hop_length / sr
139
+ axs[cu_ax].plot(t_p, p_np, color="tab:green")
140
+ axs[cu_ax].set_title("Pitch")
141
+ axs[cu_ax].set_ylabel("Frequency (Hz)")
142
+ axs[cu_ax].set_xlim([0, max_t])
143
+ axs[cu_ax].grid(True, axis='both', linestyle='--', alpha=0.3)
144
+ axs[cu_ax].set_ylim([0, min(1000, p_np.max() * 1.2)])
145
+ cu_ax += 1
146
 
147
  if per is not None:
148
  per_np = per[sample_idx].detach().cpu().numpy()
149
  if per_np.ndim > 1:
150
  per_np = per_np.squeeze()
151
  t_per = np.arange(len(per_np)) * hop_length / sr
152
+ axs[cu_ax].plot(t_per, per_np, color="tab:red")
153
+ axs[cu_ax].set_title("Period (Voice Activity)")
154
+ axs[cu_ax].set_ylabel("periodocity")
155
+ axs[cu_ax].set_xlim([0, max_t])
156
+ axs[cu_ax].grid(True, axis='both', linestyle='--', alpha=0.3)
157
+ axs[cu_ax].set_ylim([-0.05, 1.05])
158
+ axs[cu_ax].axhline(y=0.5, color='k', linestyle='--', alpha=0.3)
159
 
160
  if markers is not None:
161
  for i, t in enumerate(markers):
 
164
  ax.axvline(x=t, color='k', linestyle='-', alpha=0.7, label=label if i == 0 else None)
165
  if marker_labels:
166
  axs[0].legend(loc='upper right', fontsize='small')
167
+ axs[-1].set_xlabel("t (s)")
168
  fig.suptitle(title, fontsize=16)
169
  plt.tight_layout(rect=[0, 0, 1, 0.97])
170
  plt.show()
 
233
  def tox():
234
  return {"device": get_device(), "dtype": get_dtype()}
235
 
236
+ def sinusoids(length, channels, max_tscale=10000):
237
  assert channels % 2 == 0
238
+ log_tscale_increment = np.log(max_tscale) / (channels // 2 - 1)
239
+ inv_tscales = torch.exp(-log_tscale_increment * torch.arange(channels // 2))
240
+ scaled_t = torch.arange(length)[:, np.newaxis] * inv_tscales[np.newaxis, :]
241
+ return torch.cat([torch.sin(scaled_t), torch.cos(scaled_t)], dim=1)
242
+
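The restored sinusoids helper builds the fixed sin/cos table that the decoder later adds to its token embeddings (x = self.token(x) + self.positional[:x.shape[1]]). A quick shape check, as a sketch that assumes the definition above and a 512-dim decoder:

    import torch
    pos = sinusoids(length=1500, channels=512)   # (ctx, dims) table; first half sin, second half cos
    tok = torch.randn(2, 7, 512)                 # (batch, seq, dims) token embeddings
    x = tok + pos[:tok.shape[1]]                 # broadcasts over the batch dimension
    print(pos.shape, x.shape)                    # torch.Size([1500, 512]) torch.Size([2, 7, 512])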
243
 
244
  class rotary(nn.Module):
245
+ def __init__(self, dims, head, max_ctx=1500, theta=10000, radii=True, debug: List[str] = [], use_pbias=False):
246
  super(rotary, self).__init__()
247
 
248
  self.use_pbias = use_pbias
 
255
  self.counter = 0
256
  self.last_theta = None
257
 
258
+ self.bias = nn.Parameter(torch.zeros(max_ctx, dims // 2))
259
  self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
260
 
261
  def theta_freqs(self, theta):
 
303
  f0 = f0.squeeze(0)
304
  return f0.to(device=device, dtype=dtype)
305
 
306
  def align_f0(self, ctx, f0):
307
  f0 = self.f0proj(f0)
308
  if f0.dim() == 3:
 
353
  theta = f0_mean + self.theta
354
  else:
355
  theta = self.theta
356
+
357
  freqs = self.theta_freqs(theta)
358
+
359
  freqs = t[:, None] * freqs[None, :]
360
 
361
  if self.radii and f0 is not None:
 
395
  x1 = x1.view(orig_shape)
396
  return torch.cat([x1.type_as(x), x2], dim=-1)
397
 
398
+
399
+
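In the forward pass above, the rotary base theta can be shifted by the utterance's mean F0 before per-dimension frequencies are built and outer-multiplied with positions (freqs = t[:, None] * freqs[None, :]). A small numeric sketch of that outer product using the standard RoPE frequency ladder; head_dim and the ladder formula are assumptions here, not taken from theta_freqs:

    import torch
    head_dim, ctx, theta = 8, 4, 10000.0
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))  # one freq per rotated pair
    t = torch.arange(ctx).float()
    angles = t[:, None] * freqs[None, :]       # (ctx, head_dim // 2) rotation angles
    print(angles.shape)                        # torch.Size([4, 4])
    print(torch.cos(angles)[1])                # cos components applied at position 1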
400
  class MultiheadA(nn.Module):
401
  _seen = set()
402
  rbf = False
 
410
  self.debug = debug
411
  self.counter = 0
412
 
413
+ self.q = nn.Linear(dims, dims).to(device, dtype)
414
+ self.k = nn.Linear(dims, dims, bias=False).to(device, dtype)
415
+ self.v = nn.Linear(dims, dims).to(device, dtype)
416
+ self.o = nn.Linear(dims, dims).to(device, dtype)
417
 
418
  self.pad_token = 0
419
  self.rotary_emb = rotary_emb
 
433
  else:
434
  self.rope = None
435
 
436
+ def cos_sim(self, q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
437
+ q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
438
+ k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
439
+ qk_cosine = torch.matmul(q_norm, k_norm.transpose(-1, -2))
440
+ qk_cosine = qk_cosine + mask
441
+ weights = F.softmax(qk_cosine, dim=-1)
442
+ out = torch.matmul(weights, v)
443
+ return out
444
+
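The added cos_sim path swaps scaled dot-product scores for cosine similarity: q and k are L2-normalized, so the scores sit in [-1, 1] before the additive mask and softmax. A self-contained sketch of the same computation on dummy tensors (the shapes are assumptions):

    import torch
    import torch.nn.functional as F
    b, h, ctx, d = 1, 4, 6, 16
    q, k, v = (torch.randn(b, h, ctx, d) for _ in range(3))
    mask = torch.triu(torch.full((ctx, ctx), float("-inf")), diagonal=1)  # additive causal mask
    scores = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-1, -2) + mask
    out = torch.softmax(scores, dim=-1) @ v
    print(out.shape)   # torch.Size([1, 4, 6, 16])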
445
  def rbf_scores(self, q, k, rbf_sigma=1.0, rbf_ratio=0.0):
446
  scale = (self.dims // self.head) ** -0.25
447
  dot_scores = torch.matmul(q, k.transpose(-1, -2)) * scale
 
507
  self.counter += 1
508
  return self.o(wv), qk
509
511
  class t_gate(nn.Module):
512
  def __init__(self, dims, num_types=4):
513
  super().__init__()
 
609
  self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
610
 
611
  def forward(self, x, xa=None, mask=None, enc=None, layer=None, feature_type="audio") -> Tensor:
 
 
 
612
 
 
613
  x = x + self.attna(self.lna(x), xa=None, mask=mask, enc=enc, layer=layer)[0]
614
  xb = x
615
  if self.attnb and xa is not None:
616
  x = x + self.attnb(self.lnb(x), xa=xa, mask=None, enc=enc, layer=layer)[0]
617
 
618
  if self.do_blend:
619
+ b = torch.sigmoid(self.blend)
620
  x = b * xb + (1 - b) * x
621
 
622
  if self.skip_gates:
 
838
  cgate = False
839
 
840
  self.blocks = nn.ModuleDict({
841
+
842
  "spectrogram": nn.ModuleList(
843
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
844
+ [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
845
+ if "spectrogram" in features else None),
846
+
847
  "waveform": nn.ModuleList(
848
  [WEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=11, act=act_fn)] +
849
+ [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
850
+ if "waveform" in features else None),
851
+
852
  "pitch": nn.ModuleList(
853
  [FEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act, stride=2)] +
854
+ [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
855
+ if "pitch" in features else None),
856
+
857
  "envelope": nn.ModuleList(
858
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
859
+ [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
860
+ if "envelope" in features else None),
861
+
862
  "phase": nn.ModuleList(
863
  [FEncoder(input_dims=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
864
+ [Residual(ctx=ctx, dims=dims, head=head, act=act, debug=debug, features=features, cgate=cgate) for _ in range(layer)]
865
+ if "phase" in features else None),
866
+ })
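One behavioral note on this reformatted block: the conditional now wraps the whole nn.ModuleList(...) call, so a feature that is not requested stores None in the ModuleDict, whereas the old operator precedence produced nn.ModuleList(None), i.e. an empty list. A toy comparison with stand-in modules:

    import torch.nn as nn
    features = ["spectrogram"]
    old = nn.ModuleDict({"waveform": nn.ModuleList(
        [nn.Linear(4, 4)] + [nn.Identity() for _ in range(2)] if "waveform" in features else None)})
    new = nn.ModuleDict({"waveform": nn.ModuleList(
        [nn.Linear(4, 4)] + [nn.Identity() for _ in range(2)]) if "waveform" in features else None})
    print(len(old["waveform"]), new["waveform"])   # 0 None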
867
 
868
  def forward(self, enc, layer="encoder"):
869
  enc = dict_to(enc, device, dtype)
 
870
  out = {}
871
  out.update(enc)
872
 
 
876
  for block in self.blocks[f]:
877
  x = block(x, enc=enc, layer=layer)
878
  out[f] = x
879
+
880
+ if self.counter < 1 and "encoder" in self.debug:
881
+ s = enc.get("spectrogram")
882
+ w = enc.get("waveform")
883
+ p = default(enc.get("pitch"), enc.get("f0"))
884
+ plot_waveform(x=s, w=w, p=p, hop_length=128)
885
  shapes = {k: v.shape for k, v in enc.items()}
886
  print(f"Step {self.counter}: mode: {list(enc.keys()) }: shapes: {shapes}")
887
  self.counter += 1
888
  return out
889
 
890
+
891
  class TextDecoder(nn.Module):
892
  def __init__(self, vocab: int, ctx: int, dims: int, head: int, layer: int, cross_attn: bool,
893
  debug: List[str], features: List[str]):
 
901
  self.counter = 0
902
  self.dropout = 0.01
903
  self.features = features
904
+ self.do_blend = "no_blend" not in self.debug
905
+ self.sequential = False
906
 
907
  self.token = nn.Embedding(num_embeddings=vocab, embedding_dim=dims)
908
  with torch.no_grad():
 
923
  mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
924
  self.register_buffer("mask", mask, persistent=False)
925
 
926
+ def forward(self, x, enc, order=None, layer='decoder') -> Tensor:
 
 
 
927
 
928
  if order is None:
929
  order = self.features
 
932
  x = self.token(x) + self.positional[:x.shape[1]]
933
  x = F.dropout(x, p=self.dropout, training=self.training)
934
 
935
+
936
  for block in self.block:
937
  x = block(x, xa=None, mask=mask, enc=None, layer=layer)
938
 
 
942
  for block in self.blocks[f]:
943
  out = block(x=x, xa=xa, mask=None, enc=None, layer=layer)
944
 
945
+ if self.sequential:
946
  x = out
947
  else:
948
+ a = torch.sigmoid(self.blend[f])
949
  x = a * out + (1 - a) * x
950
 
951
+ if self.counter < 1 and "decoder" in self.debug:
952
+ shapes = {k: v.shape for k, v in enc.items()}
953
+ print(f"Step {self.counter}: Decoder output shape: {x.shape}, enc keys: {list(enc.keys())}, order: {order}: shapes: {shapes}")
954
  self.counter += 1
955
 
956
  x = self.ln_dec(x)
957
  return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
958
 
959
+
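Unless self.sequential is set, each feature stream is folded into the running decoder state with a learned sigmoid gate rather than overwriting it (x = a * out + (1 - a) * x). A toy illustration of that convex blend; the shapes and the single scalar gate are assumptions:

    import torch
    x = torch.zeros(2, 5, 8)                         # running decoder state
    out = torch.ones(2, 5, 8)                        # one feature's cross-attention output
    blend = torch.nn.Parameter(torch.tensor(0.0))    # learnable gate logit (0.5 after sigmoid)
    a = torch.sigmoid(blend)
    x = a * out + (1 - a) * x
    print(x.mean().item())                           # 0.5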
960
  class Echo(nn.Module):
961
  def __init__(self, param: Dimensions):
962
  super().__init__()
963
  self.param = param
 
964
 
965
  self.encoder = AudioEncoder(
966
  mels=param.mels,
 
983
  debug=param.debug,
984
  features=param.features,
985
  )
 
986
 
987
  def forward(self,
988
  decoder_input_ids=None,
 
1010
  encoder_inputs["phase"] = phase
1011
  if f0 is not None:
1012
  encoder_inputs["f0"] = f0
1013
+
1014
  encoder_outputs = self.encoder(encoder_inputs)
1015
  logits = self.decoder(input_ids, encoder_outputs)
1016
 
 
1019
  loss = F.cross_entropy(
1020
  logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)
1021
 
1022
+ return {"logits": logits, "loss": loss}
 
 
 
 
1023
 
1024
  @property
1025
  def device(self):
 
1075
  if count > 0:
1076
  print(f"{module_type}: {count}")
1077
 
1078
  metric = evaluate.load(path="wer")
1079
 
1080
  @dataclass
 
1286
  pred_ids = pred_ids[0]
1287
  else:
1288
  pred_ids = pred_ids
1289
+ if hasattr(pred_ids, "ndim") and pred_ids.ndim == 3:
1290
+ if not isinstance(pred_ids, torch.Tensor):
1291
+ pred_ids = torch.tensor(pred_ids)
1292
  pred_ids = pred_ids.argmax(dim=-1)
1293
 
1294
+
1295
  pred_ids = pred_ids.tolist()
1296
  label_ids = label_ids.tolist()
1297
 
1298
+ pad_token_id = tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else 0
1299
+ label_ids = [[pad_token_id if token == -100 else token for token in seq] for seq in label_ids]
1300
+
1301
  if print_pred:
1302
  pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
1303
  label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=False)
 
1312
  label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
1313
  wer = 100 * metric.compute(predictions=pred_str, references=label_str)
1314
 
1315
+
1316
+ if model is None:
1317
+ global global_model
1318
+ if 'global_model' in globals():
1319
+ model = global_model
1320
+
1321
+ if model is not None:
1322
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000
1323
+ if trainable_params > 0:
1324
+ efficiency_score = (100 - wer) / trainable_params
1325
+ else:
1326
+ print("Warning: Zero trainable parameters detected")
1327
+ efficiency_score = 0.0
1328
+ else:
1329
+ print("Warning: Model not available for parameter counting")
1330
+ trainable_params = 0.0
1331
+ efficiency_score = 0.0
1332
+
1333
+ if hasattr(wer, "item"):
1334
+ wer = wer.item()
1335
+
1336
+ metrics = {
1337
+ "wer": float(wer),
1338
+ "trainable_params_M": float(trainable_params),
1339
+ "efficiency_score": float(efficiency_score),
1340
+ }
1341
+ return metrics
1342
+
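Besides WER, the rewritten metrics function reports trainable parameters (in millions) and an efficiency_score of (100 - wer) / trainable_params_M. A standalone sketch of the WER piece with the evaluate library; the strings and the 36.5 M figure are dummy values:

    import evaluate
    metric = evaluate.load("wer")
    preds = ["the cat sat on the mat", "hello world"]
    refs = ["the cat sat on a mat", "hello word"]
    wer = 100 * metric.compute(predictions=preds, references=refs)
    trainable_params_m = 36.5
    print(f"wer={wer:.2f} efficiency={(100 - wer) / trainable_params_m:.3f}")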
1343
 
1344
  logger = logging.getLogger(__name__)
1345
 
 
1364
  sp_ids = [tokenizer.token_to_id(t) for t in ["<PAD>", "<BOS>", "<EOS>"]]
1365
  ids = [id for id in ids if id not in sp_ids]
1366
  return ids
1367
+
1368
+
1369
  def bdec(ids_list, skip_special_tokens=True):
1370
  results = []
1371
  for ids in ids_list:
 
1373
  ids = [id for id in ids if id not in [0, 1, 2]]
1374
  results.append(tokenizer.decode(ids))
1375
  return results
1376
+
1377
  def save_pretrained(save_dir):
1378
  os.makedirs(save_dir, exist_ok=True)
1379
  tokenizer.save(f"{save_dir}/tokenizer.json")
 
1411
  dataset = dataset.cast_column(column="audio", feature=Audio(sampling_rate=16000)).select_columns(["audio", "transcription"])
1412
 
1413
  if sanity_check:
1414
+ dataset = dataset["test"].take(10)
1415
  dataset = dataset.select_columns(["audio", "transcription"])
1416
  prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
1417
  dataset = dataset.map(function=prepare_fn, remove_columns=["audio", "transcription"]).with_format(type="torch")
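For the sanity-check path the split is now trimmed to 10 examples before feature extraction, which keeps the end-to-end loop fast. The same idea on a dummy in-memory dataset; select(range(n)) is the long-standing map-style equivalent if the installed datasets version lacks Dataset.take:

    from datasets import Dataset
    ds = Dataset.from_dict({"audio": list(range(100)), "transcription": ["x"] * 100})
    small = ds.select(range(10))    # same effect as ds.take(10) on newer datasets releases
    print(len(small))               # 10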
 
1450
  num_train_epochs: int = 1,
1451
  logging_steps: int = 1,
1452
  eval_on_start: bool = False,
 
 
 
1453
  ) -> Seq2SeqTrainingArguments:
1454
 
1455
  return Seq2SeqTrainingArguments(
 
1457
  per_device_train_batch_size=1,
1458
  per_device_eval_batch_size=1,
1459
  gradient_accumulation_steps=1,
1460
+ eval_accumulation_steps=1,
1461
  eval_strategy="steps",
1462
  save_strategy="no",
1463
  max_steps=max_steps,
 
1473
  disable_tqdm=False,
1474
  save_total_limit=1,
1475
  label_names=["labels"],
 
 
 
 
 
 
 
1476
  save_safetensors=False,
1477
  eval_on_start=eval_on_start,
1478
  batch_eval_metrics=batch_eval_metrics,
 
1479
  )
1480
 
1481
  def main():
 
1496
  eval_steps = 1,
1497
  warmup_steps = 0,
1498
  logging_steps = 1,
1499
+ eval_on_start = True,
 
 
 
1500
  )
1501
  else:
1502
  training_args = get_training_args(
1503
  log_dir,
1504
  batch_eval_metrics = False,
1505
  max_steps = 1000,
1506
+ save_steps = 1000,
1507
  eval_steps = 100,
1508
  warmup_steps = 100,
1509
  logging_steps = 10,
1510
  eval_on_start = False,
 
 
 
1511
  )
1512
 
1513
  return training_args
 
1547
  "sampling_rate": 16000,
1548
  "pad_mode": "constant",
1549
  "center": True,
1550
+ "power": 1.0,
1551
  "window_fn": torch.hann_window,
1552
  "mel_scale": "htk",
1553
  "norm": None,
 
1558
  global global_model
1559
  global_model = model
1560
 
1561
+ metrics_fn = partial(compute_metrics, print_pred=False, num_samples=1,
1562
  tokenizer=tokenizer, model=model)
1563
 
1564
  print(f"{'Sanity check' if sanity_check else 'Training'} mode")
 
1568
  sanity_check=sanity_check,
1569
  dataset_config=dataset_config)
1570
 
1571
+ optimizer = MaxFactor(model.parameters(), lr=0.025, beta2_decay=-0.8, eps=(1e-10, 1e-7), d=1.0,
1572
+ weight_decay=0.025, gamma=0.99, max=False)
1573
+
1574
+
1575
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
1576
+ optimizer,
1577
+ T_max=training_args.max_steps,
1578
+ eta_min=1e-7,
1579
+ last_epoch=-1,
1580
+ )
1581
+
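MaxFactor is the author's custom optimizer imported at the top of the file; pairing it with CosineAnnealingLR and passing optimizers=(optimizer, scheduler) to Seq2SeqTrainer replaces the learning-rate and optimizer fields removed from get_training_args. A minimal sketch of the same wiring with a stock AdamW stand-in, since MaxFactor's implementation is not part of this diff:

    import torch
    model = torch.nn.Linear(10, 10)    # stand-in for Echo
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.025, weight_decay=0.025)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000, eta_min=1e-7)
    for step in range(3):              # the Trainer drives this loop internally
        loss = model(torch.randn(4, 10)).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()
        print(step, scheduler.get_last_lr())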
1582
  trainer = Seq2SeqTrainer(
1583
  args=training_args,
1584
  model=model,
 
1586
  eval_dataset=test_dataset,
1587
  data_collator=DataCollator(tokenizer=tokenizer),
1588
  compute_metrics=metrics_fn,
1589
+ optimizers=(optimizer, scheduler)
1590
  )
1591
 
1592
  model.init_weights()
1593
  trainer.train()
1594
 
1595
+
1596
  if __name__ == "__main__":
1597
  main()
1598
 
1599
+
1600
+
1601
+