Spaces:
Runtime error
Runtime error
waidhoferj
commited on
Commit
·
d78d6d1
1
Parent(s):
0a2992f
fixed hf training loop for wav2vec
Browse files- models/wav2vec2.py +6 -5
models/wav2vec2.py
CHANGED
@@ -27,14 +27,14 @@ class Wav2VecFeatureExtractor:
|
|
27 |
def __call__(self, waveform) -> Any:
|
28 |
waveform = self.waveform_pipeline(waveform)
|
29 |
return self.feature_extractor(
|
30 |
-
waveform, sampling_rate=self.feature_extractor.sampling_rate
|
31 |
)
|
32 |
|
33 |
def __getattr__(self, attr):
|
34 |
return getattr(self.feature_extractor, attr)
|
35 |
|
36 |
|
37 |
-
def
|
38 |
TARGET_CLASSES = config["dance_ids"]
|
39 |
DEVICE = config["device"]
|
40 |
SEED = config["seed"]
|
@@ -43,12 +43,14 @@ def train_wav_model(config: dict):
|
|
43 |
epochs = config["trainer"]["min_epochs"]
|
44 |
test_proportion = config["data_module"].get("test_proportion", 0.2)
|
45 |
pl.seed_everything(SEED, workers=True)
|
46 |
-
|
|
|
|
|
47 |
id2label, label2id = get_id_label_mapping(TARGET_CLASSES)
|
48 |
test_proportion = config["data_module"]["test_proportion"]
|
49 |
train_proporition = 1 - test_proportion
|
50 |
train_ds, test_ds = random_split(dataset, [train_proporition, test_proportion])
|
51 |
-
|
52 |
model = AutoModelForAudioClassification.from_pretrained(
|
53 |
MODEL_CHECKPOINT,
|
54 |
num_labels=len(TARGET_CLASSES),
|
@@ -77,7 +79,6 @@ def train_wav_model(config: dict):
|
|
77 |
args=training_args,
|
78 |
train_dataset=train_ds,
|
79 |
eval_dataset=test_ds,
|
80 |
-
tokenizer=feature_extractor,
|
81 |
compute_metrics=compute_hf_metrics,
|
82 |
)
|
83 |
trainer.train()
|
|
|
27 |
def __call__(self, waveform) -> Any:
|
28 |
waveform = self.waveform_pipeline(waveform)
|
29 |
return self.feature_extractor(
|
30 |
+
waveform.squeeze(0), sampling_rate=self.feature_extractor.sampling_rate
|
31 |
)
|
32 |
|
33 |
def __getattr__(self, attr):
|
34 |
return getattr(self.feature_extractor, attr)
|
35 |
|
36 |
|
37 |
+
def train_huggingface(config: dict):
|
38 |
TARGET_CLASSES = config["dance_ids"]
|
39 |
DEVICE = config["device"]
|
40 |
SEED = config["seed"]
|
|
|
43 |
epochs = config["trainer"]["min_epochs"]
|
44 |
test_proportion = config["data_module"].get("test_proportion", 0.2)
|
45 |
pl.seed_everything(SEED, workers=True)
|
46 |
+
feature_extractor = Wav2VecFeatureExtractor()
|
47 |
+
dataset = get_datasets(config["datasets"], feature_extractor)
|
48 |
+
dataset = HuggingFaceDatasetWrapper(dataset)
|
49 |
id2label, label2id = get_id_label_mapping(TARGET_CLASSES)
|
50 |
test_proportion = config["data_module"]["test_proportion"]
|
51 |
train_proporition = 1 - test_proportion
|
52 |
train_ds, test_ds = random_split(dataset, [train_proporition, test_proportion])
|
53 |
+
|
54 |
model = AutoModelForAudioClassification.from_pretrained(
|
55 |
MODEL_CHECKPOINT,
|
56 |
num_labels=len(TARGET_CLASSES),
|
|
|
79 |
args=training_args,
|
80 |
train_dataset=train_ds,
|
81 |
eval_dataset=test_ds,
|
|
|
82 |
compute_metrics=compute_hf_metrics,
|
83 |
)
|
84 |
trainer.train()
|