John Waidhofer committed on
Commit
9f53273
·
1 Parent(s): 5da8010

fixing issues with residual dancer

Browse files
models/config/train.yaml CHANGED
@@ -1,4 +1,4 @@
1
- training_fn: wav2vec2.train_huggingface
2
  device: cuda
3
  seed: 42
4
  dance_ids: &dance_ids
@@ -19,7 +19,7 @@ dance_ids: &dance_ids
19
  data_module:
20
  batch_size: 64
21
  num_workers: 5
22
- test_proportion: 0.2
23
 
24
  datasets:
25
  preprocessing.dataset.Music4DanceDataset:
 
1
+ training_fn: residual.train_residual_dancer
2
  device: cuda
3
  seed: 42
4
  dance_ids: &dance_ids
 
19
  data_module:
20
  batch_size: 64
21
  num_workers: 5
22
+ test_proportion: 0.15
23
 
24
  datasets:
25
  preprocessing.dataset.Music4DanceDataset:
models/residual.py CHANGED
@@ -25,7 +25,6 @@ class ResidualDancer(nn.Module):
25
  self.n_channels = n_channels
26
  self.n_classes = n_classes
27
 
28
- # Spectrogram
29
  self.spec_bn = nn.BatchNorm2d(1)
30
 
31
  # CNN
@@ -111,6 +110,7 @@ def train_residual_dancer(config: dict):
111
  TARGET_CLASSES = config["dance_ids"]
112
  DEVICE = config["device"]
113
  SEED = config["seed"]
 
114
  pl.seed_everything(SEED, workers=True)
115
  feature_extractor = SpectrogramTrainingPipeline(**config["feature_extractor"])
116
  dataset = get_datasets(config["datasets"], feature_extractor)
 
25
  self.n_channels = n_channels
26
  self.n_classes = n_classes
27
 
 
28
  self.spec_bn = nn.BatchNorm2d(1)
29
 
30
  # CNN
 
110
  TARGET_CLASSES = config["dance_ids"]
111
  DEVICE = config["device"]
112
  SEED = config["seed"]
113
+ torch.set_float32_matmul_precision('medium')
114
  pl.seed_everything(SEED, workers=True)
115
  feature_extractor = SpectrogramTrainingPipeline(**config["feature_extractor"])
116
  dataset = get_datasets(config["datasets"], feature_extractor)
preprocessing/dataset.py CHANGED
@@ -73,7 +73,13 @@ class SongDataset(Dataset):
73
 
74
  waveform = self._waveform_from_index(idx)
75
  dance_labels = self._label_from_index(idx)
76
- return waveform, dance_labels
 
 
 
 
 
 
77
 
78
  def _idx2audio_idx(self, idx: int) -> int:
79
  return self._get_audio_loc_from_idx(idx)[0]
 
73
 
74
  waveform = self._waveform_from_index(idx)
75
  dance_labels = self._label_from_index(idx)
76
+
77
+ if self._validate_output(waveform, dance_labels):
78
+ return waveform, dance_labels
79
+ else:
80
+ # WARNING: Could cause train/test split leak
81
+ return self[idx-1]
82
+
83
 
84
  def _idx2audio_idx(self, idx: int) -> int:
85
  return self._get_audio_loc_from_idx(idx)[0]
preprocessing/pipelines.py CHANGED
@@ -115,7 +115,4 @@ class AudioToSpectrogram:
115
  def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
116
  spectrogram = self.spec(waveform)
117
  spectrogram = self.to_db(spectrogram)
118
-
119
- # Normalize
120
- spectrogram = (spectrogram - spectrogram.mean()) / (2 * spectrogram.std())
121
  return spectrogram
 
115
  def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
116
  spectrogram = self.spec(waveform)
117
  spectrogram = self.to_db(spectrogram)
 
 
 
118
  return spectrogram