waidhoferj committed
Commit d78d6d1 · 1 Parent(s): 0a2992f

fixed hf training loop for wav2vec

Files changed (1):
  models/wav2vec2.py  +6 -5
models/wav2vec2.py CHANGED
@@ -27,14 +27,14 @@ class Wav2VecFeatureExtractor:
     def __call__(self, waveform) -> Any:
         waveform = self.waveform_pipeline(waveform)
         return self.feature_extractor(
-            waveform, sampling_rate=self.feature_extractor.sampling_rate
+            waveform.squeeze(0), sampling_rate=self.feature_extractor.sampling_rate
         )
 
     def __getattr__(self, attr):
         return getattr(self.feature_extractor, attr)
 
 
-def train_wav_model(config: dict):
+def train_huggingface(config: dict):
     TARGET_CLASSES = config["dance_ids"]
     DEVICE = config["device"]
     SEED = config["seed"]
@@ -43,12 +43,14 @@ def train_wav_model(config: dict):
     epochs = config["trainer"]["min_epochs"]
     test_proportion = config["data_module"].get("test_proportion", 0.2)
     pl.seed_everything(SEED, workers=True)
-    dataset = get_datasets(config["datasets"])
+    feature_extractor = Wav2VecFeatureExtractor()
+    dataset = get_datasets(config["datasets"], feature_extractor)
+    dataset = HuggingFaceDatasetWrapper(dataset)
     id2label, label2id = get_id_label_mapping(TARGET_CLASSES)
     test_proportion = config["data_module"]["test_proportion"]
     train_proporition = 1 - test_proportion
     train_ds, test_ds = random_split(dataset, [train_proporition, test_proportion])
-    feature_extractor = Wav2VecFeatureExtractor()
+
     model = AutoModelForAudioClassification.from_pretrained(
         MODEL_CHECKPOINT,
         num_labels=len(TARGET_CLASSES),
@@ -77,7 +79,6 @@ def train_wav_model(config: dict):
         args=training_args,
         train_dataset=train_ds,
         eval_dataset=test_ds,
-        tokenizer=feature_extractor,
         compute_metrics=compute_hf_metrics,
     )
     trainer.train()
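Why the squeeze(0)? torchaudio loads audio as a (channels, samples) tensor, while the wav2vec2 feature extractor expects each example as a 1-D raw-speech array. A minimal sketch of the shape handling, assuming mono audio; the checkpoint name is illustrative (the diff's MODEL_CHECKPOINT is defined elsewhere in the file):

import torch
from transformers import AutoFeatureExtractor

# Illustrative checkpoint, standing in for MODEL_CHECKPOINT.
extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")

waveform = torch.rand(1, 16_000)  # (channels=1, samples), as torchaudio.load returns

# Dropping the channel dimension hands the extractor the 1-D waveform it
# expects; a 2-D (1, samples) input would carry an extra dimension into the batch.
features = extractor(waveform.squeeze(0), sampling_rate=extractor.sampling_rate)
print(features["input_values"][0].shape)  # one 1-D array of 16_000 samples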
 
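HuggingFaceDatasetWrapper's implementation is not part of this diff. A hypothetical sketch of what such an adapter could look like, assuming the wrapped dataset yields (features, label) pairs and that Trainer wants dict items keyed by the model's forward arguments:

# Hypothetical sketch only: the real HuggingFaceDatasetWrapper is not shown here.
# Assumes the wrapped dataset yields (features, label) pairs, where features
# were produced by Wav2VecFeatureExtractor above.
from torch.utils.data import Dataset

class HuggingFaceDatasetWrapper(Dataset):
    """Adapts (features, label) pairs to the dict items transformers.Trainer consumes."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        features, label = self.dataset[idx]
        # Wav2vec2 models take raw audio as "input_values"; Trainer treats
        # "label" as the classification target.
        return {"input_values": features["input_values"][0], "label": label}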
 
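Dropping tokenizer=feature_extractor from the Trainer call lines up with moving feature extraction into the dataset itself, so the Trainer no longer preprocesses raw audio. For orientation, a hypothetical config showing the keys train_huggingface reads in this diff (all values are made up):

# Hypothetical example config; keys mirror the diff, values are illustrative.
config = {
    "dance_ids": ["waltz", "tango", "foxtrot"],  # becomes TARGET_CLASSES
    "device": "cuda",
    "seed": 42,
    "trainer": {"min_epochs": 5},
    "data_module": {"test_proportion": 0.2},
    "datasets": {},  # whatever get_datasets consumes; not shown in this diff
}
train_huggingface(config)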