Spaces:
Runtime error
Runtime error
File size: 2,872 Bytes
557fb53 d78d6d1 557fb53 d78d6d1 557fb53 d78d6d1 557fb53 d78d6d1 557fb53 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import os
from typing import Any
import pytorch_lightning as pl
from torch.utils.data import random_split
from transformers import AutoFeatureExtractor
from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer
from preprocessing.dataset import (
HuggingFaceDatasetWrapper,
BestBallroomDataset,
get_datasets,
)
from preprocessing.pipelines import WaveformTrainingPipeline
from .utils import get_id_label_mapping, compute_hf_metrics
# Hugging Face Hub checkpoint used for both the feature extractor and the
# audio-classification model in this module.
MODEL_CHECKPOINT = "facebook/wav2vec2-base"


class Wav2VecFeatureExtractor:
    """Callable that turns a raw waveform into wav2vec2 model inputs.

    The waveform is first passed through ``WaveformTrainingPipeline`` and then
    through the Hugging Face feature extractor loaded from ``MODEL_CHECKPOINT``.
    Attributes not defined on this wrapper (e.g. ``sampling_rate``) are
    forwarded to the wrapped Hugging Face feature extractor.
    """

    def __init__(self) -> None:
        self.waveform_pipeline = WaveformTrainingPipeline()
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
            MODEL_CHECKPOINT,
        )

    def __call__(self, waveform) -> Any:
        """Preprocess ``waveform`` and return the feature extractor's output.

        The pipeline output is squeezed along dim 0 before extraction,
        presumably to drop a leading channel dimension (TODO confirm the
        pipeline's output shape). The extractor's own sampling rate is passed
        explicitly.
        """
        waveform = self.waveform_pipeline(waveform)
        return self.feature_extractor(
            waveform.squeeze(0), sampling_rate=self.feature_extractor.sampling_rate
        )

    def __getattr__(self, attr):
        """Delegate unknown attributes to the wrapped feature extractor.

        ``__getattr__`` only runs after normal attribute lookup fails. The
        wrapped object is read from ``__dict__`` directly. On a half-built
        instance (during ``copy``/``pickle`` reconstruction, or if ``__init__``
        raised) it is absent. Accessing ``self.feature_extractor`` there would
        re-enter ``__getattr__`` and recurse until ``RecursionError``.

        Raises:
            AttributeError: if the wrapped extractor is not set, or it does
                not have ``attr`` either.
        """
        try:
            extractor = self.__dict__["feature_extractor"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            ) from None
        return getattr(extractor, attr)
def train_huggingface(config: dict):
    """Fine-tune wav2vec2 for audio classification with the Hugging Face Trainer.

    Args:
        config: Training configuration. Keys read here:
            ``dance_ids`` (target classes), ``device``, ``seed``,
            ``datasets`` (passed to ``get_datasets``),
            ``data_module.batch_size``,
            ``data_module.test_proportion`` (optional, default 0.2), and
            ``trainer.min_epochs`` (used as the number of training epochs).

    Returns:
        The model instance that was handed to the Trainer and trained.

    Raises:
        KeyError: if a required config key is missing.
        ValueError: if ``test_proportion`` is not strictly between 0 and 1.
    """
    target_classes = config["dance_ids"]
    device = config["device"]
    seed = config["seed"]
    output_dir = "models/weights/wav2vec2"
    batch_size = config["data_module"]["batch_size"]
    epochs = config["trainer"]["min_epochs"]
    # Optional key: previously re-read later with [...], which raised KeyError
    # when absent and made this default dead code.
    test_proportion = config["data_module"].get("test_proportion", 0.2)
    # Fail fast, before the expensive model/dataset loading below. A split of
    # 0 or 1 would leave either the train or the eval set empty.
    if not 0 < test_proportion < 1:
        raise ValueError(
            f"data_module.test_proportion must be in (0, 1), got {test_proportion!r}"
        )
    train_proportion = 1 - test_proportion

    pl.seed_everything(seed, workers=True)
    feature_extractor = Wav2VecFeatureExtractor()
    dataset = get_datasets(config["datasets"], feature_extractor)
    dataset = HuggingFaceDatasetWrapper(dataset)
    id2label, label2id = get_id_label_mapping(target_classes)

    # random_split accepts fractional lengths that sum to 1.
    train_ds, test_ds = random_split(dataset, [train_proportion, test_proportion])

    model = AutoModelForAudioClassification.from_pretrained(
        MODEL_CHECKPOINT,
        num_labels=len(target_classes),
        label2id=label2id,
        id2label=id2label,
        # The classification head size differs from the pretrained checkpoint.
        ignore_mismatched_sizes=True,
    ).to(device)

    training_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=3e-5,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=5,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        warmup_ratio=0.1,
        logging_steps=10,
        load_best_model_at_end=True,
        # Assumes compute_hf_metrics returns an "accuracy" key; verify.
        metric_for_best_model="accuracy",
        push_to_hub=False,
        use_mps_device=device == "mps",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        compute_metrics=compute_hf_metrics,
    )
    trainer.train()
    return model
|