Spaces:
Runtime error
Runtime error
John Waidhofer
committed on
Commit
·
9f53273
1
Parent(s):
5da8010
fixing issues with residual dancer
Browse files
- models/config/train.yaml +2 -2
- models/residual.py +1 -1
- preprocessing/dataset.py +7 -1
- preprocessing/pipelines.py +0 -3
models/config/train.yaml
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
training_fn:
|
2 |
device: cuda
|
3 |
seed: 42
|
4 |
dance_ids: &dance_ids
|
@@ -19,7 +19,7 @@ dance_ids: &dance_ids
|
|
19 |
data_module:
|
20 |
batch_size: 64
|
21 |
num_workers: 5
|
22 |
-
test_proportion: 0.
|
23 |
|
24 |
datasets:
|
25 |
preprocessing.dataset.Music4DanceDataset:
|
|
|
1 |
+
training_fn: residual.train_residual_dancer
|
2 |
device: cuda
|
3 |
seed: 42
|
4 |
dance_ids: &dance_ids
|
|
|
19 |
data_module:
|
20 |
batch_size: 64
|
21 |
num_workers: 5
|
22 |
+
test_proportion: 0.15
|
23 |
|
24 |
datasets:
|
25 |
preprocessing.dataset.Music4DanceDataset:
|
models/residual.py
CHANGED
@@ -25,7 +25,6 @@ class ResidualDancer(nn.Module):
|
|
25 |
self.n_channels = n_channels
|
26 |
self.n_classes = n_classes
|
27 |
|
28 |
-
# Spectrogram
|
29 |
self.spec_bn = nn.BatchNorm2d(1)
|
30 |
|
31 |
# CNN
|
@@ -111,6 +110,7 @@ def train_residual_dancer(config: dict):
|
|
111 |
TARGET_CLASSES = config["dance_ids"]
|
112 |
DEVICE = config["device"]
|
113 |
SEED = config["seed"]
|
|
|
114 |
pl.seed_everything(SEED, workers=True)
|
115 |
feature_extractor = SpectrogramTrainingPipeline(**config["feature_extractor"])
|
116 |
dataset = get_datasets(config["datasets"], feature_extractor)
|
|
|
25 |
self.n_channels = n_channels
|
26 |
self.n_classes = n_classes
|
27 |
|
|
|
28 |
self.spec_bn = nn.BatchNorm2d(1)
|
29 |
|
30 |
# CNN
|
|
|
110 |
TARGET_CLASSES = config["dance_ids"]
|
111 |
DEVICE = config["device"]
|
112 |
SEED = config["seed"]
|
113 |
+
torch.set_float32_matmul_precision('medium')
|
114 |
pl.seed_everything(SEED, workers=True)
|
115 |
feature_extractor = SpectrogramTrainingPipeline(**config["feature_extractor"])
|
116 |
dataset = get_datasets(config["datasets"], feature_extractor)
|
preprocessing/dataset.py
CHANGED
@@ -73,7 +73,13 @@ class SongDataset(Dataset):
|
|
73 |
|
74 |
waveform = self._waveform_from_index(idx)
|
75 |
dance_labels = self._label_from_index(idx)
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
def _idx2audio_idx(self, idx: int) -> int:
|
79 |
return self._get_audio_loc_from_idx(idx)[0]
|
|
|
73 |
|
74 |
waveform = self._waveform_from_index(idx)
|
75 |
dance_labels = self._label_from_index(idx)
|
76 |
+
|
77 |
+
if self._validate_output(waveform, dance_labels):
|
78 |
+
return waveform, dance_labels
|
79 |
+
else:
|
80 |
+
# WARNING: Could cause train/test split leak
|
81 |
+
return self[idx-1]
|
82 |
+
|
83 |
|
84 |
def _idx2audio_idx(self, idx: int) -> int:
|
85 |
return self._get_audio_loc_from_idx(idx)[0]
|
preprocessing/pipelines.py
CHANGED
@@ -115,7 +115,4 @@ class AudioToSpectrogram:
|
|
115 |
def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
|
116 |
spectrogram = self.spec(waveform)
|
117 |
spectrogram = self.to_db(spectrogram)
|
118 |
-
|
119 |
-
# Normalize
|
120 |
-
spectrogram = (spectrogram - spectrogram.mean()) / (2 * spectrogram.std())
|
121 |
return spectrogram
|
|
|
115 |
def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
|
116 |
spectrogram = self.spec(waveform)
|
117 |
spectrogram = self.to_db(spectrogram)
|
|
|
|
|
|
|
118 |
return spectrogram
|