from transformers import ASTModel, AutoFeatureExtractor, AutoModelForAudioClassification, TrainingArguments, Trainer
import torch
from torch import nn
import evaluate
import numpy as np

accuracy = evaluate.load("accuracy")


class MultiModalAST(nn.Module):
    def __init__(self, labels, sample_rate, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        id2label, label2id = get_id_label_mapping(labels)
        model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
        self.ast_feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
        # Backbone only; the label mappings are stored on the config for reference,
        # since the classification head here is the custom fusion classifier below.
        self.ast_model = ASTModel.from_pretrained(
            model_checkpoint,
            num_labels=len(label2id),
            label2id=label2id,
            id2label=id2label,
            ignore_mismatched_sizes=True
        )
        self.sample_rate = sample_rate
        # BPM branch: maps a len(labels)-bin tempo vector (see vectorize_bpm) to 50 features
        self.bpm_model = nn.Sequential(
            nn.Linear(len(labels), 100),
            nn.Linear(100, 50)
        )
        # Fused feature size: AST pooled output (hidden_size) + BPM branch output (50)
        out_dim = self.ast_model.config.hidden_size + 50
        self.classifier = nn.Sequential(
            nn.Linear(out_dim, 100),
            nn.Linear(100, len(labels))
        )

    def vectorize_bpm(self, waveform):
        # Minimal sketch under assumptions: estimate tempo with librosa's beat
        # tracker and one-hot bucket it over a 0-240 BPM range into as many bins
        # as the BPM branch expects. Any tempo-to-vector scheme of that size works.
        import librosa
        y = waveform.numpy() if torch.is_tensor(waveform) else np.asarray(waveform)
        tempo, _ = librosa.beat.beat_track(y=y, sr=self.sample_rate)
        n_bins = self.bpm_model[0].in_features
        bucket = min(int(float(tempo) / (240 / n_bins)), n_bins - 1)
        vec = torch.zeros(n_bins)
        vec[bucket] = 1.0
        return vec.unsqueeze(0)  # (1, n_bins) so it concatenates with batched AST output

    def forward(self, audio):
        bpm_vector = self.vectorize_bpm(audio)
        bpm_out = self.bpm_model(bpm_vector)
        features = self.ast_feature_extractor(
            audio, sampling_rate=self.sample_rate, return_tensors="pt"
        )
        ast_out = self.ast_model(input_values=features.input_values).pooler_output
        # Late fusion: concatenate the two branches along the feature dimension
        z = torch.cat([ast_out, bpm_out], dim=-1)
        return self.classifier(z)
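

# Usage sketch for the fused model. Assumptions: a mono float waveform at the
# model's sample rate, and illustrative label names; nothing here comes from the
# original project beyond the class itself.
def demo_forward():
    model = MultiModalAST(labels=["techno", "house", "ambient"], sample_rate=16000)
    waveform = np.random.randn(16000 * 5).astype(np.float32)  # 5 s of synthetic audio
    logits = model(waveform)  # shape (1, len(labels))
    print(logits.shape)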


def compute_metrics(eval_pred):
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=eval_pred.label_ids)


def get_id_label_mapping(labels: list[str]) -> tuple[dict, dict]:
    id2label = {str(i): label for i, label in enumerate(labels)}
    label2id = {label: str(i) for i, label in enumerate(labels)}
    return id2label, label2id
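

# Example (illustrative label names): get_id_label_mapping(["techno", "house"])
# returns ({"0": "techno", "1": "house"}, {"techno": "0", "house": "1"});
# string ids match the format transformers stores in config.json label maps.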


def train(
    labels,
    train_ds,
    test_ds,
    output_dir="models/weights/ast",
    device="cpu",
    batch_size=128,
    epochs=10):
    id2label, label2id = get_id_label_mapping(labels)
    model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)

    def preprocess_waveform(wf):
        # Convert raw waveforms into padded log-mel spectrogram features for AST
        return feature_extractor(wf, sampling_rate=train_ds.resample_frequency, padding="max_length", return_tensors="pt")

    # map returns a new dataset; the result must be reassigned or the features are lost
    train_ds = train_ds.map(preprocess_waveform)
    test_ds = test_ds.map(preprocess_waveform)
    model = AutoModelForAudioClassification.from_pretrained(
        model_checkpoint,
        num_labels=len(labels),
        label2id=label2id,
        id2label=id2label,
        ignore_mismatched_sizes=True
    ).to(device)
    training_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=5e-5,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=5,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        warmup_ratio=0.1,
        logging_steps=10,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        push_to_hub=False,
        use_mps_device=(device == "mps")
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        tokenizer=feature_extractor,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    return model
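

# Usage sketch (assumptions: the dataset objects come from the rest of the project;
# any dataset exposing .map and .resample_frequency as used above should work):
#
# labels = ["techno", "house", "ambient"]
# model = train(labels, train_ds, test_ds, device="mps", batch_size=32, epochs=5)
# model.save_pretrained("models/weights/ast")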