Spaces:
Runtime error
Runtime error
from typing import Any | |
import pandas as pd | |
from sklearn.model_selection import train_test_split | |
from transformers import ( | |
AutoFeatureExtractor, | |
AutoModelForAudioClassification, | |
TrainingArguments, | |
Trainer, | |
ASTConfig, | |
ASTFeatureExtractor, | |
ASTForAudioClassification, | |
) | |
import torch | |
from torch import nn | |
from models.training_environment import TrainingEnvironment | |
from preprocessing.pipelines import WaveformTrainingPipeline | |
from preprocessing.dataset import ( | |
DanceDataModule, | |
HuggingFaceDatasetWrapper, | |
get_datasets, | |
) | |
from preprocessing.dataset import get_music4dance_examples | |
from .utils import get_id_label_mapping, compute_hf_metrics | |
import pytorch_lightning as pl | |
from pytorch_lightning import callbacks as cb | |
# Hugging Face Hub id of the pretrained Audio Spectrogram Transformer (AST)
# checkpoint (fine-tuned on AudioSet, per its name).
MODEL_CHECKPOINT = "MIT/ast-finetuned-audioset-10-10-0.4593"
class AST(nn.Module):
    """Small Audio Spectrogram Transformer classifier built from scratch.

    Wraps a randomly initialised ``ASTForAudioClassification`` whose
    classification head is sized to ``labels`` and exposes only its logits.

    Args:
        labels: Class labels; converted to ``id2label`` / ``label2id``
            mappings with ``get_id_label_mapping``.
        hidden_size: Transformer embedding width. Should be divisible by
            ``num_attention_heads``.
        num_attention_heads: Attention heads per encoder layer.
        num_hidden_layers: Number of transformer encoder layers.
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(
        self,
        labels,
        *args,
        hidden_size: int = 300,
        num_attention_heads: int = 5,
        num_hidden_layers: int = 3,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        id2label, label2id = get_id_label_mapping(labels)
        # ``ignore_mismatched_sizes`` is a ``from_pretrained`` option, not a
        # config field; on a config it is merely stored as an unused attribute.
        # No pretrained weights are loaded here, so it is intentionally omitted.
        config = ASTConfig(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_hidden_layers=num_hidden_layers,
            id2label=id2label,
            label2id=label2id,
            num_labels=len(label2id),
        )
        self.model = ASTForAudioClassification(config)

    def forward(self, x):
        """Return the classification logits for ``x``.

        ``x`` is passed straight to ``ASTForAudioClassification`` as its
        input values; presumably spectrogram features as produced by
        ``ASTFeatureExtractor`` (batch, time, mel_bins) — TODO confirm.
        """
        return self.model(x).logits
class ASTExtractorWrapper:
    """Convert a waveform tensor into AST ``input_values`` features.

    The waveform is first run through a ``WaveformTrainingPipeline``, then
    through an ``ASTFeatureExtractor``. The resulting features are moved back
    to the device on which the pipeline's output lived.

    Args:
        sampling_rate: Sample rate (Hz) of the waveform reaching the feature
            extractor. The extractor is constructed with the same rate, so its
            sampling-rate check and filterbank use one consistent value.
        return_tensors: Tensor type requested from the feature extractor. The
            final ``.to(device)`` call requires PyTorch tensors (``"pt"``).
    """

    def __init__(self, sampling_rate=16000, return_tensors="pt") -> None:
        # Forward the rate to the extractor. Otherwise it keeps its 16 kHz
        # default and rejects any other ``sampling_rate`` passed in __call__.
        self.extractor = ASTFeatureExtractor(sampling_rate=sampling_rate)
        self.sampling_rate = sampling_rate
        self.return_tensors = return_tensors
        self.waveform_pipeline = WaveformTrainingPipeline()  # TODO configure from yaml

    def __call__(self, x) -> Any:
        """Return AST features for the waveform ``x``.

        ``x`` is a waveform tensor whose leading dimension is squeezed away
        after the waveform pipeline (presumably a single channel — TODO
        confirm). The leading dimension of the extractor's ``input_values``
        is squeezed as well, and the result is returned on the input's device.
        """
        x = self.waveform_pipeline(x)
        device = x.device
        # The feature extractor works on NumPy arrays. ``Tensor.numpy()`` fails
        # for non-CPU tensors and for tensors that require grad, so detach and
        # move to CPU first (both are no-ops for a plain CPU tensor).
        waveform = x.squeeze(0).detach().cpu().numpy()
        features = self.extractor(
            waveform,
            return_tensors=self.return_tensors,
            sampling_rate=self.sampling_rate,
        )
        return features["input_values"].squeeze(0).to(device)
def train_lightning_ast(config: dict):
    """Train, then test, the from-scratch ``AST`` model with PyTorch Lightning.

    ``config`` must provide ``dance_ids`` (target classes), ``device``,
    ``seed``, ``datasets`` (passed to ``get_datasets``), ``data_module``
    (keyword arguments for ``DanceDataModule``) and ``trainer`` (keyword
    arguments for ``pl.Trainer``). The whole ``config`` is also handed to
    ``TrainingEnvironment``.

    The loss is cross-entropy weighted by the data module's label weights.
    Early stopping monitors ``val/loss`` with a patience of 5.
    """
    dance_ids = config["dance_ids"]
    device = config["device"]
    random_seed = config["seed"]
    pl.seed_everything(random_seed, workers=True)

    # The wrapper is handed to the datasets, presumably so waveforms are
    # turned into AST features per sample — TODO confirm in get_datasets.
    extractor = ASTExtractorWrapper()
    datasets = get_datasets(config["datasets"], extractor)
    data_module = DanceDataModule(
        datasets,
        target_classes=dance_ids,
        **config["data_module"],
    )

    classifier = AST(dance_ids).to(device)
    class_weights = data_module.get_label_weights().to(device)
    loss_fn = nn.CrossEntropyLoss(class_weights)
    environment = TrainingEnvironment(classifier, loss_fn, config)

    trainer = pl.Trainer(
        callbacks=[
            cb.EarlyStopping("val/loss", patience=5),
            cb.RichProgressBar(),
        ],
        **config["trainer"],
    )
    trainer.fit(environment, datamodule=data_module)
    trainer.test(environment, datamodule=data_module)
def train_huggingface_ast(config: dict):
    """Fine-tune the pretrained AST checkpoint with the Hugging Face ``Trainer``.

    Reads ``dance_ids``, ``device``, ``seed`` and ``datasets`` from ``config``,
    plus ``batch_size``, ``min_epochs`` (used as the number of training epochs)
    and an optional ``test_proportion`` (default 0.2) from
    ``config["data_module"]``. Checkpoints are written to
    ``models/weights/ast`` once per epoch.

    Returns:
        The fine-tuned ``AutoModelForAudioClassification`` model.
    """
    target_classes = config["dance_ids"]
    device = config["device"]
    seed = config["seed"]
    output_dir = "models/weights/ast"
    data_config = config["data_module"]
    batch_size = data_config["batch_size"]
    epochs = data_config["min_epochs"]
    # Read once, honouring the default. A second strict lookup used to
    # override this value and raise KeyError when the key was absent.
    test_proportion = data_config.get("test_proportion", 0.2)
    train_proportion = 1 - test_proportion

    pl.seed_everything(seed, workers=True)
    dataset = get_datasets(config["datasets"])
    hf_dataset = HuggingFaceDatasetWrapper(dataset)
    id2label, label2id = get_id_label_mapping(target_classes)

    feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_CHECKPOINT)

    def preprocess_waveform(wf):
        """Apply the checkpoint's feature extractor to one waveform."""
        # Previously this read ``train_ds.resample_frequency``. ``train_ds``
        # is a ``torch.utils.data.Subset`` (created below) with no such
        # attribute, so every call raised AttributeError. The extractor's own
        # rate is used instead. NOTE(review): this assumes the dataset pipeline
        # already yields waveforms at that rate — TODO confirm.
        return feature_extractor(wf, sampling_rate=feature_extractor.sampling_rate)

    hf_dataset.append_to_pipeline(preprocess_waveform)
    train_ds, test_ds = torch.utils.data.random_split(
        hf_dataset, [train_proportion, test_proportion]
    )

    model = AutoModelForAudioClassification.from_pretrained(
        MODEL_CHECKPOINT,
        num_labels=len(target_classes),
        label2id=label2id,
        id2label=id2label,
        # The pretrained AudioSet head has a different number of classes; let
        # it be re-initialised for our label set.
        ignore_mismatched_sizes=True,
    ).to(device)

    # NOTE(review): newer transformers releases rename ``evaluation_strategy``
    # to ``eval_strategy``; keep in sync with the pinned version.
    training_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=5e-5,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=5,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        warmup_ratio=0.1,
        logging_steps=10,
        load_best_model_at_end=True,
        # Requires compute_hf_metrics to report an "accuracy" entry.
        metric_for_best_model="accuracy",
        push_to_hub=False,
        use_mps_device=device == "mps",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        tokenizer=feature_extractor,
        compute_metrics=compute_hf_metrics,
    )
    trainer.train()
    return model