import torch
import torch.nn as nn
from torchaudio.models import Conformer
from huggingface_hub import PyTorchModelHubMixin
from .config import (
    N_MELS,
    CNN_CH,
    N_HEADS,
    D_MODEL,
    FF_DIM,
    N_LAYERS,
    DROPOUT,
    DEPTHWISE_CONV_KERNEL_SIZE,
    HIDDEN_DIM,
    DEVICE,
)
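
# NOTE: .config is assumed to be a sibling module defining the hyperparameters
# imported above. Illustrative values only (assumptions, not the project's
# actual settings):
#   N_MELS = 80; CNN_CH = 32; D_MODEL = 256; N_HEADS = 4; FF_DIM = 1024
#   N_LAYERS = 8; DROPOUT = 0.1; DEPTHWISE_CONV_KERNEL_SIZE = 31
#   HIDDEN_DIM = 128; DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
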
class TaikoConformer5(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        # 1) CNN frontend: frequency-only pooling
        self.cnn = nn.Sequential(
            nn.Conv2d(1, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
            nn.Conv2d(CNN_CH, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
        )
        feat_dim = CNN_CH * (N_MELS // 4)
        # 2) Linear projection to model dimension
        self.proj = nn.Linear(feat_dim, D_MODEL)
        # 3) FiLM conditioning for notes_per_second
        self.film = nn.Linear(1, 2 * D_MODEL)
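        # FiLM (feature-wise linear modulation): the scalar condition is mapped
        # to a per-channel scale gamma and shift beta, applied in forward() as
        # x = gamma * x + beta, so a single encoder can adapt its features to
        # the requested note density.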
        # 4) Conformer encoder
        self.encoder = Conformer(
            input_dim=D_MODEL,
            num_heads=N_HEADS,
            ffn_dim=FF_DIM,
            num_layers=N_LAYERS,
            depthwise_conv_kernel_size=DEPTHWISE_CONV_KERNEL_SIZE,
            dropout=DROPOUT,
            use_group_norm=False,
            convolution_first=False,
        )
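        # torchaudio's Conformer consumes (B, T, input_dim) together with
        # per-example lengths and returns (output, output_lengths), masking
        # padded frames internally.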
        # 5) Presence regressor head
        self.presence_regressor = nn.Sequential(
            nn.Dropout(DROPOUT),
            nn.Linear(D_MODEL, HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(DROPOUT),
            nn.Linear(HIDDEN_DIM, 3),  # Don, Ka, DrumRoll energy
            nn.Sigmoid(),  # Output between 0 and 1
        )
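        # Three independent sigmoids rather than a softmax: each channel is a
        # per-frame energy in [0, 1], so the outputs are treated as independent
        # regression targets (the training loss is defined outside this file).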
        # 6) Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
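        # NOTE: torch.nn.init has no GELU-specific gain, so the ReLU gain is
        # the conventional stand-in for the GELU activations used above.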

    def forward(
        self, mel: torch.Tensor, lengths: torch.Tensor, notes_per_second: torch.Tensor
    ):
        """
        Args:
            mel: (B, 1, N_MELS, T_mel)
            lengths: (B,) valid frame counts after the CNN (time resolution is
                preserved, so these equal the input mel lengths)
            notes_per_second: (B,) conditioning value (target note density)
        Returns:
            Dict with:
                'presence': (B, T, 3)
                'lengths': lengths
        """
        # CNN frontend
        x = self.cnn(mel)  # (B, C, F, T)
        B, C, F, T = x.size()
        x = x.permute(0, 3, 1, 2).reshape(B, T, C * F)
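        # Both conv layers stride only along frequency (stride=(2, 1)), so T
        # equals the input mel length and C * F == CNN_CH * (N_MELS // 4),
        # matching self.proj's expected input size.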
        # Project to model dimension
        x = self.proj(x)  # (B, T, D_MODEL)
        # FiLM conditioning
        nps = notes_per_second.unsqueeze(-1)  # (B, 1)
        gamma_beta = self.film(nps)  # (B, 2*D_MODEL)
        gamma, beta = gamma_beta.chunk(2, dim=-1)
        x = gamma.unsqueeze(1) * x + beta.unsqueeze(1)
        # Conformer encoder
        x, _ = self.encoder(x, lengths=lengths)
        # Presence prediction
        presence = self.presence_regressor(x)
        return {"presence": presence, "lengths": lengths}


if __name__ == "__main__":
    model = TaikoConformer5().to(device=DEVICE)
    print(model)

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name}: {param.numel():,}")
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {params / 1e6:.2f}M")

    batch_size = 4
    mel_time_steps = 1024
    input_mel = torch.randn(batch_size, 1, N_MELS, mel_time_steps).to(DEVICE)
    conformer_lengths = torch.tensor(
        [mel_time_steps] * batch_size, dtype=torch.long
    ).to(DEVICE)
    notes_per_second_input = torch.tensor([10.0] * batch_size, dtype=torch.float32).to(
        DEVICE
    )
    output = model(input_mel, conformer_lengths, notes_per_second_input)
    print("Output shapes:")
    for key, value in output.items():
        print(f"{key}: {value.shape}")