Upload echoutils.py
echoutils.py  (+817 −56)  CHANGED
@@ -38,10 +38,37 @@ def create_attention_mask(batch_size, ctx, is_causal=True, padding_mask=None, de
     else:
         mask = torch.zeros((batch_size, 1, ctx, ctx), device=device)
         if padding_mask is not None:
-            padding_mask = padding_mask.unsqueeze(1).unsqueeze(2)
+            padding_mask = padding_mask.unsqueeze(1).unsqueeze(2).bool()
            mask = mask | (~padding_mask)
     return mask
 
+def cos_sim(q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
+    q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
+    k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
+    qk_cosine = torch.matmul(q_norm, k_norm.transpose(-1, -2))
+    qk_cosine = qk_cosine + mask
+    weights = F.softmax(qk_cosine, dim=-1)
+    out = torch.matmul(weights, v)
+    return out
+
+def rbf_scores(q, k, rbf_sigma=1.0, rbf_ratio=0.0):
+    dot_scores = torch.matmul(q, k.transpose(-1, -2))
+    if rbf_ratio <= 0.0:
+        return dot_scores
+    q_norm = q.pow(2).sum(dim=-1, keepdim=True)
+    k_norm = k.pow(2).sum(dim=-1, keepdim=True)
+    qk = torch.matmul(q, k.transpose(-1, -2))
+    dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
+    rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
+    return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
+
+def sliding_window_mask(q_len, k_len, window, device):
+    # mask[i, j] = 1 if j in [i-window+1, i], else 0
+    idxs = torch.arange(q_len, device=device).unsqueeze(1)
+    jdxs = torch.arange(k_len, device=device).unsqueeze(0)
+    mask = (jdxs >= (idxs - window + 1)) & (jdxs <= idxs)
+    return mask.float()  # shape: (q_len, k_len)
+
 def mask_win(text_ctx, aud_ctx):
     mask = torch.tril(torch.ones(text_ctx, text_ctx, device=device, dtype=dtype), diagonal=0)
     audio_mask = torch.tril(torch.ones(text_ctx, aud_ctx - text_ctx, device=device, dtype=dtype))
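For reference, a minimal smoke test of the three helpers added above (a sketch, assuming the module's torch / F imports; the additive-mask construction is illustrative and not part of the upload, since cos_sim adds its mask to the scores before softmax):

```python
import torch
import torch.nn.functional as F

B, H, L, D = 1, 2, 6, 8
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

# sliding_window_mask returns 0/1 floats; blocked positions are pushed
# to a large negative value so the softmax inside cos_sim zeroes them.
win = sliding_window_mask(L, L, window=3, device=q.device)
additive = (1.0 - win) * torch.finfo(q.dtype).min

out = cos_sim(q, k, v, additive)                          # (1, 2, 6, 8)
scores = rbf_scores(q, k, rbf_sigma=1.0, rbf_ratio=0.5)   # (1, 2, 6, 6)
print(out.shape, scores.shape)
```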
@@ -93,18 +120,23 @@ def calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True):
     out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
     return out, None
 
-
-
-
-
-
-
 
-
-
-
-
-
+class KVCache(nn.Module):
+    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
+        super().__init__()
+        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
+        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
+        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
+
+    def update(self, input_pos, k_val, v_val):
+        # input_pos: [S], k_val: [B, H, S, D]
+        assert input_pos.shape[0] == k_val.shape[2]
+
+        k_out = self.k_cache
+        v_out = self.v_cache
+        k_out[:, :, input_pos] = k_val  # pyright: ignore[reportIndexIssue]
+        v_out[:, :, input_pos] = v_val  # pyright: ignore[reportIndexIssue]
+
+        return k_out, v_out
 
 def mel_scale_scalar(freq: float) -> float:
     return 1127.0 * math.log(1.0 + freq / 700.0)
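A hypothetical single-step decode with the new KVCache (names below are illustrative; the cache scatters new keys/values into pre-allocated buffers and returns the full buffers):

```python
import torch

cache = KVCache(max_batch_size=1, max_seq_length=16, n_heads=2, head_dim=4,
                dtype=torch.float32)

pos = torch.tensor([0])            # position being decoded
k_new = torch.randn(1, 2, 1, 4)    # (B, H, S=1, D)
v_new = torch.randn(1, 2, 1, 4)
k_all, v_all = cache.update(pos, k_new, v_new)
print(k_all.shape)                 # torch.Size([1, 2, 16, 4])
```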
@@ -142,7 +174,7 @@ def track_xa(new_xa, operation=""):
     current_id = id(new_xa)
     if current_id != xa_id[0]:
         print(f"xa FLOW: {xa_id[0]} → {current_id} in {operation}")
-        xa_id[0] = current_id
+        xa_id[0] = current_id  # pyright: ignore[reportArgumentType, reportCallIssue]
     else:
         print(f"xa REUSE: {current_id} in {operation}")
     return new_xa
@@ -163,18 +195,6 @@ def get_activation(act: str) -> nn.Module:
     }
     return act_map.get(act, nn.GELU())
 
-@dataclass
-class Dimensions:
-    vocab: int
-    mels: int
-    ctx: int
-    dims: int
-    head: int
-    layer: int
-    act: str
-    debug: List[str]
-    features: List[str]
-
 def get_generation_config(param):
     return GenerationConfig( # type: ignore
         max_length=param.text_ctx,

@@ -193,6 +213,350 @@ def get_generation_config(param):
         use_cache=False,
         return_timestamps=False)
 
+# class rotary(nn.Module):
+#     def __init__(self, dims, head, max_ctx=1500, radii=False, debug: List[str] = [], use_pbias=False, axial=False, spec_shape=None):
+
+#         super(rotary, self).__init__()
+#         self.use_pbias = use_pbias
+#         self.dims = dims
+#         self.head = head
+#         self.head_dim = dims // head
+#         self.radii = radii
+#         self.debug = debug
+#         self.counter = 0
+#         self.last_theta = None
+#         self.axial = axial
+
+#         self.bias = nn.Parameter(torch.zeros(max_ctx, dims // 2), requires_grad=True if use_pbias else False)
+#         theta = (torch.tensor(10000, device=device, dtype=dtype))
+#         self.theta = nn.Parameter(theta, requires_grad=True)
+#         self.theta_values = []
+
+#         if axial and spec_shape is not None:
+#             time_frames, freq_bins = spec_shape
+#             self.time_frames = time_frames
+#             self.freq_bins = freq_bins
+
+#             time_theta = 50.0
+#             time_freqs = 1.0 / (time_theta ** (torch.arange(0, dims, 4)[:(dims // 4)].float() / dims))
+#             self.register_buffer('time_freqs', time_freqs)
+
+#             freq_theta = 100.0
+#             freq_freqs = 1.0 / (freq_theta ** (torch.arange(0, dims, 4)[:(dims // 4)].float() / dims))
+#             self.register_buffer('freq_freqs', freq_freqs)
+
+#     def pitch_bias(self, f0):
+#         if f0 is None:
+#             return None
+#         f0_flat = f0.squeeze().float()
+#         f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
+#         f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
+#                                         f0_norm.unsqueeze(1)))
+#         return f0_sim.unsqueeze(0).unsqueeze(0)
+
+#     def theta_freqs(self, theta):
+#         if theta.dim() == 0:
+#             theta = theta.unsqueeze(0)
+#         freq = (theta.unsqueeze(-1) / 220.0) * 700 * (
+#             torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
+#             self.head_dim // 2, device=theta.device, dtype=theta.dtype) / 2595) - 1) / 1000
+#         return freq
+
+#     def _apply_radii(self, freqs, f0, ctx):
+#         if self.radii and f0 is not None:
+#             radius = f0.to(device, dtype)
+#             L = radius.shape[0]
+#             if L != ctx:
+#                 feature = L / ctx
+#                 idx = torch.arange(ctx, device=f0.device)
+#                 idx = (idx * feature).long().clamp(0, L - 1)
+#                 radius = radius[idx]
+#                 return torch.polar(radius.unsqueeze(-1), freqs), radius
+#             else:
+#                 return torch.polar(radius.unsqueeze(-1), freqs), radius
+#         else:
+#             return torch.polar(torch.ones_like(freqs), freqs), None
+
+#     def check_f0(self, f0, f0t, ctx):
+#         if f0 is not None and f0.shape[1] == ctx:
+#             return f0
+#         elif f0t is not None and f0t.shape[1] == ctx:
+#             return f0t
+#         else:
+#             return None
+
+#     def axial_freqs(self, ctx):
+#         if not self.axial:
+#             return None
+#         time_frames = self.time_frames
+#         freq_bins = self.freq_bins
+
+#         t = torch.arange(ctx, device=device, dtype=dtype)
+#         t_x = (t % time_frames).float()
+#         t_y = torch.div(t, time_frames, rounding_mode='floor').float()
+#         freqs_x = torch.outer(t_x, self.time_freqs)
+#         freqs_y = torch.outer(t_y, self.freq_freqs)
+#         freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
+#         freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
+#         return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
+
+#     def forward(self, x=None, feats=None, feature=None, layer=None) -> Tensor:
+#         ctx=x
+#         f0 = feats.get("f0") if feats is not None else None
+#         f0t = feats.get("f0t") if feats is not None else None
+
+#         f0 = self.check_f0(f0, f0t, ctx)
+#         if f0 is not None:
+#             # if f0.dim() == 2:
+#             #     f0 = f0.squeeze(0)
+#             theta = f0 + self.theta
+#         else:
+#             theta = self.theta
+#         freqs = self.theta_freqs(theta)
+#         t = torch.arange(ctx, device=device, dtype=dtype) # type: ignore
+#         freqs = t[:, None] * freqs
+#         freqs, radius = self._apply_radii(freqs, f0, ctx)
+
+#         if self.axial and feature == "spectrogram":
+#             freqs_2d = self.axial_freqs(ctx)
+#             if freqs_2d is not None:
+#                 return freqs_2d.unsqueeze(0)
+
+#         if "radius" in self.debug and self.counter == 10:
+#             print(f" [{layer}] [Radius] {radius.shape if radius is not None else None} {radius.mean() if radius is not None else None} [Theta] {theta.mean() if theta is not None else None} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
+#         self.counter += 1
+#         return freqs.unsqueeze(0)
+
+#     @staticmethod
+#     def split(X: Tensor):
+#         half_dim = X.shape[-1] // 2
+#         return X[..., :half_dim], X[..., half_dim:]
+
+#     @staticmethod
+#     def apply_rotary(x, freqs):
+#         x1 = x[..., :freqs.shape[-1]*2]
+#         x2 = x[..., freqs.shape[-1]*2:]
+#         orig_shape = x1.shape
+#         if x1.ndim == 2:
+#             x1 = x1.unsqueeze(0)
+#         x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
+#         x1 = torch.view_as_complex(x1) * freqs
+#         x1 = torch.view_as_real(x1).flatten(-2)
+#         x1 = x1.view(orig_shape)
+#         return torch.cat([x1.type_as(x), x2], dim=-1)
+
+
+# class feature_encoder(nn.Module):
+#     def __init__(self, mels, input_dims, dims, head, layer, act, features, feature=None, use_rope=False, spec_shape=None, debug=[], attend_feature=False, target_length=None):
+#         """
+#         Feature encoder for audio processing.
+#         """
+#         super().__init__()
+
+#         self.dims = dims
+#         self.head = head
+#         self.head_dim = dims // head
+#         self.dropout = 0.01
+#         self.use_rope = use_rope
+#         self.attend_feature = attend_feature
+#         self.target_length = target_length
+#         self.feature = feature
+
+#         self.debug = debug
+#         act_fn = get_activation(act)
+
+#         if self.attend_feature:
+#             self.q, self.k, self.v, self.o, self.scale = qkv_init(dims, head)
+#             self.mlp = nn.Sequential(nn.Linear(dims, dims), nn.ReLU(), nn.Linear(dims, dims))
+#         else:
+#             self.q, self.k, self.v, self.o, self.scale = None, None, None, None, None
+#             self.mlp = None
+
+#         self.spectrogram = nn.Sequential(
+#             Conv1d(mels, dims, kernel_size=3), act_fn,
+#             Conv1d(dims, dims, kernel_size=3), act_fn,
+#             Conv1d(dims, dims, kernel_size=3, groups=dims), act_fn)
+
+#         self.waveform = nn.Sequential(
+#             Conv1d(1, dims//4, kernel_size=15, stride=4, padding=7), act_fn,
+#             Conv1d(dims//4, dims//2, kernel_size=7, stride=2, padding=3), act_fn,
+#             Conv1d(dims//2, dims, kernel_size=5, stride=2, padding=2), act_fn)
+
+#         self.pitch = nn.Sequential(
+#             Conv1d(1, dims, kernel_size=7, stride=1, padding=3), act_fn,
+#             Conv1d(dims, dims, kernel_size=5, stride=1, padding=2), act_fn,
+#             Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
+
+#         if use_rope:
+#             # if spec_shape is not None:
+#             self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
+#             self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)
+#         else:
+#             self.rope = None
+#             self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
+#         self.norm = RMSNorm(dims)
+
+#     def rope(self, x, xa=None, mask=None, feats=None, feature=None, layer=None):
+#         if isinstance(x, int):
+#             ctx = x
+#         elif isinstance(x, torch.Tensor):
+#             ctx = x.shape[1] if x.dim() > 1 else x.shape[0]
+#         batch, ctx, dims = x.shape[0], ctx, x.shape[-1]
+
+#         x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
+#         freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)
+#         x = self.rope.apply_rotary(x, freqs) # pyright: ignore[reportOptionalSubscript, reportAttributeAccessIssue]
+#         x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
+#         return x
+
+#     def mel_scalar(self, freq: float) -> float:
+#         return 1127.0 * math.log(1.0 + freq / 700.0)
+
+#     def forward(self, x, xa=None, mask=None, feats=None, feature=None, layer=None, max_tscale=36000):
+#         target_length = x.shape[1] if self.target_length is None else self.target_length
+
+#         if feature == "pitch":
+#             xp = x.clone()
+#             enc_dict = feats if feats is not None else {}
+#             enc_dict = dict(enc_dict)
+#             enc_dict["f0"] = xp
+#             # xp = self.mel_scalar(xp.mean())
+#             # print(f"Using pitch scalar: {xp}")
+#             # max_tscale = xp*300
+#             # print(f"Using max_tscale: {max_tscale}")
+#             feats = enc_dict
+#             if x.dim() == 2:
+#                 x = x.unsqueeze(0)
+#             x = self.pitch(x).permute(0, 2, 1)
+
+#         if feature == "phase":
+#             if x.dim() == 2:
+#                 x = x.unsqueeze(0)
+#             x = self.pitch(x).permute(0, 2, 1)
+
+#         if feature == "waveform":
+#             if x.dim() == 2:
+#                 x = x.unsqueeze(0)
+#             x = self.waveform(x).permute(0, 2, 1)
+#             if target_length and x.shape[1] != self.target_length:
+#                 x = F.adaptive_avg_pool1d(x.transpose(1, 2), target_length).transpose(1, 2)
+
+#         if feature == "harmonics":
+#             if x.dim() == 2:
+#                 x = x.unsqueeze(0)
+#             x = self.spectrogram(x).permute(0, 2, 1)
+
+#         if feature == "aperiodic":
+#             if x.dim() == 2:
+#                 x = x.unsqueeze(0)
+#             x = self.spectrogram(x).permute(0, 2, 1)
+
+#         if feature == "spectrogram":
+#             if x.dim() == 2:
+#                 x = x.unsqueeze(0)
+#             x = self.spectrogram(x).permute(0, 2, 1)
+
+#         if self.use_rope:
+#             x = x + self.positional(x.shape[1], x.shape[-1], max_tscale).to(device, dtype)
+#             x = self.rope(x=x, xa=None, mask=None, feats=feats, feature=feature, layer=layer)
+#         else:
+#             max_tscale = x.shape[1] * 1000 if max_tscale is None else max_tscale
+#             x = x + self.positional(x.shape[1], x.shape[-1], max_tscale).to(device, dtype)
+#             x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+#             x = self.norm(x)
+
+#         if self.attend_feature:
+#             xa = feats[feature] # pyright: ignore[reportOptionalSubscript]
+#             if xa is not None:
+#                 q, k, v = create_qkv(self.q, self.k, self.v, x=xa, xa=x, head=self.head)
+#                 out, _ = calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True)
+#                 x = x + out
+
+#         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+#         x = self.norm(x)
+#         return x
+
+class OneShot(nn.Module):
+    def __init__(self, dims: int, head: int, scale: float = 0.3, features: Optional[List[str]] = None):
+        super().__init__()
+        if features is None:
+            features = ["spectrogram", "waveform", "pitch", "aperiodic", "harmonics"]
+        self.head = head
+        self.head_dim = dims // head
+        self.scale = 1.0 // len(features) if features else scale
+
+        self.q = Linear(dims, dims)
+        self.k = Linear(dims, dims)
+
+    def forward(self, x: Tensor, xa: Tensor, feature=None) -> Tensor | None:
+        B, L, D = x.shape
+        K = xa.size(1)
+        q = self.q(x).view(B, L, self.head, self.head_dim).transpose(1,2)
+        k = self.k(xa).view(B, K, self.head, self.head_dim).transpose(1,2)
+        bias = (q @ k.transpose(-1, -2)) * self.scale / math.sqrt(self.head_dim)
+        return bias
+
+class curiosity(nn.Module):
+    def __init__(self, d, h, bias=True):
+        super().__init__()
+        self.h = h
+        self.dh = d // h
+        self.qkv = nn.Linear(d, d * 3, bias=bias)
+        self.qkv_aux = nn.Linear(d, d * 3, bias=bias)
+        self.o = nn.Linear(d, d, bias=bias)
+        self.g = nn.Parameter(torch.zeros(h))
+
+    def split(self, x):
+        b, t, _ = x.shape
+        return x.view(b, t, self.h, self.dh).transpose(1, 2)
+
+    def merge(self, x):
+        b, h, t, dh = x.shape
+        return x.transpose(1, 2).contiguous().view(b, t, h * dh)
+
+    def forward(self, x, xa, mask=None):
+        q, k, v = self.qkv(x).chunk(3, -1)
+        qa, ka, va = self.qkv_aux(xa).chunk(3, -1)
+        q, k, v = map(self.split, (q, k, v))
+        qa, ka, va = map(self.split, (qa, ka, va))
+        dots = (q @ k.transpose(-2, -1)) / self.dh**0.5
+        dots_aux = (q @ ka.transpose(-2, -1)) / self.dh**0.5
+        if mask is not None: dots = dots.masked_fill(mask, -9e15)
+        p = dots.softmax(-1)
+        pa = dots_aux.softmax(-1)
+        h_main = p @ v
+        h_aux = pa @ va
+        g = torch.sigmoid(self.g).view(1, -1, 1, 1)
+        out = self.merge(h_main * (1 - g) + h_aux * g)
+        return self.o(out)
+
+class PositionalEncoding(nn.Module):
+    def __init__(self, dims, ctx):
+        super(PositionalEncoding, self).__init__()
+        self.dims = dims
+        self.ctx = ctx
+        self.pe = self.get_positional_encoding(max_ctx=ctx)
+
+    def get_positional_encoding(self, max_ctx):
+        pe = torch.zeros(max_ctx, self.dims)
+        position = torch.arange(0, max_ctx, dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.dims, 2, dtype=torch.float32)
+            * (-math.log(10000.0) / self.dims)
+        )
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)
+        return pe.to(device)
+
+    def forward(self, x):
+        ctx = x.size(1)
+        pe = self.pe[:, :ctx, :]
+        x = x * math.sqrt(self.dims)
+        x = x + pe
+        return x
+
+
def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_length=160,
|
561 |
title="", markers=None, marker_labels=None,
|
562 |
show_voiced_regions=True, show_energy=False):
|
|
|
@@ -391,18 +755,99 @@ class Sinusoids(nn.Module):
         features[:, 0::2] = torch.sin(position * div_term)
         features[:, 1::2] = torch.cos(position * div_term)
         self.register_buffer('sinusoid', tensor=features)
-        self.positional_embeddings = nn.Parameter(self.sinusoid.clone())
+        self.positional_embeddings = nn.Parameter(self.sinusoid.clone())  # type: ignore
     def forward(self, positions):
         position_embeddings = self.positional_embeddings[positions]
         return position_embeddings
 
 def sinusoids(length, channels, max_tscale=10000):
     assert channels % 2 == 0
-    log_tscale_increment =
-    inv_tscales = torch.exp(-log_tscale_increment * torch.arange(channels // 2))
-    scaled_t = torch.arange(length
+    log_tscale_increment = torch.log(torch.tensor(float(max_tscale))) / (channels // 2 - 1)
+    inv_tscales = torch.exp(-log_tscale_increment * torch.arange(channels // 2, device=device, dtype=torch.float32))
+    scaled_t = torch.arange(length, device=device, dtype=torch.float32).unsqueeze(1) * inv_tscales.unsqueeze(0)
     return torch.cat([torch.sin(scaled_t), torch.cos(scaled_t)], dim=1)
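Quick sanity check of the repaired sinusoids() (a sketch; the function now uses the module-level device defined elsewhere in echoutils, and channels must be even):

```python
import torch

pe = sinusoids(length=6, channels=8)
print(pe.shape)  # torch.Size([6, 8]): sin half concatenated with cos half
```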
+class SelfCriticalRL(nn.Module):
+    def __init__(self, model, tokenizer, reward_fn):
+        super().__init__()
+        self.model = model
+        self.tokenizer = tokenizer
+        self.reward_fn = reward_fn
+
+    def forward(self, input_ids, features, labels=None, max_len=128, feature_name="spectrogram"):
+
+        with torch.no_grad():
+            greedy_ids = self.model.generate(input_ids=input_ids, **{feature_name: features}, max_length=max_len)
+            greedy_text = [self.tokenizer.decode(ids) for ids in greedy_ids]
+        sampled_ids = self.model.generate(input_ids=input_ids, **{feature_name: features}, max_length=max_len, do_sample=True, top_k=5)
+        sampled_text = [self.tokenizer.decode(ids) for ids in sampled_ids]
+
+        rewards = []
+        baseline = []
+        for s, g, ref in zip(sampled_text, greedy_text, labels): # type: ignore
+            ref_text = self.tokenizer.decode(ref)
+            rewards.append(self.reward_fn(s, ref_text))
+            baseline.append(self.reward_fn(g, ref_text))
+        rewards = torch.tensor(rewards, device=device, dtype=torch.float)
+        baseline = torch.tensor(baseline, device=device, dtype=torch.float)
+        advantage = rewards - baseline
+        logits = self.model(input_ids=sampled_ids, **{feature_name: features})["logits"]  # logits: [batch, sampled_seq_len, vocab_size]
+        log_probs = F.log_softmax(logits, dim=-1)
+        log_probs_seq = torch.gather(log_probs, 2, sampled_ids.unsqueeze(-1)).squeeze(-1)
+        log_probs_sum = log_probs_seq.sum(dim=1)
+        loss = -(advantage * log_probs_sum).mean()
+        return loss
+
+class SelfTrainingModule(nn.Module):
+    def __init__(self, model, tokenizer, quality_fn=None, threshold=0.8):
+        super().__init__()
+        self.model = model
+        self.tokenizer = tokenizer
+        self.quality_fn = quality_fn
+        self.threshold = threshold
+
+    def generate_pseudo_labels(self, unlabeled_batch, features, max_len=128, feature_name="spectrogram"):
+        with torch.no_grad():
+            pred_ids = self.model.generate(input_ids=unlabeled_batch, **{feature_name: features}, max_length=max_len)
+
+        if self.quality_fn is not None:
+            quality_scores = self.quality_fn(pred_ids, self.model, features)
+            mask = quality_scores > self.threshold
+            pred_ids = pred_ids[mask]
+        return pred_ids
+
+    def forward(self, unlabeled_batch, features, max_len=128, feature_name="spectrogram"):
+        pseudo_labels = self.generate_pseudo_labels(unlabeled_batch, features, max_len, feature_name=feature_name)
+        logits = self.model(input_ids=unlabeled_batch, **{feature_name: features}, labels=pseudo_labels)["logits"]
+        loss = nn.functional.cross_entropy(
+            logits.view(-1, logits.shape[-1]), pseudo_labels.view(-1), ignore_index=0)
+        return loss
+
+def confidence_indicator(pred_ids, model, features):
+    with torch.no_grad():
+        logits = model(input_ids=pred_ids, **features)["logits"]
+    probs = torch.softmax(logits, dim=-1)
+    max_probs, _ = probs.max(dim=-1)
+    return max_probs.mean(dim=1)
+
+def wer_reward(hyp, ref):
+
+    hyp_words = hyp.split()
+    ref_words = ref.split()
+    d = [[0] * (len(ref_words)+1) for _ in range(len(hyp_words)+1)]
+    for i in range(len(hyp_words)+1):
+        d[i][0] = i
+    for j in range(len(ref_words)+1):
+        d[0][j] = j
+    for i in range(1, len(hyp_words)+1):
+        for j in range(1, len(ref_words)+1):
+            if hyp_words[i-1] == ref_words[j-1]:
+                d[i][j] = d[i-1][j-1]
+            else:
+                d[i][j] = 1 + min(d[i-1][j], d[i][j-1], d[i-1][j-1])
+    wer = d[-1][-1] / max(1, len(ref_words))
+    return -wer  # negative WER as reward
+
def clean_ids(ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
|
852 |
if isinstance(ids, torch.Tensor):
|
853 |
ids = ids.tolist()
|
|
|
858 |
|
859 |
def setup_tokenizer(dir: str):
|
860 |
from tokenizers import Tokenizer
|
861 |
+
tokenizer = Tokenizer.from_file(f"{dir}")
|
862 |
orig_encode = tokenizer.encode
|
863 |
orig_decode = tokenizer.decode
|
864 |
|
|
|
@@ -480,7 +925,33 @@ def world_to_mel(sp, ap, sample_rate=16000, n_mels=128):
     ap_mel = torch.matmul(ap, mel_basis.T)  # (frames, 128)
     return sp_mel, ap_mel
 
-def extract_features(batch, tokenizer, waveform=False, spec=False, f0=False, f0t=False, pitch=False, harmonics=False, sample_rate=16000, hop_length=256, mode="mean", debug=False, phase_mod=False, crepe=False, aperiodics=False):
+def extract_features(batch, tokenizer, waveform=False, spec=False, f0=False, f0t=False, pitch=False, harmonics=False, sample_rate=16000, hop_length=256, mode="mean", debug=False, phase_mod=False, crepe=False, aperiodics=False, dummy=False):
+
+    # import torchaudio
+    # import torchaudio.functional
+    # import torchaudio.transforms
+
+    # torch_windows = {
+    #     'hann': torch.hann_window,
+    #     'hamming': torch.hamming_window,
+    #     'blackman': torch.blackman_window,
+    #     'bartlett': torch.bartlett_window,
+    #     'ones': torch.ones,
+    #     None: torch.ones,
+    # }
+    # if dummy:
+    #     return {
+    #         "spectrogram": torch.zeros((1, 128, 100)),
+    #         "f0": torch.zeros((1, 100)),
+    #         "f0t": torch.zeros((1, 100)),
+    #         "pitch": torch.zeros((1, 100)),
+    #         "harmonics": torch.zeros((1, 128, 100)),
+    #         "aperiodics": torch.zeros((1, 128, 100)),
+    #         "crepe_time": None,
+    #         "crepe_frequency": None,
+    #         "crepe_confidence": None,
+    #         "crepe_activation": None,
+    #     }
 
     audio = batch["audio"]
     sample_rate = audio["sampling_rate"]
wav = load_wave(wave_data=audio, sample_rate=sample_rate)
|
960 |
|
961 |
spectrogram_config = {
|
962 |
+
# "hop_length": 256,
|
963 |
+
# "f_min": 150,
|
964 |
+
# "f_max": 2000,
|
965 |
+
# "n_mels": 128,
|
966 |
+
# "n_fft": 1024,
|
967 |
"sample_rate": 16000,
|
968 |
+
# "pad_mode": "constant",
|
969 |
+
# "center": True,
|
970 |
+
# "power": 1.0,
|
971 |
+
# "window_fn": torch.hann_window,
|
972 |
+
# "mel_scale": "htk",
|
973 |
+
# "norm": None,
|
974 |
+
# "normalized": False,
|
975 |
}
|
976 |
|
977 |
+
def crepe_predict(wav, sample_rate, viterbi=False):
|
978 |
+
import torchcrepe
|
979 |
+
wav = wav.numpy().astype(np.float32)
|
980 |
+
time, frequency, confidence, activation = torchcrepe.predict(
|
981 |
+
wav, sample_rate=sample_rate, viterbi=viterbi)
|
982 |
crepe_time = torch.from_numpy(time)
|
983 |
crepe_frequency = torch.from_numpy(frequency)
|
984 |
crepe_confidence = torch.from_numpy(confidence)
|
985 |
crepe_activation = torch.from_numpy(activation)
|
986 |
+
return crepe_time, crepe_frequency, crepe_confidence, crepe_activation
|
987 |
+
|
988 |
+
if crepe:
|
989 |
+
crepe_time, crepe_frequency, crepe_confidence, crepe_activation = crepe_predict(wav, sample_rate, viterbi=True)
|
990 |
+
|
991 |
else:
|
992 |
crepe_time = None
|
993 |
crepe_frequency = None
|
994 |
crepe_confidence = None
|
995 |
crepe_activation = None
|
996 |
|
997 |
+
# def spectrogram(wav, sample_rate, n_fft=1024, hop_length=256, window_fn=torch.hann_window):
|
998 |
+
# if isinstance(window_fn, str):
|
999 |
+
# window_fn = torch_windows[window_fn]
|
1000 |
+
# if window_fn is None:
|
1001 |
+
# window_fn = torch.ones(n_fft)
|
1002 |
+
# if isinstance(window_fn, torch.Tensor):
|
1003 |
+
# window_fn = window_fn.to(device)
|
1004 |
+
# return torchaudio.functional.spectrogram(
|
1005 |
+
# wav, n_fft=n_fft, hop_length=hop_length, win_length=n_fft,
|
1006 |
+
# window=window_fn, center=True, pad_mode="reflect", power=1.0)
|
1007 |
+
|
1008 |
+
# def mel_spectrogram(wav, sample_rate, n_fft=1024, hop_length=256, window_fn=torch.hann_window):
|
1009 |
+
# transform = torchaudio.transforms.MelSpectrogram(**spectrogram_config)
|
1010 |
+
# mel_spectrogram = transform(wav)
|
1011 |
+
# log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
|
1012 |
+
# log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
|
1013 |
+
# spectrogram_tensor = (log_mel + 4.0) / 4.0
|
1014 |
+
# spectrogram_tensor = torch.tensor(spectrogram_tensor)
|
1015 |
+
# return spectrogram_tensor
|
1016 |
+
if spec:
|
1017 |
transform = torchaudio.transforms.MelSpectrogram(**spectrogram_config)
|
1018 |
mel_spectrogram = transform(wav)
|
1019 |
log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
|
|
|
1020 |
log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
|
1021 |
spectrogram_tensor = (log_mel + 4.0) / 4.0
|
1022 |
spectrogram_tensor = torch.tensor(spectrogram_tensor)
|
1023 |
+
|
1024 |
+
|
1025 |
+
|
1026 |
+
# if spec:
|
1027 |
+
# if isinstance(wav, torch.Tensor):
|
1028 |
+
# wav = wav.to(device)
|
1029 |
+
# spectrogram_tensor = mel_spectrogram(wav, sample_rate, **spectrogram_config)
|
1030 |
+
# spectrogram_tensor = spectrogram_tensor.permute(1, 0)
|
1031 |
+
|
1032 |
+
|
1033 |
+
def mfcc(wav, sample_rate, n_mels=128, n_fft=1024, hop_length=256, window_fn=torch.hann_window):
|
1034 |
+
transform = torchaudio.transforms.MFCC(
|
1035 |
+
sample_rate=sample_rate,
|
1036 |
+
n_mfcc=n_mels,
|
1037 |
+
melkwargs={
|
1038 |
+
"n_fft": n_fft,
|
1039 |
+
"hop_length": hop_length,
|
1040 |
+
"window_fn": window_fn,
|
1041 |
+
"n_mels": n_mels,
|
1042 |
+
"center": True,
|
1043 |
+
"pad_mode": "reflect",
|
1044 |
+
"norm": None,
|
1045 |
+
"mel_scale": "htk",
|
1046 |
+
}
|
1047 |
+
)
|
1048 |
+
mfcc_tensor = transform(wav)
|
1049 |
+
return mfcc_tensor
|
1050 |
+
|
1051 |
+
|
1052 |
+
def compute_pitch(wav, sample_rate, hop_length=256):
|
1053 |
+
import pyworld as pw
|
1054 |
+
wav_np = wav.numpy().astype(np.float64)
|
1055 |
+
f0, t = pw.dio(wav_np, sample_rate, frame_period=hop_length / sample_rate * 1000)
|
1056 |
+
f0 = pw.stonemask(wav_np, f0, t, sample_rate)
|
1057 |
+
return f0, t
|
1058 |
+
|
1059 |
+
def compute_harmonics_and_aperiodics(wav, f0, t, sample_rate):
|
1060 |
+
import pyworld as pw
|
1061 |
+
wav_np = wav.numpy().astype(np.float64)
|
1062 |
+
sp = pw.cheaptrick(wav_np, f0, t, sample_rate, fft_size=256)
|
1063 |
+
ap = pw.d4c(wav_np, f0, t, sample_rate, fft_size=256)
|
1064 |
+
harmonic_tensor = torch.from_numpy(sp)
|
1065 |
+
aperiodic_tensor = torch.from_numpy(ap)
|
1066 |
+
harmonic_tensor = harmonic_tensor[:, :128].contiguous().T
|
1067 |
+
aperiodic_tensor = aperiodic_tensor[:, :128].contiguous().T
|
1068 |
+
harmonic_tensor = torch.where(harmonic_tensor == 0.0, torch.zeros_like(harmonic_tensor), harmonic_tensor / 1.0)
|
1069 |
+
aperiodic_tensor = torch.where(aperiodic_tensor == 0.0, torch.zeros_like(aperiodic_tensor), aperiodic_tensor / 1.0)
|
1070 |
+
return harmonic_tensor, aperiodic_tensor
|
1071 |
+
|
1072 |
|
1073 |
if f0 or f0t or pitch or harmonics or aperiodics:
|
1074 |
wavnp = wav.numpy().astype(np.float64)
|
|
|
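The nested helpers wrap pyworld's DIO → StoneMask → CheapTrick/D4C chain; an equivalent standalone sketch (assumes pyworld is installed; fft_size=256 yields 129 spectral bins, which the code above then slices to 128):

```python
import numpy as np
import pyworld as pw

sr = 16000
wav_np = np.zeros(sr, dtype=np.float64)      # stand-in for one second of mono audio

f0, t = pw.dio(wav_np, sr, frame_period=256 / sr * 1000)  # coarse F0 per frame
f0 = pw.stonemask(wav_np, f0, t, sr)                      # refined F0
sp = pw.cheaptrick(wav_np, f0, t, sr, fft_size=256)       # spectral envelope
ap = pw.d4c(wav_np, f0, t, sr, fft_size=256)              # aperiodicity
print(f0.shape, sp.shape, ap.shape)          # (frames,) (frames, 129) (frames, 129)
```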
@@ -548,7 +1092,7 @@ def extract_features(batch, tokenizer, waveform=False, spec=False, f0=False, f0t
         end_idx = torch.searchsorted(t2, token_ends, side="right")
         pitch_tok = torch.zeros(T, dtype=torch.float32)
         for i in range(T):
-            lo, hi = start_idx[i], max(start_idx[i]+1, end_idx[i])
+            lo, hi = start_idx[i], max(start_idx[i]+1, end_idx[i])  # type: ignore
             segment = f0_np[lo:hi]
             if mode == "mean":
                 pitch_tok[i] = segment.mean()
@@ -559,14 +1103,14 @@ def extract_features(batch, tokenizer, waveform=False, spec=False, f0=False, f0t
         pitch_tok[pitch_tok < 100.0] = 0.0
         bos_pitch = pitch_tok[0] if len(pitch_tok) > 0 else 0.0
         f0t_tensor = torch.cat([torch.tensor([bos_pitch]), pitch_tok])
-
+        f0t_tensor = torch.where(f0t_tensor == 0.0, torch.zeros_like(f0t_tensor), (f0t_tensor - 71.0) / (500.0 - 71.0))
     else:
         f0t_tensor = None
 
     if phase_mod:
         tframe = torch.mean(t2[1:] - t2[:-1])
         phi0 = 0.0
-        omega = 2 * torch.pi * f0_tensor
+        omega = 2 * torch.pi * f0_tensor  # type: ignore
         dphi = omega * tframe
         phi = torch.cumsum(dphi, dim=0) + phi0
         phase = torch.remainder(phi, 2 * torch.pi)
@@ -574,7 +1118,10 @@ def extract_features(batch, tokenizer, waveform=False, spec=False, f0=False, f0t
         phase = None
 
     if pitch:
-        p_tensor =
+        p_tensor = compute_pitch(wav, sample_rate, hop_length=hop_length)[0]
+        p_tensor = torch.from_numpy(p_tensor)
+        p_tensor = p_tensor.unsqueeze(0)
+        # p_tensor = torch.from_numpy(f0_np)
     else:
         p_tensor = None
 
@@ -596,9 +1143,27 @@ def extract_features(batch, tokenizer, waveform=False, spec=False, f0=False, f0t
     else:
         wave_tensor = None
 
+    if dummy:
+        if spectrogram_tensor is not None:
+            dummy_tensor = torch.ones_like(spectrogram_tensor)
+        elif p_tensor is not None:
+            dummy_tensor = torch.ones_like(p_tensor)
+        elif f0_tensor is not None:
+            dummy_tensor = torch.ones_like(f0_tensor)
+        elif f0t_tensor is not None:
+            dummy_tensor = torch.ones_like(f0t_tensor)
+        else:
+            batch_size = 128
+            seq_len = 1024
+            dummy_tensor = torch.ones(batch_size, seq_len)
+        dummy_tensor = dummy_tensor.to(device)
+
+    else:
+        dummy_tensor = None
+
     if debug:
 
-        print(f"['f0']: {f0_tensor.shape if f0 else None}")
+        print(f"['f0']: {f0_tensor.shape if f0 else None}")
         print(f"['f0t']: {f0t_tensor.shape if f0t else None}")
         print(f"['harmonic']: {harmonic_tensor.shape if harmonics else None}")
         print(f"['aperiodic']: {aperiodic_tensor.shape if aperiodics else None}")

@@ -607,10 +1172,11 @@ def extract_features(batch, tokenizer, waveform=False, spec=False, f0=False, f0t
         print(f"['labels']: {len(labels) if labels else None}")
         print(f"['phase']: {phase.shape if phase else None}")
         print(f"['pitch']: {p_tensor.shape if pitch else None}")
-        print(f"['crepe_time']: {crepe_time.shape if crepe else None}")
+        print(f"['crepe_time']: {crepe_time.shape if crepe else None}")
         print(f"['crepe_frequency']: {crepe_frequency.shape if crepe else None}")
         print(f"['crepe_confidence']: {crepe_confidence.shape if crepe else None}")
         print(f"['crepe_activation']: {crepe_activation.shape if crepe else None}")
+        print(f"['dummy']: {dummy_tensor.shape if dummy else None}")
 
     return {
         "waveform": wave_tensor if waveform else None,

@@ -626,6 +1192,7 @@ def extract_features(batch, tokenizer, waveform=False, spec=False, f0=False, f0t
         "crepe_frequency": crepe_frequency if crepe else None,
         "crepe_confidence": crepe_confidence if crepe else None,
         "crepe_activation": crepe_activation if crepe else None,
+        "dummy": dummy_tensor if dummy else None,
     }
 def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=False,

@@ -646,6 +1213,7 @@ def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, st
         "debug": False,
         "phase_mod": False,
         "crepe": False,
+        "dummy": False,
     }
 
     if load_saved:

@@ -703,6 +1271,199 @@ def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, st
 
     return train_dataset, test_dataset
 
+def get_feature_encoder(feature: str, mels: int, input_dims: int, dims: int, head: int, layer: int, act=None, features=None) -> nn.Module:
+    if feature == "spectrogram":
+        return FEncoder(mels=mels, input_dims=input_dims, dims=dims, head=head, layer=layer, act=act, feature=feature, features=features)
+    elif feature == "waveform":
+        return WEncoder(input_dims, dims, head, layer, act, feature, features)
+    elif feature == "pitch":
+        return PEncoder(input_dims, dims, head, layer, act, feature, features)
+    else:
+        raise ValueError(f"Unknown feature type: {feature}")
+
+class FEncoder(nn.Module):
+    def __init__(self, mels, input_dims, dims, head, layer, act, feature, features, use_rope=False, spec_shape=None, debug=[]):
+        super().__init__()
+
+        self.head = head
+        self.head_dim = dims // head
+        self.dropout = 0.01
+        self.use_rope = use_rope
+        self.dims = dims
+        self.debug = debug
+        self.feature = feature
+        self.mels = mels
+        self.input_dims = input_dims
+        act_fn = get_activation(act)
+
+        self.encoder = nn.Sequential(
+            Conv1d(mels, dims, kernel_size=3, stride=1, padding=1), act_fn,
+            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
+            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
+
+        if use_rope:
+            if spec_shape is not None:
+                self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape) # type: ignore
+        else:
+            self.rope = None
+        self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
+        self.norm = RMSNorm(dims)
+
+    def apply_rope_to_features(self, x, xa=None, mask=None, feats=None, feature="audio", layer="FEncoder"):
+        batch, ctx, dims = x.shape
+        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
+        freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)  # type: ignore
+        x = self.rope.apply_rotary(x, freqs)  # type: ignore
+        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
+
+        return x
+
+    def forward(self, x, xa=None, mask=None, feats=None, feature="audio", layer="FEncoder"):
+        x = self.encoder(x).permute(0, 2, 1)
+        if self.use_rope:
+            x = self.apply_rope_to_features(x, xa=xa, mask=mask, feats=feats, feature=feature, layer=layer)
+        else:
+            x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
+        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+        print(f"feature encoder: {x.shape} {feature}") if "fencoder" in self.debug else None
+        x = self.norm(x)
+        return x
+
+class WEncoder(nn.Module):  # waveform encoder
+    def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False, debug=[], spec_shape=None):
+        super().__init__()
+
+        self.head = head
+        self.head_dim = dims // head
+        self.dropout = 0.01
+        self.use_rope = use_rope
+        self.dims = dims
+        self.debug = debug
+        act_fn = get_activation(act)
+        self.target_length = None
+        self.encoder = nn.Sequential(
+            Conv1d(input_dims, dims//4, kernel_size=15, stride=4, padding=7), act_fn,
+            Conv1d(dims//4, dims//2, kernel_size=7, stride=2, padding=3), act_fn,
+            Conv1d(dims//2, dims, kernel_size=5, stride=2, padding=2), act_fn)
+
+        if use_rope:
+            if spec_shape is not None:
+                self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)  # type: ignore
+        else:
+            self.rope = None
+        self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
+        self.norm = RMSNorm(dims)
+
+    def apply_rope_to_features(self, x, xa=None, mask=None, feats=None, feature="waveform", layer="WEncoder"):
+        batch, ctx, dims = x.shape
+        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
+        freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)  # type: ignore
+        x = self.rope.apply_rotary(x, freqs)  # type: ignore
+        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
+        return x
+
+    def forward(self, x, xa=None, mask=None, feats=None, feature="waveform", layer="WEncoder"):
+        x = self.encoder(x).permute(0, 2, 1)  # (batch, time, dims)
+        if self.target_length and x.shape[1] != self.target_length:
+            x = F.adaptive_avg_pool1d(x.transpose(1, 2), self.target_length).transpose(1, 2)
+        if self.use_rope:
+            x = self.apply_rope_to_features(x, xa=xa, mask=mask, feats=feats, feature=feature, layer=layer)
+        else:
+            x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
+        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+        print(f"waveform encoder: {x.shape} {feature}") if "fencoder" in self.debug else None
+        return self.norm(x)
+
+class PEncoder(nn.Module):  # pitch encoder
+    def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False, debug=[], one_shot=False, spec_shape=None):
+        super().__init__()
+
+        self.head = head
+        self.head_dim = dims // head
+        self.dims = dims
+        self.dropout = 0.01
+        self.use_rope = use_rope
+        self.debug = debug
+        act_fn = get_activation(act)
+
+        self.attend_pitch = False
+
+        if self.attend_pitch:
+            self.q, self.k, self.v, self.o, self.scale = qkv_init(dims, head)
+            self.mlp = nn.Sequential(
+                nn.Linear(dims, dims),
+                nn.ReLU(),
+                nn.Linear(dims, dims),
+            )
+        else:
+            self.q, self.k, self.v, self.o, self.scale = None, None, None, None, None
+            self.mlp = None
+
+        self.pitch_encoder = nn.Sequential(
+            Conv1d(input_dims, dims, kernel_size=7, stride=1, padding=3), act_fn,
+            Conv1d(dims, dims, kernel_size=5, stride=1, padding=2), act_fn,
+            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
+
+        # self.spectrogram_encoder = nn.Sequential(
+        #     Conv1d(input_dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
+        #     Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
+        #     Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
+
+        # self.waveform_encoder = nn.Sequential(
+        #     Conv1d(input_dims, dims//4, kernel_size=15, stride=4, padding=7), act_fn,
+        #     Conv1d(dims//4, dims//2, kernel_size=7, stride=2, padding=3), act_fn,
+        #     Conv1d(dims//2, dims, kernel_size=5, stride=2, padding=2), act_fn)
+
+        if use_rope:
+            self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)  # type: ignore
+        else:
+            self.rope = None
+        self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
+        self.norm = RMSNorm(dims)
+
+    def rope_to_feature(self, x, xa=None, mask=None, feats=None, feature="pitch", layer="PEncoder"):
+        batch, ctx, dims = x.shape
+        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
+        freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)  # type: ignore
+        x = self.rope.apply_rotary(x, freqs)  # type: ignore
+        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
+        return x
+
+    def forward(self, x, xa=None, mask=None, feats=None, feature="pitch", layer="PEncoder"):
+        # f0=x
+        # freqs = self.rope(f0.shape[1], feats=feats, feature=feature, layer=layer)
+        if x.dim() == 2:
+            x = x.unsqueeze(0)
+        if feature == "pitch":
+            x = self.pitch_encoder(x).permute(0, 2, 1)
+        # elif feature == "spectrogram":
+        #     x = self.spectrogram_encoder(x).permute(0, 2, 1)
+        # elif feature == "waveform":
+        #     x = self.waveform_encoder(x).permute(0, 2, 1)
+
+        # if self.target_length and x.shape[1] != self.target_length:
+        #     x = F.adaptive_avg_pool1d(x.transpose(1, 2), self.target_length).transpose(1, 2)
+
+        if self.use_rope:
+            x = self.rope_to_feature(x, xa=xa, mask=mask, feats=feats, feature=feature, layer=layer)
+
+        x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
+        if self.mlp is not None:
+            x = self.mlp(x)
+
+        if self.attend_pitch:
+            if xa is not None:
+                q, k, v = create_qkv(self.q, self.k, self.v, x=xa, xa=x, head=self.head)
+                out, _ = calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True)
+
+                x = x + out
+
+        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+        x = self.norm(x)
+        print(f"Pitch encoder: {x.shape} {feature}") if "fencoder" in self.debug else None
+        return x
+
+
1467 |
@dataclass
|
1468 |
class DataCollator:
|
1469 |
tokenizer: Any
|
|
|
1495 |
batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
|
1496 |
batch["labels"] = torch.tensor(all_labels, dtype=torch.long)
|
1497 |
|
1498 |
+
elif key in ["spectrogram", "waveform", "pitch", "harmonic", "aperiodic", "f0t", "f0", "phase", "crepe_time", "crepe_frequency", "crepe_confidence", "crepe_activation", "dummy"]:
|
1499 |
items = [f[key] for f in features if key in f]
|
1500 |
items = [item for item in items if item is not None]
|
1501 |
if not items:
|