Spaces:
Runtime error
Runtime error
from transformers import ASTFeatureExtractor, AutoFeatureExtractor, ASTConfig, AutoModelForAudioClassification, TrainingArguments, Trainer | |
import torch | |
from torch import nn | |
from sklearn.utils.class_weight import compute_class_weight | |
import evaluate | |
import numpy as np | |
# Accuracy metric from the `evaluate` library, loaded once when the module is imported.
accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Score one evaluation pass with the module-level accuracy metric.

    Args:
        eval_pred: Object exposing ``predictions`` (per-class scores, reduced
            with an argmax over axis 1) and ``label_ids`` (reference labels).

    Returns:
        Whatever ``accuracy.compute`` returns for the predicted class indices
        against the reference labels.
    """
    # The highest-scoring column of each row is taken as the predicted class.
    predicted_ids = np.argmax(eval_pred.predictions, axis=1)
    return accuracy.compute(
        predictions=predicted_ids,
        references=eval_pred.label_ids,
    )
def get_id_label_mapping(labels: list[str]) -> tuple[dict, dict]:
    """Build the lookup tables between label names and their string indices.

    Args:
        labels: Label names; each label's position in the list is its index.

    Returns:
        A pair ``(id2label, label2id)``. ``id2label`` maps the stringified
        index to the label name and ``label2id`` maps the label name to the
        stringified index.
    """
    id2label = {}
    label2id = {}
    for index, name in enumerate(labels):
        key = str(index)  # indices are stored as strings in both tables
        id2label[key] = name
        label2id[name] = key
    return id2label, label2id
def train(
    labels,
    train_ds,
    test_ds,
    output_dir="models/weights/ast",
    device="cpu",
    batch_size=128,
    epochs=10,
):
    """Fine-tune the pretrained AST audio classifier on the given datasets.

    Args:
        labels: Class names; each name's position in the list is its class index.
        train_ds: Training dataset. Must expose ``resample_frequency`` (passed to
            the feature extractor as the sampling rate) and a ``map`` method.
        test_ds: Evaluation dataset with a ``map`` method. It is preprocessed
            with the training set's sampling rate.
        output_dir: Directory handed to ``TrainingArguments`` as ``output_dir``.
        device: Device the model is moved to after loading; the value ``"mps"``
            additionally enables ``use_mps_device``.
        batch_size: Per-device batch size used for both training and evaluation.
        epochs: Number of training epochs.

    Returns:
        The model object after ``Trainer.train()`` has completed.
    """
    id2label, label2id = get_id_label_mapping(labels)
    model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)

    # Read the rate once so preprocessing does not depend on which object
    # ``train_ds`` is bound to after the map calls below.
    sampling_rate = train_ds.resample_frequency

    def preprocess_waveform(wf):
        return feature_extractor(
            wf,
            sampling_rate=sampling_rate,
            padding="max_length",
            return_tensors="pt",
        )

    # The original discarded the result of ``map``. If ``map`` returns a new
    # dataset (as Hugging Face ``Dataset.map`` does) the preprocessing was
    # silently lost; if it works in place and returns None, the original
    # object is kept.
    # NOTE(review): confirm which behaviour the project's dataset class has.
    mapped_train = train_ds.map(preprocess_waveform)
    if mapped_train is not None:
        train_ds = mapped_train
    mapped_test = test_ds.map(preprocess_waveform)
    if mapped_test is not None:
        test_ds = mapped_test

    model = AutoModelForAudioClassification.from_pretrained(
        model_checkpoint,
        num_labels=len(labels),
        label2id=label2id,
        id2label=id2label,
        # Allows loading although the new label count differs from the checkpoint's.
        ignore_mismatched_sizes=True,
    ).to(device)

    # NOTE(review): newer transformers releases rename ``evaluation_strategy``
    # to ``eval_strategy`` and deprecate ``use_mps_device`` and
    # ``Trainer(tokenizer=...)`` -- confirm against the pinned version.
    training_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=5e-5,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=5,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        warmup_ratio=0.1,
        logging_steps=10,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        push_to_hub=False,
        use_mps_device=device == "mps",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        tokenizer=feature_extractor,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    return model