# asr-model / modelA.py
import os
import math
import warnings
import logging
from itertools import chain
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from tensordict import TensorDict
from typing import Optional, Dict, Union, List, Tuple
import numpy as np
from functools import partial
from datetime import datetime
from transformers.trainer_seq2seq import Seq2SeqTrainer
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
from echoutils import *
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.ERROR)
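# Rotary positional embedding with optional pitch conditioning: the base theta is a
# learnable parameter, per-position radii can be taken from an f0 contour when
# radii=True, and the axial mode builds separate time/frequency rotations for
# 2-D spectrogram layouts.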
class rotary(nn.Module):
def __init__(self, dims, head, max_ctx=1500, radii=False, debug: List[str] = [], use_pbias=False, axial=False, spec_shape=None):
super(rotary, self).__init__()
self.use_pbias = use_pbias
self.dims = dims
self.head = head
self.head_dim = dims // head
self.radii = radii
self.debug = debug
self.counter = 0
self.last_theta = None
self.axial = axial
self.bias = nn.Parameter(torch.zeros(max_ctx, dims // 2), requires_grad=True if use_pbias else False)
theta = (torch.tensor(10000, device=device, dtype=dtype))
self.theta = nn.Parameter(theta, requires_grad=True)
self.theta_values = []
if axial and spec_shape is not None:
time_frames, freq_bins = spec_shape
self.time_frames = time_frames
self.freq_bins = freq_bins
time_theta = 50.0
time_freqs = 1.0 / (time_theta ** (torch.arange(0, dims, 4)[:(dims // 4)].float() / dims))
self.register_buffer('time_freqs', time_freqs)
freq_theta = 100.0
freq_freqs = 1.0 / (freq_theta ** (torch.arange(0, dims, 4)[:(dims // 4)].float() / dims))
self.register_buffer('freq_freqs', freq_freqs)
def pitch_bias(self, f0):
if f0 is None:
return None
f0_flat = f0.squeeze().float()
f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
f0_norm.unsqueeze(1)))
return f0_sim.unsqueeze(0).unsqueeze(0)
def theta_freqs(self, theta):
if theta.dim() == 0:
theta = theta.unsqueeze(0)
freq = (theta.unsqueeze(-1) / 220.0) * 700 * (
torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
self.head_dim // 2, device=theta.device, dtype=theta.dtype) / 2595) - 1) / 1000
return freq
    def _apply_radii(self, freqs, f0, ctx):
        if self.radii and f0 is not None:
            radius = f0.to(device, dtype)
            L = radius.shape[0]
            if L != ctx:
                # resample the f0-derived radius to the rotary context length
                step = L / ctx
                idx = torch.arange(ctx, device=f0.device)
                idx = (idx * step).long().clamp(0, L - 1)
                radius = radius[idx]
            return torch.polar(radius.unsqueeze(-1), freqs), radius
        return torch.polar(torch.ones_like(freqs), freqs), None
def check_f0(self, f0, f0t, ctx):
if f0 is not None and f0.shape[1] == ctx:
return f0
elif f0t is not None and f0t.shape[1] == ctx:
return f0t
else:
return None
def axial_freqs(self, ctx):
if not self.axial:
return None
time_frames = self.time_frames
freq_bins = self.freq_bins
t = torch.arange(ctx, device=device, dtype=dtype)
t_x = (t % time_frames).float()
t_y = torch.div(t, time_frames, rounding_mode='floor').float()
freqs_x = torch.outer(t_x, self.time_freqs)
freqs_y = torch.outer(t_y, self.freq_freqs)
freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
def forward(self, x=None, en=None, f=None, layer=None) -> Tensor:
        ctx = x  # x is the context length (number of positions) to build frequencies for
f0 = en.get("f0") if en is not None else None
f0t = en.get("f0t") if en is not None else None
f0 = self.check_f0(f0, f0t, ctx)
if f0 is not None:
if f0.dim() == 2:
f0 = f0.squeeze(0)
theta = f0 + self.theta
else:
theta = self.theta
freqs = self.theta_freqs(theta)
t = torch.arange(ctx, device=device, dtype=dtype)
freqs = t[:, None] * freqs
freqs, radius = self._apply_radii(freqs, f0, ctx)
if self.axial and f == "spectrogram":
freqs_2d = self.axial_freqs(ctx)
if freqs_2d is not None:
return freqs_2d.unsqueeze(0)
if "radius" in self.debug and self.counter == 10:
print(f" [{layer}] [Radius] {radius.shape if radius is not None else None} {radius.mean() if radius is not None else None} [Theta] {theta.mean() if theta is not None else None} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
self.counter += 1
return freqs.unsqueeze(0)
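    # apply_rotary rotates the first 2 * freqs.shape[-1] channels of x in the complex
    # plane and passes any remaining channels through unchanged.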
@staticmethod
def apply_rotary(x, freqs):
x1 = x[..., :freqs.shape[-1]*2]
x2 = x[..., freqs.shape[-1]*2:]
orig_shape = x1.shape
if x1.ndim == 2:
x1 = x1.unsqueeze(0)
x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
x1 = torch.view_as_complex(x1) * freqs
x1 = torch.view_as_real(x1).flatten(-2)
x1 = x1.view(orig_shape)
return torch.cat([x1.type_as(x), x2], dim=-1)
class MultiheadA(nn.Module):
rbf = False
def __init__(self, dims: int, head: int, rotary_emb: bool = True,
zero_val: float = 1e-7, minz: float = 1e-8, maxz: float = 1e-6, debug: List[str] = [], optim_attn=False, use_pbias=False):
super(MultiheadA, self).__init__()
self.dims = dims
self.head = head
self.head_dim = dims // head
self.debug = debug
self.counter = 0
self.use_pbias = use_pbias
self.q = nn.Linear(dims, dims).to(device, dtype)
self.k = nn.Linear(dims, dims, bias=False).to(device, dtype)
self.v = nn.Linear(dims, dims).to(device, dtype)
self.o = nn.Linear(dims, dims).to(device, dtype)
self.pad_token = 0
self.rotary_emb = rotary_emb
self.minz = minz
self.maxz = maxz
self.zero_val = zero_val
self.optim_attn = optim_attn
self.fzero = nn.Parameter(torch.tensor(zero_val, device=device, dtype=dtype), requires_grad=False)
if rotary_emb:
self.rope = rotary(
dims=dims,
head=head,
debug=debug,
radii=False,
)
else:
self.rope = None
def cos_sim(self, q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
qk_cosine = torch.matmul(q_norm, k_norm.transpose(-1, -2))
qk_cosine = qk_cosine + mask
weights = F.softmax(qk_cosine, dim=-1)
out = torch.matmul(weights, v)
return out
def rbf_scores(self, q, k, rbf_sigma=1.0, rbf_ratio=0.0):
scale = (self.dims // self.head) ** -0.25
dot_scores = torch.matmul(q, k.transpose(-1, -2)) * scale
if rbf_ratio <= 0.0:
return dot_scores
q_norm = q.pow(2).sum(dim=-1, keepdim=True)
k_norm = k.pow(2).sum(dim=-1, keepdim=True)
qk = torch.matmul(q, k.transpose(-1, -2))
dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores
def forward(self, x: Tensor, xa = None, mask = None, en= None, layer = None, f=None) -> tuple:
x = x.to(device, dtype)
if xa is not None:
xa = xa.to(device, dtype)
scale = (self.dims // self.head) ** -0.25
z = default(xa, x).to(device, dtype)
q = self.q(x)
k = self.k(z)
v = self.v(z)
if self.rotary_emb:
q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
q2 = q.shape[2]
k2 = k.shape[2]
q = self.rope.apply_rotary(q, (self.rope(x=q2, en=en, f=f, layer=layer)))
k = self.rope.apply_rotary(k, (self.rope(x=k2, en=en, f=f, layer=layer)))
else:
q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
qk = (q * scale) @ (k * scale).transpose(-1, -2)
if self.rbf:
qk = self.rbf_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
if self.use_pbias:
pbias = self.rope.pitch_bias(f0 = en.get("f0", None) if en is not None else None)
if pbias is not None:
qk = qk + pbias[:,:,:q2,:q2]
token_ids = k[:, :, :, 0]
zscale = torch.ones_like(token_ids)
fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
zscale[token_ids.float() == self.pad_token] = fzero
if mask is not None:
if mask.dim() == 4:
mask = mask[0, 0]
mask = mask[:q2, :k2] if xa is not None else mask[:q2, :q2]
qk = qk + mask * zscale.unsqueeze(-2).expand(qk.shape)
qk = qk * zscale.unsqueeze(-2)
w = F.softmax(qk, dim=-1).to(q.dtype)
wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
if "multihead" in self.debug and self.counter % 100 == 0:
print(f"MHA: q={q.shape}, k={k.shape}, v={v.shape} - {qk.shape}, wv shape: {wv.shape}")
self.counter += 1
return self.o(wv), qk
@staticmethod
    def split(X: Tensor) -> Tuple[Tensor, Tensor]:
half_dim = X.shape[-1] // 2
return X[..., :half_dim], X[..., half_dim:]
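# Gating modules used by Residual: t_gate mixes several sigmoid gates through a
# softmax type classifier, m_gate averages a direct gate with a learned-memory
# gate, c_gate fuses per-feature gated streams, and mlp_gate is a single sigmoid
# gate used as a fallback.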
class t_gate(nn.Module):
def __init__(self, dims, num_types=4, enabled=True):
super().__init__()
self.enabled = enabled
self.gate_projections = nn.ModuleList([
nn.Sequential(Linear(dims, 1), nn.Sigmoid())
for _ in range(num_types)])
self.type_classifier = nn.Sequential(
Linear(dims, num_types),
nn.Softmax(dim=-1))
def forward(self, x):
if not self.enabled:
return None
type_probs = self.type_classifier(x)
gates = torch.stack([gate(x) for gate in self.gate_projections], dim=-1)
comb_gate = torch.sum(gates * type_probs.unsqueeze(2), dim=-1)
return comb_gate
class m_gate(nn.Module):
def __init__(self, dims, mem_size=64, enabled=True):
super().__init__()
self.enabled = enabled
if enabled:
self.m_key = nn.Parameter(torch.randn(mem_size, dims))
self.m_val = nn.Parameter(torch.randn(mem_size, 1))
self.gate_proj = nn.Sequential(Linear(dims, dims//2), nn.SiLU(), Linear(dims//2, 1))
def forward(self, x):
if not self.enabled:
return None
d_gate = torch.sigmoid(self.gate_proj(x))
attention = torch.matmul(x, self.m_key.transpose(0, 1))
attention = F.softmax(attention / math.sqrt(x.shape[-1]), dim=-1)
m_gate = torch.matmul(attention, self.m_val)
m_gate = torch.sigmoid(m_gate)
return 0.5 * (d_gate + m_gate)
class c_gate(nn.Module):
def __init__(self, dims, enabled=True):
super().__init__()
self.enabled = enabled
if enabled:
self.s_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
self.w_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
self.p_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
self.e_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
self.ph_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
self.integ = Linear(dims*5, dims)
def forward(self, x, features):
if not self.enabled:
return None
s_feat = features.get("spectrogram", x)
w_feat = features.get("waveform", x)
p_feat = features.get("pitch", x)
e_feat = features.get("envelope", x)
ph_feat = features.get("phase", x)
s = self.s_gate(x) * s_feat
w = self.w_gate(x) * w_feat
p = self.p_gate(x) * p_feat
e = self.e_gate(x) * e_feat
ph = self.ph_gate(x) * ph_feat
comb = torch.cat([s, w, p, e, ph], dim=-1)
return self.integ(comb)
class mlp_gate(nn.Module):
def __init__(self, dims, head, enabled=True, one_shot=True):
super().__init__()
self.enabled = enabled
if enabled:
self.gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
def forward(self, x, xa=None, f=None):
if not self.enabled:
return None
return self.gate(x)
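# Pre-norm residual block: rotary multi-head attention, a 4x MLP, and one of the
# gate modules above, with a learnable blend between the attended and original
# streams.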
class Residual(nn.Module):
_seen = set()
def __init__(self, ctx, dims, head, act, debug: List[str] = [],
tgate=True, mgate=False, cgate=False, mem_size=512, features=None, one_shot=False):
super().__init__()
self.dims = dims
self.head = head
self.ctx = ctx
self.head_dim = dims // head
self.features = features
self.debug = debug
self.counter = 0
self.dropout = 0.01
self.one_shot = one_shot
self.blend = nn.Parameter(torch.tensor(0.5))
act_fn = get_activation(act)
self.attn = MultiheadA(dims, head, rotary_emb=True, debug=debug)
self.curiosity = curiosity(dims, head)
mlp = dims * 4
self.mlp = nn.Sequential(Linear(dims, mlp), act_fn, Linear(mlp, dims))
self.t_gate = t_gate(dims=dims, num_types=4*2, enabled=tgate)
self.m_gate = m_gate(dims=dims, mem_size=mem_size, enabled=mgate)
self.c_gate = c_gate(dims=dims, enabled=cgate)
self.mlp_gate = mlp_gate(dims=dims, head=head, enabled=not any([tgate, mgate, cgate]), one_shot=True)
self.lna = RMSNorm(dims)
self.lnb = RMSNorm(dims)
self.lnc = RMSNorm(dims)
def forward(self, x, xa=None, mask=None, en=None, layer=None, f=None) -> Tensor:
b = torch.sigmoid(self.blend)
ax = x + self.attn(self.lna(x), xa=xa, mask=mask, en=en, layer=layer, f=f)[0]
bx = b * ax + (1 - b) * x
cx = self.lnb(bx)
dx = self.mlp(cx)
        ex = self.t_gate(cx)
        if ex is None:
            # fall back to the memory gate / mlp gate when the type gate is disabled
            ex = default(self.m_gate(cx), self.mlp_gate(cx))
fx = x + ex + dx
gx = self.lnc(fx)
return gx
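# OneShot computes a scaled query-key bias against a guide sequence; it is defined
# here but not instantiated by the blocks in this file.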
class OneShot(nn.Module):
def __init__(self, dims: int, head: int, scale: float = 0.3):
super().__init__()
self.head = head
self.hdim = dims // head
self.scale = scale
self.q_proj = Linear(dims, dims)
self.k_proj = Linear(dims, dims)
def forward(self, x: Tensor, guide: Tensor, f=None) -> Tensor | None:
B, Q, _ = x.shape
K = guide.size(1)
q = self.q_proj(x ).view(B, Q, self.head, self.hdim).transpose(1,2)
k = self.k_proj(guide).view(B, K, self.head, self.hdim).transpose(1,2)
bias = (q @ k.transpose(-1, -2)) * self.scale / math.sqrt(self.hdim)
return bias
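# curiosity: dual-stream attention in which queries from x attend to both x and an
# auxiliary stream xa, blended per head by a learned sigmoid gate. Residual
# instantiates it but does not call it in its forward pass.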
class curiosity(nn.Module):
def __init__(self, d, h, bias=True):
super().__init__()
self.h = h
self.dh = d // h
self.qkv = nn.Linear(d, d * 3, bias=bias)
self.qkv_aux = nn.Linear(d, d * 3, bias=bias)
self.o = nn.Linear(d, d, bias=bias)
self.g = nn.Parameter(torch.zeros(h))
def split(self, x):
b, t, _ = x.shape
return x.view(b, t, self.h, self.dh).transpose(1, 2)
def merge(self, x):
b, h, t, dh = x.shape
return x.transpose(1, 2).contiguous().view(b, t, h * dh)
def forward(self, x, xa, mask=None):
q, k, v = self.qkv(x).chunk(3, -1)
qa, ka, va = self.qkv_aux(xa).chunk(3, -1)
q, k, v = map(self.split, (q, k, v))
qa, ka, va = map(self.split, (qa, ka, va))
dots = (q @ k.transpose(-2, -1)) / self.dh**0.5
dots_aux = (q @ ka.transpose(-2, -1)) / self.dh**0.5
if mask is not None: dots = dots.masked_fill(mask, -9e15)
p = dots.softmax(-1)
pa = dots_aux.softmax(-1)
h_main = p @ v
h_aux = pa @ va
g = torch.sigmoid(self.g).view(1, -1, 1, 1)
out = self.merge(h_main * (1 - g) + h_aux * g)
return self.o(out)
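# Classic sinusoidal positional encoding (input scaled by sqrt(dims)); kept for
# reference and not used by the encoders below.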
class PositionalEncoding(nn.Module):
def __init__(self, dims, ctx):
super(PositionalEncoding, self).__init__()
self.dims = dims
self.ctx = ctx
self.pe = self.get_positional_encoding(max_ctx=ctx)
def get_positional_encoding(self, max_ctx):
pe = torch.zeros(max_ctx, self.dims)
position = torch.arange(0, max_ctx, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.dims, 2, dtype=torch.float32)
* (-math.log(10000.0) / self.dims)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
return pe.to(device)
def forward(self, x):
ctx = x.size(1)
pe = self.pe[:, :ctx, :]
x = x * math.sqrt(self.dims)
x = x + pe
return x
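# FEncoder: convolutional front-end for mel-spectrogram features (three Conv1d
# layers, the last depthwise), followed by rotary or sinusoidal positions, dropout,
# and RMSNorm. The kernel_size and stride arguments are currently unused.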
class FEncoder(nn.Module):
def __init__(self, mels, dims, head, layer, kernel_size, act, stride=1, use_rope=False, spec_shape=None, debug=[]):
super().__init__()
self.head = head
self.head_dim = dims // head
self.dropout = 0.01
self.use_rope = use_rope
self.dims = dims
self.debug = debug
act_fn = get_activation(act)
self.attend_pitch = False
if self.attend_pitch:
self.q, self.k, self.v, self.o, self.scale = qkv_init(dims, head)
self.mlp = nn.Sequential(
nn.Linear(dims, dims),
nn.ReLU(),
nn.Linear(dims, dims),
)
else:
self.q, self.k, self.v, self.o, self.scale = None, None, None, None, None
self.mlp = None
self.encoder = nn.Sequential(
Conv1d(mels, dims, kernel_size=3, stride=1, padding=1), act_fn,
Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
        if use_rope:
            self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)
        else:
            self.rope = None
self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
self.norm = RMSNorm(dims)
def apply_rope_to_features(self, x, en=None, f=None, layer="audio"):
batch, ctx, dims = x.shape
x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
freqs = self.rope(ctx, en=en, f=f, layer=layer)
x = self.rope.apply_rotary(x, freqs)
x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
return x
def forward(self, x: Tensor, en=None, f=None, layer = None):
x = self.encoder(x).permute(0, 2, 1)
if self.use_rope:
x = self.apply_rope_to_features(x, en=en, f=f, layer=layer)
else:
x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
if self.mlp is not None:
x = self.mlp(x)
if self.attend_pitch:
xa = en["input_ids"]
if xa is not None:
q, k, v = create_qkv(self.q, self.k, self.v, x=xa, xa=x, head=self.head)
out, _ = calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True)
out = self.o(out)
x = x + out
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
x = self.norm(x)
return x
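# WEncoder: raw-waveform front-end; a strided Conv1d stack (about 16x temporal
# downsampling), optional rotary positions, dropout, and RMSNorm.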
class WEncoder(nn.Module):
def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False, debug=[], spec_shape=None):
super().__init__()
self.head = head
self.head_dim = dims // head
self.dropout = 0.01
self.use_rope = use_rope
self.dims = dims
self.debug = debug
act_fn = get_activation(act)
self.target_length = None
self.encoder = nn.Sequential(
Conv1d(input_dims, dims//4, kernel_size=15, stride=4, padding=7), act_fn,
Conv1d(dims//4, dims//2, kernel_size=7, stride=2, padding=3), act_fn,
Conv1d(dims//2, dims, kernel_size=5, stride=2, padding=2), act_fn)
        if use_rope:
            self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)
        else:
            self.rope = None
self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
self.norm = RMSNorm(dims)
def apply_rope_to_features(self, x, en=None, f=None, layer="audio"):
batch, ctx, dims = x.shape
x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
freqs = self.rope(ctx, en=en, f=f, layer=layer)
x = self.rope.apply_rotary(x, freqs)
x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
return x
def forward(self, x: Tensor, en= None, f=None, layer = None):
x = self.encoder(x).permute(0, 2, 1)
if self.target_length and x.shape[1] != self.target_length:
x = F.adaptive_avg_pool1d(x.transpose(1, 2), self.target_length).transpose(1, 2)
if self.use_rope:
x = self.apply_rope_to_features(x, en=en, f=f, layer=layer)
else:
x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self.norm(x)
        if "encoder" in self.debug:
            print(f"X: {x.shape} {f}")
        return x
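# PEncoder: front-end for 1-D contours such as pitch or phase; a non-strided Conv1d
# stack with rotary (default) or sinusoidal positions.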
class PEncoder(nn.Module):
def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=True, debug=[], one_shot=False, spec_shape=None):
super().__init__()
self.head = head
self.head_dim = dims // head
self.dims = dims
self.dropout = 0.01
self.use_rope = use_rope
self.debug = debug
act_fn = get_activation(act)
self.encoder = nn.Sequential(
Conv1d(input_dims, dims, kernel_size=7, stride=1, padding=3), act_fn,
Conv1d(dims, dims, kernel_size=5, stride=1, padding=2), act_fn,
Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
if use_rope:
self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)
else:
self.rope = None
self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
self.norm = RMSNorm(dims)
def rope_to_feature(self, x, en=None, f="pitch", layer="PEncoder"):
batch, ctx, dims = x.shape
x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
freqs = self.rope(ctx, en=en, f=f, layer=layer)
x = self.rope.apply_rotary(x, freqs)
x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
return x
def forward(self, x: Tensor, en= None, f="pitch", layer="PEncoder"):
if x.dim() == 2:
x = x.unsqueeze(0)
x = self.encoder(x).permute(0, 2, 1)
if self.use_rope:
x = self.rope_to_feature(x, en=en, f=f, layer=layer)
else:
x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
x = self.norm(x)
print(f"X: {x.shape} {f}") if "PEncoder" in self.debug else None
return x
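# theBridge: blockA holds one encoder stack per audio feature, blockB decodes text
# with masked self-attention plus cross-attention over the encoded feature, and the
# modal blocks mix the concatenated text/audio streams; logits are produced by
# projecting onto the token embedding matrix.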
class theBridge(nn.Module):
def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int,
debug: List[str], features: List[str], act: str = "gelu"):
super(theBridge, self).__init__()
tgate = True
mgate = False
cgate = False
self.debug = debug
self.counter = 0
self.dropout = 0.01
self.features = features
self.do_blend = "no_blend" not in self.debug
self.sequential = "sequential" in self.debug
self.layer = layer
self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
        # learned decoder positional embedding; initialized here since _init_weights does not cover raw Parameters
        self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype).normal_(mean=0.0, std=0.02), requires_grad=True)
self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
self.norm = RMSNorm(dims)
self.sinusoid_pos = lambda length, dims, max_tscale: sinusoids(length, dims, 10000)
self.rotary = rotary(dims=dims, head=head, debug=debug, radii=False)
with torch.no_grad():
self.token.weight[0].zero_()
act_fn = get_activation(act)
if features == ["spectrogram", "waveform", "pitch"]:
cgate=True
else:
cgate = False
self.blockA = nn.ModuleDict()
self.blockA["waveform"] = nn.ModuleList(
[WEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=11, act=act_fn)] +
[Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features)
for _ in range(layer)] if "waveform" in features else None)
for feature_type in ["spectrogram", "aperiodic", "harmonic"]:
if feature_type in features:
self.blockA[feature_type] = nn.ModuleList(
[FEncoder(mels=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
[Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features) for _ in range(layer)] if feature_type in features else None)
else:
self.blockA[feature_type] = None
for feature_type in ["pitch", "phase"]:
if feature_type in features:
self.blockA[feature_type] = nn.ModuleList(
[PEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act_fn)] +
[Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features) for _ in range(layer)] if feature_type in features else None)
else:
self.blockA[feature_type] = None
self.blockB = nn.ModuleList([
Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features)
for _ in range(layer)])
self.modal = nn.ModuleList([
Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features)
for _ in range(layer)])
mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
self.register_buffer("mask", mask, persistent=False)
def forward(self, x, xa, en, f, sequential=False) -> Tensor:
mask = self.mask[:x.shape[1], :x.shape[1]]
x = self.token(x.long()) + self.positional[:x.shape[1]]
out = {}
out["input_ids"] = x
out.update(en)
for b in chain(self.blockA[f] or []):
xa = b(x=xa, en=out, f=f, layer="en")
for b in chain(self.blockB or []):
x = b(x=x, xa=None, mask=mask, en=out, f=f, layer="dec")
y = b(x, xa=xa, mask=None, en=out, f=f, layer="cross")
if sequential:
x = y
else:
a = torch.sigmoid(self.blend)
x = a * y + (1 - a) * x
for b in self.modal:
xc = b(x=torch.cat([x, xa], dim=1), xa=None, mask=None, en=out, f=f, layer="modal")
xm = b(x=xc[:, :x.shape[1]], xa=xc[:, x.shape[1]:], mask=None, en=out, f=f, layer="modal")
if sequential:
x = xm
else:
a = torch.sigmoid(self.blend)
x = a * x + (1 - a) * xm
if self.counter < 1 and "encoder" in self.debug:
shapes = {k: v.shape for k, v in en.items()}
print(f"Step {self.counter}: mode: {list(en.keys()) }: shapes: {shapes}")
self.counter += 1
x = self.norm(x)
x = x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
return x
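# Echo: top-level model. It gathers whatever feature tensors are provided into a
# dict and runs theBridge once per feature, returning logits and an optional
# cross-entropy loss (pad index 0 ignored).
#
# A minimal usage sketch (shapes and values are illustrative only):
#   model = Echo(Dimensions(vocab=40000, mels=128, ctx=2048, dims=512, head=4,
#                           layer=4, act="swish", debug={}, features=["spectrogram"]))
#   out = model(input_ids=ids, spectrogram=spec, labels=ids)
#   out["logits"], out["loss"]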
class Echo(nn.Module):
def __init__(self, param: Dimensions):
super().__init__()
self.param = param
self.processor = theBridge(
vocab=param.vocab,
mels=param.mels,
ctx=param.ctx,
dims=param.dims,
head=param.head,
layer=param.layer,
features=param.features,
act=param.act,
debug=param.debug,
)
def forward(self,
labels=None,
input_ids=None,
waveform: Optional[torch.Tensor]=None,
spectrogram: Optional[torch.Tensor]=None,
pitch: Optional[torch.Tensor]=None,
f0: Optional[torch.Tensor]=None,
f0t: Optional[torch.Tensor]=None,
harmonic: Optional[torch.Tensor]=None,
aperiodic: Optional[torch.Tensor]=None,
phase: Optional[torch.Tensor]=None,
) -> Dict[str, Optional[torch.Tensor]]:
        en = {}
if f0 is not None:
en["f0"] = f0
if f0t is not None:
en["f0t"] = f0t
if harmonic is not None:
en["harmonic"] = harmonic
if aperiodic is not None:
en["aperiodic"] = aperiodic
if phase is not None:
en["phase"] = phase
if pitch is not None:
en["pitch"] = pitch
if waveform is not None:
en["waveform"] = waveform
if spectrogram is not None:
en["spectrogram"] = spectrogram
x = input_ids
for f, xa in en.items():
logits = self.processor(x, xa, en, f)
loss = None
if labels is not None:
loss = F.cross_entropy(
logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)
return {"logits": logits, "loss": loss}
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
    def _init_weights(self, module):
        std = 0.02
        if isinstance(module, RMSNorm):
            nn.init.ones_(module.weight)
            self.init_counts["RMSNorm"] += 1
        elif isinstance(module, nn.Linear):
            if module.weight is not None:
                nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Linear"] += 1
        elif isinstance(module, Conv1d):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Conv1d"] += 1
        elif isinstance(module, Conv2d):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Conv2d"] += 1
        elif isinstance(module, MultiheadA):
            self.init_counts["MultiheadA"] += 1
        elif isinstance(module, Residual):
            self.init_counts["Residual"] += 1
        elif isinstance(module, PEncoder):
            self.init_counts["PEncoder"] += 1
        elif isinstance(module, FEncoder):
            self.init_counts["FEncoder"] += 1
        elif isinstance(module, WEncoder):
            self.init_counts["WEncoder"] += 1
        elif isinstance(module, theBridge):
            self.init_counts["theBridge"] += 1
        elif isinstance(module, Echo):
            self.init_counts["Echo"] += 1
    def init_weights(self):
        print("Initializing model weights...")
        self.init_counts = {
            "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
            "Conv2d": 0, "theBridge": 0, "Echo": 0,
            "Residual": 0, "MultiheadA": 0,
            "MultiheadC": 0, "MultiheadD": 0, "FEncoder": 0,
            "WEncoder": 0, "PEncoder": 0}
        self.apply(self._init_weights)
        print("Initialization summary:")
        for module_type, count in self.init_counts.items():
            if count > 0:
                print(f"{module_type}: {count}")
def generate(self, input_ids=None, spectrogram=None, waveform=None, pitch=None, f0=None,
envelope=None, phase=None, tokenizer=None, max_length=128, min_length=1, device=None, **kwargs):
if device is None:
device = self.device
pad_token_id = getattr(tokenizer, "pad_token_id", 0)
bos_token_id = getattr(tokenizer, "bos_token_id", 1)
eos_token_id = getattr(tokenizer, "eos_token_id", 2)
batch_size = 1
for x in [spectrogram, waveform, pitch, f0, envelope, phase]:
if x is not None:
batch_size = x.shape[0]
break
ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=device)
feature = {}
if spectrogram is not None:
feature["spectrogram"] = spectrogram
if waveform is not None:
feature["waveform"] = waveform
if pitch is not None:
feature["pitch"] = pitch
if envelope is not None:
feature["envelope"] = envelope
if phase is not None:
feature["phase"] = phase
if f0 is not None:
feature["f0"] = f0
        for i in range(max_length - 1):
            with torch.no_grad():
                feature["input_ids"] = ids
                # run the full model; keyword names must match Echo.forward's signature
                out = self(**feature)
                logits = out["logits"]
            next_token_logits = logits[:, -1, :]
            if i < min_length:
                # block EOS until the minimum length has been generated
                next_token_logits[:, eos_token_id] = float("-inf")
            next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            ids = torch.cat([ids, next_tokens], dim=1)
            if (next_tokens == eos_token_id).all() and i >= min_length:
                break
        return ids
@property
def config(self):
class Config:
pad_token_id = getattr(self.param, "pad_token_id", 0)
bos_token_id = getattr(self.param, "bos_token_id", 1)
eos_token_id = getattr(self.param, "eos_token_id", 2)
def to_json_string(self):
import json
return json.dumps({
"pad_token_id": self.pad_token_id,
"bos_token_id": self.bos_token_id,
"eos_token_id": self.eos_token_id,
})
return Config()
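# Training entry point: builds the tokenizer, datasets, and an Echo model, then trains
# with Hugging Face Seq2SeqTrainer using AdamW and cosine annealing. Helpers such as
# setup_tokenizer, prepare_datasets, DataCollator, and compute_metrics are expected to
# come from the echoutils import.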
def main():
token = ""
log_dir = os.path.join('./output/logs', datetime.now().strftime('%m-%d_%H_%M_%S'))
os.makedirs(log_dir, exist_ok=True)
tokenizer = setup_tokenizer("./")
sanity_check = False
streaming = False
load_saved = False
save_dataset = False
cache_dir = None
extract_args = {
"waveform": False,
"spec": True,
"f0": False,
"f0t": False,
"pitch": True,
"harmonics": False,
"aperiodics": False,
"phase_mod": False,
"crepe": False,
"sample_rate": 16000,
"hop_length": 256,
"mode": "mean",
"debug": False,
}
param = Dimensions(
vocab=40000,
mels=128,
ctx=2048,
dims=512,
head=4,
layer=4,
act="swish",
debug={"encoder"},
features = ["spectrogram", "pitch"],
)
train_dataset, test_dataset = prepare_datasets(tokenizer, token, sanity_check=sanity_check, sample_rate=16000, streaming=streaming,
load_saved=load_saved, save_dataset=save_dataset, cache_dir=cache_dir, extract_args=extract_args, max_ctx=param.ctx)
model = Echo(param).to('cuda')
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
from functools import partial
metrics_fn = partial(compute_metrics,
print_pred=True,
num_samples=1,
tokenizer=tokenizer, model=model)
if sanity_check:
training_args = Seq2SeqTrainingArguments(
output_dir=log_dir,
per_device_train_batch_size=1,
per_device_eval_batch_size=1,
max_steps=10,
eval_steps=5,
save_steps=0,
warmup_steps=0,
logging_steps=1,
logging_dir=log_dir,
eval_strategy="steps",
save_strategy="no",
logging_strategy="no",
report_to=["tensorboard"],
push_to_hub=False,
save_total_limit=1,
label_names=["labels"],
save_safetensors=False,
eval_on_start=True,
batch_eval_metrics=False,
disable_tqdm=False,
include_tokens_per_second=True,
include_num_input_tokens_seen=True,
learning_rate=1e-7,
weight_decay=0.01,
)
else:
training_args = Seq2SeqTrainingArguments(
output_dir=log_dir,
per_device_train_batch_size=1,
per_device_eval_batch_size=1,
max_steps=1000,
eval_steps=100,
save_steps=1000,
warmup_steps=100,
logging_steps=10,
logging_dir=log_dir,
logging_strategy="steps",
eval_strategy="steps",
save_strategy="no",
report_to=["tensorboard"],
push_to_hub=False,
save_total_limit=1,
label_names=["labels"],
save_safetensors=False,
eval_on_start=True,
batch_eval_metrics=False,
disable_tqdm=False,
include_tokens_per_second=True,
include_num_input_tokens_seen=True,
learning_rate=0.00025,
weight_decay=0.025,
)
optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate, eps=1e-8, weight_decay=training_args.weight_decay, betas=(0.9, 0.999),
amsgrad=False, foreach=False, fused=False, capturable=False, differentiable=False, maximize=False)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_args.max_steps, eta_min=1e-9, last_epoch=-1)
trainer = Seq2SeqTrainer(
args=training_args,
model=model,
train_dataset=train_dataset,
eval_dataset=test_dataset,
data_collator=DataCollator(tokenizer=tokenizer),
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
compute_metrics=metrics_fn,
optimizers=(optimizer, scheduler)
)
model.init_weights()
trainer.train()
if __name__ == "__main__":
main()