Spaces:
Running
on
Zero
Running
on
Zero
from __future__ import annotations | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from transformers import AutoFeatureExtractor, AutoModel | |
from utmosv2.dataset._utils import get_dataset_num | |
class _SSLEncoder(nn.Module):
    """Pair a Hugging Face feature extractor with its pretrained model.

    Both components are loaded with ``from_pretrained(model_name)``. The
    forward pass returns the model's ``hidden_states`` output rather than
    only the final layer.
    """

    def __init__(self, sr: int, model_name: str, freeze: bool):
        """Load the feature extractor and model for ``model_name``.

        Args:
            sr: Sampling rate handed to the feature extractor on every call.
            model_name: Identifier passed to both ``from_pretrained`` calls.
            freeze: When true, gradients are disabled for every parameter of
                the loaded model.
        """
        super().__init__()
        self.sr = sr
        self.processor = AutoFeatureExtractor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        if freeze:
            # Turns off requires_grad on all parameters of the wrapped model.
            self.model.requires_grad_(False)

    def forward(self, x):
        """Run the feature extractor and the model on a batch.

        Args:
            x: Iterable of tensors; each one is moved to CPU and converted to
                a NumPy array before being given to the feature extractor
                (presumably one waveform per item -- TODO confirm with caller).

        Returns:
            The ``hidden_states`` attribute of the model output, requested
            via ``output_hidden_states=True``.
        """
        # The feature extractor is fed NumPy arrays, so leave the device first.
        arrays = [item.cpu().numpy() for item in x]
        features = self.processor(
            arrays,
            sampling_rate=self.sr,
            return_tensors="pt",
        )
        # Bring the extracted features back onto the model's device.
        features = features.to(self.model.device)
        result = self.model(**features, output_hidden_states=True)
        return result.hidden_states
class SSLExtModel(nn.Module):
    """Regression/classification head on top of weighted SSL hidden states.

    The encoder's hidden states are combined with learnable per-layer
    weights, optionally passed through a stack of self-attention layers,
    pooled over dimension 1 (mean and max), concatenated with ``d`` and fed
    to a final linear layer.
    """

    def __init__(self, cfg, name: str | None = None):
        """Build the encoder, layer weights, optional attention and head.

        Args:
            cfg: Configuration object; reads ``cfg.sr`` and the
                ``cfg.model.ssl`` fields ``name``, ``freeze``, ``attn`` and
                ``num_classes``. Also passed to ``get_dataset_num``.
            name: Optional model name; falls back to ``cfg.model.ssl.name``
                when falsy.
        """
        super().__init__()
        self.cfg = cfg
        ssl_cfg = cfg.model.ssl
        model_name = name or ssl_cfg.name

        self.encoder = _SSLEncoder(cfg.sr, model_name, ssl_cfg.freeze)

        hidden_num, in_features = get_ssl_output_shape(model_name)
        # One learnable weight per hidden state, initialised as a softmax
        # over random normal values.
        initial_weights = F.softmax(torch.randn(hidden_num), dim=0)
        self.weights = nn.Parameter(initial_weights)

        num_attn_layers = ssl_cfg.attn
        if num_attn_layers:
            attn_kwargs = {
                "embed_dim": in_features,
                "num_heads": 8,
                "dropout": 0.2,
                "batch_first": True,
            }
            self.attn = nn.ModuleList(
                [
                    nn.MultiheadAttention(**attn_kwargs)
                    for _ in range(num_attn_layers)
                ]
            )

        self.num_dataset = get_dataset_num(cfg)
        # Head input: mean-pooled + max-pooled features plus ``d``.
        head_in_features = in_features * 2 + self.num_dataset
        self.fc = nn.Linear(head_in_features, ssl_cfg.num_classes)

    def forward(self, x, d):
        """Compute the head output for a batch.

        Args:
            x: Input forwarded unchanged to the encoder.
            d: Tensor concatenated to the pooled features along dim 1
                (presumably a dataset indicator of width ``num_dataset`` --
                TODO confirm with caller).

        Returns:
            Output of the final linear layer.
        """
        hidden_states = self.encoder(x)
        # Weighted sum of the hidden states; zip stops at the shorter input.
        features = sum(
            state * weight for state, weight in zip(hidden_states, self.weights)
        )

        mean_source = features
        if self.cfg.model.ssl.attn:
            for layer in self.attn:
                mean_source, _ = layer(mean_source, mean_source, mean_source)

        # Mean pooling uses the (optionally) attended features, while max
        # pooling is always taken over the pre-attention features.
        mean_pooled = torch.mean(mean_source, dim=1)
        max_pooled = torch.max(features, dim=1)[0]
        pooled = torch.cat([mean_pooled, max_pooled], dim=1)

        return self.fc(torch.cat([pooled, d], dim=1))
# Model names for which get_ssl_output_shape returns (25, 1024).
_SSL_MODELS_25_1024 = (
    "facebook/w2v-bert-2.0",
    "facebook/wav2vec2-large",
    "facebook/wav2vec2-large-robust",
    "facebook/wav2vec2-large-960h",
    "microsoft/wavlm-large",
    "facebook/wav2vec2-large-xlsr-53",
)

# Model names for which get_ssl_output_shape returns (13, 768).
_SSL_MODELS_13_768 = (
    "facebook/hubert-base-ls960",
    "facebook/data2vec-audio-base-960h",
    "microsoft/wavlm-base",
    "microsoft/wavlm-base-plus",
    "microsoft/wavlm-base-plus-sv",
    "facebook/wav2vec2-base",
)


def get_ssl_output_shape(name: str) -> tuple[int, int]:
    """Return the ``(hidden_num, in_features)`` pair for a known SSL model.

    ``SSLExtModel`` uses the first value as the number of per-layer weights
    and the second as the feature dimension of the attention layers and head.

    Args:
        name: Model identifier, e.g. ``"microsoft/wavlm-base"``.

    Returns:
        ``(25, 1024)`` or ``(13, 768)`` depending on which table lists
        ``name``.

    Raises:
        NotImplementedError: If ``name`` is in neither table. The message
            names the rejected model and lists the supported ones.
    """
    if name in _SSL_MODELS_25_1024:
        return 25, 1024
    if name in _SSL_MODELS_13_768:
        return 13, 768
    supported = ", ".join(_SSL_MODELS_25_1024 + _SSL_MODELS_13_768)
    raise NotImplementedError(
        f"Unsupported SSL model {name!r}; supported models: {supported}"
    )