File size: 1,527 Bytes
d967830
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from transformers import ASTConfig, ASTFeatureExtractor, ASTModel

class BirdAST(nn.Module):
    """Pre-trained AST backbone followed by an MLP classification head.

    The backbone is loaded with ``ASTModel.from_pretrained``. The head is
    ``n_mlp_layers`` blocks of ``Linear(hidden, hidden)`` + activation,
    followed by a final ``Linear(hidden, n_classes)`` projection, where
    ``hidden`` is the backbone config's ``hidden_size``.
    """

    def __init__(self, backbone_name, n_classes, n_mlp_layers=1, activation='silu'):
        """Build the backbone and the classification head.

        Args:
            backbone_name: Identifier forwarded to ``ASTConfig.from_pretrained``
                and ``ASTModel.from_pretrained``.
            n_classes: Output width of the final linear layer.
            n_mlp_layers: Number of hidden ``Linear`` + activation blocks placed
                before the final projection; ``0`` leaves only the projection.
            activation: ``'relu'`` or ``'silu'``.

        Raises:
            ValueError: If ``activation`` is neither ``'relu'`` nor ``'silu'``.
        """
        super().__init__()

        # Backbone: fetch the config first so its hidden size can size the head.
        ast_config = ASTConfig.from_pretrained(backbone_name)
        self.ast = ASTModel.from_pretrained(backbone_name, config=ast_config)
        self.hidden_size = ast_config.hidden_size

        # Pick the activation by name; the for/else raises when nothing matches.
        for known_name, factory in (('relu', nn.ReLU), ('silu', nn.SiLU)):
            if activation == known_name:
                self.activation = factory()
                break
        else:
            raise ValueError("Unsupported activation function. Choose 'relu' or 'silu'.")

        # Head: hidden blocks first, then the projection to class logits.
        # The single activation instance is reused after every hidden layer.
        width = self.hidden_size
        head = []
        for _ in range(n_mlp_layers):
            head += [nn.Linear(width, width), self.activation]
        head.append(nn.Linear(width, n_classes))
        self.mlp = nn.Sequential(*head)

    def forward(self, spectrogram):
        """Classify a batch of spectrograms.

        Args:
            spectrogram: Batched input passed unchanged to the AST backbone.
                NOTE(review): the previous inline comment gave this as
                (batch_size, n_mels, n_frames), while the Hugging Face docs
                describe ``ASTModel`` ``input_values`` as
                (batch_size, max_length, num_mel_bins) -- confirm the axis
                order callers actually use.

        Returns:
            A dict with one key, ``'logits'``, holding the head's output of
            shape (batch_size, n_classes).
        """
        encoded = self.ast(spectrogram, output_hidden_states=False)
        # Pool by taking sequence position 0, i.e. the CLS token.
        cls_embedding = encoded.last_hidden_state[:, 0, :]
        return {'logits': self.mlp(cls_embedding)}