kAIto47802
Resolved conflict in README.md
b55d767
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoFeatureExtractor, AutoModel
from utmosv2.dataset._utils import get_dataset_num
class _SSLEncoder(nn.Module):
    """Wrap a Hugging Face SSL speech model together with its feature extractor.

    Waveforms are run through ``AutoFeatureExtractor`` and then ``AutoModel``;
    the forward pass returns every hidden state the model exposes.

    Args:
        sr: Sampling rate forwarded to the feature extractor.
        model_name: Hugging Face identifier used to load both the feature
            extractor and the model.
        freeze: When true, every model parameter gets ``requires_grad=False``.
    """

    def __init__(self, sr: int, model_name: str, freeze: bool):
        super().__init__()
        self.sr = sr
        self.processor = AutoFeatureExtractor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        if freeze:
            # Module.requires_grad_ sets requires_grad on every parameter.
            self.model.requires_grad_(False)

    def forward(self, x):
        """Encode a batch of waveforms and return the model's hidden states.

        Args:
            x: Iterable of waveform tensors; each one is moved to CPU and
                converted to NumPy before feature extraction.
                NOTE(review): the extractor is called without a ``padding``
                argument, so presumably all clips share one length — confirm
                with callers.

        Returns:
            ``hidden_states`` from the underlying model's output.
        """
        waveforms = [wav.cpu().numpy() for wav in x]
        features = self.processor(
            waveforms,
            sampling_rate=self.sr,
            return_tensors="pt",
        ).to(self.model.device)
        return self.model(**features, output_hidden_states=True).hidden_states
class SSLExtModel(nn.Module):
    """Predict scores from layer-weighted SSL features plus a dataset vector.

    Pipeline:
        1. :class:`_SSLEncoder` produces the hidden states.
        2. The hidden states are mixed with a learnable per-layer weight
           vector.
        3. Optionally, a stack of self-attention layers refines them.
        4. The features are mean- and max-pooled over dim 1 and concatenated
           with ``d``.
        5. A linear head produces the output.

    Args:
        cfg: Experiment config. Reads ``cfg.sr`` and ``cfg.model.ssl``
            (``name``, ``freeze``, ``attn``, ``num_classes``) and is passed to
            ``get_dataset_num``.
        name: Optional SSL model identifier that overrides
            ``cfg.model.ssl.name``.
    """

    def __init__(self, cfg, name: str | None = None):
        super().__init__()
        self.cfg = cfg
        ssl_cfg = cfg.model.ssl
        ssl_name = name or ssl_cfg.name
        self.encoder = _SSLEncoder(cfg.sr, ssl_name, ssl_cfg.freeze)
        hidden_num, in_features = get_ssl_output_shape(ssl_name)
        # Softmax is applied once, at initialisation only; afterwards the
        # weights are free parameters and are not re-normalised in forward.
        self.weights = nn.Parameter(F.softmax(torch.randn(hidden_num), dim=0))
        if ssl_cfg.attn:
            # ssl_cfg.attn doubles as the number of stacked attention layers.
            layers = []
            for _ in range(ssl_cfg.attn):
                layers.append(
                    nn.MultiheadAttention(
                        embed_dim=in_features,
                        num_heads=8,
                        dropout=0.2,
                        batch_first=True,
                    )
                )
            self.attn = nn.ModuleList(layers)
        self.num_dataset = get_dataset_num(cfg)
        # Head input = mean-pooled features + max-pooled features + d.
        self.fc = nn.Linear(
            in_features * 2 + self.num_dataset, ssl_cfg.num_classes
        )

    def forward(self, x, d):
        """Score a batch of waveforms.

        Args:
            x: Batch of waveforms accepted by :class:`_SSLEncoder`.
            d: Per-sample dataset features, concatenated along dim 1 before
                the head. NOTE(review): presumably a one-hot of width
                ``self.num_dataset`` — confirm against the data pipeline.

        Returns:
            Output of ``self.fc``.
        """
        hidden_states = self.encoder(x)
        # Layer-weighted sum of all hidden states.
        feats = sum(h * w for h, w in zip(hidden_states, self.weights))
        if self.cfg.model.ssl.attn:
            attended = feats
            for layer in self.attn:
                attended, _ = layer(attended, attended, attended)
            # NOTE(review): the mean uses the attention output while the max
            # below uses the pre-attention features. This is kept as-is so
            # trained checkpoints behave identically; confirm whether it is
            # intentional.
            mean_part = attended.mean(dim=1)
        else:
            mean_part = feats.mean(dim=1)
        pooled = torch.cat([mean_part, feats.max(dim=1).values], dim=1)
        return self.fc(torch.cat([pooled, d], dim=1))
# Known SSL checkpoints -> (hidden-state count, feature width).
_SSL_OUTPUT_SHAPES: dict[str, tuple[int, int]] = {
    **dict.fromkeys(
        (
            "facebook/w2v-bert-2.0",
            "facebook/wav2vec2-large",
            "facebook/wav2vec2-large-robust",
            "facebook/wav2vec2-large-960h",
            "microsoft/wavlm-large",
            "facebook/wav2vec2-large-xlsr-53",
        ),
        (25, 1024),
    ),
    **dict.fromkeys(
        (
            "facebook/hubert-base-ls960",
            "facebook/data2vec-audio-base-960h",
            "microsoft/wavlm-base",
            "microsoft/wavlm-base-plus",
            "microsoft/wavlm-base-plus-sv",
            "facebook/wav2vec2-base",
        ),
        (13, 768),
    ),
}


def get_ssl_output_shape(name: str) -> tuple[int, int]:
    """Return ``(hidden_num, in_features)`` for a supported SSL model.

    ``SSLExtModel`` uses the first value as the length of its per-layer
    weight vector. It uses the second as the feature width for its attention
    layers and linear head.

    Args:
        name: Hugging Face model identifier.

    Returns:
        A ``(hidden_num, in_features)`` tuple.

    Raises:
        NotImplementedError: If ``name`` is not a known model. The message
            names the offending model and lists the supported ones.
    """
    try:
        return _SSL_OUTPUT_SHAPES[name]
    except KeyError:
        # The KeyError context adds nothing for callers, so it is suppressed.
        raise NotImplementedError(
            f"Unsupported SSL model: {name!r}. "
            f"Supported models: {sorted(_SSL_OUTPUT_SHAPES)}"
        ) from None