Spaces:
Runtime error
Runtime error
waidhoferj
committed on
Commit
·
99c5692
1
Parent(s):
d68daac
added audio duration tracking to BestBallroomDataset
Browse files- preprocessing/dataset.py +22 -1
preprocessing/dataset.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import importlib
|
|
|
2 |
import os
|
3 |
from typing import Any
|
4 |
import torch
|
@@ -7,6 +8,7 @@ import numpy as np
|
|
7 |
import pandas as pd
|
8 |
import torchaudio as ta
|
9 |
import pytorch_lightning as pl
|
|
|
10 |
|
11 |
from preprocessing.preprocess import (
|
12 |
fix_dance_rating_counts,
|
@@ -170,7 +172,12 @@ class BestBallroomDataset(Dataset):
|
|
170 |
) -> None:
|
171 |
super().__init__()
|
172 |
song_paths, labels = self.get_examples(audio_dir, class_list)
|
173 |
-
|
|
|
|
|
|
|
|
|
|
|
174 |
|
175 |
def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor]:
|
176 |
return self.song_dataset[index]
|
@@ -388,3 +395,17 @@ def get_class_counts(config: dict):
|
|
388 |
)
|
389 |
labels = sorted(config["dance_ids"])
|
390 |
return dict(zip(labels, counts))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import importlib
|
2 |
+
import json
|
3 |
import os
|
4 |
from typing import Any
|
5 |
import torch
|
|
|
8 |
import pandas as pd
|
9 |
import torchaudio as ta
|
10 |
import pytorch_lightning as pl
|
11 |
+
from glob import iglob
|
12 |
|
13 |
from preprocessing.preprocess import (
|
14 |
fix_dance_rating_counts,
|
|
|
172 |
) -> None:
|
173 |
super().__init__()
|
174 |
song_paths, labels = self.get_examples(audio_dir, class_list)
|
175 |
+
with open(os.path.join(audio_dir, "audio_durations.json"), "r") as f:
|
176 |
+
durations = json.load(f)
|
177 |
+
audio_durations = [durations[song] for song in song_paths]
|
178 |
+
self.song_dataset = SongDataset(
|
179 |
+
song_paths, labels, audio_durations=audio_durations, **kwargs
|
180 |
+
)
|
181 |
|
182 |
def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor]:
|
183 |
return self.song_dataset[index]
|
|
|
395 |
)
|
396 |
labels = sorted(config["dance_ids"])
|
397 |
return dict(zip(labels, counts))
|
398 |
+
|
399 |
+
|
400 |
+
def record_audio_durations(folder: str) -> None:
    """
    Records a path: duration mapping of all .wav files under a folder to a json file.

    The folder is searched recursively. For every match, torchaudio's metadata
    is read and `num_frames / sample_rate` is stored under the path exactly as
    the glob yielded it (i.e. prefixed with `folder`, not a bare filename).
    The mapping is written to `<folder>/audio_durations.json`, replacing any
    existing file of that name.

    Args:
        folder: Directory to scan for .wav files; the json file is written here too.
    """
    durations = {}
    music_files = iglob(os.path.join(folder, "**", "*.wav"), recursive=True)
    for file in music_files:
        # Only the header metadata is queried; the waveform itself is not loaded here.
        meta = ta.info(file)
        # Keys stay as the glob produced them: BestBallroomDataset.__init__
        # indexes this mapping by song path, so the spelling must match.
        durations[file] = meta.num_frames / meta.sample_rate

    with open(os.path.join(folder, "audio_durations.json"), "w") as f:
        json.dump(durations, f)
|