File size: 1,986 Bytes
51e2f90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

import numpy as np
import pedalboard as pb
import torch
import torchaudio as ta
from torch.utils import data

from models.bandit.core.data._types import AudioDict, DataDict


class BaseSourceSeparationDataset(data.Dataset, ABC):
    """Abstract base dataset for source-separation audio.

    Subclasses define how one stem is loaded for an identifier
    (``get_stem``) and how an integer index maps to an identifier
    (``get_identifier``). This base class assembles the per-stem tensors for
    one identifier into a dict keyed by stem name, and can optionally rebuild
    the ``"mixture"`` entry by summing the other stems instead of loading it.
    """

    def __init__(
        self,
        split: str,
        stems: List[str],
        files: List[str],
        data_path: str,
        fs: int,
        npy_memmap: bool,
        recompute_mixture: bool,
    ):
        """Store the dataset configuration.

        Args:
            split: Split name; stored for subclasses (not read in this class).
            stems: Stem names to load; may include ``"mixture"``.
            files: File list; stored for subclasses (not read in this class).
            data_path: Data root; stored for subclasses (not read in this
                class).
            fs: Stored for subclasses; presumably the sample rate in Hz —
                TODO confirm against subclasses.
            npy_memmap: Flag stored for subclasses (not read in this class);
                the name suggests memory-mapped ``.npy`` loading — confirm.
            recompute_mixture: When True, ``get_audio`` skips loading
                ``"mixture"`` and computes it by summing the other stems.
        """
        self.split = split
        self.stems = stems
        # Requested stems minus "mixture": what gets loaded when the mixture
        # is recomputed rather than read from storage.
        self.stems_no_mixture = [s for s in stems if s != "mixture"]
        self.files = files
        self.data_path = data_path
        self.fs = fs
        self.npy_memmap = npy_memmap
        self.recompute_mixture = recompute_mixture

    @abstractmethod
    def get_stem(
        self,
        *,
        stem: str,
        identifier: Dict[str, Any],
    ) -> torch.Tensor:
        """Load a single stem for ``identifier``.

        Args:
            stem: Name of the stem to load.
            identifier: Identifier produced by ``get_identifier``.

        Returns:
            The stem's audio tensor.
        """
        raise NotImplementedError

    def _get_audio(
        self,
        stems: List[str],
        identifier: Dict[str, Any],
    ) -> AudioDict:
        """Return ``{stem: get_stem(stem=stem, identifier=identifier)}`` for
        each name in ``stems``, loaded in the given order."""
        return {
            stem: self.get_stem(stem=stem, identifier=identifier)
            for stem in stems
        }

    def get_audio(self, identifier: Dict[str, Any]) -> AudioDict:
        """Load every configured stem for ``identifier``.

        If ``recompute_mixture`` is set, only ``stems_no_mixture`` are loaded
        and ``"mixture"`` is filled in with ``compute_mixture`` of them.
        Otherwise every name in ``stems`` is loaded as-is.

        Args:
            identifier: Identifier produced by ``get_identifier``.

        Returns:
            Dict mapping stem name to the tensor returned by ``get_stem``
            (plus the computed ``"mixture"`` when recomputing).

        Raises:
            ValueError: If recomputing the mixture with no non-mixture stems.
        """
        if not self.recompute_mixture:
            return self._get_audio(self.stems, identifier=identifier)

        audio = self._get_audio(self.stems_no_mixture, identifier=identifier)
        audio["mixture"] = self.compute_mixture(audio)
        return audio

    @abstractmethod
    def get_identifier(self, index: int) -> Dict[str, Any]:
        """Map a dataset index to the identifier passed to ``get_stem``."""
        raise NotImplementedError

    def compute_mixture(self, audio: AudioDict) -> torch.Tensor:
        """Sum every entry of ``audio`` except ``"mixture"``.

        Args:
            audio: Dict of stem name to tensor.

        Returns:
            Element-wise sum of the non-mixture stems. ``sum`` starts from
            the integer 0, so the result is always a new tensor, even with a
            single stem, and never aliases an input.

        Raises:
            ValueError: If ``audio`` has no entries other than ``"mixture"``.
                Without this check, ``sum`` would return the int ``0`` rather
                than a tensor.
        """
        sources = [audio[stem] for stem in audio if stem != "mixture"]
        if not sources:
            raise ValueError(
                "Cannot compute mixture: no non-mixture stems in audio."
            )
        return sum(sources)