import torch
from torch import nn
import torch.nn.functional as F
from transformers import ASTConfig, ASTFeatureExtractor, ASTModel

BirdAST_FEATURE_EXTRACTOR = ASTFeatureExtractor()

DEFAULT_SR = 16_000
DEFAULT_BACKBONE = "MIT/ast-finetuned-audioset-10-10-0.4593"
DEFAULT_N_CLASSES = 728
DEFAULT_ACTIVATION = "silu"
DEFAULT_N_MLP_LAYERS = 1

def birdast_preprocess(audio_array, sr=DEFAULT_SR):
    """
    Preprocess an audio array for the BirdAST model.

    audio_array: np.array, audio array of the recording, shape (n_samples,)
    sr: int, sampling rate of the audio array (default: 16_000)

    Note:
        1. The audio array should be normalized to [-1, 1].
        2. The audio length should be 10 seconds (or 10.24 seconds); longer audio will be truncated.
    """
    # Extract log-mel features; padding/truncation to the extractor's fixed max_length
    features = BirdAST_FEATURE_EXTRACTOR(audio_array, sampling_rate=sr, padding="max_length", return_tensors="pt")

    # return_tensors="pt" already yields a torch.Tensor, so no torch.tensor() wrapper is needed
    spectrogram = features["input_values"].squeeze(0)

    return spectrogram
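
# A minimal sketch of preprocessing a real recording, assuming librosa is
# available (it is not imported above; "recording.wav" is a placeholder path).
# librosa.load resamples to the requested rate and returns float32 samples
# already scaled to [-1, 1], matching the normalization note in the docstring:
#
#     import librosa
#     audio_array, sr = librosa.load("recording.wav", sr=DEFAULT_SR)
#     spectrogram = birdast_preprocess(audio_array, sr=sr)
#     # With the extractor's defaults (max_length=1024 frames, 128 mel bins),
#     # the result has shape (1024, 128).
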
def birdast_inference(
    model_weights,
    spectrogram,
    device='cpu',
    backbone_name=DEFAULT_BACKBONE,
    n_classes=DEFAULT_N_CLASSES,
    activation=DEFAULT_ACTIVATION,
    n_mlp_layers=DEFAULT_N_MLP_LAYERS,
):
    """
    Perform ensemble inference with one or more BirdAST checkpoints.

    model_weights: list, paths to model weight files (one per ensemble member)
    spectrogram: torch.Tensor, spectrogram tensor, shape (batch_size, n_frames, n_mels)
    device: str, device to run inference on (default: 'cpu')
    backbone_name: str, name of the backbone model (default: 'MIT/ast-finetuned-audioset-10-10-0.4593')
    n_classes: int, number of classes (default: 728)
    activation: str, activation function (default: 'silu')
    n_mlp_layers: int, number of MLP layers (default: 1)

    Returns:
        predictions: np.array, array of class probabilities, shape (n_models, batch_size, n_classes)
    """
    model = BirdAST(
        backbone_name=backbone_name,
        n_classes=n_classes,
        n_mlp_layers=n_mlp_layers,
        activation=activation,
    )

    predict_collects = []
    for _weights in model_weights:
        model.load_state_dict(torch.load(_weights, map_location=device))
        if device != 'cpu':
            model.to(device)
        model.eval()

        with torch.no_grad():
            if device != 'cpu':
                spectrogram = spectrogram.to(device)

            # Ensure the input tensor has a batch dimension
            if spectrogram.dim() == 2:
                spectrogram = spectrogram.unsqueeze(0)  # -> (batch_size, n_frames, n_mels)

            output = model(spectrogram)
            logits = output['logits']
            probs = F.softmax(logits, dim=-1)
            predict_collects.append(probs)

    if device != 'cpu':
        predict_collects = [pred.cpu() for pred in predict_collects]

    # Stack (not cat) so the result keeps a separate model axis, matching the
    # documented return shape (n_models, batch_size, n_classes)
    predict_collects = torch.stack(predict_collects, dim=0).numpy()

    return predict_collects
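
# A minimal sketch of collapsing the (n_models, batch_size, n_classes) output
# into a single ensemble prediction; the variable names here are illustrative,
# not part of the original API:
#
#     probs = birdast_inference(model_weights, spectrogram)
#     ensemble = probs.mean(axis=0)         # (batch_size, n_classes)
#     top_class = ensemble.argmax(axis=-1)  # most likely class per sample
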
class BirdAST(nn.Module):
    def __init__(self, backbone_name, n_classes, n_mlp_layers=1, activation='silu'):
        super().__init__()

        # Pre-trained AST backbone
        backbone_config = ASTConfig.from_pretrained(backbone_name)
        self.ast = ASTModel.from_pretrained(backbone_name, config=backbone_config)
        self.hidden_size = backbone_config.hidden_size

        # Set the activation function
        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'silu':
            self.activation = nn.SiLU()
        else:
            raise ValueError("Unsupported activation function. Choose 'relu' or 'silu'.")

        # Define the MLP head: n_mlp_layers hidden layers with activation,
        # followed by a final classification layer
        layers = []
        for _ in range(n_mlp_layers):
            layers.append(nn.Linear(self.hidden_size, self.hidden_size))
            layers.append(self.activation)
        layers.append(nn.Linear(self.hidden_size, n_classes))
        self.mlp = nn.Sequential(*layers)

    def forward(self, spectrogram):
        # spectrogram: (batch_size, n_frames, n_mels)
        # output: (batch_size, n_classes)
        ast_output = self.ast(spectrogram, output_hidden_states=False)
        logits = self.mlp(ast_output.last_hidden_state[:, 0, :])  # use the CLS token
        return {'logits': logits}
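
# A minimal sketch of using BirdAST on its own, assuming the default backbone
# weights can be fetched from the Hugging Face Hub; the dummy input mimics one
# preprocessed spectrogram batch:
#
#     model = BirdAST(DEFAULT_BACKBONE, DEFAULT_N_CLASSES)
#     model.eval()
#     with torch.no_grad():
#         out = model(torch.randn(1, 1024, 128))
#     print(out['logits'].shape)  # torch.Size([1, 728])
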
if __name__ == '__main__':
    import numpy as np
    import matplotlib.pyplot as plt

    # Example usage of BirdAST

    # Create a random 10-second audio array at the default 16 kHz sampling rate
    # (a smoke test only; real audio should be normalized to [-1, 1])
    audio_array = np.random.randn(16_000 * 10)

    # Preprocess the audio array
    spectrogram = birdast_preprocess(audio_array)

    model_weights_dir = '/workspace/voice_of_jungle/training_logs'

    # Load model weights
    model_weights = [f'{model_weights_dir}/BirdAST_Baseline_GroupKFold_fold_{i}.pth' for i in range(5)]

    # Perform inference
    predictions = birdast_inference(model_weights, spectrogram.unsqueeze(0))

    # Plot per-model predictions
    fig, ax = plt.subplots()
    for i, pred in enumerate(predictions):
        ax.plot(pred[0], label=f'model_{i}')
    ax.legend()
    fig.savefig('test_BirdAST.png')

    print("Inference completed successfully!")