# asr-model / echoutils.py
import torch
import os
import pyworld as pw
import numpy as np
import torchaudio
import torch.nn.functional as F
from datasets import load_dataset
from datasets import Audio
from dataclasses import dataclass
import math
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.init as init
from torch import Tensor
from typing import Any, List, Dict, Optional, Union, Tuple
from torch.nn.functional import scaled_dot_product_attention
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
def shape(tensor: torch.Tensor, head: int, head_dim: int, batch: int, ctx: int):
    return tensor.view(batch, ctx, head, head_dim).transpose(1, 2).contiguous()
def reshape_to_output(attn_output, head: int, head_dim: int, batch: int, ctx: int, dims: int):
    return attn_output.permute(0, 2, 1, 3).reshape(batch, ctx, dims).contiguous()
def create_attention_mask(batch_size, ctx, is_causal=True, padding_mask=None, device=None):
    # boolean mask for scaled_dot_product_attention: True = may attend, False = masked
    if is_causal:
        mask = torch.tril(torch.ones((ctx, ctx), device=device, dtype=torch.bool))
        mask = mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, ctx, ctx)
    else:
        mask = torch.ones((batch_size, 1, ctx, ctx), device=device, dtype=torch.bool)
    if padding_mask is not None:
        padding_mask = padding_mask.unsqueeze(1).unsqueeze(2).bool()
        mask = mask & padding_mask
    return mask
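# Minimal usage sketch: the returned boolean mask is shaped (batch, 1, ctx, ctx) with True
# marking positions a query may attend to, the convention scaled_dot_product_attention expects.
def _demo_create_attention_mask():
    padding = torch.tensor([[1, 1, 1, 0]], dtype=torch.bool)  # last position is padding
    mask = create_attention_mask(batch_size=1, ctx=4, is_causal=True, padding_mask=padding, device=torch.device("cpu"))
    return mask  # row i allows keys 0..i, and the padded key 3 is masked everywhere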
def cos_sim(q: Tensor, k: Tensor, v: Tensor, mask=None) -> Tensor:
    q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
    k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
    qk_cosine = torch.matmul(q_norm, k_norm.transpose(-1, -2))
    if mask is not None:
        qk_cosine = qk_cosine + mask
weights = F.softmax(qk_cosine, dim=-1)
out = torch.matmul(weights, v)
return out
def rbf_scores(q, k, rbf_sigma=1.0, rbf_ratio=0.0):
dot_scores = torch.matmul(q, k.transpose(-1, -2))
if rbf_ratio <= 0.0:
return dot_scores
q_norm = q.pow(2).sum(dim=-1, keepdim=True)
k_norm = k.pow(2).sum(dim=-1, keepdim=True)
qk = torch.matmul(q, k.transpose(-1, -2))
dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
def sliding_window_mask(q_len, k_len, window, device):
# mask[i, j] = 1 if j in [i-window+1, i], else 0
idxs = torch.arange(q_len, device=device).unsqueeze(1)
jdxs = torch.arange(k_len, device=device).unsqueeze(0)
mask = (jdxs >= (idxs - window + 1)) & (jdxs <= idxs)
return mask.float() # shape: (q_len, k_len)
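# Minimal usage sketch: each query position i may see keys i-window+1 .. i (a float 0/1 mask).
def _demo_sliding_window_mask():
    m = sliding_window_mask(q_len=5, k_len=5, window=2, device=torch.device("cpu"))
    return m  # m[3] == tensor([0., 0., 1., 1., 0.])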
def mask_win(text_ctx, aud_ctx):
mask = torch.tril(torch.ones(text_ctx, text_ctx, device=device, dtype=dtype), diagonal=0)
audio_mask = torch.tril(torch.ones(text_ctx, aud_ctx - text_ctx, device=device, dtype=dtype))
full_mask = torch.cat([mask, audio_mask], dim=-1)
return full_mask
def maskc(ctx, device):
return torch.tril(torch.ones(ctx, ctx, device=device, dtype=dtype), diagonal=0)
def qkv_init(dims: int, head: int):
head_dim = dims // head
scale = head_dim ** -0.5
q = nn.Linear(dims, dims)
k = nn.Linear(dims, dims, bias=False)
v = nn.Linear(dims, dims)
o = nn.Linear(dims, dims)
return q, k, v, o, scale
def create_qkv(q, k, v, x, xa=None, head=8):
    head_dim = q.out_features // head
    scale = head_dim ** -0.5
    q = q(x) * scale
    k = k(xa if xa is not None else x) * scale
    v = v(xa if xa is not None else x)
    batch = q.shape[0]
    def _shape(tensor):
        # use each tensor's own length so cross-attention works when xa and x differ in length
        return tensor.view(batch, tensor.shape[1], head, head_dim).transpose(1, 2).contiguous()
    return _shape(q), _shape(k), _shape(v)
def calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True):
    batch, head, ctx, head_dim = q.shape
    attn_mask = None
    if mask is not None:
        if mask.dim() <= 3:
            attn_mask = create_attention_mask(
                batch_size=batch,
                ctx=ctx,
                is_causal=is_causal,
                padding_mask=mask if mask.dim() > 1 else None,
                device=q.device)
        else:
            attn_mask = mask
    scaled_q = q
    if temperature != 1.0 and temperature > 0:
        scaled_q = q * (1.0 / temperature)**.5
    a = scaled_dot_product_attention(scaled_q, k, v, attn_mask=attn_mask, is_causal=is_causal if attn_mask is None else False)
    out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
    return out, None
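# Minimal usage sketch tying qkv_init, create_qkv and calculate_attention together on random
# data; shapes only, values are illustrative.
def _demo_calculate_attention():
    dims, head, batch, ctx = 64, 8, 2, 10
    q_proj, k_proj, v_proj, _o_proj, _scale = qkv_init(dims, head)
    x = torch.randn(batch, ctx, dims)
    q, k, v = create_qkv(q_proj, k_proj, v_proj, x=x, head=head)  # each (batch, head, ctx, head_dim)
    out, _ = calculate_attention(q, k, v, mask=None, is_causal=True)
    return out  # (batch, ctx, dims)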
class KVCache(nn.Module):
def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
super().__init__()
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
def update(self, input_pos, k_val, v_val):
# input_pos: [S], k_val: [B, H, S, D]
assert input_pos.shape[0] == k_val.shape[2]
k_out = self.k_cache
v_out = self.v_cache
k_out[:, :, input_pos] = k_val # pyright: ignore[reportIndexIssue]
v_out[:, :, input_pos] = v_val # pyright: ignore[reportIndexIssue]
return k_out, v_out
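# Minimal usage sketch: the cache is preallocated to max_seq_length and decoding steps are
# written in place at the positions given by input_pos.
def _demo_kv_cache():
    cache = KVCache(max_batch_size=1, max_seq_length=16, n_heads=4, head_dim=8, dtype=torch.float32)
    k_val = torch.randn(1, 4, 2, 8)
    v_val = torch.randn(1, 4, 2, 8)
    k_out, v_out = cache.update(input_pos=torch.tensor([0, 1]), k_val=k_val, v_val=v_val)
    return k_out.shape, v_out.shape  # both torch.Size([1, 4, 16, 8])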
def mel_scale_scalar(freq: float) -> float:
return 1127.0 * math.log(1.0 + freq / 700.0)
def mel_scale(freq: Tensor) -> Tensor:
return 1127.0 * (1.0 + freq / 700.0).log()
def trace_x(func):
def wrapper(*args, **kwargs):
print(f"Calling {func.__name__}")
result = func(*args, **kwargs)
if isinstance(result, torch.Tensor):
print(f" {func.__name__} returned shape: {result.shape}")
return result
return wrapper
def track_x(new_x, operation=""):
    """ track_x(x, "x") """
    if new_x is None:
        return new_x
    current_id = id(new_x)
    last_id = getattr(track_x, "_last_id", None)
    if last_id is None or current_id != last_id:
        print(f"x FLOW: {last_id} -> {current_id} in {operation}")
        track_x._last_id = current_id
    else:
        print(f"x REUSE: {current_id} in {operation}")
    return new_x
def track_xa(new_xa, operation=""):
    """ track_xa(xa, "xa - decoder") """
    if new_xa is None:
        return new_xa
    current_id = id(new_xa)
    last_id = getattr(track_xa, "_last_id", None)
    if last_id is None or current_id != last_id:
        print(f"xa FLOW: {last_id} -> {current_id} in {operation}")
        track_xa._last_id = current_id
    else:
        print(f"xa REUSE: {current_id} in {operation}")
    return new_xa
def get_activation(act: str) -> nn.Module:
"""Get activation function by name."""
act_map = {
"gelu": nn.GELU(),
"relu": nn.ReLU(),
"sigmoid": nn.Sigmoid(),
"tanh": nn.Tanh(),
"swish": nn.SiLU(),
"tanhshrink": nn.Tanhshrink(),
"softplus": nn.Softplus(),
"softshrink": nn.Softshrink(),
"leaky_relu": nn.LeakyReLU(),
"elu": nn.ELU()
}
return act_map.get(act, nn.GELU())
def get_generation_config(param):
    from transformers import GenerationConfig
    return GenerationConfig(
max_length=param.text_ctx,
pad_token_id=getattr(param, "pad_token_id", 0),
bos_token_id=getattr(param, "bos_token_id", 1),
eos_token_id=getattr(param, "eos_token_id", 2),
do_sample=False,
num_beams=1,
early_stopping=False,
length_penalty=1.0,
no_repeat_ngram_size=0,
repetition_penalty=1.0,
temperature=1.0,
decoder_start_token_id=1,
is_multilingual=False,
use_cache=False,
return_timestamps=False)
# class rotary(nn.Module):
# def __init__(self, dims, head, max_ctx=1500, radii=False, debug: List[str] = [], use_pbias=False, axial=False, spec_shape=None):
# super(rotary, self).__init__()
# self.use_pbias = use_pbias
# self.dims = dims
# self.head = head
# self.head_dim = dims // head
# self.radii = radii
# self.debug = debug
# self.counter = 0
# self.last_theta = None
# self.axial = axial
# self.bias = nn.Parameter(torch.zeros(max_ctx, dims // 2), requires_grad=True if use_pbias else False)
# theta = (torch.tensor(10000, device=device, dtype=dtype))
# self.theta = nn.Parameter(theta, requires_grad=True)
# self.theta_values = []
# if axial and spec_shape is not None:
# time_frames, freq_bins = spec_shape
# self.time_frames = time_frames
# self.freq_bins = freq_bins
# time_theta = 50.0
# time_freqs = 1.0 / (time_theta ** (torch.arange(0, dims, 4)[:(dims // 4)].float() / dims))
# self.register_buffer('time_freqs', time_freqs)
# freq_theta = 100.0
# freq_freqs = 1.0 / (freq_theta ** (torch.arange(0, dims, 4)[:(dims // 4)].float() / dims))
# self.register_buffer('freq_freqs', freq_freqs)
# def pitch_bias(self, f0):
# if f0 is None:
# return None
# f0_flat = f0.squeeze().float()
# f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
# f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
# f0_norm.unsqueeze(1)))
# return f0_sim.unsqueeze(0).unsqueeze(0)
# def theta_freqs(self, theta):
# if theta.dim() == 0:
# theta = theta.unsqueeze(0)
# freq = (theta.unsqueeze(-1) / 220.0) * 700 * (
# torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
# self.head_dim // 2, device=theta.device, dtype=theta.dtype) / 2595) - 1) / 1000
# return freq
# def _apply_radii(self, freqs, f0, ctx):
# if self.radii and f0 is not None:
# radius = f0.to(device, dtype)
# L = radius.shape[0]
# if L != ctx:
# feature = L / ctx
# idx = torch.arange(ctx, device=f0.device)
# idx = (idx * feature).long().clamp(0, L - 1)
# radius = radius[idx]
# return torch.polar(radius.unsqueeze(-1), freqs), radius
# else:
# return torch.polar(radius.unsqueeze(-1), freqs), radius
# else:
# return torch.polar(torch.ones_like(freqs), freqs), None
# def check_f0(self, f0, f0t, ctx):
# if f0 is not None and f0.shape[1] == ctx:
# return f0
# elif f0t is not None and f0t.shape[1] == ctx:
# return f0t
# else:
# return None
# def axial_freqs(self, ctx):
# if not self.axial:
# return None
# time_frames = self.time_frames
# freq_bins = self.freq_bins
# t = torch.arange(ctx, device=device, dtype=dtype)
# t_x = (t % time_frames).float()
# t_y = torch.div(t, time_frames, rounding_mode='floor').float()
# freqs_x = torch.outer(t_x, self.time_freqs)
# freqs_y = torch.outer(t_y, self.freq_freqs)
# freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
# freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
# return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
# def forward(self, x=None, feats=None, feature=None, layer=None) -> Tensor:
# ctx=x
# f0 = feats.get("f0") if feats is not None else None
# f0t = feats.get("f0t") if feats is not None else None
# f0 = self.check_f0(f0, f0t, ctx)
# if f0 is not None:
# # if f0.dim() == 2:
# # f0 = f0.squeeze(0)
# theta = f0 + self.theta
# else:
# theta = self.theta
# freqs = self.theta_freqs(theta)
# t = torch.arange(ctx, device=device, dtype=dtype) # type: ignore
# freqs = t[:, None] * freqs
# freqs, radius = self._apply_radii(freqs, f0, ctx)
# if self.axial and feature == "spectrogram":
# freqs_2d = self.axial_freqs(ctx)
# if freqs_2d is not None:
# return freqs_2d.unsqueeze(0)
# if "radius" in self.debug and self.counter == 10:
# print(f" [{layer}] [Radius] {radius.shape if radius is not None else None} {radius.mean() if radius is not None else None} [Theta] {theta.mean() if theta is not None else None} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
# self.counter += 1
# return freqs.unsqueeze(0)
# @staticmethod
# def split(X: Tensor):
# half_dim = X.shape[-1] // 2
# return X[..., :half_dim], X[..., half_dim:]
# @staticmethod
# def apply_rotary(x, freqs):
# x1 = x[..., :freqs.shape[-1]*2]
# x2 = x[..., freqs.shape[-1]*2:]
# orig_shape = x1.shape
# if x1.ndim == 2:
# x1 = x1.unsqueeze(0)
# x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
# x1 = torch.view_as_complex(x1) * freqs
# x1 = torch.view_as_real(x1).flatten(-2)
# x1 = x1.view(orig_shape)
# return torch.cat([x1.type_as(x), x2], dim=-1)
# class feature_encoder(nn.Module):
# def __init__(self, mels, input_dims, dims, head, layer, act, features, feature=None, use_rope=False, spec_shape=None, debug=[], attend_feature=False, target_length=None):
# """
# Feature encoder for audio processing.
# """
# super().__init__()
# self.dims = dims
# self.head = head
# self.head_dim = dims // head
# self.dropout = 0.01
# self.use_rope = use_rope
# self.attend_feature = attend_feature
# self.target_length = target_length
# self.feature = feature
# self.debug = debug
# act_fn = get_activation(act)
# if self.attend_feature:
# self.q, self.k, self.v, self.o, self.scale = qkv_init(dims, head)
# self.mlp = nn.Sequential(nn.Linear(dims, dims), nn.ReLU(), nn.Linear(dims, dims))
# else:
# self.q, self.k, self.v, self.o, self.scale = None, None, None, None, None
# self.mlp = None
# self.spectrogram = nn.Sequential(
# Conv1d(mels, dims, kernel_size=3), act_fn,
# Conv1d(dims, dims, kernel_size=3), act_fn,
# Conv1d(dims, dims, kernel_size=3, groups=dims), act_fn)
# self.waveform = nn.Sequential(
# Conv1d(1, dims//4, kernel_size=15, stride=4, padding=7), act_fn,
# Conv1d(dims//4, dims//2, kernel_size=7, stride=2, padding=3), act_fn,
# Conv1d(dims//2, dims, kernel_size=5, stride=2, padding=2), act_fn)
# self.pitch = nn.Sequential(
# Conv1d(1, dims, kernel_size=7, stride=1, padding=3), act_fn,
# Conv1d(dims, dims, kernel_size=5, stride=1, padding=2), act_fn,
# Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
# if use_rope:
# # if spec_shape is not None:
# self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
# self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)
# else:
# self.rope = None
# self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
# self.norm = RMSNorm(dims)
# def rope(self, x, xa=None, mask=None, feats=None, feature=None, layer=None):
# if isinstance(x, int):
# ctx = x
# elif isinstance(x, torch.Tensor):
# ctx = x.shape[1] if x.dim() > 1 else x.shape[0]
# batch, ctx, dims = x.shape[0], ctx, x.shape[-1]
# x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
# freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)
# x = self.rope.apply_rotary(x, freqs) # pyright: ignore[reportOptionalSubscript, reportAttributeAccessIssue]
# x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
# return x
# def mel_scalar(self, freq: float) -> float:
# return 1127.0 * math.log(1.0 + freq / 700.0)
# def forward(self, x, xa=None, mask=None, feats=None, feature=None, layer=None, max_tscale=36000):
# target_length = x.shape[1] if self.target_length is None else self.target_length
# if feature == "pitch":
# xp = x.clone()
# enc_dict = feats if feats is not None else {}
# enc_dict = dict(enc_dict)
# enc_dict["f0"] = xp
# # xp = self.mel_scalar(xp.mean())
# # print(f"Using pitch scalar: {xp}")
# # max_tscale = xp*300
# # print(f"Using max_tscale: {max_tscale}")
# feats = enc_dict
# if x.dim() == 2:
# x = x.unsqueeze(0)
# x = self.pitch(x).permute(0, 2, 1)
# if feature == "phase":
# if x.dim() == 2:
# x = x.unsqueeze(0)
# x = self.pitch(x).permute(0, 2, 1)
# if feature == "waveform":
# if x.dim() == 2:
# x = x.unsqueeze(0)
# x = self.waveform(x).permute(0, 2, 1)
# if target_length and x.shape[1] != self.target_length:
# x = F.adaptive_avg_pool1d(x.transpose(1, 2), target_length).transpose(1, 2)
# if feature == "harmonics":
# if x.dim() == 2:
# x = x.unsqueeze(0)
# x = self.spectrogram(x).permute(0, 2, 1)
# if feature == "aperiodic":
# if x.dim() == 2:
# x = x.unsqueeze(0)
# x = self.spectrogram(x).permute(0, 2, 1)
# if feature == "spectrogram":
# if x.dim() == 2:
# x = x.unsqueeze(0)
# x = self.spectrogram(x).permute(0, 2, 1)
# if self.use_rope:
# x = x + self.positional(x.shape[1], x.shape[-1], max_tscale).to(device, dtype)
# x = self.rope(x=x, xa=None, mask=None, feats=feats, feature=feature, layer=layer)
# else:
# max_tscale = x.shape[1] * 1000 if max_tscale is None else max_tscale
# x = x + self.positional(x.shape[1], x.shape[-1], max_tscale).to(device, dtype)
# x = nn.functional.dropout(x, p=self.dropout, training=self.training)
# x = self.norm(x)
# if self.attend_feature:
# xa = feats[feature] # pyright: ignore[reportOptionalSubscript]
# if xa is not None:
# q, k, v = create_qkv(self.q, self.k, self.v, x=xa, xa=x, head=self.head)
# out, _ = calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True)
# x = x + out
# x = nn.functional.dropout(x, p=self.dropout, training=self.training)
# x = self.norm(x)
# return x
class OneShot(nn.Module):
def __init__(self, dims: int, head: int, scale: float = 0.3, features: Optional[List[str]] = None):
super().__init__()
if features is None:
features = ["spectrogram", "waveform", "pitch", "aperiodic", "harmonics"]
self.head = head
self.head_dim = dims // head
        self.scale = 1.0 / len(features) if features else scale
self.q = Linear(dims, dims)
self.k = Linear(dims, dims)
def forward(self, x: Tensor, xa: Tensor, feature=None) -> Tensor | None:
B, L, D = x.shape
K = xa.size(1)
q = self.q(x).view(B, L, self.head, self.head_dim).transpose(1,2)
k = self.k(xa).view(B, K, self.head, self.head_dim).transpose(1,2)
bias = (q @ k.transpose(-1, -2)) * self.scale / math.sqrt(self.head_dim)
return bias
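# Minimal shape sketch: OneShot turns x (queries) and xa (keys) into a (batch, head, len_x, len_xa)
# attention-bias map; inputs here are random and purely illustrative.
def _demo_oneshot_bias():
    bias_mod = OneShot(dims=64, head=8)
    x = torch.randn(2, 10, 64)
    xa = torch.randn(2, 20, 64)
    return bias_mod(x, xa).shape  # torch.Size([2, 8, 10, 20])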
class curiosity(nn.Module):
def __init__(self, d, h, bias=True):
super().__init__()
self.h = h
self.dh = d // h
self.qkv = nn.Linear(d, d * 3, bias=bias)
self.qkv_aux = nn.Linear(d, d * 3, bias=bias)
self.o = nn.Linear(d, d, bias=bias)
self.g = nn.Parameter(torch.zeros(h))
def split(self, x):
b, t, _ = x.shape
return x.view(b, t, self.h, self.dh).transpose(1, 2)
def merge(self, x):
b, h, t, dh = x.shape
return x.transpose(1, 2).contiguous().view(b, t, h * dh)
def forward(self, x, xa, mask=None):
q, k, v = self.qkv(x).chunk(3, -1)
qa, ka, va = self.qkv_aux(xa).chunk(3, -1)
q, k, v = map(self.split, (q, k, v))
qa, ka, va = map(self.split, (qa, ka, va))
dots = (q @ k.transpose(-2, -1)) / self.dh**0.5
dots_aux = (q @ ka.transpose(-2, -1)) / self.dh**0.5
if mask is not None: dots = dots.masked_fill(mask, -9e15)
p = dots.softmax(-1)
pa = dots_aux.softmax(-1)
h_main = p @ v
h_aux = pa @ va
g = torch.sigmoid(self.g).view(1, -1, 1, 1)
out = self.merge(h_main * (1 - g) + h_aux * g)
return self.o(out)
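# Minimal shape sketch: curiosity mixes self-attention over x with attention from x into xa
# through the learned per-head gate g; random inputs, shapes only.
def _demo_curiosity():
    layer = curiosity(d=64, h=8)
    x = torch.randn(2, 10, 64)
    xa = torch.randn(2, 15, 64)
    return layer(x, xa).shape  # torch.Size([2, 10, 64])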
class PositionalEncoding(nn.Module):
def __init__(self, dims, ctx):
super(PositionalEncoding, self).__init__()
self.dims = dims
self.ctx = ctx
self.pe = self.get_positional_encoding(max_ctx=ctx)
def get_positional_encoding(self, max_ctx):
pe = torch.zeros(max_ctx, self.dims)
position = torch.arange(0, max_ctx, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.dims, 2, dtype=torch.float32)
* (-math.log(10000.0) / self.dims)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
return pe.to(device)
def forward(self, x):
ctx = x.size(1)
pe = self.pe[:, :ctx, :]
x = x * math.sqrt(self.dims)
x = x + pe
return x
def plot_waveform(x=None, w=None, p=None, per=None, sample_idx=0, sr=16000, hop_length=160,
title="", markers=None, marker_labels=None,
show_voiced_regions=True, show_energy=False):
num_plots = sum([x is not None, w is not None, p is not None, per is not None])
if num_plots == 0:
raise ValueError("No data to plot. Please provide at least one input tensor.")
t_spans = []
if w is not None:
w_np = w[sample_idx].detach().cpu().numpy()
if w_np.ndim > 1:
w_np = w_np.squeeze()
t_spans.append(len(w_np) / sr)
if x is not None:
x_np = x[sample_idx].detach().cpu().numpy()
if x_np.shape[0] < x_np.shape[1]:
x_np = x_np.T
t_spans.append(x_np.shape[0] * hop_length / sr)
if p is not None:
p_np = p[sample_idx].detach().cpu().numpy()
if p_np.ndim > 1:
p_np = p_np.squeeze()
t_spans.append(len(p_np) * hop_length / sr)
if per is not None:
per_np = per[sample_idx].detach().cpu().numpy()
if per_np.ndim > 1:
per_np = per_np.squeeze()
t_spans.append(len(per_np) * hop_length / sr)
max_t = max(t_spans) if t_spans else 0
fig, axs = plt.subplots(num_plots, 1, figsize=(14, 4*num_plots), sharex=True)
if num_plots == 1:
axs = [axs]
if show_voiced_regions and per is not None:
per_np = per[sample_idx].detach().cpu().numpy()
if per_np.ndim > 1:
per_np = per_np.squeeze()
t_per = np.arange(len(per_np)) * hop_length / sr
threshold = 0.5
for ax in axs:
for i in range(len(per_np)-1):
if per_np[i] > threshold:
ax.axvspan(t_per[i], t_per[i+1], color='lightblue', alpha=0.2, zorder=0)
cu_ax = 0
if w is not None:
w_np = w[sample_idx].detach().cpu().numpy()
if w_np.ndim > 1:
w_np = w_np.squeeze()
t = np.arange(len(w_np)) / sr
axs[cu_ax].plot(t, w_np, color="tab:blue")
if show_energy:
frame_length = hop_length
hop_length_energy = hop_length // 2
energy = []
for i in range(0, len(w_np)-frame_length, hop_length_energy):
frame = w_np[i:i+frame_length]
energy.append(np.sqrt(np.mean(frame**2)))
energy = np.array(energy)
energy = energy / np.max(energy) * 0.8 * max(abs(w_np.min()), abs(w_np.max()))
t_energy = np.arange(len(energy)) * hop_length_energy / sr
axs[cu_ax].plot(t_energy, energy, color="red", alpha=0.7, label="Energy")
axs[cu_ax].legend(loc='upper right')
axs[cu_ax].set_title("Waveform")
axs[cu_ax].set_ylabel("Amplitude")
axs[cu_ax].set_xlim([0, max_t])
axs[cu_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
cu_ax += 1
if x is not None:
x_np = x[sample_idx].detach().cpu().numpy()
if x_np.shape[0] < x_np.shape[1]:
x_np = x_np.T
axs[cu_ax].imshow(x_np.T, aspect="auto", origin="lower", cmap="magma",
extent=[0, x_np.shape[0]*hop_length/sr, 0, x_np.shape[1]])
axs[cu_ax].set_title("Spectrogram")
axs[cu_ax].set_ylabel("Mel Bin")
axs[cu_ax].set_xlim([0, max_t])
axs[cu_ax].grid(True, axis='x', linestyle='--', alpha=0.3)
cu_ax += 1
if p is not None:
p_np = p[sample_idx].detach().cpu().numpy()
if p_np.ndim > 1:
p_np = p_np.squeeze()
t_p = np.arange(len(p_np)) * hop_length / sr
axs[cu_ax].plot(t_p, p_np, color="tab:green")
axs[cu_ax].set_title("Pitch")
axs[cu_ax].set_ylabel("Frequency (Hz)")
axs[cu_ax].set_xlim([0, max_t])
axs[cu_ax].grid(True, axis='both', linestyle='--', alpha=0.3)
axs[cu_ax].set_ylim([0, min(1000, p_np.max() * 1.2)])
cu_ax += 1
if per is not None:
per_np = per[sample_idx].detach().cpu().numpy()
if per_np.ndim > 1:
per_np = per_np.squeeze()
t_per = np.arange(len(per_np)) * hop_length / sr
axs[cu_ax].plot(t_per, per_np, color="tab:red")
axs[cu_ax].set_title("Period (Voice Activity)")
axs[cu_ax].set_ylabel("periodocity")
axs[cu_ax].set_xlim([0, max_t])
axs[cu_ax].grid(True, axis='both', linestyle='--', alpha=0.3)
axs[cu_ax].set_ylim([-0.05, 1.05])
axs[cu_ax].axhline(y=0.5, color='k', linestyle='--', alpha=0.3)
if markers is not None:
for i, t in enumerate(markers):
label = marker_labels[i] if marker_labels and i < len(marker_labels) else None
for ax in axs:
ax.axvline(x=t, color='k', linestyle='-', alpha=0.7, label=label if i == 0 else None)
if marker_labels:
axs[0].legend(loc='upper right', fontsize='small')
axs[-1].set_xlabel("t (s)")
fig.suptitle(title, fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.97]) # type: ignore
plt.show()
return fig
def valid(default_value, *items):
"""Get first non-None item"""
for item in items:
if item is not None:
return item
return default_value
def dict_to(d, device, dtype=dtype):
return {k: v.to(device, dtype) if isinstance(v, torch.Tensor) else v
for k, v in d.items()}
def exists(v):
return v is not None
def default(v, b):
return v if exists(v) else b
class Conv1d(nn.Conv1d):
def _conv_forward(
self, x: Tensor, weight: Tensor, bias) -> Tensor:
return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))
class Conv2d(nn.Conv2d):
def _conv_forward(
self, x: Tensor, weight: Tensor, bias) -> Tensor:
return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))
class Linear(nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
super(Linear, self).__init__()
self.linear = nn.Linear(in_features, out_features, bias=bias)
init.xavier_uniform_(self.linear.weight)
if bias:
init.zeros_(self.linear.bias)
def forward(self, x: Tensor) -> Tensor:
return self.linear(x)
class RMSNorm(nn.Module):
def __init__(self, dims: Union[int, Tensor, List, Tuple],
eps = 1e-8, elementwise_affine = True):
super(RMSNorm, self).__init__()
if isinstance(dims, int):
self.normalized_shape = (dims,)
else:
self.normalized_shape = tuple(dims)
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(torch.empty(self.normalized_shape)) # type: ignore
init.ones_(self.weight)
else:
self.register_parameter("weight", None)
def forward(self, x):
return F.rms_norm(x, self.normalized_shape, self.weight, self.eps) # type: ignore
def LayerNorm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple],
weight: Optional[Tensor] = None, bias: Optional[Tensor] = None,
eps: float = 1e-5) -> Tensor:
return F.layer_norm(x, normalized_shape, weight, bias, eps) # type: ignore
def get_device():
return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def get_dtype():
return torch.float32 if torch.cuda.is_available() else torch.float64
def tox():
return {"device": get_device(), "dtype": get_dtype()}
class Sinusoids(nn.Module):
def __init__(self, ctx: int, dims: int):
super().__init__()
position = torch.arange(start=0, end=ctx, dtype=dtype).unsqueeze(dim=1)
div_term = torch.exp(input=torch.arange(start=0, end=dims, step=2, dtype=dtype) * -(math.log(10000.0) / dims))
features = torch.zeros(ctx, dims)
features[:, 0::2] = torch.sin(position * div_term)
features[:, 1::2] = torch.cos(position* div_term)
self.register_buffer('sinusoid', tensor=features)
self.positional_embeddings = nn.Parameter(self.sinusoid.clone()) # type: ignore
def forward(self, positions):
position_embeddings = self.positional_embeddings[positions]
return position_embeddings
def sinusoids(length, channels, max_tscale=10000):
assert channels % 2 == 0
log_tscale_increment = torch.log(torch.tensor(float(max_tscale))) / (channels // 2 - 1)
inv_tscales = torch.exp(-log_tscale_increment * torch.arange(channels // 2, device=device, dtype=torch.float32))
scaled_t = torch.arange(length, device=device, dtype=torch.float32).unsqueeze(1) * inv_tscales.unsqueeze(0)
return torch.cat([torch.sin(scaled_t), torch.cos(scaled_t)], dim=1)
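# Minimal usage sketch: fixed sinusoidal positional encodings for a 50-frame sequence.
def _demo_sinusoids():
    pe = sinusoids(length=50, channels=64, max_tscale=10000)
    return pe.shape  # torch.Size([50, 64])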
class SelfCriticalRL(nn.Module):
def __init__(self, model, tokenizer, reward_fn):
super().__init__()
self.model = model
self.tokenizer = tokenizer
self.reward_fn = reward_fn
def forward(self, input_ids, features, labels=None, max_len=128, feature_name="spectrogram"):
with torch.no_grad():
greedy_ids = self.model.generate(input_ids=input_ids, **{feature_name: features}, max_length=max_len)
greedy_text = [self.tokenizer.decode(ids) for ids in greedy_ids]
sampled_ids = self.model.generate(input_ids=input_ids, **{feature_name: features}, max_length=max_len, do_sample=True, top_k=5)
sampled_text = [self.tokenizer.decode(ids) for ids in sampled_ids]
rewards = []
baseline = []
for s, g, ref in zip(sampled_text, greedy_text, labels): # type: ignore
ref_text = self.tokenizer.decode(ref)
rewards.append(self.reward_fn(s, ref_text))
baseline.append(self.reward_fn(g, ref_text))
rewards = torch.tensor(rewards, device=device, dtype=torch.float)
baseline = torch.tensor(baseline, device=device, dtype=torch.float)
advantage = rewards - baseline
logits = self.model(input_ids=sampled_ids, **{feature_name: features})["logits"] # logits: [batch, sampled_seq_len, vocab_size]
log_probs = F.log_softmax(logits, dim=-1)
log_probs_seq = torch.gather(log_probs, 2, sampled_ids.unsqueeze(-1)).squeeze(-1)
log_probs_sum = log_probs_seq.sum(dim=1)
loss = -(advantage * log_probs_sum).mean()
return loss
class SelfTrainingModule(nn.Module):
def __init__(self, model, tokenizer, quality_fn=None, threshold=0.8):
super().__init__()
self.model = model
self.tokenizer = tokenizer
self.quality_fn = quality_fn
self.threshold = threshold
def generate_pseudo_labels(self, unlabeled_batch, features, max_len=128, feature_name="spectrogram"):
with torch.no_grad():
pred_ids = self.model.generate(input_ids=unlabeled_batch, **{feature_name: features}, max_length=max_len)
if self.quality_fn is not None:
quality_scores = self.quality_fn(pred_ids, self.model, features)
mask = quality_scores > self.threshold
pred_ids = pred_ids[mask]
return pred_ids
def forward(self, unlabeled_batch, features, max_len=128, feature_name="spectrogram"):
pseudo_labels = self.generate_pseudo_labels(unlabeled_batch, features, max_len, feature_name=feature_name)
logits = self.model(input_ids=unlabeled_batch, **{feature_name: features}, labels=pseudo_labels)["logits"]
loss = nn.functional.cross_entropy(
logits.view(-1, logits.shape[-1]), pseudo_labels.view(-1), ignore_index=0)
return loss
def confidence_indicator(pred_ids, model, features):
with torch.no_grad():
logits = model(input_ids=pred_ids, **features)["logits"]
probs = torch.softmax(logits, dim=-1)
max_probs, _ = probs.max(dim=-1)
return max_probs.mean(dim=1)
def wer_reward(hyp, ref):
hyp_words = hyp.split()
ref_words = ref.split()
d = [[0] * (len(ref_words)+1) for _ in range(len(hyp_words)+1)]
for i in range(len(hyp_words)+1):
d[i][0] = i
for j in range(len(ref_words)+1):
d[0][j] = j
for i in range(1, len(hyp_words)+1):
for j in range(1, len(ref_words)+1):
if hyp_words[i-1] == ref_words[j-1]:
d[i][j] = d[i-1][j-1]
else:
d[i][j] = 1 + min(d[i-1][j], d[i][j-1], d[i-1][j-1])
wer = d[-1][-1] / max(1, len(ref_words))
return -wer # negative WER as reward
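# Minimal usage sketch: the reward is the negative word error rate, so a perfect hypothesis
# scores 0.0 and each additional word error lowers the reward.
def _demo_wer_reward():
    perfect = wer_reward("the cat sat", "the cat sat")    # -0.0
    one_sub = wer_reward("the cat sit", "the cat sat")    # -1/3
    return perfect, one_sub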
def clean_ids(ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
if isinstance(ids, torch.Tensor):
ids = ids.tolist()
return [int(id) for id in ids if id != -100 and id != pad_token_id and id != bos_token_id and id != eos_token_id]
def clean_batch(batch_ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
return [clean_ids(seq, pad_token_id, bos_token_id, eos_token_id) for seq in batch_ids]
def setup_tokenizer(dir: str):
from tokenizers import Tokenizer
tokenizer = Tokenizer.from_file(f"{dir}")
orig_encode = tokenizer.encode
orig_decode = tokenizer.decode
def enc(text, add_special_tokens=True):
ids = orig_encode(text).ids
if not add_special_tokens:
sp_ids = [tokenizer.token_to_id(t) for t in ["<PAD>", "<BOS>", "<EOS>"]]
ids = [id for id in ids if id not in sp_ids]
return ids
def bdec(ids_list, pad_token_id=0, bos_token_id=1, eos_token_id=2, skip_special_tokens=True):
results = []
if isinstance(ids_list, torch.Tensor):
ids_list = ids_list.tolist()
elif isinstance(ids_list, np.ndarray):
ids_list = ids_list.tolist()
for ids in ids_list:
ids = [int(id) for id in ids if id not in (pad_token_id, bos_token_id, eos_token_id, -100)]
results.append(orig_decode(ids))
return results
def dec(ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
ids = [int(id) for id in ids if id not in (pad_token_id, bos_token_id, eos_token_id, -100)]
return orig_decode(ids)
def save_pretrained(save_dir):
os.makedirs(save_dir, exist_ok=True)
tokenizer.save(f"{save_dir}/tokenizer.json")
tokenizer.encode = enc
tokenizer.batch_decode = bdec
tokenizer.decode = dec
tokenizer.save_pretrained = save_pretrained
tokenizer.pad_token_id = 0
tokenizer.bos_token_id = 1
tokenizer.eos_token_id = 2
return tokenizer
def tokenize_pitch(pitch_features, target_length):
pitch_len = pitch_features.shape[-1]
token_len = target_length
if pitch_len > token_len:
pitch_tokens = F.adaptive_avg_pool1d(pitch_features, token_len)
else:
pitch_tokens = F.interpolate(pitch_features, token_len)
return pitch_tokens
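# Minimal usage sketch: a frame-level pitch track is pooled down or interpolated up to a
# token-level length so it can be aligned with the text tokens.
def _demo_tokenize_pitch():
    pitch = torch.rand(1, 1, 100)                         # (batch, channel, frames)
    pooled = tokenize_pitch(pitch, target_length=25)      # adaptive average pooling
    stretched = tokenize_pitch(pitch, target_length=200)  # nearest-neighbour interpolation
    return pooled.shape, stretched.shape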
def load_wave(wave_data, sample_rate=16000):
if isinstance(wave_data, str):
waveform, sample_rate = torchaudio.load(uri=wave_data, normalize=False)
elif isinstance(wave_data, dict):
waveform = torch.tensor(data=wave_data["array"]).float()
sample_rate = wave_data["sampling_rate"] # noqa: F841
else:
raise TypeError("Invalid wave_data format.")
return waveform
def world_to_mel(sp, ap, sample_rate=16000, n_mels=128):
import librosa
mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=1024, n_mels=n_mels)
mel_basis = torch.from_numpy(mel_basis).float()
sp_mel = torch.matmul(sp, mel_basis.T) # (frames, 128)
ap_mel = torch.matmul(ap, mel_basis.T) # (frames, 128)
return sp_mel, ap_mel
def extract_features(batch, tokenizer, waveform=False, spec=False, f0=False, f0t=False, pitch=False, harmonics=False, sample_rate=16000, hop_length=256, mode="mean", debug=False, phase_mod=False, crepe=False, aperiodics=False, dummy=False):
# import torchaudio
# import torchaudio.functional
# import torchaudio.transforms
# torch_windows = {
# 'hann': torch.hann_window,
# 'hamming': torch.hamming_window,
# 'blackman': torch.blackman_window,
# 'bartlett': torch.bartlett_window,
# 'ones': torch.ones,
# None: torch.ones,
# }
# if dummy:
# return {
# "spectrogram": torch.zeros((1, 128, 100)),
# "f0": torch.zeros((1, 100)),
# "f0t": torch.zeros((1, 100)),
# "pitch": torch.zeros((1, 100)),
# "harmonics": torch.zeros((1, 128, 100)),
# "aperiodics": torch.zeros((1, 128, 100)),
# "crepe_time": None,
# "crepe_frequency": None,
# "crepe_confidence": None,
# "crepe_activation": None,
# }
audio = batch["audio"]
sample_rate = audio["sampling_rate"]
labels = tokenizer.encode(batch["transcription"])
wav = load_wave(wave_data=audio, sample_rate=sample_rate)
spectrogram_config = {
# "hop_length": 256,
# "f_min": 150,
# "f_max": 2000,
# "n_mels": 128,
# "n_fft": 1024,
"sample_rate": 16000,
# "pad_mode": "constant",
# "center": True,
# "power": 1.0,
# "window_fn": torch.hann_window,
# "mel_scale": "htk",
# "norm": None,
# "normalized": False,
}
    def crepe_predict(wav, sample_rate, viterbi=False):
        # the (time, frequency, confidence, activation) interface here matches the `crepe` package, not torchcrepe
        import crepe
        wav = wav.numpy().astype(np.float32)
        time, frequency, confidence, activation = crepe.predict(
            wav, sample_rate, viterbi=viterbi)
crepe_time = torch.from_numpy(time)
crepe_frequency = torch.from_numpy(frequency)
crepe_confidence = torch.from_numpy(confidence)
crepe_activation = torch.from_numpy(activation)
return crepe_time, crepe_frequency, crepe_confidence, crepe_activation
if crepe:
crepe_time, crepe_frequency, crepe_confidence, crepe_activation = crepe_predict(wav, sample_rate, viterbi=True)
else:
crepe_time = None
crepe_frequency = None
crepe_confidence = None
crepe_activation = None
# def spectrogram(wav, sample_rate, n_fft=1024, hop_length=256, window_fn=torch.hann_window):
# if isinstance(window_fn, str):
# window_fn = torch_windows[window_fn]
# if window_fn is None:
# window_fn = torch.ones(n_fft)
# if isinstance(window_fn, torch.Tensor):
# window_fn = window_fn.to(device)
# return torchaudio.functional.spectrogram(
# wav, n_fft=n_fft, hop_length=hop_length, win_length=n_fft,
# window=window_fn, center=True, pad_mode="reflect", power=1.0)
# def mel_spectrogram(wav, sample_rate, n_fft=1024, hop_length=256, window_fn=torch.hann_window):
# transform = torchaudio.transforms.MelSpectrogram(**spectrogram_config)
# mel_spectrogram = transform(wav)
# log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
# log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
# spectrogram_tensor = (log_mel + 4.0) / 4.0
# spectrogram_tensor = torch.tensor(spectrogram_tensor)
# return spectrogram_tensor
    if spec:
        transform = torchaudio.transforms.MelSpectrogram(**spectrogram_config)
        mel_spectrogram = transform(wav)
        log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
        log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
        spectrogram_tensor = (log_mel + 4.0) / 4.0
# if spec:
# if isinstance(wav, torch.Tensor):
# wav = wav.to(device)
# spectrogram_tensor = mel_spectrogram(wav, sample_rate, **spectrogram_config)
# spectrogram_tensor = spectrogram_tensor.permute(1, 0)
def mfcc(wav, sample_rate, n_mels=128, n_fft=1024, hop_length=256, window_fn=torch.hann_window):
transform = torchaudio.transforms.MFCC(
sample_rate=sample_rate,
n_mfcc=n_mels,
melkwargs={
"n_fft": n_fft,
"hop_length": hop_length,
"window_fn": window_fn,
"n_mels": n_mels,
"center": True,
"pad_mode": "reflect",
"norm": None,
"mel_scale": "htk",
}
)
mfcc_tensor = transform(wav)
return mfcc_tensor
def compute_pitch(wav, sample_rate, hop_length=256):
import pyworld as pw
wav_np = wav.numpy().astype(np.float64)
f0, t = pw.dio(wav_np, sample_rate, frame_period=hop_length / sample_rate * 1000)
f0 = pw.stonemask(wav_np, f0, t, sample_rate)
return f0, t
def compute_harmonics_and_aperiodics(wav, f0, t, sample_rate):
import pyworld as pw
wav_np = wav.numpy().astype(np.float64)
sp = pw.cheaptrick(wav_np, f0, t, sample_rate, fft_size=256)
ap = pw.d4c(wav_np, f0, t, sample_rate, fft_size=256)
harmonic_tensor = torch.from_numpy(sp)
aperiodic_tensor = torch.from_numpy(ap)
harmonic_tensor = harmonic_tensor[:, :128].contiguous().T
aperiodic_tensor = aperiodic_tensor[:, :128].contiguous().T
harmonic_tensor = torch.where(harmonic_tensor == 0.0, torch.zeros_like(harmonic_tensor), harmonic_tensor / 1.0)
aperiodic_tensor = torch.where(aperiodic_tensor == 0.0, torch.zeros_like(aperiodic_tensor), aperiodic_tensor / 1.0)
return harmonic_tensor, aperiodic_tensor
if f0 or f0t or pitch or harmonics or aperiodics:
wavnp = wav.numpy().astype(np.float64)
f0_np, t = pw.dio(wavnp, sample_rate, frame_period=hop_length / sample_rate * 1000)
f0_np = pw.stonemask(wavnp, f0_np, t, sample_rate)
if f0:
f0_tensor = torch.from_numpy(f0_np)
else:
f0_tensor = None
if f0t:
wav = torch.from_numpy(wavnp)
t2 = torch.from_numpy(t)
audio_duration = len(wav) / sample_rate
T = len(labels)
tok_dur_sec = audio_duration / T
token_starts = torch.arange(T) * tok_dur_sec
token_ends = token_starts + tok_dur_sec
start_idx = torch.searchsorted(t2, token_starts, side="left")
end_idx = torch.searchsorted(t2, token_ends, side="right")
pitch_tok = torch.zeros(T, dtype=torch.float32)
for i in range(T):
lo, hi = start_idx[i], max(start_idx[i]+1, end_idx[i]) # type: ignore
segment = f0_np[lo:hi]
if mode == "mean":
pitch_tok[i] = segment.mean()
elif mode == "median":
pitch_tok[i] = torch.median(segment)
else:
pitch_tok[i] = segment[-1]
pitch_tok[pitch_tok < 100.0] = 0.0
bos_pitch = pitch_tok[0] if len(pitch_tok) > 0 else 0.0
f0t_tensor = torch.cat([torch.tensor([bos_pitch]), pitch_tok])
f0t_tensor = torch.where(f0t_tensor == 0.0, torch.zeros_like(f0t_tensor), (f0t_tensor - 71.0) / (500.0 - 71.0))
else:
f0t_tensor = None
        if phase_mod and f0_tensor is not None:
            # phase needs the frame-level f0 track (f0=True)
            t_axis = torch.from_numpy(t)
            tframe = torch.mean(t_axis[1:] - t_axis[:-1])
            phi0 = 0.0
            omega = 2 * torch.pi * f0_tensor
            dphi = omega * tframe
            phi = torch.cumsum(dphi, dim=0) + phi0
            phase = torch.remainder(phi, 2 * torch.pi)
        else:
            phase = None
if pitch:
p_tensor = compute_pitch(wav, sample_rate, hop_length=hop_length)[0]
p_tensor = torch.from_numpy(p_tensor)
p_tensor = p_tensor.unsqueeze(0)
# p_tensor = torch.from_numpy(f0_np)
else:
p_tensor = None
if harmonics or aperiodics:
spnp = pw.cheaptrick(wavnp, f0_np, t, sample_rate, fft_size=256)
apnp = pw.d4c(wavnp, f0_np, t, sample_rate, fft_size=256)
harmonic_tensor = torch.from_numpy(spnp)
aperiodic_tensor = torch.from_numpy(apnp)
harmonic_tensor = harmonic_tensor[:, :128].contiguous().T
aperiodic_tensor = aperiodic_tensor[:, :128].contiguous().T
harmonic_tensor = torch.where(harmonic_tensor == 0.0, torch.zeros_like(harmonic_tensor), harmonic_tensor / 1.0)
aperiodic_tensor = torch.where(aperiodic_tensor == 0.0, torch.zeros_like(aperiodic_tensor), aperiodic_tensor / 1.0)
else:
harmonic_tensor = None
aperiodic_tensor = None
if waveform:
wave_tensor = wav
else:
wave_tensor = None
if dummy:
if spectrogram_tensor is not None:
dummy_tensor = torch.ones_like(spectrogram_tensor)
elif p_tensor is not None:
dummy_tensor = torch.ones_like(p_tensor)
elif f0_tensor is not None:
dummy_tensor = torch.ones_like(f0_tensor)
elif f0t_tensor is not None:
dummy_tensor = torch.ones_like(f0t_tensor)
else:
batch_size = 128
seq_len = 1024
dummy_tensor = torch.ones(batch_size, seq_len)
dummy_tensor = dummy_tensor.to(device)
else:
dummy_tensor = None
if debug:
print(f"['f0']: {f0_tensor.shape if f0 else None}")
print(f"['f0t']: {f0t_tensor.shape if f0t else None}")
print(f"['harmonic']: {harmonic_tensor.shape if harmonics else None}")
print(f"['aperiodic']: {aperiodic_tensor.shape if aperiodics else None}")
print(f"['spectrogram']: {spectrogram_tensor.shape if spec else None}")
print(f"['waveform']: {wave_tensor.shape if waveform else None}")
print(f"['labels']: {len(labels) if labels else None}")
print(f"['phase']: {phase.shape if phase else None}")
print(f"['pitch']: {p_tensor.shape if pitch else None}")
print(f"['crepe_time']: {crepe_time.shape if crepe else None}")
print(f"['crepe_frequency']: {crepe_frequency.shape if crepe else None}")
print(f"['crepe_confidence']: {crepe_confidence.shape if crepe else None}")
print(f"['crepe_activation']: {crepe_activation.shape if crepe else None}")
print(f"['dummy']: {dummy_tensor.shape if dummy else None}")
return {
"waveform": wave_tensor if waveform else None,
"spectrogram": spectrogram_tensor if spec else None,
"f0": f0_tensor if f0 else None,
"f0t": f0t_tensor if f0t else None,
"pitch": p_tensor if pitch else None,
"harmonic": harmonic_tensor if harmonics else None,
"aperiodic": aperiodic_tensor if aperiodics else None,
"labels": labels,
"phase": phase if phase_mod else None,
"crepe_time": crepe_time if crepe else None,
"crepe_frequency": crepe_frequency if crepe else None,
"crepe_confidence": crepe_confidence if crepe else None,
"crepe_activation": crepe_activation if crepe else None,
"dummy": dummy_tensor if dummy else None,
}
def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=False,
load_saved=False, save_dataset=False, cache_dir=None, extract_args=None, max_ctx=2048):
if extract_args is None:
extract_args = {
"waveform": False,
"spec": False,
"f0": False,
"f0t": False,
"pitch": False,
"harmonic": False,
"aperiodic": False,
"sample_rate": 16000,
"hop_length": 256,
"mode": "mean",
"debug": False,
"phase_mod": False,
"crepe": False,
"dummy": False,
}
    if cache_dir is None:
        cache_dir = "./processed_datasets"
    os.makedirs(cache_dir, exist_ok=True)
    cache_file_train = os.path.join(cache_dir, "train.arrow")
    cache_file_test = os.path.join(cache_dir, "test.arrow")
    if load_saved:
if os.path.exists(cache_file_train) and os.path.exists(cache_file_test):
from datasets import Dataset
train_dataset = Dataset.load_from_disk(cache_file_train)
test_dataset = Dataset.load_from_disk(cache_file_test)
return train_dataset, test_dataset
if sanity_check:
test = load_dataset(
"google/fleurs", "en_us", token=token, split="test", trust_remote_code=True, streaming=streaming).cast_column("audio", Audio(sampling_rate=sample_rate)).take(1)
dataset = test.map(
lambda x: extract_features(x, tokenizer, **extract_args),
remove_columns=test.column_names)
train_dataset = dataset
test_dataset = dataset
return train_dataset, test_dataset
else:
def filter_func(x):
return (0 < len(x["transcription"]) < max_ctx and
len(x["audio"]["array"]) > 0 and
len(x["audio"]["array"]) < max_ctx * 160)
raw_train = load_dataset(
"google/fleurs", "en_us", token=token, split="train", trust_remote_code=True, streaming=streaming).take(1000)
raw_test = load_dataset(
"google/fleurs", "en_us", token=token, split="test", trust_remote_code=True, streaming=streaming).take(100)
raw_train = raw_train.filter(filter_func)
raw_test = raw_test.filter(filter_func)
raw_train = raw_train.cast_column("audio", Audio(sampling_rate=sample_rate))
raw_test = raw_test.cast_column("audio", Audio(sampling_rate=sample_rate))
train_dataset = raw_train.map(
lambda x: extract_features(x, tokenizer, **extract_args), remove_columns=raw_train.column_names)
test_dataset = raw_test.map(
lambda x: extract_features(x, tokenizer, **extract_args), remove_columns=raw_test.column_names)
train_dataset.save_to_disk(cache_file_train) if save_dataset is True else None
test_dataset.save_to_disk(cache_file_test) if save_dataset is True else None
return train_dataset, test_dataset
def get_feature_encoder(feature: str, mels: int, input_dims: int, dims: int, head: int, layer: int, act=None, features=None) -> nn.Module:
    if feature == "spectrogram":
        return FEncoder(mels=mels, input_dims=input_dims, dims=dims, head=head, layer=layer, act=act, feature=feature, features=features)
    elif feature == "waveform":
        # WEncoder/PEncoder take (input_dims, dims, head, layer, kernel_size, act); their kernel sizes are hardcoded internally
        return WEncoder(input_dims=input_dims, dims=dims, head=head, layer=layer, kernel_size=None, act=act)
    elif feature == "pitch":
        return PEncoder(input_dims=input_dims, dims=dims, head=head, layer=layer, kernel_size=None, act=act)
    else:
        raise ValueError(f"Unknown feature type: {feature}")
class FEncoder(nn.Module):
def __init__(self, mels, input_dims, dims, head, layer, act, feature, features, use_rope=False, spec_shape=None, debug=[]):
super().__init__()
self.head = head
self.head_dim = dims // head
self.dropout = 0.01
self.use_rope = use_rope
self.dims = dims
self.debug = debug
self.feature = feature
self.mels = mels
self.input_dims = input_dims
act_fn = get_activation(act)
self.encoder = nn.Sequential(
Conv1d(mels, dims, kernel_size=3, stride=1, padding=1), act_fn,
Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
        if use_rope and spec_shape is not None:
            self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape) # type: ignore
        else:
            self.rope = None
        self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
self.norm = RMSNorm(dims)
def apply_rope_to_features(self, x, xa=None, mask=None, feats=None, feature="audio", layer="FEncoder"):
batch, ctx, dims = x.shape
x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)# type: ignore
x = self.rope.apply_rotary(x, freqs)# type: ignore
x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
return x
def forward(self, x, xa=None, mask=None, feats=None, feature="audio", layer="FEncoder"):
x = self.encoder(x).permute(0, 2, 1)
if self.use_rope:
x = self.apply_rope_to_features(x, xa=xa, mask=mask, feats=feats, feature=feature, layer=layer)
else:
x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
print(f"feature encoder: {x.shape} {feature}") if "fencoder" in self.debug else None
x = self.norm(x)
return x
class WEncoder(nn.Module): # waveform encoder
def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False, debug=[], spec_shape=None):
super().__init__()
self.head = head
self.head_dim = dims // head
self.dropout = 0.01
self.use_rope = use_rope
self.dims = dims
self.debug = debug
act_fn = get_activation(act)
self.target_length = None
self.encoder = nn.Sequential(
Conv1d(input_dims, dims//4, kernel_size=15, stride=4, padding=7), act_fn,
Conv1d(dims//4, dims//2, kernel_size=7, stride=2, padding=3), act_fn,
Conv1d(dims//2, dims, kernel_size=5, stride=2, padding=2), act_fn)
        if use_rope and spec_shape is not None:
            self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)# type: ignore
        else:
            self.rope = None
        self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
self.norm = RMSNorm(dims)
def apply_rope_to_features(self, x, xa=None, mask=None, feats=None, feature="waveform", layer="WEncoder"):
batch, ctx, dims = x.shape
x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)# type: ignore
x = self.rope.apply_rotary(x, freqs)# type: ignore
x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
return x
def forward(self, x, xa=None, mask=None, feats= None, feature="waveform", layer = "WEncoder"):
x = self.encoder(x).permute(0, 2, 1) # (batch, time, dims)
if self.target_length and x.shape[1] != self.target_length:
x = F.adaptive_avg_pool1d(x.transpose(1, 2), self.target_length).transpose(1, 2)
if self.use_rope:
x = self.apply_rope_to_features(x, xa=xa, mask=mask, feats=feats, feature=feature, layer=layer)
else:
x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
print(f"waveform encoder: {x.shape} {feature}") if "fencoder" in self.debug else None
return self.norm(x)
class PEncoder(nn.Module): # pitch encoder
def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False, debug=[], one_shot=False, spec_shape=None):
super().__init__()
self.head = head
self.head_dim = dims // head
self.dims = dims
self.dropout = 0.01
self.use_rope = use_rope
self.debug = debug
act_fn = get_activation(act)
self.attend_pitch = False
if self.attend_pitch:
self.q, self.k, self.v, self.o, self.scale = qkv_init(dims, head)
self.mlp = nn.Sequential(
nn.Linear(dims, dims),
nn.ReLU(),
nn.Linear(dims, dims),
)
else:
self.q, self.k, self.v, self.o, self.scale = None, None, None, None, None
self.mlp = None
self.pitch_encoder = nn.Sequential(
Conv1d(input_dims, dims, kernel_size=7, stride=1, padding=3), act_fn,
Conv1d(dims, dims, kernel_size=5, stride=1, padding=2), act_fn,
Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
# self.spectrogram_encoder = nn.Sequential(
# Conv1d(input_dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
# Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
# Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
# self.waveform_encoder = nn.Sequential(
# Conv1d(input_dims, dims//4, kernel_size=15, stride=4, padding=7), act_fn,
# Conv1d(dims//4, dims//2, kernel_size=7, stride=2, padding=3), act_fn,
# Conv1d(dims//2, dims, kernel_size=5, stride=2, padding=2), act_fn)
if use_rope:
self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)# type: ignore
else:
self.rope = None
self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
self.norm = RMSNorm(dims)
def rope_to_feature(self, x, xa=None, mask=None, feats=None, feature="pitch", layer="PEncoder"):
batch, ctx, dims = x.shape
x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer) # type: ignore
x = self.rope.apply_rotary(x, freqs)# type: ignore
x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
return x
def forward(self, x, xa=None, mask=None, feats= None, feature="pitch", layer="PEncoder"):
# f0=x
# freqs = self.rope(f0.shape[1], feats=feats, feature=feature, layer=layer)
if x.dim() == 2:
x = x.unsqueeze(0)
if feature == "pitch":
x = self.pitch_encoder(x).permute(0, 2, 1)
# elif feature == "spectrogram":
# x = self.spectrogram_encoder(x).permute(0, 2, 1)
# elif feature == "waveform":
# x = self.waveform_encoder(x).permute(0, 2, 1)
# if self.target_length and x.shape[1] != self.target_length:
# x = F.adaptive_avg_pool1d(x.transpose(1, 2), self.target_length).transpose(1, 2)
if self.use_rope:
x = self.rope_to_feature(x, xa=xa, mask=mask, feats=feats, feature=feature, layer=layer)
x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
if self.mlp is not None:
x = self.mlp(x)
if self.attend_pitch:
if xa is not None:
q, k, v = create_qkv(self.q, self.k, self.v, x=xa, xa=x, head=self.head)
out, _ = calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True)
x = x + out
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
x = self.norm(x)
print(f"Pitch encoder: {x.shape} {feature}") if "fencoder" in self.debug else None
return x
@dataclass
class DataCollator:
tokenizer: Any
def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
all_keys = set()
for f in features:
all_keys.update(f.keys())
batch = {}
pad_token_id = getattr(self.tokenizer, 'pad_token_id', 0)
bos_token_id = getattr(self.tokenizer, 'bos_token_id', 1)
eos_token_id = getattr(self.tokenizer, 'eos_token_id', 2)
for key in all_keys:
if key == "labels":
labels_list = [f["labels"] for f in features]
max_len = max(len(l) for l in labels_list) # noqa: E741
all_ids, all_labels = [], []
for label in labels_list:
label_list = label.tolist() if isinstance(label, torch.Tensor) else label
decoder_input = [bos_token_id] + label_list
label_eos = label_list + [eos_token_id]
input_len = max_len + 1 - len(decoder_input)
label_len = max_len + 1 - len(label_eos)
padded_input = decoder_input + [pad_token_id] * input_len
padded_labels = label_eos + [pad_token_id] * label_len
all_ids.append(padded_input)
all_labels.append(padded_labels)
batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
batch["labels"] = torch.tensor(all_labels, dtype=torch.long)
elif key in ["spectrogram", "waveform", "pitch", "harmonic", "aperiodic", "f0t", "f0", "phase", "crepe_time", "crepe_frequency", "crepe_confidence", "crepe_activation", "dummy"]:
items = [f[key] for f in features if key in f]
items = [item for item in items if item is not None]
if not items:
continue
items = [torch.tensor(item) if not isinstance(item, torch.Tensor) else item for item in items]
max_len = max(item.shape[-1] for item in items)
padded = []
for item in items:
pad_width = max_len - item.shape[-1]
if pad_width > 0:
pad_item = F.pad(item, (0, pad_width), mode='constant', value=pad_token_id)
else:
pad_item = item
padded.append(pad_item)
batch[key] = torch.stack(padded)
# if key == "spectrogram":
# batch["spectrogram"] = batch[key]
return batch
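# Minimal usage sketch with a hypothetical stand-in tokenizer (only the special-token ids are
# read by the collator): labels get BOS/EOS handling and tensor features are right-padded.
def _demo_data_collator():
    class _Tok:
        pad_token_id, bos_token_id, eos_token_id = 0, 1, 2
    collator = DataCollator(tokenizer=_Tok())
    features = [
        {"labels": [5, 6, 7], "spectrogram": torch.randn(128, 40)},
        {"labels": [8, 9], "spectrogram": torch.randn(128, 36)},
    ]
    batch = collator(features)
    return batch["input_ids"].shape, batch["labels"].shape, batch["spectrogram"].shape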
def levenshtein(reference_words, hypothesis_words):
m, n = len(reference_words), len(hypothesis_words)
dist_matrix = [[0 for _ in range(n+1)] for _ in range(m+1)]
for i in range(m+1):
dist_matrix[i][0] = i
for j in range(n+1):
dist_matrix[0][j] = j
for i in range(1, m+1):
for j in range(1, n+1):
if reference_words[i-1] == hypothesis_words[j-1]:
dist_matrix[i][j] = dist_matrix[i-1][j-1]
else:
substitution = dist_matrix[i-1][j-1] + 1
insertion = dist_matrix[i][j-1] + 1
deletion = dist_matrix[i-1][j] + 1
dist_matrix[i][j] = min(substitution, insertion, deletion)
return dist_matrix[m][n]
def wer_batch(references, hypotheses):
total_errors = 0
total_words = 0
for ref, hyp in zip(references, hypotheses):
ref_words = ref.lower().split()
errors = levenshtein(ref_words, hyp.lower().split())
total_errors += errors
total_words += len(ref_words)
return (total_errors / total_words) * 100 if total_words > 0 else 0.0
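# Minimal usage sketch: corpus-level WER in percent over a small batch.
def _demo_wer_batch():
    refs = ["the quick brown fox", "hello world"]
    hyps = ["the quick brown fox", "hello word"]
    return wer_batch(refs, hyps)  # 1 error over 6 reference words ~= 16.67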
def compute_metrics(pred, tokenizer=None, model=None, print_pred=False, num_samples=0):
def clean(ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
if isinstance(ids, torch.Tensor):
ids = ids.tolist()
if isinstance(ids[0], (list, torch.Tensor, np.ndarray)):
return [[int(i) for i in seq if i not in (-100, pad_token_id, bos_token_id, eos_token_id)] for seq in ids]
else:
return [int(i) for i in ids if i not in (-100, pad_token_id, bos_token_id, eos_token_id)]
pred_ids = pred.predictions
label_ids = pred.label_ids
if isinstance(pred_ids, tuple):
pred_ids = pred_ids[0]
if not isinstance(pred_ids, torch.Tensor):
pred_ids = torch.tensor(pred_ids)
label_ids = clean(label_ids)
pred_ids = clean(pred_ids)
pred_str = tokenizer.batch_decode(pred_ids)
label_str = tokenizer.batch_decode(label_ids)
if print_pred:
for i in range(min(num_samples, len(pred_ids))):
print(f"Pred tokens: {pred_ids[i]}")
print(f"Label tokens: {label_ids[i]}")
print(f"Pred: '{pred_str[i]}'")
print(f"Label: '{label_str[i]}'")
print("-" * 40)
wer = wer_batch(label_str, pred_str)
if model is not None:
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000
efficiency_score = (100 - wer) / trainable_params if trainable_params > 0 else 0.0
else:
trainable_params = 0.0
efficiency_score = 0.0
return {
"wer": float(wer),
"efficiency_score": float(efficiency_score),
}
def preprocess_logits_for_metrics(logits, labels):
pred_ids = torch.argmax(logits, dim=-1)
return pred_ids, labels
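# Minimal sketch of the metrics path with a stand-in prediction object (the real one comes from
# the trainer) and a tokenizer that exposes batch_decode, e.g. the one built by setup_tokenizer.
def _demo_compute_metrics(tokenizer):
    from types import SimpleNamespace
    pred = SimpleNamespace(
        predictions=torch.tensor([[1, 5, 6, 2, 0]]),
        label_ids=torch.tensor([[1, 5, 6, 2, 0]]))
    return compute_metrics(pred, tokenizer=tokenizer)  # {"wer": 0.0, "efficiency_score": 0.0}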