Spaces:
Runtime error
Runtime error
Commit
·
6193575
1
Parent(s):
9c870fa
updated to return string category labels
Browse files- preprocessing/dataset.py +10 -4
preprocessing/dataset.py
CHANGED
@@ -173,13 +173,19 @@ class BestBallroomDataset(Dataset):
|
|
173 |
self, audio_dir="data/ballroom-songs", class_list=None, **kwargs
|
174 |
) -> None:
|
175 |
super().__init__()
|
176 |
-
song_paths,
|
|
|
|
|
|
|
177 |
with open(os.path.join(audio_dir, "audio_durations.json"), "r") as f:
|
178 |
durations = json.load(f)
|
179 |
-
durations = {
|
|
|
|
|
|
|
180 |
audio_durations = [durations[song] for song in song_paths]
|
181 |
self.song_dataset = SongDataset(
|
182 |
-
song_paths,
|
183 |
)
|
184 |
|
185 |
def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor]:
|
@@ -208,7 +214,7 @@ class BestBallroomDataset(Dataset):
|
|
208 |
song_paths.extend(os.path.join(folder_path, f) for f in folder_contents)
|
209 |
labels.extend([dance_label] * len(folder_contents))
|
210 |
|
211 |
-
return np.array(song_paths), np.stack(labels)
|
212 |
|
213 |
|
214 |
class Music4DanceDataset(Dataset):
|
|
|
173 |
self, audio_dir="data/ballroom-songs", class_list=None, **kwargs
|
174 |
) -> None:
|
175 |
super().__init__()
|
176 |
+
song_paths, encoded_labels, str_labels = self.get_examples(
|
177 |
+
audio_dir, class_list
|
178 |
+
)
|
179 |
+
self.labels = str_labels
|
180 |
with open(os.path.join(audio_dir, "audio_durations.json"), "r") as f:
|
181 |
durations = json.load(f)
|
182 |
+
durations = {
|
183 |
+
os.path.join(audio_dir, filepath): duration
|
184 |
+
for filepath, duration in durations.items()
|
185 |
+
}
|
186 |
audio_durations = [durations[song] for song in song_paths]
|
187 |
self.song_dataset = SongDataset(
|
188 |
+
song_paths, encoded_labels, audio_durations=audio_durations, **kwargs
|
189 |
)
|
190 |
|
191 |
def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor]:
|
|
|
214 |
song_paths.extend(os.path.join(folder_path, f) for f in folder_contents)
|
215 |
labels.extend([dance_label] * len(folder_contents))
|
216 |
|
217 |
+
return np.array(song_paths), np.stack(labels), dances
|
218 |
|
219 |
|
220 |
class Music4DanceDataset(Dataset):
|