import os
from typing import Any
import pytorch_lightning as pl
from torch.utils.data import random_split
from transformers import AutoFeatureExtractor
from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer

from preprocessing.dataset import (
    HuggingFaceDatasetWrapper,
    BestBallroomDataset,
    get_datasets,
)
from preprocessing.pipelines import WaveformTrainingPipeline

from .utils import get_id_label_mapping, compute_hf_metrics

MODEL_CHECKPOINT = "facebook/wav2vec2-base"


class Wav2VecFeatureExtractor:
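    """Applies the waveform training pipeline before the pretrained wav2vec2 feature extractor."""
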
    def __init__(self) -> None:
        self.waveform_pipeline = WaveformTrainingPipeline()
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
            MODEL_CHECKPOINT,
        )

    def __call__(self, waveform) -> Any:
        waveform = self.waveform_pipeline(waveform)
        return self.feature_extractor(
            waveform.squeeze(0), sampling_rate=self.feature_extractor.sampling_rate
        )

    def __getattr__(self, attr):
        # Delegate unknown attribute lookups to the wrapped Hugging Face extractor.
        return getattr(self.feature_extractor, attr)


def train_huggingface(config: dict):
    TARGET_CLASSES = config["dance_ids"]
    DEVICE = config["device"]
    SEED = config["seed"]
    OUTPUT_DIR = "models/weights/wav2vec2"
    batch_size = config["data_module"]["batch_size"]
    epochs = config["trainer"]["min_epochs"]
    test_proportion = config["data_module"].get("test_proportion", 0.2)
    pl.seed_everything(SEED, workers=True)
    feature_extractor = Wav2VecFeatureExtractor()
    dataset = get_datasets(config["datasets"], feature_extractor)
    dataset = HuggingFaceDatasetWrapper(dataset)
    id2label, label2id = get_id_label_mapping(TARGET_CLASSES)
    train_proportion = 1 - test_proportion
    train_ds, test_ds = random_split(dataset, [train_proportion, test_proportion])

    model = AutoModelForAudioClassification.from_pretrained(
        MODEL_CHECKPOINT,
        num_labels=len(TARGET_CLASSES),
        label2id=label2id,
        id2label=id2label,
        ignore_mismatched_sizes=True,
    ).to(DEVICE)
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=3e-5,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=5,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        warmup_ratio=0.1,
        logging_steps=10,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        push_to_hub=False,
        use_mps_device=DEVICE == "mps",
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        compute_metrics=compute_hf_metrics,
    )
    trainer.train()
    return model
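

# Illustrative entry point, not part of the original module: the config layout
# mirrors the keys read in train_huggingface above, but every concrete value
# (dance ids, device, dataset spec) is an assumption for demonstration only.
if __name__ == "__main__":
    example_config = {
        "dance_ids": ["waltz", "tango", "foxtrot"],  # assumed label set
        "device": "cpu",
        "seed": 42,
        "datasets": {},  # shape depends on preprocessing.dataset.get_datasets; assumed here
        "data_module": {"batch_size": 8, "test_proportion": 0.2},
        "trainer": {"min_epochs": 3},
    }
    train_huggingface(example_config)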