waidhoferj committed on
Commit
6ba247f
·
1 Parent(s): e748bc2

added class count function

Browse files
Files changed (1) hide show
  1. preprocessing/dataset.py +25 -4
preprocessing/dataset.py CHANGED
@@ -2,7 +2,7 @@ import importlib
2
  import os
3
  from typing import Any
4
  import torch
5
- from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset
6
  import numpy as np
7
  import pandas as pd
8
  import torchaudio as ta
@@ -278,6 +278,7 @@ class DanceDataModule(pl.LightningDataModule):
278
  target_classes: list[str] = None,
279
  batch_size: int = 64,
280
  num_workers=10,
 
281
  ):
282
  super().__init__()
283
  self.val_proportion = val_proportion
@@ -286,6 +287,10 @@ class DanceDataModule(pl.LightningDataModule):
286
  self.target_classes = target_classes
287
  self.batch_size = batch_size
288
  self.num_workers = num_workers
 
 
 
 
289
  self.dataset = dataset
290
 
291
  def setup(self, stage: str):
@@ -317,9 +322,10 @@ class DanceDataModule(pl.LightningDataModule):
317
  )
318
 
319
  def get_label_weights(self):
320
- weights = [
321
- ds.song_dataset.get_label_weights() for ds in self.dataset._data.datasets
322
- ]
 
323
  return torch.mean(torch.stack(weights), dim=0) # TODO: Make this weighted
324
 
325
 
@@ -349,3 +355,18 @@ def get_datasets(dataset_config: dict, feature_extractor) -> Dataset:
349
  ProvidedDataset = getattr(module, class_name)
350
  datasets.append(ProvidedDataset(**kwargs))
351
  return PipelinedDataset(ConcatDataset(datasets), feature_extractor)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import os
3
  from typing import Any
4
  import torch
5
+ from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset, Subset
6
  import numpy as np
7
  import pandas as pd
8
  import torchaudio as ta
 
278
  target_classes: list[str] = None,
279
  batch_size: int = 64,
280
  num_workers=10,
281
+ data_subset=None,
282
  ):
283
  super().__init__()
284
  self.val_proportion = val_proportion
 
287
  self.target_classes = target_classes
288
  self.batch_size = batch_size
289
  self.num_workers = num_workers
290
+
291
+ if data_subset is not None and float(data_subset) != 1.0:
292
+ dataset, _ = random_split(dataset, [data_subset, 1 - data_subset])
293
+
294
  self.dataset = dataset
295
 
296
  def setup(self, stage: str):
 
322
  )
323
 
324
  def get_label_weights(self):
325
+ dataset = (
326
+ self.dataset.dataset if isinstance(self.dataset, Subset) else self.dataset
327
+ )
328
+ weights = [ds.song_dataset.get_label_weights() for ds in dataset._data.datasets]
329
  return torch.mean(torch.stack(weights), dim=0) # TODO: Make this weighted
330
 
331
 
 
355
  ProvidedDataset = getattr(module, class_name)
356
  datasets.append(ProvidedDataset(**kwargs))
357
  return PipelinedDataset(ConcatDataset(datasets), feature_extractor)
358
+
359
+
360
def get_class_counts(config: dict):
    """Count songs per dance class across all configured datasets.

    Each song is assigned to the class whose column holds the largest value
    in its row of ``dance_labels`` (``argmax`` along axis 1), and the
    per-class totals are summed over every dataset.

    Args:
        config: Must provide ``"datasets"`` (forwarded to ``get_datasets``)
            and ``"dance_ids"`` (the collection of class identifiers).

    Returns:
        A dict mapping each id in ``sorted(config["dance_ids"])`` to its
        count. All counts are zero when no datasets are configured.
    """
    # TODO: Figure out why music4dance has fractional labels
    # NOTE(review): counts are paired with the *sorted* dance ids, which assumes
    # the columns of ``dance_labels`` follow that same order -- confirm.
    labels = sorted(config["dance_ids"])
    num_classes = len(labels)

    # Identity feature extractor: only the labels are inspected here.
    dataset = get_datasets(config["datasets"], lambda x: x)

    # Start from zeros so an empty dataset list still yields a valid mapping
    # instead of failing when zipping labels with the int 0 from ``sum``.
    counts = np.zeros(num_classes, dtype=np.int64)
    for ds in dataset._data.datasets:
        winners = np.asarray(ds.song_dataset.dance_labels.argmax(1))
        # Slice to ``num_classes`` so indices beyond the known classes are
        # ignored, matching the original equality-based counting.
        counts += np.bincount(winners, minlength=num_classes)[:num_classes]
    return dict(zip(labels, counts))