|
import torch |
|
from torch.utils.data import Dataset, DataLoader |
|
from torch import nn |
|
from transformers import ASTConfig, ASTFeatureExtractor, ASTModel |
|
|
|
class BirdAST(nn.Module):
    """Audio Spectrogram Transformer (AST) backbone with an MLP classification head.

    A pretrained AST model produces token embeddings from an input
    spectrogram; the embedding of the first ([CLS]) token is fed through a
    small MLP to produce class logits.
    """

    # Supported activation names mapped to their constructors; a fresh
    # instance is created per hidden layer so each Sequential entry is a
    # distinct module (clearer repr/introspection, identical behavior for
    # these stateless activations).
    _ACTIVATIONS = {'relu': nn.ReLU, 'silu': nn.SiLU}

    def __init__(self, backbone_name: str, n_classes: int,
                 n_mlp_layers: int = 1, activation: str = 'silu'):
        """Build the model.

        Args:
            backbone_name: HuggingFace model id or path passed to
                ``ASTConfig.from_pretrained`` / ``ASTModel.from_pretrained``.
            n_classes: Number of output classes (size of the final logits).
            n_mlp_layers: Number of hidden Linear+activation layers in the
                head (0 gives a single Linear projection). Must be >= 0.
            activation: Hidden-layer activation, ``'relu'`` or ``'silu'``.

        Raises:
            ValueError: If ``activation`` is unsupported or
                ``n_mlp_layers`` is negative.
        """
        super().__init__()

        if n_mlp_layers < 0:
            raise ValueError("n_mlp_layers must be >= 0.")
        if activation not in self._ACTIVATIONS:
            raise ValueError("Unsupported activation function. Choose 'relu' or 'silu'.")

        # Load the pretrained AST backbone and remember its embedding width.
        backbone_config = ASTConfig.from_pretrained(backbone_name)
        self.ast = ASTModel.from_pretrained(backbone_name, config=backbone_config)
        self.hidden_size = backbone_config.hidden_size

        # Kept for backward compatibility: some callers may inspect
        # ``self.activation``. The head itself uses fresh instances.
        activation_cls = self._ACTIVATIONS[activation]
        self.activation = activation_cls()

        # Head: n_mlp_layers x (Linear -> activation), then final projection.
        layers = []
        for _ in range(n_mlp_layers):
            layers.append(nn.Linear(self.hidden_size, self.hidden_size))
            layers.append(activation_cls())
        layers.append(nn.Linear(self.hidden_size, n_classes))
        self.mlp = nn.Sequential(*layers)

    def forward(self, spectrogram: torch.Tensor) -> dict:
        """Run a forward pass.

        Args:
            spectrogram: Batched spectrogram tensor in the format expected by
                the AST backbone (presumably ``(batch, time, freq)`` — confirm
                against the feature extractor used upstream).

        Returns:
            Dict with key ``'logits'`` holding a ``(batch, n_classes)`` tensor.
        """
        ast_output = self.ast(spectrogram, output_hidden_states=False)
        # Classify from the first ([CLS]) token embedding only.
        logits = self.mlp(ast_output.last_hidden_state[:, 0, :])
        return {'logits': logits}