import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

import numpy as np
import pedalboard as pb
import torch
import torchaudio as ta
from torch.utils import data

from models.bandit.core.data._types import AudioDict, DataDict


class BaseSourceSeparationDataset(data.Dataset, ABC):
    """Abstract base class for source-separation datasets.

    Subclasses define how a single stem is loaded (``get_stem``) and how a
    dataset index maps to an identifier (``get_identifier``). This class
    assembles the per-stem tensors into a dict and can optionally rebuild
    the ``"mixture"`` entry as the sum of the other stems.
    """

    def __init__(
        self,
        split: str,
        stems: List[str],
        files: List[str],
        data_path: str,
        fs: int,
        npy_memmap: bool,
        recompute_mixture: bool,
    ):
        """Store the dataset configuration.

        Args:
            split: Name of the dataset split (presumably e.g. "train" /
                "val" / "test"; only stored here).
            stems: Stem names to load; may include "mixture".
            files: File entries belonging to this dataset (stored for
                subclasses).
            data_path: Root path of the data (stored for subclasses).
            fs: Presumably the sampling rate in Hz (only stored here).
            npy_memmap: Flag stored for subclasses; presumably selects
                memory-mapped ``.npy`` loading -- TODO confirm.
            recompute_mixture: If true, ``get_audio`` does not load the
                "mixture" stem and instead computes it by summing the
                other stems.
        """
        self.split = split
        self.stems = stems
        # Every requested stem except "mixture", i.e. the individual sources.
        self.stems_no_mixture = [s for s in stems if s != "mixture"]
        self.files = files
        self.data_path = data_path
        self.fs = fs
        self.npy_memmap = npy_memmap
        self.recompute_mixture = recompute_mixture

    @abstractmethod
    def get_stem(
        self, *, stem: str, identifier: Dict[str, Any]
    ) -> torch.Tensor:
        """Load the audio of one ``stem`` for the item described by ``identifier``."""
        raise NotImplementedError

    def _get_audio(
        self, stems: List[str], identifier: Dict[str, Any]
    ) -> Dict[str, torch.Tensor]:
        """Return a dict mapping each name in ``stems`` to its loaded audio."""
        return {
            stem: self.get_stem(stem=stem, identifier=identifier)
            for stem in stems
        }

    def get_audio(self, identifier: Dict[str, Any]) -> AudioDict:
        """Load all configured stems for ``identifier``.

        When ``recompute_mixture`` is set, the stored "mixture" stem is
        never loaded; it is replaced by the sum of the source stems.
        """
        if not self.recompute_mixture:
            return self._get_audio(self.stems, identifier=identifier)

        audio = self._get_audio(self.stems_no_mixture, identifier=identifier)
        audio["mixture"] = self.compute_mixture(audio)
        return audio

    @abstractmethod
    def get_identifier(self, index: int) -> Dict[str, Any]:
        """Map a dataset ``index`` to the identifier passed to ``get_stem``."""
        raise NotImplementedError

    def compute_mixture(self, audio: AudioDict) -> torch.Tensor:
        """Return the element-wise sum of every non-"mixture" entry in ``audio``.

        Raises:
            ValueError: If ``audio`` has no entry other than "mixture".
                Without this check ``sum`` would return the int ``0``
                instead of a tensor.
        """
        sources = [audio[stem] for stem in audio if stem != "mixture"]
        if not sources:
            raise ValueError(
                "cannot compute mixture: audio contains no non-mixture stems"
            )
        # ``sum`` starts from 0, so the result is always a fresh tensor and
        # never aliases one of the source stems (even with a single source).
        return sum(sources)