Spaces:
Runtime error
Runtime error
Commit
·
6ba247f
1
Parent(s):
e748bc2
added class count function
Browse files- preprocessing/dataset.py +25 -4
preprocessing/dataset.py
CHANGED
@@ -2,7 +2,7 @@ import importlib
|
|
2 |
import os
|
3 |
from typing import Any
|
4 |
import torch
|
5 |
-
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset
|
6 |
import numpy as np
|
7 |
import pandas as pd
|
8 |
import torchaudio as ta
|
@@ -278,6 +278,7 @@ class DanceDataModule(pl.LightningDataModule):
|
|
278 |
target_classes: list[str] = None,
|
279 |
batch_size: int = 64,
|
280 |
num_workers=10,
|
|
|
281 |
):
|
282 |
super().__init__()
|
283 |
self.val_proportion = val_proportion
|
@@ -286,6 +287,10 @@ class DanceDataModule(pl.LightningDataModule):
|
|
286 |
self.target_classes = target_classes
|
287 |
self.batch_size = batch_size
|
288 |
self.num_workers = num_workers
|
|
|
|
|
|
|
|
|
289 |
self.dataset = dataset
|
290 |
|
291 |
def setup(self, stage: str):
|
@@ -317,9 +322,10 @@ class DanceDataModule(pl.LightningDataModule):
|
|
317 |
)
|
318 |
|
319 |
def get_label_weights(self):
|
320 |
-
|
321 |
-
|
322 |
-
|
|
|
323 |
return torch.mean(torch.stack(weights), dim=0) # TODO: Make this weighted
|
324 |
|
325 |
|
@@ -349,3 +355,18 @@ def get_datasets(dataset_config: dict, feature_extractor) -> Dataset:
|
|
349 |
ProvidedDataset = getattr(module, class_name)
|
350 |
datasets.append(ProvidedDataset(**kwargs))
|
351 |
return PipelinedDataset(ConcatDataset(datasets), feature_extractor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import os
|
3 |
from typing import Any
|
4 |
import torch
|
5 |
+
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset, Subset
|
6 |
import numpy as np
|
7 |
import pandas as pd
|
8 |
import torchaudio as ta
|
|
|
278 |
target_classes: list[str] = None,
|
279 |
batch_size: int = 64,
|
280 |
num_workers=10,
|
281 |
+
data_subset=None,
|
282 |
):
|
283 |
super().__init__()
|
284 |
self.val_proportion = val_proportion
|
|
|
287 |
self.target_classes = target_classes
|
288 |
self.batch_size = batch_size
|
289 |
self.num_workers = num_workers
|
290 |
+
|
291 |
+
if data_subset is not None and float(data_subset) != 1.0:
|
292 |
+
dataset, _ = random_split(dataset, [data_subset, 1 - data_subset])
|
293 |
+
|
294 |
self.dataset = dataset
|
295 |
|
296 |
def setup(self, stage: str):
|
|
|
322 |
)
|
323 |
|
324 |
def get_label_weights(self):
|
325 |
+
dataset = (
|
326 |
+
self.dataset.dataset if isinstance(self.dataset, Subset) else self.dataset
|
327 |
+
)
|
328 |
+
weights = [ds.song_dataset.get_label_weights() for ds in dataset._data.datasets]
|
329 |
return torch.mean(torch.stack(weights), dim=0) # TODO: Make this weighted
|
330 |
|
331 |
|
|
|
355 |
ProvidedDataset = getattr(module, class_name)
|
356 |
datasets.append(ProvidedDataset(**kwargs))
|
357 |
return PipelinedDataset(ConcatDataset(datasets), feature_extractor)
|
358 |
+
|
359 |
+
|
360 |
+
def get_class_counts(config: dict):
|
361 |
+
# TODO: Figure out why music4dance has fractional labels
|
362 |
+
dataset = get_datasets(config["datasets"], lambda x: x)
|
363 |
+
counts = sum(
|
364 |
+
np.sum(
|
365 |
+
np.arange(len(config["dance_ids"]))
|
366 |
+
== np.expand_dims(ds.song_dataset.dance_labels.argmax(1), 1),
|
367 |
+
axis=0,
|
368 |
+
)
|
369 |
+
for ds in dataset._data.datasets
|
370 |
+
)
|
371 |
+
labels = sorted(config["dance_ids"])
|
372 |
+
return dict(zip(labels, counts))
|