Sin2pi committed on
Commit c05d8b0 · verified · 1 Parent(s): 5e7abab

Upload echoutils.py

Files changed (1)
  1. echoutils.py +817 -56
echoutils.py CHANGED
@@ -38,10 +38,37 @@ def create_attention_mask(batch_size, ctx, is_causal=True, padding_mask=None, de
38
  else:
39
  mask = torch.zeros((batch_size, 1, ctx, ctx), device=device)
40
  if padding_mask is not None:
41
- padding_mask = padding_mask.unsqueeze(1).unsqueeze(2)
42
  mask = mask | (~padding_mask)
43
  return mask
44
 
45
  def mask_win(text_ctx, aud_ctx):
46
  mask = torch.tril(torch.ones(text_ctx, text_ctx, device=device, dtype=dtype), diagonal=0)
47
  audio_mask = torch.tril(torch.ones(text_ctx, aud_ctx - text_ctx, device=device, dtype=dtype))
@@ -93,18 +120,23 @@ def calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True):
93
  out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
94
  return out, None
95
 
96
- # example
97
- # class MyAttention(nn.Module):
98
- # def __init__(self, dims, head):
99
- # super().__init__()
100
- # self.q, self.k, self.v, self.o, self.scale = qkv_init(dims, head)
101
- # self.head = head
102
 
103
- # def forward(self, x, xa=None, mask=None, temperature=1.0):
104
- # q, k, v = create_qkv(self.q, self.k, self.v, x, xa, head=self.head)
105
- # out, _ = calculate_attention(q, k, v, mask=mask, temperature=temperature)
106
- # out = self.o(out) # Final projection
107
- # return out
 
108
 
109
  def mel_scale_scalar(freq: float) -> float:
110
  return 1127.0 * math.log(1.0 + freq / 700.0)
@@ -142,7 +174,7 @@ def track_xa(new_xa, operation=""):
142
  current_id = id(new_xa)
143
  if current_id != xa_id[0]:
144
  print(f"xa FLOW: {xa_id[0]} → {current_id} in {operation}")
145
- xa_id[0] = current_id
146
  else:
147
  print(f"xa REUSE: {current_id} in {operation}")
148
  return new_xa
@@ -163,18 +195,6 @@ def get_activation(act: str) -> nn.Module:
163
  }
164
  return act_map.get(act, nn.GELU())
165
 
166
- @dataclass
167
- class Dimensions:
168
- vocab: int
169
- mels: int
170
- ctx: int
171
- dims: int
172
- head: int
173
- layer: int
174
- act: str
175
- debug: List[str]
176
- features: List[str]
177
-
178
  def get_generation_config(param):
179
  return GenerationConfig( # type: ignore
180
  max_length=param.text_ctx,
@@ -193,6 +213,350 @@ def get_generation_config(param):
193
  use_cache=False,
194
  return_timestamps=False)
196
  def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_length=160,
197
  title="", markers=None, marker_labels=None,
198
  show_voiced_regions=True, show_energy=False):
@@ -391,18 +755,99 @@ class Sinusoids(nn.Module):
391
  features[:, 0::2] = torch.sin(position * div_term)
392
  features[:, 1::2] = torch.cos(position* div_term)
393
  self.register_buffer('sinusoid', tensor=features)
394
- self.positional_embeddings = nn.Parameter(self.sinusoid.clone())
395
  def forward(self, positions):
396
  position_embeddings = self.positional_embeddings[positions]
397
  return position_embeddings
398
 
399
  def sinusoids(length, channels, max_tscale=10000):
400
  assert channels % 2 == 0
401
- log_tscale_increment = np.log(max_tscale) / (channels // 2 - 1)
402
- inv_tscales = torch.exp(-log_tscale_increment * torch.arange(channels // 2))
403
- scaled_t = torch.arange(length)[:, np.newaxis] * inv_tscales[np.newaxis, :]
404
  return torch.cat([torch.sin(scaled_t), torch.cos(scaled_t)], dim=1)
405
406
  def clean_ids(ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
407
  if isinstance(ids, torch.Tensor):
408
  ids = ids.tolist()
@@ -413,7 +858,7 @@ def clean_batch(batch_ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
413
 
414
  def setup_tokenizer(dir: str):
415
  from tokenizers import Tokenizer
416
- tokenizer = Tokenizer.from_file(f"{dir}/tokenizer.json")
417
  orig_encode = tokenizer.encode
418
  orig_decode = tokenizer.decode
419
 
@@ -480,7 +925,33 @@ def world_to_mel(sp, ap, sample_rate=16000, n_mels=128):
480
  ap_mel = torch.matmul(ap, mel_basis.T) # (frames, 128)
481
  return sp_mel, ap_mel
482
 
483
- def extract_features(batch, tokenizer, waveform=False, spec=False, f0=False, f0t=False, pitch=False, harmonics=False, sample_rate=16000, hop_length=256, mode="mean", debug=False, phase_mod=False, crepe=False, aperiodics=False):
484
 
485
  audio = batch["audio"]
486
  sample_rate = audio["sampling_rate"]
@@ -488,43 +959,116 @@ def extract_features(batch, tokenizer, waveform=False, spec=False, f0=False, f0t
488
  wav = load_wave(wave_data=audio, sample_rate=sample_rate)
489
 
490
  spectrogram_config = {
491
- "hop_length": 256,
492
- "f_min": 150,
493
- "f_max": 2000,
494
- "n_mels": 128,
495
- "n_fft": 1024,
496
  "sample_rate": 16000,
497
- "pad_mode": "constant",
498
- "center": True,
499
- "power": 1.0,
500
- "window_fn": torch.hann_window,
501
- "mel_scale": "htk",
502
- "norm": None,
503
- "normalized": False,
504
  }
505
 
506
- if crepe:
507
- time, frequency, confidence, activation = crepe.predict(wav, sample_rate, viterbi=True)
508
  crepe_time = torch.from_numpy(time)
509
  crepe_frequency = torch.from_numpy(frequency)
510
  crepe_confidence = torch.from_numpy(confidence)
511
  crepe_activation = torch.from_numpy(activation)
512
  else:
513
  crepe_time = None
514
  crepe_frequency = None
515
  crepe_confidence = None
516
  crepe_activation = None
517
 
518
- if spec:
519
  transform = torchaudio.transforms.MelSpectrogram(**spectrogram_config)
520
  mel_spectrogram = transform(wav)
521
  log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
522
- # spectrogram_tensor = mel_spectrogram.log10()
523
  log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
524
  spectrogram_tensor = (log_mel + 4.0) / 4.0
525
  spectrogram_tensor = torch.tensor(spectrogram_tensor)
526
- else:
527
- spectrogram_tensor = None
528
 
529
  if f0 or f0t or pitch or harmonics or aperiodics:
530
  wavnp = wav.numpy().astype(np.float64)
@@ -548,7 +1092,7 @@ def extract_features(batch, tokenizer, waveform=False, spec=False, f0=False, f0t
548
  end_idx = torch.searchsorted(t2, token_ends, side="right")
549
  pitch_tok = torch.zeros(T, dtype=torch.float32)
550
  for i in range(T):
551
- lo, hi = start_idx[i], max(start_idx[i]+1, end_idx[i])
552
  segment = f0_np[lo:hi]
553
  if mode == "mean":
554
  pitch_tok[i] = segment.mean()
@@ -559,14 +1103,14 @@ def extract_features(batch, tokenizer, waveform=False, spec=False, f0=False, f0t
559
  pitch_tok[pitch_tok < 100.0] = 0.0
560
  bos_pitch = pitch_tok[0] if len(pitch_tok) > 0 else 0.0
561
  f0t_tensor = torch.cat([torch.tensor([bos_pitch]), pitch_tok])
562
- # f0t_tensor = torch.where(f0t_tensor == 0.0, torch.zeros_like(f0t_tensor), (f0t_tensor - 71.0) / (500.0 - 71.0))
563
  else:
564
  f0t_tensor = None
565
 
566
  if phase_mod:
567
  tframe = torch.mean(t2[1:] - t2[:-1])
568
  phi0 = 0.0
569
- omega = 2 * torch.pi * f0_tensor
570
  dphi = omega * tframe
571
  phi = torch.cumsum(dphi, dim=0) + phi0
572
  phase = torch.remainder(phi, 2 * torch.pi)
@@ -574,7 +1118,10 @@ def extract_features(batch, tokenizer, waveform=False, spec=False, f0=False, f0t
574
  phase = None
575
 
576
  if pitch:
577
- p_tensor = torch.from_numpy(f0_np)
 
 
 
578
  else:
579
  p_tensor = None
580
 
@@ -596,9 +1143,27 @@ def extract_features(batch, tokenizer, waveform=False, spec=False, f0=False, f0t
596
  else:
597
  wave_tensor = None
598
 
599
  if debug:
600
 
601
- print(f"['f0']: {f0_tensor.shape if f0 else None}")
602
  print(f"['f0t']: {f0t_tensor.shape if f0t else None}")
603
  print(f"['harmonic']: {harmonic_tensor.shape if harmonics else None}")
604
  print(f"['aperiodic']: {aperiodic_tensor.shape if aperiodics else None}")
@@ -607,10 +1172,11 @@ def extract_features(batch, tokenizer, waveform=False, spec=False, f0=False, f0t
607
  print(f"['labels']: {len(labels) if labels else None}")
608
  print(f"['phase']: {phase.shape if phase else None}")
609
  print(f"['pitch']: {p_tensor.shape if pitch else None}")
610
- print(f"['crepe_time']: {crepe_time.shape if crepe else None}")
611
  print(f"['crepe_frequency']: {crepe_frequency.shape if crepe else None}")
612
  print(f"['crepe_confidence']: {crepe_confidence.shape if crepe else None}")
613
  print(f"['crepe_activation']: {crepe_activation.shape if crepe else None}")
 
614
 
615
  return {
616
  "waveform": wave_tensor if waveform else None,
@@ -626,6 +1192,7 @@ def extract_features(batch, tokenizer, waveform=False, spec=False, f0=False, f0t
626
  "crepe_frequency": crepe_frequency if crepe else None,
627
  "crepe_confidence": crepe_confidence if crepe else None,
628
  "crepe_activation": crepe_activation if crepe else None,
 
629
  }
630
 
631
  def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=False,
@@ -646,6 +1213,7 @@ def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, st
646
  "debug": False,
647
  "phase_mod": False,
648
  "crepe": False,
 
649
  }
650
 
651
  if load_saved:
@@ -703,6 +1271,199 @@ def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, st
703
 
704
  return train_dataset, test_dataset
705
 
706
  @dataclass
707
  class DataCollator:
708
  tokenizer: Any
@@ -734,7 +1495,7 @@ class DataCollator:
734
  batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
735
  batch["labels"] = torch.tensor(all_labels, dtype=torch.long)
736
 
737
- elif key in ["spectrogram", "waveform", "pitch", "harmonic", "aperiodic", "f0t", "f0", "phase", "crepe_time", "crepe_frequency", "crepe_confidence", "crepe_activation"]:
738
  items = [f[key] for f in features if key in f]
739
  items = [item for item in items if item is not None]
740
  if not items:
 
38
  else:
39
  mask = torch.zeros((batch_size, 1, ctx, ctx), device=device)
40
  if padding_mask is not None:
41
+ padding_mask = padding_mask.unsqueeze(1).unsqueeze(2).bool()
42
  mask = mask | (~padding_mask)
43
  return mask
44
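A minimal sketch of the padding path above (one assumption: the zeros base must be boolean for the `|` to type-check, hence the explicit casts):

    import torch

    batch_size, ctx = 2, 4
    padding_mask = torch.tensor([[1, 1, 1, 0],
                                 [1, 1, 0, 0]]).bool()    # 1 = real token, 0 = pad
    base = torch.zeros((batch_size, 1, ctx, ctx)).bool()  # non-causal branch
    pad = padding_mask.unsqueeze(1).unsqueeze(2)          # (B, 1, 1, ctx), broadcasts over queries
    mask = base | (~pad)                                  # True marks masked-out key positions
    print(mask[0, 0])                                     # last key column is True for sample 0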
 
45
+ def cos_sim(q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
46
+ q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
47
+ k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
48
+ qk_cosine = torch.matmul(q_norm, k_norm.transpose(-1, -2))
49
+ qk_cosine = qk_cosine + mask
50
+ weights = F.softmax(qk_cosine, dim=-1)
51
+ out = torch.matmul(weights, v)
52
+ return out
53
+
54
+ def rbf_scores(q, k, rbf_sigma=1.0, rbf_ratio=0.0):
55
+ dot_scores = torch.matmul(q, k.transpose(-1, -2))
56
+ if rbf_ratio <= 0.0:
57
+ return dot_scores
58
+ q_norm = q.pow(2).sum(dim=-1, keepdim=True)
59
+ k_norm = k.pow(2).sum(dim=-1, keepdim=True)
60
+ qk = torch.matmul(q, k.transpose(-1, -2))
61
+ dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
62
+ rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
63
+ return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
64
+
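A quick sanity check of the blend (sketch; `rbf_scores` from the hunk above must be in scope): with `rbf_ratio=0.0` the function reduces to plain dot-product scores, while larger ratios mix in Gaussian-kernel similarities:

    import torch

    q = torch.randn(1, 4, 8, 16)   # (batch, head, ctx, head_dim)
    k = torch.randn(1, 4, 8, 16)
    assert torch.allclose(rbf_scores(q, k, rbf_ratio=0.0), q @ k.transpose(-1, -2))
    blended = rbf_scores(q, k, rbf_sigma=1.0, rbf_ratio=0.5)   # 50/50 mix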
65
+ def sliding_window_mask(q_len, k_len, window, device):
66
+ # mask[i, j] = 1 if j in [i-window+1, i], else 0
67
+ idxs = torch.arange(q_len, device=device).unsqueeze(1)
68
+ jdxs = torch.arange(k_len, device=device).unsqueeze(0)
69
+ mask = (jdxs >= (idxs - window + 1)) & (jdxs <= idxs)
70
+ return mask.float() # shape: (q_len, k_len)
71
+
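For `window=2` on a square 4x4 case the mask is a two-wide band along the diagonal (sketch, run on CPU):

    import torch

    m = sliding_window_mask(q_len=4, k_len=4, window=2, device=torch.device("cpu"))
    # tensor([[1., 0., 0., 0.],
    #         [1., 1., 0., 0.],
    #         [0., 1., 1., 0.],
    #         [0., 0., 1., 1.]])
    print(m)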
72
  def mask_win(text_ctx, aud_ctx):
73
  mask = torch.tril(torch.ones(text_ctx, text_ctx, device=device, dtype=dtype), diagonal=0)
74
  audio_mask = torch.tril(torch.ones(text_ctx, aud_ctx - text_ctx, device=device, dtype=dtype))
 
120
  out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
121
  return out, None
122
 
123
+ class KVCache(nn.Module):
124
+ def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
125
+ super().__init__()
126
+ cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
127
+ self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
128
+ self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
129
+
130
+ def update(self, input_pos, k_val, v_val):
131
+ # input_pos: [S], k_val: [B, H, S, D]
132
+ assert input_pos.shape[0] == k_val.shape[2]
133
 
134
+ k_out = self.k_cache
135
+ v_out = self.v_cache
136
+ k_out[:, :, input_pos] = k_val # pyright: ignore[reportIndexIssue]
137
+ v_out[:, :, input_pos] = v_val # pyright: ignore[reportIndexIssue]
138
+
139
+ return k_out, v_out
140
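A minimal decode-step sketch for the cache (variable names here are illustrative, not from the file):

    import torch

    cache = KVCache(max_batch_size=1, max_seq_length=16, n_heads=4, head_dim=8,
                    dtype=torch.float32)
    input_pos = torch.arange(3)          # write positions 0..2; shape [S]
    k_val = torch.randn(1, 4, 3, 8)      # [B, H, S, D]
    v_val = torch.randn(1, 4, 3, 8)
    k_out, v_out = cache.update(input_pos, k_val, v_val)
    assert torch.allclose(k_out[:, :, :3], k_val)   # rest of the buffer stays zero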
 
141
  def mel_scale_scalar(freq: float) -> float:
142
  return 1127.0 * math.log(1.0 + freq / 700.0)
 
174
  current_id = id(new_xa)
175
  if current_id != xa_id[0]:
176
  print(f"xa FLOW: {xa_id[0]} → {current_id} in {operation}")
177
+ xa_id[0] = current_id # pyright: ignore[reportArgumentType, reportCallIssue]
178
  else:
179
  print(f"xa REUSE: {current_id} in {operation}")
180
  return new_xa
 
195
  }
196
  return act_map.get(act, nn.GELU())
197
 
198
  def get_generation_config(param):
199
  return GenerationConfig( # type: ignore
200
  max_length=param.text_ctx,
 
213
  use_cache=False,
214
  return_timestamps=False)
215
 
216
+ # class rotary(nn.Module):
217
+ # def __init__(self, dims, head, max_ctx=1500, radii=False, debug: List[str] = [], use_pbias=False, axial=False, spec_shape=None):
218
+
219
+ # super(rotary, self).__init__()
220
+ # self.use_pbias = use_pbias
221
+ # self.dims = dims
222
+ # self.head = head
223
+ # self.head_dim = dims // head
224
+ # self.radii = radii
225
+ # self.debug = debug
226
+ # self.counter = 0
227
+ # self.last_theta = None
228
+ # self.axial = axial
229
+
230
+ # self.bias = nn.Parameter(torch.zeros(max_ctx, dims // 2), requires_grad=True if use_pbias else False)
231
+ # theta = (torch.tensor(10000, device=device, dtype=dtype))
232
+ # self.theta = nn.Parameter(theta, requires_grad=True)
233
+ # self.theta_values = []
234
+
235
+ # if axial and spec_shape is not None:
236
+ # time_frames, freq_bins = spec_shape
237
+ # self.time_frames = time_frames
238
+ # self.freq_bins = freq_bins
239
+
240
+ # time_theta = 50.0
241
+ # time_freqs = 1.0 / (time_theta ** (torch.arange(0, dims, 4)[:(dims // 4)].float() / dims))
242
+ # self.register_buffer('time_freqs', time_freqs)
243
+
244
+ # freq_theta = 100.0
245
+ # freq_freqs = 1.0 / (freq_theta ** (torch.arange(0, dims, 4)[:(dims // 4)].float() / dims))
246
+ # self.register_buffer('freq_freqs', freq_freqs)
247
+
248
+ # def pitch_bias(self, f0):
249
+ # if f0 is None:
250
+ # return None
251
+ # f0_flat = f0.squeeze().float()
252
+ # f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
253
+ # f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
254
+ # f0_norm.unsqueeze(1)))
255
+ # return f0_sim.unsqueeze(0).unsqueeze(0)
256
+
257
+ # def theta_freqs(self, theta):
258
+ # if theta.dim() == 0:
259
+ # theta = theta.unsqueeze(0)
260
+ # freq = (theta.unsqueeze(-1) / 220.0) * 700 * (
261
+ # torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
262
+ # self.head_dim // 2, device=theta.device, dtype=theta.dtype) / 2595) - 1) / 1000
263
+ # return freq
264
+
265
+ # def _apply_radii(self, freqs, f0, ctx):
266
+ # if self.radii and f0 is not None:
267
+ # radius = f0.to(device, dtype)
268
+ # L = radius.shape[0]
269
+ # if L != ctx:
270
+ # feature = L / ctx
271
+ # idx = torch.arange(ctx, device=f0.device)
272
+ # idx = (idx * feature).long().clamp(0, L - 1)
273
+ # radius = radius[idx]
274
+ # return torch.polar(radius.unsqueeze(-1), freqs), radius
275
+ # else:
276
+ # return torch.polar(radius.unsqueeze(-1), freqs), radius
277
+ # else:
278
+ # return torch.polar(torch.ones_like(freqs), freqs), None
279
+
280
+ # def check_f0(self, f0, f0t, ctx):
281
+ # if f0 is not None and f0.shape[1] == ctx:
282
+ # return f0
283
+ # elif f0t is not None and f0t.shape[1] == ctx:
284
+ # return f0t
285
+ # else:
286
+ # return None
287
+
288
+ # def axial_freqs(self, ctx):
289
+ # if not self.axial:
290
+ # return None
291
+ # time_frames = self.time_frames
292
+ # freq_bins = self.freq_bins
293
+
294
+ # t = torch.arange(ctx, device=device, dtype=dtype)
295
+ # t_x = (t % time_frames).float()
296
+ # t_y = torch.div(t, time_frames, rounding_mode='floor').float()
297
+ # freqs_x = torch.outer(t_x, self.time_freqs)
298
+ # freqs_y = torch.outer(t_y, self.freq_freqs)
299
+ # freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
300
+ # freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
301
+ # return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
302
+
303
+ # def forward(self, x=None, feats=None, feature=None, layer=None) -> Tensor:
304
+ # ctx=x
305
+ # f0 = feats.get("f0") if feats is not None else None
306
+ # f0t = feats.get("f0t") if feats is not None else None
307
+
308
+ # f0 = self.check_f0(f0, f0t, ctx)
309
+ # if f0 is not None:
310
+ # # if f0.dim() == 2:
311
+ # # f0 = f0.squeeze(0)
312
+ # theta = f0 + self.theta
313
+ # else:
314
+ # theta = self.theta
315
+ # freqs = self.theta_freqs(theta)
316
+ # t = torch.arange(ctx, device=device, dtype=dtype) # type: ignore
317
+ # freqs = t[:, None] * freqs
318
+ # freqs, radius = self._apply_radii(freqs, f0, ctx)
319
+
320
+ # if self.axial and feature == "spectrogram":
321
+ # freqs_2d = self.axial_freqs(ctx)
322
+ # if freqs_2d is not None:
323
+ # return freqs_2d.unsqueeze(0)
324
+
325
+ # if "radius" in self.debug and self.counter == 10:
326
+ # print(f" [{layer}] [Radius] {radius.shape if radius is not None else None} {radius.mean() if radius is not None else None} [Theta] {theta.mean() if theta is not None else None} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
327
+ # self.counter += 1
328
+ # return freqs.unsqueeze(0)
329
+
330
+ # @staticmethod
331
+ # def split(X: Tensor):
332
+ # half_dim = X.shape[-1] // 2
333
+ # return X[..., :half_dim], X[..., half_dim:]
334
+
335
+ # @staticmethod
336
+ # def apply_rotary(x, freqs):
337
+ # x1 = x[..., :freqs.shape[-1]*2]
338
+ # x2 = x[..., freqs.shape[-1]*2:]
339
+ # orig_shape = x1.shape
340
+ # if x1.ndim == 2:
341
+ # x1 = x1.unsqueeze(0)
342
+ # x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
343
+ # x1 = torch.view_as_complex(x1) * freqs
344
+ # x1 = torch.view_as_real(x1).flatten(-2)
345
+ # x1 = x1.view(orig_shape)
346
+ # return torch.cat([x1.type_as(x), x2], dim=-1)
347
+
348
+
349
+ # class feature_encoder(nn.Module):
350
+ # def __init__(self, mels, input_dims, dims, head, layer, act, features, feature=None, use_rope=False, spec_shape=None, debug=[], attend_feature=False, target_length=None):
351
+ # """
352
+ # Feature encoder for audio processing.
353
+ # """
354
+ # super().__init__()
355
+
356
+ # self.dims = dims
357
+ # self.head = head
358
+ # self.head_dim = dims // head
359
+ # self.dropout = 0.01
360
+ # self.use_rope = use_rope
361
+ # self.attend_feature = attend_feature
362
+ # self.target_length = target_length
363
+ # self.feature = feature
364
+
365
+ # self.debug = debug
366
+ # act_fn = get_activation(act)
367
+
368
+ # if self.attend_feature:
369
+ # self.q, self.k, self.v, self.o, self.scale = qkv_init(dims, head)
370
+ # self.mlp = nn.Sequential(nn.Linear(dims, dims), nn.ReLU(), nn.Linear(dims, dims))
371
+ # else:
372
+ # self.q, self.k, self.v, self.o, self.scale = None, None, None, None, None
373
+ # self.mlp = None
374
+
375
+ # self.spectrogram = nn.Sequential(
376
+ # Conv1d(mels, dims, kernel_size=3), act_fn,
377
+ # Conv1d(dims, dims, kernel_size=3), act_fn,
378
+ # Conv1d(dims, dims, kernel_size=3, groups=dims), act_fn)
379
+
380
+ # self.waveform = nn.Sequential(
381
+ # Conv1d(1, dims//4, kernel_size=15, stride=4, padding=7), act_fn,
382
+ # Conv1d(dims//4, dims//2, kernel_size=7, stride=2, padding=3), act_fn,
383
+ # Conv1d(dims//2, dims, kernel_size=5, stride=2, padding=2), act_fn)
384
+
385
+ # self.pitch = nn.Sequential(
386
+ # Conv1d(1, dims, kernel_size=7, stride=1, padding=3), act_fn,
387
+ # Conv1d(dims, dims, kernel_size=5, stride=1, padding=2), act_fn,
388
+ # Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
389
+
390
+ # if use_rope:
391
+ # # if spec_shape is not None:
392
+ # self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
393
+ # self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)
394
+ # else:
395
+ # self.rope = None
396
+ # self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
397
+ # self.norm = RMSNorm(dims)
398
+
399
+ # def rope(self, x, xa=None, mask=None, feats=None, feature=None, layer=None):
400
+ # if isinstance(x, int):
401
+ # ctx = x
402
+ # elif isinstance(x, torch.Tensor):
403
+ # ctx = x.shape[1] if x.dim() > 1 else x.shape[0]
404
+ # batch, ctx, dims = x.shape[0], ctx, x.shape[-1]
405
+
406
+ # x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
407
+ # freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)
408
+ # x = self.rope.apply_rotary(x, freqs) # pyright: ignore[reportOptionalSubscript, reportAttributeAccessIssue]
409
+ # x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
410
+ # return x
411
+
412
+ # def mel_scalar(self, freq: float) -> float:
413
+ # return 1127.0 * math.log(1.0 + freq / 700.0)
414
+
415
+ # def forward(self, x, xa=None, mask=None, feats=None, feature=None, layer=None, max_tscale=36000):
416
+ # target_length = x.shape[1] if self.target_length is None else self.target_length
417
+
418
+ # if feature == "pitch":
419
+ # xp = x.clone()
420
+ # enc_dict = feats if feats is not None else {}
421
+ # enc_dict = dict(enc_dict)
422
+ # enc_dict["f0"] = xp
423
+ # # xp = self.mel_scalar(xp.mean())
424
+ # # print(f"Using pitch scalar: {xp}")
425
+ # # max_tscale = xp*300
426
+ # # print(f"Using max_tscale: {max_tscale}")
427
+ # feats = enc_dict
428
+ # if x.dim() == 2:
429
+ # x = x.unsqueeze(0)
430
+ # x = self.pitch(x).permute(0, 2, 1)
431
+
432
+ # if feature == "phase":
433
+ # if x.dim() == 2:
434
+ # x = x.unsqueeze(0)
435
+ # x = self.pitch(x).permute(0, 2, 1)
436
+
437
+ # if feature == "waveform":
438
+ # if x.dim() == 2:
439
+ # x = x.unsqueeze(0)
440
+ # x = self.waveform(x).permute(0, 2, 1)
441
+ # if target_length and x.shape[1] != self.target_length:
442
+ # x = F.adaptive_avg_pool1d(x.transpose(1, 2), target_length).transpose(1, 2)
443
+
444
+ # if feature == "harmonics":
445
+ # if x.dim() == 2:
446
+ # x = x.unsqueeze(0)
447
+ # x = self.spectrogram(x).permute(0, 2, 1)
448
+
449
+ # if feature == "aperiodic":
450
+ # if x.dim() == 2:
451
+ # x = x.unsqueeze(0)
452
+ # x = self.spectrogram(x).permute(0, 2, 1)
453
+
454
+ # if feature == "spectrogram":
455
+ # if x.dim() == 2:
456
+ # x = x.unsqueeze(0)
457
+ # x = self.spectrogram(x).permute(0, 2, 1)
458
+
459
+ # if self.use_rope:
460
+ # x = x + self.positional(x.shape[1], x.shape[-1], max_tscale).to(device, dtype)
461
+ # x = self.rope(x=x, xa=None, mask=None, feats=feats, feature=feature, layer=layer)
462
+ # else:
463
+ # max_tscale = x.shape[1] * 1000 if max_tscale is None else max_tscale
464
+ # x = x + self.positional(x.shape[1], x.shape[-1], max_tscale).to(device, dtype)
465
+ # x = nn.functional.dropout(x, p=self.dropout, training=self.training)
466
+ # x = self.norm(x)
467
+
468
+ # if self.attend_feature:
469
+ # xa = feats[feature] # pyright: ignore[reportOptionalSubscript]
470
+ # if xa is not None:
471
+ # q, k, v = create_qkv(self.q, self.k, self.v, x=xa, xa=x, head=self.head)
472
+ # out, _ = calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True)
473
+ # x = x + out
474
+
475
+ # x = nn.functional.dropout(x, p=self.dropout, training=self.training)
476
+ # x = self.norm(x)
477
+ # return x
478
+
479
+ class OneShot(nn.Module):
480
+ def __init__(self, dims: int, head: int, scale: float = 0.3, features: Optional[List[str]] = None):
481
+ super().__init__()
482
+ if features is None:
483
+ features = ["spectrogram", "waveform", "pitch", "aperiodic", "harmonics"]
484
+ self.head = head
485
+ self.head_dim = dims // head
486
+ self.scale = 1.0 / len(features) if features else scale  # true division; // would floor the scale to 0.0
487
+
488
+ self.q = Linear(dims, dims)
489
+ self.k = Linear(dims, dims)
490
+
491
+ def forward(self, x: Tensor, xa: Tensor, feature=None) -> Tensor | None:
492
+ B, L, D = x.shape
493
+ K = xa.size(1)
494
+ q = self.q(x).view(B, L, self.head, self.head_dim).transpose(1,2)
495
+ k = self.k(xa).view(B, K, self.head, self.head_dim).transpose(1,2)
496
+ bias = (q @ k.transpose(-1, -2)) * self.scale / math.sqrt(self.head_dim)
497
+ return bias
498
+
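`OneShot` produces an additive attention bias rather than an attended output; a sketch of the intended use (an assumption: the bias is added to raw attention logits before softmax, which the surrounding code does not show):

    import torch

    one_shot = OneShot(dims=64, head=4)
    x = torch.randn(2, 10, 64)                   # query-side states
    xa = torch.randn(2, 20, 64)                  # feature-side states
    bias = one_shot(x, xa)                       # (2, 4, 10, 20)
    scores = torch.randn(2, 4, 10, 20) + bias    # fold into logits, then softmax
    weights = scores.softmax(dim=-1)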
499
+ class curiosity(nn.Module):
500
+ def __init__(self, d, h, bias=True):
501
+ super().__init__()
502
+ self.h = h
503
+ self.dh = d // h
504
+ self.qkv = nn.Linear(d, d * 3, bias=bias)
505
+ self.qkv_aux = nn.Linear(d, d * 3, bias=bias)
506
+ self.o = nn.Linear(d, d, bias=bias)
507
+ self.g = nn.Parameter(torch.zeros(h))
508
+
509
+ def split(self, x):
510
+ b, t, _ = x.shape
511
+ return x.view(b, t, self.h, self.dh).transpose(1, 2)
512
+
513
+ def merge(self, x):
514
+ b, h, t, dh = x.shape
515
+ return x.transpose(1, 2).contiguous().view(b, t, h * dh)
516
+
517
+ def forward(self, x, xa, mask=None):
518
+ q, k, v = self.qkv(x).chunk(3, -1)
519
+ qa, ka, va = self.qkv_aux(xa).chunk(3, -1)
520
+ q, k, v = map(self.split, (q, k, v))
521
+ qa, ka, va = map(self.split, (qa, ka, va))
522
+ dots = (q @ k.transpose(-2, -1)) / self.dh**0.5
523
+ dots_aux = (q @ ka.transpose(-2, -1)) / self.dh**0.5
524
+ if mask is not None: dots = dots.masked_fill(mask, -9e15)
525
+ p = dots.softmax(-1)
526
+ pa = dots_aux.softmax(-1)
527
+ h_main = p @ v
528
+ h_aux = pa @ va
529
+ g = torch.sigmoid(self.g).view(1, -1, 1, 1)
530
+ out = self.merge(h_main * (1 - g) + h_aux * g)
531
+ return self.o(out)
532
+
533
+ class PositionalEncoding(nn.Module):
534
+ def __init__(self, dims, ctx):
535
+ super(PositionalEncoding, self).__init__()
536
+ self.dims = dims
537
+ self.ctx = ctx
538
+ self.pe = self.get_positional_encoding(max_ctx=ctx)
539
+
540
+ def get_positional_encoding(self, max_ctx):
541
+ pe = torch.zeros(max_ctx, self.dims)
542
+ position = torch.arange(0, max_ctx, dtype=torch.float32).unsqueeze(1)
543
+ div_term = torch.exp(
544
+ torch.arange(0, self.dims, 2, dtype=torch.float32)
545
+ * (-math.log(10000.0) / self.dims)
546
+ )
547
+ pe[:, 0::2] = torch.sin(position * div_term)
548
+ pe[:, 1::2] = torch.cos(position * div_term)
549
+ pe = pe.unsqueeze(0)
550
+ return pe.to(device)
551
+
552
+ def forward(self, x):
553
+ ctx = x.size(1)
554
+ pe = self.pe[:, :ctx, :]
555
+ x = x * math.sqrt(self.dims)
556
+ x = x + pe
557
+ return x
558
+
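A short usage sketch (assuming the module-level `device` the class relies on); note the input is scaled by sqrt(dims) before the fixed table is added:

    import torch

    pe = PositionalEncoding(dims=64, ctx=128)
    x = torch.randn(2, 50, 64).to(device)
    y = pe(x)                      # x * sqrt(64) + pe[:, :50, :]
    assert y.shape == (2, 50, 64)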
559
+
560
  def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_length=160,
561
  title="", markers=None, marker_labels=None,
562
  show_voiced_regions=True, show_energy=False):
 
755
  features[:, 0::2] = torch.sin(position * div_term)
756
  features[:, 1::2] = torch.cos(position* div_term)
757
  self.register_buffer('sinusoid', tensor=features)
758
+ self.positional_embeddings = nn.Parameter(self.sinusoid.clone()) # type: ignore
759
  def forward(self, positions):
760
  position_embeddings = self.positional_embeddings[positions]
761
  return position_embeddings
762
 
763
  def sinusoids(length, channels, max_tscale=10000):
764
  assert channels % 2 == 0
765
+ log_tscale_increment = torch.log(torch.tensor(float(max_tscale))) / (channels // 2 - 1)
766
+ inv_tscales = torch.exp(-log_tscale_increment * torch.arange(channels // 2, device=device, dtype=torch.float32))
767
+ scaled_t = torch.arange(length, device=device, dtype=torch.float32).unsqueeze(1) * inv_tscales.unsqueeze(0)
768
  return torch.cat([torch.sin(scaled_t), torch.cos(scaled_t)], dim=1)
769
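The rewritten helper returns a `(length, channels)` table with sines in the first half of the channel axis and cosines in the second; a quick check:

    import torch

    tab = sinusoids(length=100, channels=64)
    assert tab.shape == (100, 64)
    assert torch.allclose(tab[0, :32], torch.zeros(32, device=tab.device))  # sin(0) = 0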
 
770
+ class SelfCriticalRL(nn.Module):
771
+ def __init__(self, model, tokenizer, reward_fn):
772
+ super().__init__()
773
+ self.model = model
774
+ self.tokenizer = tokenizer
775
+ self.reward_fn = reward_fn
776
+
777
+ def forward(self, input_ids, features, labels=None, max_len=128, feature_name="spectrogram"):
778
+
779
+ with torch.no_grad():
780
+ greedy_ids = self.model.generate(input_ids=input_ids, **{feature_name: features}, max_length=max_len)
781
+ greedy_text = [self.tokenizer.decode(ids) for ids in greedy_ids]
782
+ sampled_ids = self.model.generate(input_ids=input_ids, **{feature_name: features}, max_length=max_len, do_sample=True, top_k=5)
783
+ sampled_text = [self.tokenizer.decode(ids) for ids in sampled_ids]
784
+
785
+ rewards = []
786
+ baseline = []
787
+ for s, g, ref in zip(sampled_text, greedy_text, labels): # type: ignore
788
+ ref_text = self.tokenizer.decode(ref)
789
+ rewards.append(self.reward_fn(s, ref_text))
790
+ baseline.append(self.reward_fn(g, ref_text))
791
+ rewards = torch.tensor(rewards, device=device, dtype=torch.float)
792
+ baseline = torch.tensor(baseline, device=device, dtype=torch.float)
793
+ advantage = rewards - baseline
794
+ logits = self.model(input_ids=sampled_ids, **{feature_name: features})["logits"] # logits: [batch, sampled_seq_len, vocab_size]
795
+ log_probs = F.log_softmax(logits, dim=-1)
796
+ log_probs_seq = torch.gather(log_probs, 2, sampled_ids.unsqueeze(-1)).squeeze(-1)
797
+ log_probs_sum = log_probs_seq.sum(dim=1)
798
+ loss = -(advantage * log_probs_sum).mean()
799
+ return loss
800
+
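A wiring sketch for the self-critical loss, not runnable as-is: `model`, `tokenizer`, `input_ids`, `spec_batch`, and `label_ids` are placeholder names, `model.generate` is assumed to follow the HF-style interface used above, and `wer_reward` is defined further down in this file:

    scst = SelfCriticalRL(model, tokenizer, reward_fn=wer_reward)  # placeholders
    loss = scst(input_ids, features=spec_batch, labels=label_ids,
                feature_name="spectrogram")
    loss.backward()   # REINFORCE-style update; the greedy decode reward is the baseline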
801
+ class SelfTrainingModule(nn.Module):
802
+ def __init__(self, model, tokenizer, quality_fn=None, threshold=0.8):
803
+ super().__init__()
804
+ self.model = model
805
+ self.tokenizer = tokenizer
806
+ self.quality_fn = quality_fn
807
+ self.threshold = threshold
808
+
809
+ def generate_pseudo_labels(self, unlabeled_batch, features, max_len=128, feature_name="spectrogram"):
810
+ with torch.no_grad():
811
+ pred_ids = self.model.generate(input_ids=unlabeled_batch, **{feature_name: features}, max_length=max_len)
812
+
813
+ if self.quality_fn is not None:
814
+ quality_scores = self.quality_fn(pred_ids, self.model, features)
815
+ mask = quality_scores > self.threshold
816
+ pred_ids = pred_ids[mask]
817
+ return pred_ids
818
+
819
+ def forward(self, unlabeled_batch, features, max_len=128, feature_name="spectrogram"):
820
+ pseudo_labels = self.generate_pseudo_labels(unlabeled_batch, features, max_len, feature_name=feature_name)
821
+ logits = self.model(input_ids=unlabeled_batch, **{feature_name: features}, labels=pseudo_labels)["logits"]
822
+ loss = nn.functional.cross_entropy(
823
+ logits.view(-1, logits.shape[-1]), pseudo_labels.view(-1), ignore_index=0)
824
+ return loss
825
+
826
+ def confidence_indicator(pred_ids, model, features):
827
+ with torch.no_grad():
828
+ logits = model(input_ids=pred_ids, **features)["logits"]
829
+ probs = torch.softmax(logits, dim=-1)
830
+ max_probs, _ = probs.max(dim=-1)
831
+ return max_probs.mean(dim=1)
832
+
833
+ def wer_reward(hyp, ref):
834
+
835
+ hyp_words = hyp.split()
836
+ ref_words = ref.split()
837
+ d = [[0] * (len(ref_words)+1) for _ in range(len(hyp_words)+1)]
838
+ for i in range(len(hyp_words)+1):
839
+ d[i][0] = i
840
+ for j in range(len(ref_words)+1):
841
+ d[0][j] = j
842
+ for i in range(1, len(hyp_words)+1):
843
+ for j in range(1, len(ref_words)+1):
844
+ if hyp_words[i-1] == ref_words[j-1]:
845
+ d[i][j] = d[i-1][j-1]
846
+ else:
847
+ d[i][j] = 1 + min(d[i-1][j], d[i][j-1], d[i-1][j-1])
848
+ wer = d[-1][-1] / max(1, len(ref_words))
849
+ return -wer # negative WER as reward
850
+
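A worked example of the edit-distance reward: one substitution against a three-word reference gives WER = 1/3, so the reward is -1/3, and a perfect hypothesis scores 0:

    assert abs(wer_reward("a b c", "a b d") - (-1 / 3)) < 1e-9
    assert wer_reward("a b c", "a b c") == 0.0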
851
  def clean_ids(ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
852
  if isinstance(ids, torch.Tensor):
853
  ids = ids.tolist()
 
858
 
859
  def setup_tokenizer(dir: str):
860
  from tokenizers import Tokenizer
861
+ tokenizer = Tokenizer.from_file(f"{dir}")
862
  orig_encode = tokenizer.encode
863
  orig_decode = tokenizer.decode
864
 
 
925
  ap_mel = torch.matmul(ap, mel_basis.T) # (frames, 128)
926
  return sp_mel, ap_mel
927
 
928
+ def extract_features(batch, tokenizer, waveform=False, spec=False, f0=False, f0t=False, pitch=False, harmonics=False, sample_rate=16000, hop_length=256, mode="mean", debug=False, phase_mod=False, crepe=False, aperiodics=False, dummy=False):
929
+
930
+ # import torchaudio
931
+ # import torchaudio.functional
932
+ # import torchaudio.transforms
933
+
934
+ # torch_windows = {
935
+ # 'hann': torch.hann_window,
936
+ # 'hamming': torch.hamming_window,
937
+ # 'blackman': torch.blackman_window,
938
+ # 'bartlett': torch.bartlett_window,
939
+ # 'ones': torch.ones,
940
+ # None: torch.ones,
941
+ # }
942
+ # if dummy:
943
+ # return {
944
+ # "spectrogram": torch.zeros((1, 128, 100)),
945
+ # "f0": torch.zeros((1, 100)),
946
+ # "f0t": torch.zeros((1, 100)),
947
+ # "pitch": torch.zeros((1, 100)),
948
+ # "harmonics": torch.zeros((1, 128, 100)),
949
+ # "aperiodics": torch.zeros((1, 128, 100)),
950
+ # "crepe_time": None,
951
+ # "crepe_frequency": None,
952
+ # "crepe_confidence": None,
953
+ # "crepe_activation": None,
954
+ # }
955
 
956
  audio = batch["audio"]
957
  sample_rate = audio["sampling_rate"]
 
959
  wav = load_wave(wave_data=audio, sample_rate=sample_rate)
960
 
961
  spectrogram_config = {
962
+ # "hop_length": 256,
963
+ # "f_min": 150,
964
+ # "f_max": 2000,
965
+ # "n_mels": 128,
966
+ # "n_fft": 1024,
967
  "sample_rate": 16000,
968
+ # "pad_mode": "constant",
969
+ # "center": True,
970
+ # "power": 1.0,
971
+ # "window_fn": torch.hann_window,
972
+ # "mel_scale": "htk",
973
+ # "norm": None,
974
+ # "normalized": False,
975
  }
976
 
977
+ def crepe_predict(wav, sample_rate, viterbi=False):
978
+ import crepe  # the crepe package's predict() returns (time, frequency, confidence, activation); torchcrepe's does not
979
+ wav = wav.numpy().astype(np.float32)
980
+ time, frequency, confidence, activation = crepe.predict(
981
+ wav, sample_rate, viterbi=viterbi)
982
  crepe_time = torch.from_numpy(time)
983
  crepe_frequency = torch.from_numpy(frequency)
984
  crepe_confidence = torch.from_numpy(confidence)
985
  crepe_activation = torch.from_numpy(activation)
986
+ return crepe_time, crepe_frequency, crepe_confidence, crepe_activation
987
+
988
+ if crepe:
989
+ crepe_time, crepe_frequency, crepe_confidence, crepe_activation = crepe_predict(wav, sample_rate, viterbi=True)
990
+
991
  else:
992
  crepe_time = None
993
  crepe_frequency = None
994
  crepe_confidence = None
995
  crepe_activation = None
996
 
997
+ # def spectrogram(wav, sample_rate, n_fft=1024, hop_length=256, window_fn=torch.hann_window):
998
+ # if isinstance(window_fn, str):
999
+ # window_fn = torch_windows[window_fn]
1000
+ # if window_fn is None:
1001
+ # window_fn = torch.ones(n_fft)
1002
+ # if isinstance(window_fn, torch.Tensor):
1003
+ # window_fn = window_fn.to(device)
1004
+ # return torchaudio.functional.spectrogram(
1005
+ # wav, n_fft=n_fft, hop_length=hop_length, win_length=n_fft,
1006
+ # window=window_fn, center=True, pad_mode="reflect", power=1.0)
1007
+
1008
+ # def mel_spectrogram(wav, sample_rate, n_fft=1024, hop_length=256, window_fn=torch.hann_window):
1009
+ # transform = torchaudio.transforms.MelSpectrogram(**spectrogram_config)
1010
+ # mel_spectrogram = transform(wav)
1011
+ # log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
1012
+ # log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
1013
+ # spectrogram_tensor = (log_mel + 4.0) / 4.0
1014
+ # spectrogram_tensor = torch.tensor(spectrogram_tensor)
1015
+ # return spectrogram_tensor
1016
+ if spec:
1017
  transform = torchaudio.transforms.MelSpectrogram(**spectrogram_config)
1018
  mel_spectrogram = transform(wav)
1019
  log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
 
1020
  log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
1021
  spectrogram_tensor = (log_mel + 4.0) / 4.0
1022
  spectrogram_tensor = torch.tensor(spectrogram_tensor)
1023
+ else:
1024
+ spectrogram_tensor = None
1025
+
1026
+ # if spec:
1027
+ # if isinstance(wav, torch.Tensor):
1028
+ # wav = wav.to(device)
1029
+ # spectrogram_tensor = mel_spectrogram(wav, sample_rate, **spectrogram_config)
1030
+ # spectrogram_tensor = spectrogram_tensor.permute(1, 0)
1031
+
1032
+
1033
+ def mfcc(wav, sample_rate, n_mels=128, n_fft=1024, hop_length=256, window_fn=torch.hann_window):
1034
+ transform = torchaudio.transforms.MFCC(
1035
+ sample_rate=sample_rate,
1036
+ n_mfcc=n_mels,
1037
+ melkwargs={
1038
+ "n_fft": n_fft,
1039
+ "hop_length": hop_length,
1040
+ "window_fn": window_fn,
1041
+ "n_mels": n_mels,
1042
+ "center": True,
1043
+ "pad_mode": "reflect",
1044
+ "norm": None,
1045
+ "mel_scale": "htk",
1046
+ }
1047
+ )
1048
+ mfcc_tensor = transform(wav)
1049
+ return mfcc_tensor
1050
+
1051
+
1052
+ def compute_pitch(wav, sample_rate, hop_length=256):
1053
+ import pyworld as pw
1054
+ wav_np = wav.numpy().astype(np.float64)
1055
+ f0, t = pw.dio(wav_np, sample_rate, frame_period=hop_length / sample_rate * 1000)
1056
+ f0 = pw.stonemask(wav_np, f0, t, sample_rate)
1057
+ return f0, t
1058
+
1059
+ def compute_harmonics_and_aperiodics(wav, f0, t, sample_rate):
1060
+ import pyworld as pw
1061
+ wav_np = wav.numpy().astype(np.float64)
1062
+ sp = pw.cheaptrick(wav_np, f0, t, sample_rate, fft_size=256)
1063
+ ap = pw.d4c(wav_np, f0, t, sample_rate, fft_size=256)
1064
+ harmonic_tensor = torch.from_numpy(sp)
1065
+ aperiodic_tensor = torch.from_numpy(ap)
1066
+ harmonic_tensor = harmonic_tensor[:, :128].contiguous().T
1067
+ aperiodic_tensor = aperiodic_tensor[:, :128].contiguous().T
1068
+ harmonic_tensor = torch.where(harmonic_tensor == 0.0, torch.zeros_like(harmonic_tensor), harmonic_tensor / 1.0)
1069
+ aperiodic_tensor = torch.where(aperiodic_tensor == 0.0, torch.zeros_like(aperiodic_tensor), aperiodic_tensor / 1.0)
1070
+ return harmonic_tensor, aperiodic_tensor
1071
+
1072
 
1073
  if f0 or f0t or pitch or harmonics or aperiodics:
1074
  wavnp = wav.numpy().astype(np.float64)
 
1092
  end_idx = torch.searchsorted(t2, token_ends, side="right")
1093
  pitch_tok = torch.zeros(T, dtype=torch.float32)
1094
  for i in range(T):
1095
+ lo, hi = start_idx[i], max(start_idx[i]+1, end_idx[i]) # type: ignore
1096
  segment = f0_np[lo:hi]
1097
  if mode == "mean":
1098
  pitch_tok[i] = segment.mean()
 
1103
  pitch_tok[pitch_tok < 100.0] = 0.0
1104
  bos_pitch = pitch_tok[0] if len(pitch_tok) > 0 else 0.0
1105
  f0t_tensor = torch.cat([torch.tensor([bos_pitch]), pitch_tok])
1106
+ f0t_tensor = torch.where(f0t_tensor == 0.0, torch.zeros_like(f0t_tensor), (f0t_tensor - 71.0) / (500.0 - 71.0))
1107
  else:
1108
  f0t_tensor = None
1109
 
1110
  if phase_mod:
1111
  tframe = torch.mean(t2[1:] - t2[:-1])
1112
  phi0 = 0.0
1113
+ omega = 2 * torch.pi * f0_tensor # type: ignore
1114
  dphi = omega * tframe
1115
  phi = torch.cumsum(dphi, dim=0) + phi0
1116
  phase = torch.remainder(phi, 2 * torch.pi)
 
1118
  phase = None
1119
 
1120
  if pitch:
1121
+ p_tensor = compute_pitch(wav, sample_rate, hop_length=hop_length)[0]
1122
+ p_tensor = torch.from_numpy(p_tensor)
1123
+ p_tensor = p_tensor.unsqueeze(0)
1124
+ # p_tensor = torch.from_numpy(f0_np)
1125
  else:
1126
  p_tensor = None
1127
 
 
1143
  else:
1144
  wave_tensor = None
1145
 
1146
+ if dummy:
1147
+ if spectrogram_tensor is not None:
1148
+ dummy_tensor = torch.ones_like(spectrogram_tensor)
1149
+ elif p_tensor is not None:
1150
+ dummy_tensor = torch.ones_like(p_tensor)
1151
+ elif f0_tensor is not None:
1152
+ dummy_tensor = torch.ones_like(f0_tensor)
1153
+ elif f0t_tensor is not None:
1154
+ dummy_tensor = torch.ones_like(f0t_tensor)
1155
+ else:
1156
+ batch_size = 128
1157
+ seq_len = 1024
1158
+ dummy_tensor = torch.ones(batch_size, seq_len)
1159
+ dummy_tensor = dummy_tensor.to(device)
1160
+
1161
+ else:
1162
+ dummy_tensor = None
1163
+
1164
  if debug:
1165
 
1166
+ print(f"['f0']: {f0_tensor.shape if f0 else None}")
1167
  print(f"['f0t']: {f0t_tensor.shape if f0t else None}")
1168
  print(f"['harmonic']: {harmonic_tensor.shape if harmonics else None}")
1169
  print(f"['aperiodic']: {aperiodic_tensor.shape if aperiodics else None}")
 
1172
  print(f"['labels']: {len(labels) if labels else None}")
1173
  print(f"['phase']: {phase.shape if phase else None}")
1174
  print(f"['pitch']: {p_tensor.shape if pitch else None}")
1175
+ print(f"['crepe_time']: {crepe_time.shape if crepe else None}")
1176
  print(f"['crepe_frequency']: {crepe_frequency.shape if crepe else None}")
1177
  print(f"['crepe_confidence']: {crepe_confidence.shape if crepe else None}")
1178
  print(f"['crepe_activation']: {crepe_activation.shape if crepe else None}")
1179
+ print(f"['dummy']: {dummy_tensor.shape if dummy else None}")
1180
 
1181
  return {
1182
  "waveform": wave_tensor if waveform else None,
 
1192
  "crepe_frequency": crepe_frequency if crepe else None,
1193
  "crepe_confidence": crepe_confidence if crepe else None,
1194
  "crepe_activation": crepe_activation if crepe else None,
1195
+ "dummy": dummy_tensor if dummy else None,
1196
  }
1197
 
1198
  def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=False,
 
1213
  "debug": False,
1214
  "phase_mod": False,
1215
  "crepe": False,
1216
+ "dummy": False,
1217
  }
1218
 
1219
  if load_saved:
 
1271
 
1272
  return train_dataset, test_dataset
1273
 
1274
+ def get_feature_encoder(feature: str, mels: int, input_dims: int, dims: int, head: int, layer: int, act=None, features=None) -> nn.Module:
1275
+ if feature == "spectrogram":
1276
+ return FEncoder(mels=mels, input_dims=input_dims, dims=dims, head=head, layer=layer, act=act, feature=feature, features=features)
1277
+ elif feature == "waveform":
1278
+ return WEncoder(input_dims=input_dims, dims=dims, head=head, layer=layer, kernel_size=7, act=act)  # WEncoder takes no feature/features args; kernel_size value is a placeholder
1279
+ elif feature == "pitch":
1280
+ return PEncoder(input_dims=input_dims, dims=dims, head=head, layer=layer, kernel_size=7, act=act)  # PEncoder takes no feature/features args; kernel_size value is a placeholder
1281
+ else:
1282
+ raise ValueError(f"Unknown feature type: {feature}")
1283
+
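A construction sketch for the factory (the encoder classes follow just below; argument values are illustrative, and the module-level `device`/`dtype` are assumed):

    import torch

    enc = get_feature_encoder(feature="spectrogram", mels=128, input_dims=128,
                              dims=512, head=8, layer=4, act="gelu")
    x = torch.randn(2, 128, 100).to(device, dtype)   # (batch, mels, frames)
    out = enc(x)                                     # (batch, frames, dims)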
1284
+ class FEncoder(nn.Module):
1285
+ def __init__(self, mels, input_dims, dims, head, layer, act, feature, features, use_rope=False, spec_shape=None, debug=[]):
1286
+ super().__init__()
1287
+
1288
+ self.head = head
1289
+ self.head_dim = dims // head
1290
+ self.dropout = 0.01
1291
+ self.use_rope = use_rope
1292
+ self.dims = dims
1293
+ self.debug = debug
1294
+ self.feature = feature
1295
+ self.mels = mels
1296
+ self.input_dims = input_dims
1297
+ act_fn = get_activation(act)
1298
+
1299
+ self.encoder = nn.Sequential(
1300
+ Conv1d(mels, dims, kernel_size=3, stride=1, padding=1), act_fn,
1301
+ Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
1302
+ Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
1303
+
1304
+ if use_rope:
1305
+ if spec_shape is not None:
1306
+ self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape) # type: ignore
1307
+ else:
1308
+ self.rope = None
1309
+ self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
1310
+ self.norm = RMSNorm(dims)
1311
+
1312
+ def apply_rope_to_features(self, x, xa=None, mask=None, feats=None, feature="audio", layer="FEncoder"):
1313
+ batch, ctx, dims = x.shape
1314
+ x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
1315
+ freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)# type: ignore
1316
+ x = self.rope.apply_rotary(x, freqs)# type: ignore
1317
+ x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
1318
+
1319
+ return x
1320
+
1321
+ def forward(self, x, xa=None, mask=None, feats=None, feature="audio", layer="FEncoder"):
1322
+ x = self.encoder(x).permute(0, 2, 1)
1323
+ if self.use_rope:
1324
+ x = self.apply_rope_to_features(x, xa=xa, mask=mask, feats=feats, feature=feature, layer=layer)
1325
+ else:
1326
+ x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
1327
+ x = nn.functional.dropout(x, p=self.dropout, training=self.training)
1328
+ print(f"feature encoder: {x.shape} {feature}") if "fencoder" in self.debug else None
1329
+ x = self.norm(x)
1330
+ return x
1331
+
1332
+ class WEncoder(nn.Module): # waveform encoder
1333
+ def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False, debug=[], spec_shape=None):
1334
+ super().__init__()
1335
+
1336
+ self.head = head
1337
+ self.head_dim = dims // head
1338
+ self.dropout = 0.01
1339
+ self.use_rope = use_rope
1340
+ self.dims = dims
1341
+ self.debug = debug
1342
+ act_fn = get_activation(act)
1343
+ self.target_length = None
1344
+ self.encoder = nn.Sequential(
1345
+ Conv1d(input_dims, dims//4, kernel_size=15, stride=4, padding=7), act_fn,
1346
+ Conv1d(dims//4, dims//2, kernel_size=7, stride=2, padding=3), act_fn,
1347
+ Conv1d(dims//2, dims, kernel_size=5, stride=2, padding=2), act_fn)
1348
+
1349
+ if use_rope:
1350
+ if spec_shape is not None:
1351
+ self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)# type: ignore
1352
+ else:
1353
+ self.rope = None
1354
+ self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
1355
+ self.norm = RMSNorm(dims)
1356
+
1357
+ def apply_rope_to_features(self, x, xa=None, mask=None, feats=None, feature="waveform", layer="WEncoder"):
1358
+ batch, ctx, dims = x.shape
1359
+ x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
1360
+ freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)# type: ignore
1361
+ x = self.rope.apply_rotary(x, freqs)# type: ignore
1362
+ x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
1363
+ return x
1364
+
1365
+ def forward(self, x, xa=None, mask=None, feats= None, feature="waveform", layer = "WEncoder"):
1366
+ x = self.encoder(x).permute(0, 2, 1) # (batch, time, dims)
1367
+ if self.target_length and x.shape[1] != self.target_length:
1368
+ x = F.adaptive_avg_pool1d(x.transpose(1, 2), self.target_length).transpose(1, 2)
1369
+ if self.use_rope:
1370
+ x = self.apply_rope_to_features(x, xa=xa, mask=mask, feats=feats, feature=feature, layer=layer)
1371
+ else:
1372
+ x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
1373
+ x = nn.functional.dropout(x, p=self.dropout, training=self.training)
1374
+ print(f"waveform encoder: {x.shape} {feature}") if "fencoder" in self.debug else None
1375
+ return self.norm(x)
1376
+
1377
+ class PEncoder(nn.Module): # pitch encoder
1378
+ def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False, debug=[], one_shot=False, spec_shape=None):
1379
+ super().__init__()
1380
+
1381
+ self.head = head
1382
+ self.head_dim = dims // head
1383
+ self.dims = dims
1384
+ self.dropout = 0.01
1385
+ self.use_rope = use_rope
1386
+ self.debug = debug
1387
+ act_fn = get_activation(act)
1388
+
1389
+ self.attend_pitch = False
1390
+
1391
+ if self.attend_pitch:
1392
+ self.q, self.k, self.v, self.o, self.scale = qkv_init(dims, head)
1393
+ self.mlp = nn.Sequential(
1394
+ nn.Linear(dims, dims),
1395
+ nn.ReLU(),
1396
+ nn.Linear(dims, dims),
1397
+ )
1398
+ else:
1399
+ self.q, self.k, self.v, self.o, self.scale = None, None, None, None, None
1400
+ self.mlp = None
1401
+
1402
+ self.pitch_encoder = nn.Sequential(
1403
+ Conv1d(input_dims, dims, kernel_size=7, stride=1, padding=3), act_fn,
1404
+ Conv1d(dims, dims, kernel_size=5, stride=1, padding=2), act_fn,
1405
+ Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
1406
+
1407
+ # self.spectrogram_encoder = nn.Sequential(
1408
+ # Conv1d(input_dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
1409
+ # Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
1410
+ # Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
1411
+
1412
+ # self.waveform_encoder = nn.Sequential(
1413
+ # Conv1d(input_dims, dims//4, kernel_size=15, stride=4, padding=7), act_fn,
1414
+ # Conv1d(dims//4, dims//2, kernel_size=7, stride=2, padding=3), act_fn,
1415
+ # Conv1d(dims//2, dims, kernel_size=5, stride=2, padding=2), act_fn)
1416
+
1417
+ if use_rope:
1418
+ self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)# type: ignore
1419
+ else:
1420
+ self.rope = None
1421
+ self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
1422
+ self.norm = RMSNorm(dims)
1423
+
1424
+ def rope_to_feature(self, x, xa=None, mask=None, feats=None, feature="pitch", layer="PEncoder"):
1425
+ batch, ctx, dims = x.shape
1426
+ x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
1427
+ freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer) # type: ignore
1428
+ x = self.rope.apply_rotary(x, freqs)# type: ignore
1429
+ x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
1430
+ return x
1431
+
1432
+ def forward(self, x, xa=None, mask=None, feats= None, feature="pitch", layer="PEncoder"):
1433
+ # f0=x
1434
+ # freqs = self.rope(f0.shape[1], feats=feats, feature=feature, layer=layer)
1435
+ if x.dim() == 2:
1436
+ x = x.unsqueeze(0)
1437
+ if feature == "pitch":
1438
+ x = self.pitch_encoder(x).permute(0, 2, 1)
1439
+ # elif feature == "spectrogram":
1440
+ # x = self.spectrogram_encoder(x).permute(0, 2, 1)
1441
+ # elif feature == "waveform":
1442
+ # x = self.waveform_encoder(x).permute(0, 2, 1)
1443
+
1444
+ # if self.target_length and x.shape[1] != self.target_length:
1445
+ # x = F.adaptive_avg_pool1d(x.transpose(1, 2), self.target_length).transpose(1, 2)
1446
+
1447
+ if self.use_rope:
1448
+ x = self.rope_to_feature(x, xa=xa, mask=mask, feats=feats, feature=feature, layer=layer)
1449
+
1450
+ x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
1451
+ if self.mlp is not None:
1452
+ x = self.mlp(x)
1453
+
1454
+ if self.attend_pitch:
1455
+ if xa is not None:
1456
+ q, k, v = create_qkv(self.q, self.k, self.v, x=xa, xa=x, head=self.head)
1457
+ out, _ = calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True)
1458
+
1459
+ x = x + out
1460
+
1461
+ x = nn.functional.dropout(x, p=self.dropout, training=self.training)
1462
+ x = self.norm(x)
1463
+ print(f"Pitch encoder: {x.shape} {feature}") if "fencoder" in self.debug else None
1464
+ return x
1465
+
1466
+
1467
  @dataclass
1468
  class DataCollator:
1469
  tokenizer: Any
 
1495
  batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
1496
  batch["labels"] = torch.tensor(all_labels, dtype=torch.long)
1497
 
1498
+ elif key in ["spectrogram", "waveform", "pitch", "harmonic", "aperiodic", "f0t", "f0", "phase", "crepe_time", "crepe_frequency", "crepe_confidence", "crepe_activation", "dummy"]:
1499
  items = [f[key] for f in features if key in f]
1500
  items = [item for item in items if item is not None]
1501
  if not items: