John Waidhofer committed on
Commit
ec63e8e
1 Parent(s): 99c5692

update duration loading

Browse files
models/config/train.yaml CHANGED
@@ -23,12 +23,16 @@ data_module:
23
 
24
  datasets:
25
  preprocessing.dataset.Music4DanceDataset:
26
- song_data_path: ../datastores/dance-music/songs_cleaned.csv
27
- song_audio_path: ../datastores/dance-music
28
  class_list: *dance_ids
29
  multi_label: False
30
  min_votes: 1
31
  audio_window_jitter: 0.7
 
 
 
 
32
 
33
  model:
34
  n_channels: 128
@@ -46,7 +50,7 @@ trainer:
46
  min_epochs: 7
47
  fast_dev_run: False
48
  # gradient_clip_val: 0.5
49
- # overfit_batches: 1
50
 
51
  training_environment:
52
  learning_rate: 0.00053
 
23
 
24
  datasets:
25
  preprocessing.dataset.Music4DanceDataset:
26
+ song_data_path: ../../s3_connections/music4dance/songs_cleaned.csv
27
+ song_audio_path: ../../s3_connections/music4dance
28
  class_list: *dance_ids
29
  multi_label: False
30
  min_votes: 1
31
  audio_window_jitter: 0.7
32
+ preprocessing.dataset.BestBallroomDataset:
33
+ audio_dir: ../../s3_connections/ballroom-songs
34
+ class_list: *dance_ids
35
+ audio_window_jitter: 0.7
36
 
37
  model:
38
  n_channels: 128
 
50
  min_epochs: 7
51
  fast_dev_run: False
52
  # gradient_clip_val: 0.5
53
+ overfit_batches: 1
54
 
55
  training_environment:
56
  learning_rate: 0.00053
preprocessing/dataset.py CHANGED
@@ -174,6 +174,7 @@ class BestBallroomDataset(Dataset):
174
  song_paths, labels = self.get_examples(audio_dir, class_list)
175
  with open(os.path.join(audio_dir, "audio_durations.json"), "r") as f:
176
  durations = json.load(f)
 
177
  audio_durations = [durations[song] for song in song_paths]
178
  self.song_dataset = SongDataset(
179
  song_paths, labels, audio_durations=audio_durations, **kwargs
 
174
  song_paths, labels = self.get_examples(audio_dir, class_list)
175
  with open(os.path.join(audio_dir, "audio_durations.json"), "r") as f:
176
  durations = json.load(f)
177
+ durations = {os.path.join(audio_dir, filepath): duration for filepath, duration in durations.items()}
178
  audio_durations = [durations[song] for song in song_paths]
179
  self.song_dataset = SongDataset(
180
  song_paths, labels, audio_durations=audio_durations, **kwargs