Update model.py
model.py
CHANGED
@@ -1,4 +1,3 @@
-
 import pyworld as pw
 import os
 import math
@@ -6,6 +5,7 @@ import warnings
 import logging
 import gzip
 import base64
+from einops import rearrange, repeat
 import torch
 import torchaudio
 import torchcrepe
@@ -16,12 +16,13 @@ import numpy as np
 from typing import Optional, Dict, Union, List, Tuple, Any
 from functools import partial
 from datetime import datetime
-from datasets import load_dataset, Audio
+from datasets import load_dataset, Audio, concatenate_datasets
 from transformers.trainer_seq2seq import Seq2SeqTrainer
 from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
 import transformers
 import evaluate
 from dataclasses import dataclass
+from math import pi, log
 
 torch.backends.cudnn.allow_tf32 = True
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -36,6 +37,11 @@ warnings.filterwarnings("ignore")
 logging.basicConfig(level=logging.ERROR)
 tox = {"device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), "dtype": torch.float32}
 
+# %xmode Minimal
+# %xmode Plain
+# %xmode Context
+# %xmode Verbose
+
 extractor = None
 tokenizer = None
 optimizer = None
@@ -61,7 +67,7 @@ class Dimensions:
     cross_attn: bool
     features: List[str]
     f0_rotary: bool
-
+
 def exists(v):
     return v is not None
 
@@ -99,7 +105,7 @@ class RMSNorm(nn.Module):
         self.eps = eps
         self.elementwise_affine = elementwise_affine
         if self.elementwise_affine:
-            self.weight = nn.Parameter(torch.empty(self.normalized_shape))
+            self.weight = nn.Parameter(torch.empty(self.normalized_shape))  # type: ignore
             init.ones_(self.weight)
         else:
             self.register_parameter("weight", None)
@@ -128,6 +134,7 @@ def sinusoids(length, channels, max_timescale=10000):
     scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
     return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
 
+
 class rotary(nn.Module):
     _seen = set()
     def __init__(self, dims, max_ctx=1500, theta=10000, learned_freq=False, radii=False,
@@ -144,32 +151,29 @@ class rotary(nn.Module):
         self.dims = dims
         self.max_ctx = max_ctx
         self.radii = radii
-
-
-
+        max_freq = 10.0
+        f0_scale_factor = 0.5
+        self.learned_adaptation: bool = False
+        pitch_scale = 1.0  # theta_rescale = 1.0
 
-        self.…
-            torch.tensor(…
-
-            torch.tensor(…
+        if self.learned_adaptation:
+            self.f0_scale = nn.Parameter(torch.tensor(f0_scale_factor))
+        else:
+            self.register_buffer('f0_scale', torch.tensor(f0_scale_factor))
 
-        self.theta = nn.Parameter(
-            torch.tensor(float(theta)), requires_grad=learned_theta)
+        self.theta = nn.Parameter(torch.tensor(float(theta)), requires_grad=learned_theta)
 
-        self.pitch_scale = nn.Parameter(torch.tensor(pitch_scale),
-            requires_grad=learned_pitch)
+        self.pitch_scale = nn.Parameter(torch.tensor(pitch_scale), requires_grad=learned_pitch)
 
         freqs = 1. / (theta ** (torch.arange(0, dims, 2)[:(dims // 2)].float() / dims))
         self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)
 
         if radii:
-            self.radius = nn.Parameter(torch.ones(dims // 2),
-                requires_grad=learned_radius)
+            self.radius = nn.Parameter(torch.ones(dims // 2), requires_grad=learned_radius)
 
     def get_pitch_bias(self, f0):
         if f0 is None:
             return None
-
         f0_flat = f0.squeeze().float()
         f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
         f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
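A note on the f0_scale handling added above: wrapping the value in nn.Parameter makes the optimizer train it, while register_buffer only carries it through state_dict and device moves. A minimal sketch of the same pattern (the F0Scale name is illustrative, not from this commit):

import torch
import torch.nn as nn

class F0Scale(nn.Module):
    def __init__(self, scale=0.5, learned=False):
        super().__init__()
        if learned:
            self.scale = nn.Parameter(torch.tensor(scale))      # updated by the optimizer
        else:
            self.register_buffer("scale", torch.tensor(scale))  # saved and moved, but frozen

    def forward(self, f0):
        return f0 * self.scale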
@@ -195,21 +199,18 @@ class rotary(nn.Module):
     def align_f0(self, f0, token_length):
         batch_size, f0_length = f0.shape
         if f0_length == token_length:
-            return f0
+            return f0
         frames_per_token = f0_length / token_length
-
         indices = torch.arange(token_length, device=f0.device)
         indices = (indices * frames_per_token).long()#.clamp(max=f0_length-1)
-        #center_positions = ((indices + 0.5) * frames_per_token).long()
-        batch_indices = torch.arange(batch_size, device=f0.device).unsqueeze(1)
+        batch_indices = torch.arange(batch_size, device=f0.device).unsqueeze(1) #center_positions = ((indices + 0.5) * frames_per_token).long()
         return f0[batch_indices, indices.unsqueeze(0).expand(batch_size, -1)]
 
     def scale_f0(self, f0):
         f0_min = f0.min(dim=1, keepdim=True)[0]
         f0_max = f0.max(dim=1, keepdim=True)[0]
         denom = f0_max - f0_min + 1e-8
-        normalized_f0 = (f0 - f0_min) / denom
-        # normalized_f0 = (f0 - f0_min) / (f0_max - f0_min)
+        normalized_f0 = (f0 - f0_min) / denom  # normalized_f0 = (f0 - f0_min) / (f0_max - f0_min)
         normalized_f0 = torch.clamp(normalized_f0, 0.0, 1.0)
         return normalized_f0
 
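The align_f0 context above maps a frame-rate F0 track to token positions by nearest-frame indexing; because the same index vector applies to every batch row, the batch_indices gather is equivalent to plain column indexing. A standalone sketch, assuming f0 of shape (batch, n_frames):

import torch

def align_f0_sketch(f0: torch.Tensor, n_tokens: int) -> torch.Tensor:
    batch, n_frames = f0.shape
    if n_frames == n_tokens:
        return f0
    frames_per_token = n_frames / n_tokens
    idx = (torch.arange(n_tokens, device=f0.device) * frames_per_token).long()
    idx = idx.clamp(max=n_frames - 1)     # the commit leaves this clamp commented out
    return f0[:, idx]                     # one F0 value per token position

f0 = torch.rand(2, 300) * 200 + 80        # 300 frames of fake F0 in Hz
print(align_f0_sketch(f0, 100).shape)     # torch.Size([2, 100])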
@@ -231,47 +232,68 @@ class rotary(nn.Module):
         log_freq = torch.log(freq)
         log_min_freq = torch.log(min_freq)
         log_max_freq = torch.log(max_freq)
-
         mapped_log_freq = ((log_freq - log_min_freq) / (log_max_freq - log_min_freq)) * torch.log(torch.tensor(target_max, device=self.device))
         return mapped_log_freq
 
+    def get_f0_adapted_freqs(self, ctx, f0=None):
+        f0_min: float = 80.0,  # Typical human voice low range
+        f0_max: float = 500.0,  # Typical human voice high range
+        base_freq: float = 1.0,
+        positions = torch.arange(ctx, device=device, dtype=torch.float)
+        freqs = base_freq.clone()
+        if f0 is not None:
+            f0_norm = torch.clamp((f0 - f0_min) / (f0_max - f0_min), 0.0, 1.0)
+            freq_mod = torch.pow(torch.linspace(0.5, 1.5, self.dims//2, device=device),
+                                 f0_norm.unsqueeze(-1) * self.f0_scale)
+            freqs = freqs * freq_mod
+        freqs = torch.outer(positions, freqs)
+        return torch.polar(torch.ones_like(freqs), freqs)
+
     def forward(self, x=None, f0=None, stage=None) -> Tensor:
         if isinstance(x, int):
-            …
+            ctx = x
         else:
-            batch, …
-            t = torch.arange(…
+            batch, ctx, dims = x.shape
+        t = torch.arange(ctx, device=self.device).float()
+
+        if self.learned_adaptation:
+            freqs = self.get_f0_adapted_freqs(ctx, f0)
+            x_complex = torch.view_as_complex(
+                x.float().reshape(*x.shape[:-1], -1, 2).contiguous())
+            x_rotated = x_complex * freqs.unsqueeze(0).unsqueeze(0)
+            freqs = torch.view_as_real(x_rotated).flatten(3).type_as(x)
 
         if f0 is not None:
-            f0_mean…
-
-
-            freqs = 1.0 / (…
+            f0_mean=f0.mean()+1e-8
+            pitch_scale=self.pitch_scale
+            theta=f0_mean*pitch_scale  # f0_theta = f0_mean #(f0_mean * 1e-2 + 1.0) * (theta / 2 )
+            freqs = 1.0 / (theta ** (torch.arange(0, self.dims, 2, device=self.device) / self.dims))
         else:
             freqs = self.freqs
 
         freqs = torch.einsum('i,j->ij', t, freqs)
         freqs = freqs.float()
 
-        if self.radii and f0 is not None:
-
+        if self.radii and f0 is not None:
+
+            radius = self.align_f0(f0, ctx)
             # radius = self.scale_f0(radius)
             radius = F.softplus(self.radius) * radius
-            # radius = radius.unsqueeze(-1)
+            # radius = radius.unsqueeze(-1)  # Ensure radius is of shape (batch, ctx, dims//2)
             freqs = torch.polar(radius.unsqueeze(-1), freqs.unsqueeze(0))
         else:
             freqs = torch.polar(torch.ones_like(freqs), freqs.unsqueeze(0))
         # print(f"Step {self._counter}: Block: {stage}: Radius: {radius}")
         if "rotary" in self.debug:
             if f0 is not None:
-                key = f"{self._counter}_{…
+                key = f"{self._counter}_{theta:.2f}"
                 if key not in rotary._seen:
                     if not hasattr(self, '_prev_f0_theta'):
-                        self._prev_f0_theta = …
-                        print(f"Step {self._counter}: Using raw F0 as theta: {…
-                    elif abs(self._prev_f0_theta - …
-                        print(f"Step {self._counter}: Using raw F0 as theta: {…
-                        self._prev_f0_theta = …
+                        self._prev_f0_theta = theta
+                        print(f"Step {self._counter}: Using raw F0 as theta: {theta:.2f} Hz")
+                    elif abs(self._prev_f0_theta - theta) > 200.0:
+                        print(f"Step {self._counter}: Using raw F0 as theta: {theta:.2f} Hz")
+                        self._prev_f0_theta = theta
                 rotary._seen.add(key)
         self._counter += 1
         return freqs
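As committed, get_f0_adapted_freqs does not look runnable: f0_min, f0_max and base_freq are written as annotated statements with trailing commas rather than parameters, device is unbound, and the only caller is gated on learned_adaptation, which __init__ hard-codes to False. A runnable reading of the apparent intent, offered as an assumption rather than the author's code: normalize F0, bend the per-dimension frequency ramp by it, and return unit-magnitude complex rotations.

import torch

def f0_adapted_freqs(ctx, dims, f0=None, f0_min=80.0, f0_max=500.0, f0_scale=0.5, device="cpu"):
    positions = torch.arange(ctx, device=device, dtype=torch.float)
    freqs = torch.ones(dims // 2, device=device)              # stands in for base_freq
    if f0 is not None:
        # scalar F0 summary here to keep shapes simple; the commit modulates per frame
        f0_norm = torch.clamp((f0.mean() - f0_min) / (f0_max - f0_min), 0.0, 1.0)
        freq_mod = torch.pow(torch.linspace(0.5, 1.5, dims // 2, device=device),
                             f0_norm * f0_scale)
        freqs = freqs * freq_mod
    angles = torch.outer(positions, freqs)                    # (ctx, dims//2)
    return torch.polar(torch.ones_like(angles), angles)       # complex rotations

rot = f0_adapted_freqs(10, 64, f0=torch.full((100,), 220.0))
print(rot.shape, rot.dtype)  # torch.Size([10, 32]) torch.complex64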
@@ -306,81 +328,6 @@ class rotary(nn.Module):
         x1 = torch.view_as_real(x1).flatten(-2)
         return torch.cat([x1.type_as(x), x2], dim=-1)
 
-class SliceAttention(nn.Module):
-    def __init__(self, dims, heads, dropout=0.0):
-        super().__init__()
-        self.dims = dims
-        self.heads = heads
-        self.head_dim = dims // heads
-        self.scale = self.head_dim ** -0.5
-
-        self.q_proj = Linear(dims, dims)
-        self.k_proj = Linear(dims, dims)
-        self.v_proj = Linear(dims, dims)
-        self.out_proj = Linear(dims, dims)
-        self.dropout = nn.Dropout(dropout)
-
-        assert dims % heads == 0, f"Dimensions {dims} not divisible by heads {heads}"
-
-    def parallel_slice(self, q, k, v, mask=None):
-        batch, heads, ctx, dims = q.shape
-        head_dim = self.head_dim
-        batch, ctx, dims = q.shape
-        ctx_len = k.shape[1]
-        num_heads = dims // head_dim
-
-        scores = torch.zeros(batch, num_heads, ctx, ctx_len, device=q.device)
-
-        for h in range(num_heads):
-            start_idx = h * head_dim
-            end_idx = start_idx + head_dim
-            q_h = q[:, :, start_idx:end_idx]
-            k_h = k[:, :, start_idx:end_idx]
-
-            scores[:, h] = torch.bmm(q_h, k_h.transpose(1, 2)) / math.sqrt(head_dim)
-
-        if mask is not None:
-            scores = scores + mask.unsqueeze(0).unsqueeze(0)
-
-        attn_weights = F.softmax(scores, dim=-1)
-
-        output = torch.zeros_like(q)
-        for h in range(num_heads):
-            start_idx = h * head_dim
-            end_idx = start_idx + head_dim
-            v_h = v[:, :, start_idx:end_idx]
-            output[:, :, start_idx:end_idx] = torch.bmm(attn_weights[:, h], v_h)
-        return output
-
-    def forward(self, x, context=None, mask=None):
-        batch, ctx, _ = x.shape
-        if context is None:
-            context = x
-
-        ctx_len = context.shape[1]
-        q = self.q_proj(x)
-        k = self.k_proj(context)
-        v = self.v_proj(context)
-        output = torch.zeros_like(q)
-
-        for h in range(self.heads):
-            start_idx = h * self.head_dim
-            end_idx = start_idx + self.head_dim
-
-            q_h = q[:, :, start_idx:end_idx]
-            k_h = k[:, :, start_idx:end_idx]
-            v_h = v[:, :, start_idx:end_idx]
-
-            attn_scores = torch.bmm(q_h, k_h.transpose(1, 2)) * self.scale
-            if mask is not None:
-                attn_scores = attn_scores + mask[:ctx, :ctx_len].unsqueeze(0)
-
-            attn_weights = F.softmax(attn_scores, dim=-1)
-            attn_weights = self.dropout(attn_weights)
-            head_output = torch.bmm(attn_weights, v_h)
-            output[:, :, start_idx:end_idx] = head_output
-        return self.out_proj(output)
-
 def optim_attn(q, k, v, mask=None, scale=None, pad_token=0, fzero_val=0.0001):
 
     batch, heads, ctx, dims = q.shape
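Dropping SliceAttention loses no capability: its per-head Python loop over torch.bmm computes the same scores as a single batched matmul, which is what MultiheadA now does. A quick equivalence check with illustrative sizes:

import torch

batch, heads, ctx, head_dim = 2, 4, 8, 16
q = torch.randn(batch, heads, ctx, head_dim)
k = torch.randn(batch, heads, ctx, head_dim)

# per-head loop, as in the removed class
looped = torch.stack([q[:, h] @ k[:, h].transpose(1, 2) for h in range(heads)], dim=1)
# one batched matmul over all heads at once
batched = q @ k.transpose(-1, -2)
print(torch.allclose(looped, batched, atol=1e-6))  # True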
@@ -403,44 +350,50 @@ def optim_attn(q, k, v, mask=None, scale=None, pad_token=0, fzero_val=0.0001):
 class MultiheadA(nn.Module):
     _seen = set()
     rbf = False
-    def __init__(self, dims: int, head: int, rotary_emb: bool = …
+    def __init__(self, dims: int, head: int, rotary_emb: bool = True,
         zero_val: float = 0.0001, minz: float = 0.0, maxz: float = 0.001, debug: List[str] = [], optim_attn=False):
 
         super(MultiheadA, self).__init__()
 
-        self.debug = debug
-        self.pad_token = 0
         self.dims = dims
         self.head = head
         self.head_dim = dims // head
-        self.rotary_emb = rotary_emb
-        self.minz = minz
-        self.maxz = maxz
-        self.zero_val = zero_val
-        self.optim_attn = optim_attn
-        self._counter = 0
-        if dims % head != 0:
-            raise ValueError(f"Dimensions {dims} must be divisible by number of heads {head}.")
-        if zero_val < minz or zero_val > maxz:
-            raise ValueError(f"Zero value {zero_val} must be between {minz} and {maxz}.")
 
         self.q = Linear(dims, dims)
         self.k = Linear(dims, dims, bias=False)
         self.v = Linear(dims, dims)
         self.o = Linear(dims, dims)
+
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        dtype = torch.float32
+        self.device = device
+        self.dtype = dtype
+        self.debug = debug
+        self._counter = 0
+
+        self.pad_token = 0
+        self.rotary_emb = rotary_emb
+        self.minz = minz
+        self.maxz = maxz
+        self.zero_val = zero_val
+        self.optim_attn = optim_attn
         self.fzero = nn.Parameter(torch.tensor(zero_val, dtype=torch.float32), requires_grad=True)
 
         if rotary_emb:
             self.rope = rotary(
                 dims=self.head_dim,
-                debug…
-
-
+                debug=debug,
+                radii=False,
+                learned_pitch=False,
+                learned_freq=False,
+                learned_theta=False,
+                learned_radius=False,
+            )
         else:
             self.rope = None
 
     def enhanced_attention_scores(self, q, k, rbf_sigma=1.0, rbf_ratio=0.0):
-        scale = …
+        scale = self.head_dim ** -0.25
         dot_scores = torch.matmul(q, k.transpose(-1, -2)) * scale
         if rbf_ratio <= 0.0:
             return dot_scores
@@ -452,37 +405,25 @@ class MultiheadA(nn.Module):
         return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
 
     def forward(self, x: Tensor, xa: Tensor = None, mask: Tensor = None,
-        return_attn: bool = False, f0: Tensor = None) -> tuple:
+        return_attn: bool = False, f0: Tensor = None, stage=None) -> tuple:
 
         batch, ctx, dims = x.shape
-        scale = …
+        scale = self.head_dim ** -0.25
 
         z = default(xa, x)
         q = self.q(x).to(x.dtype)
         k = self.k(z).to(x.dtype)
         v = self.v(z).to(x.dtype)
 
-        if self.rotary_emb:
-
-
-
-            else:
-                qf = self.rope(q.size(1))
-                kf = self.rope(k.size(1))
-
-            q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-            k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-            v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-
+        if self.rotary_emb:
+            qf = self.rope(q.size(1), f0=f0, stage=stage)
+            kf = self.rope(k.size(1), f0=f0, stage=stage)
+
             q = self.rope.apply_rotary(q, qf)
-            k = self.rope.apply_rotary(k, kf)
-
-        else:
-            q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-            k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-            v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-        batch, head, ctx, head_dim = q.shape
+            k = self.rope.apply_rotary(k, kf)
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=self.head), [q, k, v])  # [batch_size, sequence_length, model_dimension]
 
         if self.optim_attn and not return_attn:
             wv = optim_attn(q * scale, k * scale, v, mask=mask,
                 pad_token=self.pad_token, fzero_val=torch.clamp(F.softplus(self.fzero), self.minz, self.maxz).item())
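The head split that replaces the removed view/permute chain is the einops rearrange in the new code; the two are equivalent up to folding batch and head into one leading axis. A quick check, with illustrative sizes:

import torch
from einops import rearrange

batch, ctx, dims, head = 2, 8, 64, 4
x = torch.randn(batch, ctx, dims)

a = x.view(batch, ctx, head, -1).permute(0, 2, 1, 3)        # (b, h, n, d), old style
b = rearrange(x, 'b n (h d) -> (b h) n d', h=head)          # (b*h, n, d), new style
print(torch.allclose(a.reshape(batch * head, ctx, -1), b))  # True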
@@ -490,33 +431,43 @@ class MultiheadA(nn.Module):
 
         if self.rbf:
             qk = self.enhanced_attention_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
+        else:
+            qk = (q * scale) @ (k * scale).transpose(-1, -2)
 
-
-        if f0 is not None and self.rope.use_pbias:
+        if f0 is not None and self.rotary_emb and self.rope.use_pbias:
             pbias = self.rope.pbias(f0)
             if pbias is not None:
-
-
+                pbias = rearrange(pbias, 'b h i j -> (b h) i j')  # batch * head, sequence_length, sequence_length
+                qk = qk + pbias[:, :q.shape[1], :q.shape[1]]
+
+        token_ids = rearrange(k, '(b h) n d -> b h n d', b=batch)[:, :, :, 0]
         zscale = torch.ones_like(token_ids)
         fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
         zscale[token_ids.float() == self.pad_token] = fzero.to(q.device, q.dtype)
+        zscale = rearrange(zscale, 'b h n -> (b h) n')
 
         if mask is not None:
-            mask = mask[:q.shape[…
-
+            mask = mask[:q.shape[1], :q.shape[1]]
+            expanded_mask = mask.unsqueeze(0).expand(batch * self.head, -1, -1)
+            qk = qk + expanded_mask * zscale.unsqueeze(-2)
         qk = qk * zscale.unsqueeze(-2)
+
         if return_attn:
-
+            v_reshaped = rearrange(v, '(b h) n d -> b h n d', b=batch)  # (batch_size, head, sequence_length, head_dim)
+            return qk, v_reshaped
+
         w = F.softmax(qk, dim=-1).to(q.dtype)
-        wv = …
+        wv = w @ v
+        wv = rearrange(wv, '(b h) n d -> b n (h d)', b=batch, h=self.head)  # (batch_size, sequence_length, model_dimension)
 
         if "multihead" in self.debug and self._counter % 100 == 0:
             print(f"Step {self._counter}: Using rotary embeddings: {self.rotary_emb}")
             print(f"MHA: q={q.shape}, k={k.shape}, v={v.shape}")
             print(f"Attention shape: {qk.shape}, wv shape: {wv.shape}")
-            self._counter += 1
+        self._counter += 1
+
         return self.o(wv), qk.detach()
-
+
 class FCGate(nn.Module):
     def __init__(self, dims, dim):
         super().__init__()
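The zscale block above softens attention toward pad-token keys by scaling their logit columns toward a small learned floor (a clamped softplus of fzero) instead of masking them to minus infinity. A toy version of the idea, with hard-coded shapes and a fixed floor standing in for the learned parameter:

import torch
import torch.nn.functional as F

qk = torch.randn(1, 5, 5)                      # (batch*heads, query, key) logits
key_is_pad = torch.tensor([False, False, True, True, False])
zscale = torch.ones(1, 5)
zscale[:, key_is_pad] = 1e-4                   # the learned fzero floor
qk = qk * zscale.unsqueeze(-2)                 # shrink pad-key columns
w = F.softmax(qk, dim=-1)                      # pad keys get near-uniform, tiny logits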
@@ -572,7 +523,7 @@ class CMGate(nn.Module):
         self.integration = Linear(dims*3, dims)
 
     def forward(self, x, features):
-        sfeat = features.get("spectrogram", x)
+        sfeat = features.get("spectrogram", x)  # Default to input if missing
         wfeat = features.get("waveform", x)
         pfeat = features.get("pitch", x)
         spec = self.sgate(x) * sfeat
@@ -582,18 +533,26 @@ class CMGate(nn.Module):
         combined = torch.cat([spec, wave, pitch], dim=-1)
         return self.integration(combined)
 
-class Residual(nn.Module):
+class Residual(nn.Module):  # noqa: F811
     _seen = set()
     def __init__(self, dims: int, head: int, ctx, act, cross_attn=True, debug: List[str] = [],
                  fgate=False, tgate=False, mgate=False, cgate=False,
                  memory_size=512, features=None):
         super().__init__()
-
-        self._counter = 0
-        self.dropout = 0.01
+
         self.dims = dims
         self.head = head
+        self.ctx = ctx
         self.head_dim = dims // head
+
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        dtype = torch.float32
+        self.device = device
+        self.dtype = dtype
+        self.debug = debug
+        self._counter = 0
+
+        self.dropout = 0.01
         self.cross_attn = cross_attn
         self.debug = debug
         self.fgate = fgate
@@ -627,12 +586,13 @@ class Residual(nn.Module):
         if not any([fgate, tgate, mgate, cgate]):
             self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
 
-    def forward(self, x, xa=None, mask=None, f0=None, mode=None):
-        …
+    def forward(self, x, xa=None, mask=None, f0=None, mode=None, stage=None):
+        bln = self.blend
+        x = x + self.attna(self.lna(x), mask=mask, f0=f0, stage=stage)[0]
 
         if self.attnb and xa is not None:
-            cross = self.attnb(self.lnb(x), xa, f0=f0, mask=None)[0]
-            blend = torch.sigmoid(…
+            cross = self.attnb(self.lnb(x), xa, f0=f0, mask=None, stage=stage)[0]
+            blend = torch.sigmoid(bln)
             x = blend * x + (1 - blend) * cross
 
         normx = self.lnc(x)
@@ -661,11 +621,12 @@ class Residual(nn.Module):
             x = x + mlp_gate * mlp_out
         else:
             x = x + mlp_out
+
         if "residual" in self.debug and self._counter % 100 == 0:
             print(f"Step {self._counter}: Residual block output shape: {x.shape}, xa shape: {xa.shape if xa is not None else None}")
         self._counter += 1
         return x
-
+
 class PEncoder(nn.Module):
     def __init__(self, input_dims, dims, head, layer, kernel_size, act):
         super().__init__()
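Both Residual and TextDecoder use the same learned-blend gate: a scalar parameter pushed through a sigmoid interpolates between the residual stream and the new branch. A minimal sketch of the mechanism in isolation:

import torch
import torch.nn as nn

blend = nn.Parameter(torch.tensor(0.5))   # starts near an even mix
x = torch.randn(2, 8, 64)                 # residual stream
cross = torch.randn(2, 8, 64)             # cross-attention branch
a = torch.sigmoid(blend)                  # gate in (0, 1), learned during training
x = a * x + (1 - a) * cross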
@@ -681,7 +642,7 @@ class PEncoder(nn.Module):
             Conv1d(dims//4, dims//2, kernel_size=5, stride=4, padding=2), act_fn,
             Conv1d(dims//2, dims, kernel_size=5, stride=5, padding=2), act_fn)
 
-    def forward(self, x, f0=None):
+    def forward(self, x, f0=None, stage=None):
         x = self.encoder(x).permute(0, 2, 1)
         x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
@@ -710,7 +671,7 @@ class WEncoder(nn.Module):
         self.positional = lambda length: sinusoids(length, dims)
         self.norm = RMSNorm(dims)
 
-    def forward(self, x, f0=None):
+    def forward(self, x, f0=None, stage=None):
         x = self.downsample(x)
         x = self.encoder(x)
         x = x.permute(0, 2, 1)
@@ -737,32 +698,34 @@ class FEncoder(nn.Module):
         self.norm = RMSNorm(dims)
         self._norm = RMSNorm(dims)
 
-    def forward(self, x, f0=None):
+    def forward(self, x, f0=None, stage=None):
         x = self.encoder(x).permute(0, 2, 1)
         x = x + self.positional(x.shape[1]).to(x.device, x.dtype)
         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
         x = self._norm(x)
         return x
-
+
 class AudioEncoder(nn.Module):
     _seen = set()
     def __init__(self, mels: int, layer: int, dims: int, head: int, ctx: int, features: List[str],
                  debug: List[str], f0_rotary: bool = False, act: str = "gelu"):
         super(AudioEncoder, self).__init__()
-
+
+        self.dims = dims
+        self.head = head
+        self.ctx = ctx
+        self.head_dim = dims // head
+
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        dtype = torch.float32
+        self.device = device
+        self.dtype = dtype
         self.debug = debug
-        self.features = features
         self._counter = 0
+
+        self.features = features
         self.dropout = 0.01
         self.f0_rotary = f0_rotary
-        self.dims = dims
-        self.ctx = ctx
-        self.head = head
-        self.head_dim = dims // head
-
-        self.rope = rotary(
-            dims=self.head_dim,
-            debug=debug,)
 
         act_map = {"gelu": nn.GELU(), "relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(), "swish": nn.SiLU(), "tanhshrink": nn.Tanhshrink(), "softplus": nn.Softplus(), "softshrink": nn.Softshrink(), "leaky_relu": nn.LeakyReLU(), "elu": nn.ELU()}
         act_fn = act_map.get(act, nn.GELU())
@@ -794,7 +757,7 @@ class AudioEncoder(nn.Module):
             [Residual(dims=dims, head=head, ctx=ctx, act=act, debug=debug) for _ in range(layer)] if "spec_phase" in features else None),
         })
 
-    def forward(self, x, f0=None):
+    def forward(self, x, f0=None, stage="encoder"):
         outputs = {}
         if self.f0_rotary:
             f0 = f0 if f0 is not None else x.get("pitch")
@@ -804,7 +767,7 @@ class AudioEncoder(nn.Module):
             if y in x and y in self.blocks:
                 f = x[y]
                 for block in self.blocks[y]:
-                    f = block(f, f0=f0)
+                    f = block(f, f0=f0, stage=stage)
                 outputs[y] = f
 
         if "encoder" in self.debug and self._counter % 100 == 0:
@@ -820,9 +783,19 @@ class TextDecoder(nn.Module):
                  features: List[str], debug: List[str], f0_rotary: bool = False, sequential=False):
         super(TextDecoder, self).__init__()
 
+        self.dims = dims
+        self.head = head
+        self.ctx = ctx
+        self.head_dim = dims // head
+
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        dtype = torch.float32
+        self.device = device
+        self.dtype = dtype
+        self.debug = debug
         self._counter = 0
+
         self.dropout = 0.01
-        self.debug = debug
         self.sequential = sequential
         self.features = features
         self.f0_rotary = f0_rotary
@@ -841,13 +814,13 @@ class TextDecoder(nn.Module):
             for _ in range(layer)]) for f in features})
 
         self.blend = nn.ParameterDict({f: nn.Parameter(torch.tensor(0.5)) for f in features})
-
         self.ln_dec = RMSNorm(dims)
 
         mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
         self.register_buffer("mask", mask, persistent=False)
 
-    def forward(self, x, enc, order=None, f0=None) -> Tensor:
+    def forward(self, x, enc, order=None, f0=None, stage='decoder') -> Tensor:
+        bln = self.blend
         x = x.to(device)
         if self.f0_rotary:
             f0 = f0
@@ -859,18 +832,18 @@ class TextDecoder(nn.Module):
         x = self.token(x) + self.positional[:x.shape[1]]
         x = F.dropout(x, p=self.dropout, training=self.training)
         for block in self._blocks:
-            x = block(x, f0=f0, mask=mask)
+            x = block(x, f0=f0, mask=mask, stage=stage)
         for f in order:
             if f in enc:
                 xa = enc[f]
                 for block in self.blocks[f]:
-                    out = block(x=x, xa=xa, f0=f0, mask=None)
-                    a = torch.sigmoid(…
+                    out = block(x=x, xa=xa, f0=f0, mask=None, stage=stage)
+                    a = torch.sigmoid(bln[f])
                     x = a * out + (1 - a) * x
         x = self.ln_dec(x)
         return x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
 
-class Echo(nn.Module):
+class Echo(nn.Module):  # Echo the research model
     def __init__(self, param: Dimensions):
         super().__init__()
         self.param = param
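The decoder's return line implements weight tying: logits come from multiplying the hidden states by the transpose of the token embedding matrix rather than a separate output projection. The same operation in isolation:

import torch
import torch.nn as nn

vocab, dims = 100, 64
token = nn.Embedding(vocab, dims)         # shared with the input embedding
x = torch.randn(2, 8, dims)               # decoder output states
logits = x @ token.weight.T               # (2, 8, vocab), no extra parameters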
@@ -1054,7 +1027,7 @@ class DataCollator:
 
         if "label" in features[0] and features[0]["label"] is not None:
             labels_list = [f["label"] for f in features]
-            max_len = max(len(l) for l in labels_list)
+            max_len = max(len(l) for l in labels_list)  # noqa: E741
             all_ids = []
             all_labels = []
 
@@ -1230,6 +1203,8 @@ def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, f0=False,
         batch["envelope"] = torch.stack(envelope_list)
         batch["phase"] = torch.stack(phase_list)
 
+    # batch["spec_envelope_freq"] = hilbert_transform_2d(spec, dim=-2)
+
     wav_1d = wav.unsqueeze(0)
 
     if waveforms:
@@ -1271,6 +1246,7 @@ def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, f0=False,
         f0, t = pw.dio(wav_np, sampling_rate,
                        frame_period=hop_length/sampling_rate*1000)
         f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
+        f0 = f0
         batch["f0"] = torch.from_numpy(f0).float()
 
     if spectrogram and waveforms and pitch:
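For context on the pw.dio / pw.stonemask pair in this hunk: DIO proposes a raw F0 contour and StoneMask refines it against the waveform. Minimal usage, assuming float64 mono audio at 16 kHz (the random signal below is only a stand-in for real speech):

import numpy as np
import pyworld as pw

sr, hop_length = 16000, 128
wav_np = np.random.randn(sr).astype(np.float64)     # 1 s of stand-in audio
f0, t = pw.dio(wav_np, sr, frame_period=hop_length / sr * 1000)
f0 = pw.stonemask(wav_np, f0, t, sr)                # refined F0 per frame, 0.0 = unvoiced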
@@ -1287,6 +1263,7 @@ def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, f0=False,
         pitch_max = 600.0
         batch["pitch"] = (batch["pitch"] - pitch_min) / (pitch_max - pitch_min)
 
+    # print(f"Spectrogram shape: {batch['spectrogram'].shape}, Waveform shape: {batch['waveform'].shape if 'waveform' in batch else None}, Pitch shape: {batch['pitch'].shape if 'pitch' in batch else None}, F0 shape: {batch['f0'].shape if 'f0' in batch else None}, Envelope shape: {batch['envelope'].shape if 'envelope' in batch else None}, Phase shape: {batch['phase'].shape if 'phase' in batch else None}")
     batch["label"] = tokenizer.encode(batch["transcription"], add_special_tokens=False)
     return batch
 
@@ -1356,6 +1333,7 @@ def compute_metrics(eval_pred, compute_result: bool = True,
     }
 
     print(f"Computed metrics: WER={wer:.2f}%, Params={trainable_params:.2f}M, Efficiency={efficiency_score:.4f}")
+
     return metrics
 
 logger = logging.getLogger(__name__)
@@ -1420,12 +1398,12 @@ def prepare_datasets(tokenizer, token: str, sanity_check: bool = False, dataset_
         "en_us",
         token=token,
         trust_remote_code=True,
-        streaming=False
-
-    dataset = dataset.cast_column(column="audio", feature=Audio(sampling_rate=16000))
+        streaming=False)
+
+    dataset = dataset.cast_column(column="audio", feature=Audio(sampling_rate=16000)).select_columns(["audio", "transcription"])
 
     if sanity_check:
-        dataset = dataset["test"].take(10)
+        dataset = dataset["test"].take(10)
         dataset = dataset.select_columns(["audio", "transcription"])
         logger.info(f"Sanity dataset size: {dataset.num_rows}")
         print(f"Sanity dataset size: {dataset.num_rows}")
@@ -1443,13 +1421,14 @@ def prepare_datasets(tokenizer, token: str, sanity_check: bool = False, dataset_
         len(x["audio"]["array"]) > 0 and
         len(x["audio"]["array"]) < 1500 * 160)
 
-    dataset = dataset.filter(filter_func).shuffle()
+    dataset = dataset.filter(filter_func).shuffle(seed=4)
     logger.info(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}")
     print(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}")
     prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
-    train_dataset = dataset["train"]
-    test_dataset = dataset["test"]
     columns_to_remove = list(next(iter(dataset.values())).features)
+    train_dataset = dataset["train"]
+    test_dataset = dataset["test"].take(50)  # Limit test set size for faster processing
+    logger.info(f"Train dataset size: {train_dataset.num_rows}, Test dataset size: {test_dataset.num_rows}")
 
     train_dataset = train_dataset.map(
         function=prepare_fn,
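The dataset path assembled across these hunks is: load, resample via cast_column, length-filter, shuffle. A condensed sketch of the same calls; the dataset id is assumed (these hunks only show the "en_us" config), so treat it as illustrative:

from datasets import load_dataset, Audio

ds = load_dataset("google/fleurs", "en_us", streaming=False)      # assumed dataset id
ds = ds.cast_column("audio", Audio(sampling_rate=16000))          # resample on access
ds = ds.filter(lambda x: 0 < len(x["audio"]["array"]) < 1500 * 160)  # cap clip length
ds = ds.shuffle(seed=4)                                           # reproducible order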
@@ -1468,10 +1447,10 @@ def get_training_args(
     batch_eval_metrics: bool = False,
     max_steps: int = 10,
     save_steps: int = 1000,
-    eval_steps: int = …
+    eval_steps: int = 1,
     warmup_steps: int = 0,
     num_train_epochs: int = 1,
-    logging_steps: int = …
+    logging_steps: int = 1,
     eval_on_start: bool = False,
     learning_rate: float = 1e-4,
     weight_decay: float = 0.01,
@@ -1530,7 +1509,7 @@ def main():
         eval_steps = 1,
         warmup_steps = 0,
         logging_steps = 1,
-        eval_on_start = …
+        eval_on_start = False,
         learning_rate = 5e-6,
         weight_decay = 0.01,
     )
@@ -1538,11 +1517,11 @@ def main():
     training_args = get_training_args(
         log_dir,
         batch_eval_metrics = False,
-        max_steps = …
-        save_steps = …
-        eval_steps = …
-        warmup_steps = …
-        logging_steps = …
+        max_steps = 1000,
+        save_steps = 1000,
+        eval_steps = 100,
+        warmup_steps = 100,
+        logging_steps = 10,
         eval_on_start = False,
         learning_rate = 2.5e-4,
         weight_decay = 0.01,
@@ -1562,22 +1541,20 @@ def main():
         text_dims=512,
         text_idx=4,
         act="swish",
-        debug={}…
+        debug={"rotary"},  # {"encoder", "decoder", "residual", "rotary"},
         cross_attn=True,
-        f0_rotary=…
-        features = ["spectrogram"], # ["spectrogram", "waveform", "pitch"]
-        )
+        f0_rotary=False,
+        features = ["spectrogram"], # ["spectrogram", "waveform", "pitch"]
+        )  # features = ["spectrogram", "spec_envelope", "spec_phase"],
 
     sanity_check = False
-
     training_args = sanity(sanity_check)
-
     dataset_config = {
         "spectrogram": True,
         "waveforms": False,
         "pitch": False,
         "downsamples": False,
-        "f0": …
+        "f0": False,
         "hilbert": False,
         "hop_length": 128,
         "fmin": 150,
@@ -1608,7 +1585,6 @@ def main():
         sanity_check=sanity_check,
         dataset_config=dataset_config)
 
-
    trainer = Seq2SeqTrainer(
         args=training_args,
         model=model,
@@ -1617,10 +1593,9 @@ def main():
         data_collator=DataCollator(tokenizer=tokenizer),
         compute_metrics=metrics_fn,
     )
-
+
     trainer.train()
 
 if __name__ == "__main__":
     main()
 
-