Spaces:
Runtime error
Runtime error
File size: 2,308 Bytes
e6fd727 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
from transformers import ASTFeatureExtractor, AutoFeatureExtractor, ASTConfig, AutoModelForAudioClassification, TrainingArguments, Trainer
import torch
from torch import nn
from sklearn.utils.class_weight import compute_class_weight
import evaluate
import numpy as np
# Accuracy metric from the Hugging Face `evaluate` library. It is loaded once at
# import time and reused by every call to `compute_metrics` below.
accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Score one evaluation pass for ``transformers.Trainer``.

    Args:
        eval_pred: Object exposing ``predictions`` (per-class scores with the
            class axis at position 1) and ``label_ids`` (ground-truth class
            ids). NOTE(review): presumably the ``EvalPrediction`` that
            ``Trainer`` passes in, with logits as predictions — confirm.

    Returns:
        Whatever ``accuracy.compute`` returns. ``train`` relies on it holding
        an ``"accuracy"`` entry via ``metric_for_best_model``.
    """
    scores = eval_pred.predictions
    targets = eval_pred.label_ids
    # The predicted class for each example is its highest-scoring column.
    predicted_classes = np.argmax(scores, axis=1)
    return accuracy.compute(references=targets, predictions=predicted_classes)
def get_id_label_mapping(labels: list[str]) -> tuple[dict, dict]:
    """Build the forward and reverse class-id tables for a model config.

    Class ids are the string form of each label's position in ``labels``.

    Args:
        labels: Ordered class names.

    Returns:
        ``(id2label, label2id)``. ``id2label`` maps ``"0", "1", ...`` to the
        label at that position. ``label2id`` maps each label back to its id
        string; for a repeated label, the last position wins.
    """
    id2label = {}
    for position, name in enumerate(labels):
        id2label[str(position)] = name

    label2id = {}
    for position, name in enumerate(labels):
        label2id[name] = str(position)

    return id2label, label2id
def _accepted_params(fn) -> frozenset:
    """Return the parameter names ``fn`` accepts (used to bridge transformers versions)."""
    import inspect

    return frozenset(inspect.signature(fn).parameters)


def _apply_map(dataset, fn):
    """Call ``dataset.map(fn)`` and return the object that holds the mapped data.

    An in-place ``map`` that returns ``None`` leaves ``dataset`` itself as the
    result. A functional ``map`` that returns a new dataset has that new
    object returned, so its output is not silently discarded.
    """
    mapped = dataset.map(fn)
    return dataset if mapped is None else mapped


def train(
    labels,
    train_ds,
    test_ds,
    output_dir="models/weights/ast",
    device="cpu",
    batch_size=128,
    epochs=10,
    learning_rate=5e-5,
    gradient_accumulation_steps=5,
):
    """Fine-tune the AST AudioSet checkpoint on ``labels`` and return the model.

    Both datasets are featurised with the checkpoint's feature extractor. The
    classification head is resized to ``len(labels)`` classes. A
    ``transformers.Trainer`` then trains for ``epochs`` epochs, evaluating and
    saving a checkpoint at the end of each one.

    Args:
        labels: Class names; position ``i`` becomes class id ``i``.
        train_ds: Training split. It must expose ``resample_frequency`` (the
            rate passed to the feature extractor) and ``map(fn)``.
            NOTE(review): presumably a project dataset whose items end up as
            ``input_values`` / ``labels`` records — confirm against its
            definition.
        test_ds: Evaluation split. Only its ``map(fn)`` is used.
        output_dir: Directory where the Trainer writes per-epoch checkpoints.
        device: Device the model is moved to before training. ``"mps"`` also
            sets ``use_mps_device`` on transformers versions that still accept
            it. NOTE(review): Trainer picks its own device from its arguments,
            so this placement may be overridden.
        batch_size: Per-device train and eval batch size.
        epochs: Number of training epochs.
        learning_rate: Peak learning rate, warmed up over the first 10% of
            steps.
        gradient_accumulation_steps: Batches accumulated per optimizer update.
            The effective train batch per device is
            ``batch_size * gradient_accumulation_steps``.

    Returns:
        The fine-tuned model. ``load_best_model_at_end=True`` with
        ``metric_for_best_model="accuracy"`` makes the Trainer reload the
        best-scoring checkpoint into it when training finishes.
    """
    id2label, label2id = get_id_label_mapping(labels)
    model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)

    # Both splits are featurised at the training split's rate, so they are
    # assumed to have been resampled identically. Read the rate once so that
    # rebinding `train_ds` below cannot change what the preprocessor sees.
    sampling_rate = train_ds.resample_frequency

    def preprocess_waveform(wf):
        # NOTE(review): return_tensors="pt" on a single waveform yields a
        # leading batch axis of size 1; confirm the dataset/collator drops it.
        return feature_extractor(
            wf,
            sampling_rate=sampling_rate,
            padding="max_length",
            return_tensors="pt",
        )

    train_ds = _apply_map(train_ds, preprocess_waveform)
    test_ds = _apply_map(test_ds, preprocess_waveform)

    model = AutoModelForAudioClassification.from_pretrained(
        model_checkpoint,
        num_labels=len(labels),
        label2id=label2id,
        id2label=id2label,
        # The checkpoint's head was trained for a different label set (AudioSet,
        # per its name); let it be re-initialised for `num_labels` classes.
        ignore_mismatched_sizes=True,
    ).to(device)

    # Keyword names that differ across transformers releases.
    arg_names = _accepted_params(TrainingArguments)
    version_dependent_args = {}
    # `evaluation_strategy` was renamed to `eval_strategy` and later removed.
    strategy_key = "eval_strategy" if "eval_strategy" in arg_names else "evaluation_strategy"
    version_dependent_args[strategy_key] = "epoch"
    # `use_mps_device` is deprecated; only pass it where it still exists.
    if "use_mps_device" in arg_names:
        version_dependent_args["use_mps_device"] = device == "mps"

    training_args = TrainingArguments(
        output_dir=output_dir,
        save_strategy="epoch",
        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        warmup_ratio=0.1,
        logging_steps=10,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        push_to_hub=False,
        **version_dependent_args,
    )

    # Newer Trainer releases take the preprocessor as `processing_class`.
    processor_key = (
        "processing_class" if "processing_class" in _accepted_params(Trainer) else "tokenizer"
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        compute_metrics=compute_metrics,
        **{processor_key: feature_extractor},
    )
    trainer.train()
    return model
|