waidhoferj committed on
Commit
6193575
·
1 Parent(s): 9c870fa

Updated BestBallroomDataset to return string category labels

Browse files
Files changed (1) hide show
  1. preprocessing/dataset.py +10 -4
preprocessing/dataset.py CHANGED
@@ -173,13 +173,19 @@ class BestBallroomDataset(Dataset):
173
  self, audio_dir="data/ballroom-songs", class_list=None, **kwargs
174
  ) -> None:
175
  super().__init__()
176
- song_paths, labels = self.get_examples(audio_dir, class_list)
 
 
 
177
  with open(os.path.join(audio_dir, "audio_durations.json"), "r") as f:
178
  durations = json.load(f)
179
- durations = {os.path.join(audio_dir, filepath): duration for filepath, duration in durations.items()}
 
 
 
180
  audio_durations = [durations[song] for song in song_paths]
181
  self.song_dataset = SongDataset(
182
- song_paths, labels, audio_durations=audio_durations, **kwargs
183
  )
184
 
185
  def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor]:
@@ -208,7 +214,7 @@ class BestBallroomDataset(Dataset):
208
  song_paths.extend(os.path.join(folder_path, f) for f in folder_contents)
209
  labels.extend([dance_label] * len(folder_contents))
210
 
211
- return np.array(song_paths), np.stack(labels)
212
 
213
 
214
  class Music4DanceDataset(Dataset):
 
173
  self, audio_dir="data/ballroom-songs", class_list=None, **kwargs
174
  ) -> None:
175
  super().__init__()
176
+ song_paths, encoded_labels, str_labels = self.get_examples(
177
+ audio_dir, class_list
178
+ )
179
+ self.labels = str_labels
180
  with open(os.path.join(audio_dir, "audio_durations.json"), "r") as f:
181
  durations = json.load(f)
182
+ durations = {
183
+ os.path.join(audio_dir, filepath): duration
184
+ for filepath, duration in durations.items()
185
+ }
186
  audio_durations = [durations[song] for song in song_paths]
187
  self.song_dataset = SongDataset(
188
+ song_paths, encoded_labels, audio_durations=audio_durations, **kwargs
189
  )
190
 
191
  def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor]:
 
214
  song_paths.extend(os.path.join(folder_path, f) for f in folder_contents)
215
  labels.extend([dance_label] * len(folder_contents))
216
 
217
+ return np.array(song_paths), np.stack(labels), dances
218
 
219
 
220
  class Music4DanceDataset(Dataset):