from argparse import ArgumentParser
from typing import Callable

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import yaml
from pytorch_lightning import callbacks as cb
from sklearn.model_selection import KFold
from sklearn.utils.class_weight import compute_class_weight
from torch import nn
from torch.utils.data import DataLoader, SubsetRandomSampler, random_split
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

from models.audio_spectrogram_transformer import (
    train as train_audio_spectrogram_transformer,
    get_id_label_mapping,
)
from models.decision_tree import DanceTreeClassifier, features_from_path
from models.residual import ResidualDancer, TrainingEnvironment
from models.utils import LabelWeightedBCELoss
from preprocessing.dataset import (
    DanceDataModule,
    HuggingFaceWaveformSongDataset,
    SongDataset,
    WaveformSongDataset,
    WaveformTrainingEnvironment,
)
from preprocessing.preprocess import get_examples


def get_training_fn(id: str) -> Callable:
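    """Look up the training function that corresponds to a training id from the config."""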
    match id:
        case "ast_ptl":
            return train_ast_lightning
        case "ast_hf":
            return train_ast
        case "residual_dancer":
            return train_model
        case "decision_tree":
            return train_decision_tree
        case _:
            raise Exception(f"Couldn't find a training function for '{id}'.")


def get_config(filepath: str) -> dict:
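    """Load a YAML training configuration file into a dictionary."""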
    with open(filepath, "r") as f:
        config = yaml.safe_load(f)
    return config


def cross_validation(config, k=5):
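    """Run k-fold cross validation of the ResidualDancer model on the songs dataset."""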
    df = pd.read_csv("data/songs.csv")
    g_config = config["global"]
    batch_size = config["data_module"]["batch_size"]
    x, y = get_examples(df, "data/samples", class_list=g_config["dance_ids"])
    dataset = SongDataset(x, y)
    splits = KFold(n_splits=k, shuffle=True, random_state=g_config["seed"])
    trainer = pl.Trainer(accelerator=g_config["device"])
    for fold, (train_idx, val_idx) in enumerate(splits.split(x, y)):
        print(f"Fold {fold + 1}")
        model = ResidualDancer(n_classes=len(g_config["dance_ids"]))
        train_env = TrainingEnvironment(model, nn.BCELoss())
        train_sampler = SubsetRandomSampler(train_idx)
        test_sampler = SubsetRandomSampler(val_idx)
        train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
        test_loader = DataLoader(dataset, batch_size=batch_size, sampler=test_sampler)
        trainer.fit(train_env, train_loader)
        trainer.test(train_env, test_loader)


def train_model(config: dict):
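    """Train the ResidualDancer spectrogram model with PyTorch Lightning."""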
    TARGET_CLASSES = config["global"]["dance_ids"]
    DEVICE = config["global"]["device"]
    SEED = config["global"]["seed"]
    pl.seed_everything(SEED, workers=True)
    data = DanceDataModule(target_classes=TARGET_CLASSES, **config["data_module"])
    model = ResidualDancer(n_classes=len(TARGET_CLASSES), **config["model"])
    label_weights = data.get_label_weights().to(DEVICE)
    criterion = LabelWeightedBCELoss(label_weights)  # nn.CrossEntropyLoss(label_weights)
    train_env = TrainingEnvironment(model, criterion, config)
    callbacks = [
        # cb.LearningRateFinder(update_attr=True),
        cb.EarlyStopping("val/loss", patience=5),
        cb.StochasticWeightAveraging(1e-2),
        cb.RichProgressBar(),
        cb.DeviceStatsMonitor(),
    ]
    trainer = pl.Trainer(callbacks=callbacks, **config["trainer"])
    trainer.fit(train_env, datamodule=data)
    trainer.test(train_env, datamodule=data)


def train_ast(config: dict):
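    """Fine-tune the Audio Spectrogram Transformer with the Hugging Face training loop."""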
    TARGET_CLASSES = config["global"]["dance_ids"]
    DEVICE = config["global"]["device"]
    SEED = config["global"]["seed"]
    dataset_kwargs = config["data_module"]["dataset_kwargs"]
    test_proportion = config["data_module"].get("test_proportion", 0.2)
    train_proportion = 1.0 - test_proportion
    song_data_path = "data/songs_cleaned.csv"
    song_audio_path = "data/samples"
    pl.seed_everything(SEED, workers=True)
    df = pd.read_csv(song_data_path)
    x, y = get_examples(
        df, song_audio_path, class_list=TARGET_CLASSES, multi_label=True
    )
    train_i, test_i = random_split(
        np.arange(len(x)), [train_proportion, test_proportion]
    )
    train_ds = HuggingFaceWaveformSongDataset(
        x[train_i], y[train_i], **dataset_kwargs, resample_frequency=16000
    )
    test_ds = HuggingFaceWaveformSongDataset(
        x[test_i], y[test_i], **dataset_kwargs, resample_frequency=16000
    )
    train_audio_spectrogram_transformer(
        TARGET_CLASSES, train_ds, test_ds, device=DEVICE
    )


def train_ast_lightning(config: dict):
""" | |
work on integration between waveform dataset and environment. Should work for both HF and PTL. | |
""" | |
    TARGET_CLASSES = config["global"]["dance_ids"]
    DEVICE = config["global"]["device"]
    SEED = config["global"]["seed"]
    pl.seed_everything(SEED, workers=True)
    data = DanceDataModule(
        target_classes=TARGET_CLASSES,
        dataset_cls=WaveformSongDataset,
        **config["data_module"],
    )
    id2label, label2id = get_id_label_mapping(TARGET_CLASSES)
    model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
    model = AutoModelForAudioClassification.from_pretrained(
        model_checkpoint,
        num_labels=len(label2id),
        label2id=label2id,
        id2label=id2label,
        ignore_mismatched_sizes=True,
    ).to(DEVICE)
    label_weights = data.get_label_weights().to(DEVICE)
    criterion = LabelWeightedBCELoss(label_weights)  # nn.CrossEntropyLoss(label_weights)
    train_env = WaveformTrainingEnvironment(model, criterion, feature_extractor, config)
    callbacks = [
        # cb.LearningRateFinder(update_attr=True),
        cb.EarlyStopping("val/loss", patience=5),
        cb.StochasticWeightAveraging(1e-2),
        cb.RichProgressBar(),
    ]
    trainer = pl.Trainer(callbacks=callbacks, **config["trainer"])
    trainer.fit(train_env, datamodule=data)
    trainer.test(train_env, datamodule=data)


def train_decision_tree(config: dict):
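    """Train the decision-tree dance classifier on features extracted from the audio samples."""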
    TARGET_CLASSES = config["global"]["dance_ids"]
    DEVICE = config["global"]["device"]
    SEED = config["global"]["seed"]
    song_data_path = config["data_module"]["song_data_path"]
    song_audio_path = config["data_module"]["song_audio_path"]
    pl.seed_everything(SEED, workers=True)
    df = pd.read_csv(song_data_path)
    x, y = get_examples(
        df, song_audio_path, class_list=TARGET_CLASSES, multi_label=True
    )
    # Convert y back to string classes
    y = np.array(TARGET_CLASSES)[y.argmax(-1)]
    train_i, test_i = random_split(np.arange(len(x)), [0.8, 0.2])
    train_paths, train_y = x[train_i], y[train_i]
    train_x = features_from_path(train_paths)
    model = DanceTreeClassifier(device=DEVICE)
    model.fit(train_x, train_y)
    model.save()


if __name__ == "__main__":
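    # Parse CLI arguments, load the YAML config, and dispatch to the selected training routine.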
    parser = ArgumentParser(
        description="Trains models on the dance dataset and saves weights."
    )
    parser.add_argument(
        "--config",
        help="Path to the yaml file that defines the training configuration.",
        default="models/config/train_local.yaml",
    )
    args = parser.parse_args()
    config = get_config(args.config)
    training_id = config["global"]["id"]
    train = get_training_fn(training_id)
    train(config)