waidhoferj commited on
Commit
99c5692
·
1 Parent(s): d68daac

added audio duration tracking to BestBallroomDataset

Browse files
Files changed (1) hide show
  1. preprocessing/dataset.py +22 -1
preprocessing/dataset.py CHANGED
@@ -1,4 +1,5 @@
1
  import importlib
 
2
  import os
3
  from typing import Any
4
  import torch
@@ -7,6 +8,7 @@ import numpy as np
7
  import pandas as pd
8
  import torchaudio as ta
9
  import pytorch_lightning as pl
 
10
 
11
  from preprocessing.preprocess import (
12
  fix_dance_rating_counts,
@@ -170,7 +172,12 @@ class BestBallroomDataset(Dataset):
170
  ) -> None:
171
  super().__init__()
172
  song_paths, labels = self.get_examples(audio_dir, class_list)
173
- self.song_dataset = SongDataset(song_paths, labels, **kwargs)
 
 
 
 
 
174
 
175
  def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor]:
176
  return self.song_dataset[index]
@@ -388,3 +395,17 @@ def get_class_counts(config: dict):
388
  )
389
  labels = sorted(config["dance_ids"])
390
  return dict(zip(labels, counts))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import importlib
2
+ import json
3
  import os
4
  from typing import Any
5
  import torch
 
8
  import pandas as pd
9
  import torchaudio as ta
10
  import pytorch_lightning as pl
11
+ from glob import iglob
12
 
13
  from preprocessing.preprocess import (
14
  fix_dance_rating_counts,
 
172
  ) -> None:
173
  super().__init__()
174
  song_paths, labels = self.get_examples(audio_dir, class_list)
175
+ with open(os.path.join(audio_dir, "audio_durations.json"), "r") as f:
176
+ durations = json.load(f)
177
+ audio_durations = [durations[song] for song in song_paths]
178
+ self.song_dataset = SongDataset(
179
+ song_paths, labels, audio_durations=audio_durations, **kwargs
180
+ )
181
 
182
  def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor]:
183
  return self.song_dataset[index]
 
395
  )
396
  labels = sorted(config["dance_ids"])
397
  return dict(zip(labels, counts))
398
+
399
+
400
+ def record_audio_durations(folder: str):
401
+ """
402
+ Records a filename: duration mapping of all audio files in a folder to a json file.
403
+ """
404
+ durations = {}
405
+ music_files = iglob(os.path.join(folder, "**", "*.wav"), recursive=True)
406
+ for file in music_files:
407
+ meta = ta.info(file)
408
+ durations[file] = meta.num_frames / meta.sample_rate
409
+
410
+ with open(os.path.join(folder, "audio_durations.json"), "w") as f:
411
+ json.dump(durations, f)