from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoFeatureExtractor, AutoModel

from utmosv2.dataset._utils import get_dataset_num


class _SSLEncoder(nn.Module):
    """Wrap a pretrained HuggingFace SSL speech model and expose all of its
    hidden states."""

    def __init__(self, sr: int, model_name: str, freeze: bool):
        super().__init__()
        self.sr = sr
        self.processor = AutoFeatureExtractor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        if freeze:
            for param in self.model.parameters():
                param.requires_grad = False

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
        # The feature extractor expects numpy arrays, so each waveform is
        # moved to the CPU first; the preprocessed batch is then sent back
        # to the device the model lives on.
        x = self.processor(
            [t.cpu().numpy() for t in x],
            sampling_rate=self.sr,
            return_tensors="pt",
        ).to(self.model.device)
        outputs = self.model(**x, output_hidden_states=True)
        return outputs.hidden_states


class SSLExtModel(nn.Module):
    """MOS predictor head on top of a pretrained SSL encoder.

    The encoder's hidden states are combined with learnable layer weights,
    optionally refined by a stack of self-attention layers, pooled over time
    (mean and max), concatenated with the dataset indicator `d`, and mapped
    to the output classes by a linear layer.
    """

    def __init__(self, cfg, name: str | None = None):
        super().__init__()
        self.cfg = cfg
        self.encoder = _SSLEncoder(
            cfg.sr, name or cfg.model.ssl.name, cfg.model.ssl.freeze
        )
        hidden_num, in_features = get_ssl_output_shape(name or cfg.model.ssl.name)
        # Per-layer weights, initialized to a random point on the probability
        # simplex (they are not re-normalized during training).
        self.weights = nn.Parameter(F.softmax(torch.randn(hidden_num), dim=0))
        if cfg.model.ssl.attn:
            self.attn = nn.ModuleList(
                [
                    nn.MultiheadAttention(
                        embed_dim=in_features,
                        num_heads=8,
                        dropout=0.2,
                        batch_first=True,
                    )
                    for _ in range(cfg.model.ssl.attn)
                ]
            )
        self.num_dataset = get_dataset_num(cfg)
        # Mean and max pooling are concatenated, doubling the feature size.
        self.fc = nn.Linear(
            in_features * 2 + self.num_dataset, cfg.model.ssl.num_classes
        )

    def forward(self, x: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
        x = self.encoder(x)
        # Weighted sum over all hidden states.
        x = sum(t * w for t, w in zip(x, self.weights))
        if self.cfg.model.ssl.attn:
            y = x
            for attn in self.attn:
                y, _ = attn(y, y, y)
            # Mean-pool the attended features; max-pool the pre-attention ones.
            x = torch.cat([torch.mean(y, dim=1), torch.max(x, dim=1)[0]], dim=1)
        else:
            x = torch.cat([torch.mean(x, dim=1), torch.max(x, dim=1)[0]], dim=1)
        x = self.fc(torch.cat([x, d], dim=1))
        return x
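
# Shape walk-through for SSLExtModel.forward (illustrative sketch, assuming
# "facebook/wav2vec2-base", batch size B, and T encoder frames):
#   encoder hidden states: 13 tensors of shape (B, T, 768)
#   weighted sum:          (B, T, 768)
#   mean / max pooling:    (B, 768) each, concatenated to (B, 1536)
#   dataset indicator d:   (B, num_dataset), so fc sees (B, 1536 + num_dataset)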


def get_ssl_output_shape(name: str) -> tuple[int, int]:
    """Return (number of hidden states, hidden size) for a supported model."""
    if name in [
        "facebook/w2v-bert-2.0",
        "facebook/wav2vec2-large",
        "facebook/wav2vec2-large-robust",
        "facebook/wav2vec2-large-960h",
        "microsoft/wavlm-large",
        "facebook/wav2vec2-large-xlsr-53",
    ]:
        return 25, 1024
    elif name in [
        "facebook/hubert-base-ls960",
        "facebook/data2vec-audio-base-960h",
        "microsoft/wavlm-base",
        "microsoft/wavlm-base-plus",
        "microsoft/wavlm-base-plus-sv",
        "facebook/wav2vec2-base",
    ]:
        return 13, 768
    else:
        raise NotImplementedError(f"SSL model not supported: {name}")
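

# Minimal smoke-test sketch (illustrative, not part of the library API). It
# assumes "facebook/wav2vec2-base" and a 16 kHz sample rate, which match the
# supported list above; running it downloads the pretrained weights.
if __name__ == "__main__":
    encoder = _SSLEncoder(sr=16000, model_name="facebook/wav2vec2-base", freeze=True)
    waveforms = torch.randn(2, 16000)  # batch of two 1-second waveforms
    hidden_states = encoder(waveforms)
    hidden_num, in_features = get_ssl_output_shape("facebook/wav2vec2-base")
    assert len(hidden_states) == hidden_num  # 13 hidden states
    assert hidden_states[0].shape[-1] == in_features  # hidden size 768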