voj / model.py
amroa's picture
add files
d967830
raw
history blame
1.53 kB
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from transformers import ASTConfig, ASTFeatureExtractor, ASTModel
class BirdAST(nn.Module):
    """Audio Spectrogram Transformer backbone with an MLP classification head.

    A pre-trained AST encoder embeds the input spectrogram. The hidden state
    of the first sequence position (the [CLS] token) is passed through an MLP
    that maps it to ``n_classes`` logits.

    Args:
        backbone_name: Model id or local path accepted by
            ``ASTConfig.from_pretrained`` / ``ASTModel.from_pretrained``.
        n_classes: Number of output classes (out_features of the last layer).
        n_mlp_layers: Number of hidden ``Linear(hidden, hidden) + activation``
            pairs placed before the output layer; ``0`` gives a single linear
            classifier on top of the [CLS] embedding.
        activation: Hidden-layer activation name, ``'relu'`` or ``'silu'``.

    Raises:
        ValueError: If ``activation`` is unsupported or ``n_mlp_layers`` is
            negative. Both are checked before the backbone is loaded.
    """

    def __init__(self, backbone_name, n_classes, n_mlp_layers=1, activation='silu'):
        super().__init__()

        # Validate cheap arguments first so a typo fails immediately instead
        # of after downloading / loading the pre-trained backbone.
        act_fn = self._make_activation(activation)
        if n_mlp_layers < 0:
            raise ValueError(f"n_mlp_layers must be non-negative, got {n_mlp_layers}")

        # Pre-trained backbone; its hidden_size determines the head's width.
        backbone_config = ASTConfig.from_pretrained(backbone_name)
        self.ast = ASTModel.from_pretrained(backbone_name, config=backbone_config)
        self.hidden_size = backbone_config.hidden_size

        # Assigned after `ast` to keep the original submodule registration
        # order. The same (parameter-free) instance is reused after every
        # hidden Linear layer.
        self.activation = act_fn

        # The attribute name and layer order fix the state_dict keys
        # ("mlp.0.weight", "mlp.2.weight", ...); keep them stable so saved
        # checkpoints keep loading.
        self.mlp = self._build_mlp(self.hidden_size, n_classes, n_mlp_layers, self.activation)

    @staticmethod
    def _make_activation(name):
        """Return a new activation module for ``name`` ('relu' or 'silu')."""
        if name == 'relu':
            return nn.ReLU()
        if name == 'silu':
            return nn.SiLU()
        raise ValueError("Unsupported activation function. Choose 'relu' or 'silu'.")

    @staticmethod
    def _build_mlp(hidden_size, n_classes, n_mlp_layers, activation):
        """Stack ``n_mlp_layers`` x (Linear + activation), then the output Linear."""
        layers = []
        for _ in range(n_mlp_layers):
            layers += [nn.Linear(hidden_size, hidden_size), activation]
        layers.append(nn.Linear(hidden_size, n_classes))
        return nn.Sequential(*layers)

    def forward(self, spectrogram):
        """Classify a batch of spectrograms.

        Args:
            spectrogram: Float tensor passed to ``ASTModel`` as
                ``input_values``. Per the Hugging Face docs its shape is
                ``(batch_size, n_frames, n_mels)`` (time first, mel bins
                last), matching the ``input_values`` produced by
                ``ASTFeatureExtractor``.

        Returns:
            dict with key ``'logits'`` holding a ``(batch_size, n_classes)``
            tensor.
        """
        # return_dict=True guarantees an output object exposing
        # `.last_hidden_state`, even if the loaded config disables return_dict.
        ast_output = self.ast(spectrogram, output_hidden_states=False, return_dict=True)
        # Sequence position 0 is the [CLS] token.
        cls_embedding = ast_output.last_hidden_state[:, 0, :]
        return {'logits': self.mlp(cls_embedding)}