File size: 4,321 Bytes
f6b56a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from torch.utils.data import Dataset
from pathlib import Path
from typing import Optional
import torch
from torch.utils.data import default_collate
from typing import Tuple
from functools import partial
from gyraudio.audio_separation.properties import (
    AUG_AWGN, AUG_RESCALE, AUG_TRIM, LENGTHS, LENGTH_DIVIDER, TRIM_PROB
)


class AudioDataset(Dataset):
    """Base dataset for audio separation, working on (mixed, clean, noise) signals.

    Subclasses must implement :meth:`load_data`, which has to populate
    ``self.file_list`` (its length defines the dataset length), and
    :meth:`__getitem__`.

    Attributes:
        debug: Debug flag, stored as given.
        data_path: Location of the data, stored as given.
        augmentation_config: Dictionary keyed by augmentation name
            (``AUG_RESCALE``, ``AUG_AWGN``, ``AUG_TRIM``).
        snr_filter: SNR selection used by :meth:`filter_data`.
        length: Number of entries in ``self.file_list``.
        collate_fn: ``None`` unless ``AUG_TRIM`` is configured, in which case
            it is ``collate_fn_generic`` pre-bound with the trimming options
            (presumably meant to be handed to a ``DataLoader`` -- confirm
            against callers).
    """

    def __init__(
        self,
        data_path: Path,
        augmentation_config: Optional[dict] = None,
        snr_filter: Optional[float] = None,
        debug: bool = False
    ):
        """Store the configuration, load the file list and prepare the collate function.

        Args:
            data_path: Location of the data, consumed by ``load_data``.
            augmentation_config: Augmentations to enable. ``None`` (default)
                means no augmentation. When ``AUG_TRIM`` is present, its value
                must provide the ``LENGTHS``, ``LENGTH_DIVIDER`` and
                ``TRIM_PROB`` entries.
            snr_filter: ``None`` to keep everything, a single SNR value, or a
                container of accepted SNR values (see :meth:`filter_data`).
            debug: Debug flag.
        """
        self.debug = debug
        self.data_path = data_path
        # A fresh dict per instance: a `{}` default would be shared by every
        # instance created without an explicit configuration.
        self.augmentation_config = {} if augmentation_config is None else augmentation_config
        self.snr_filter = snr_filter
        self.load_data()
        self.length = len(self.file_list)
        self.collate_fn = None
        if AUG_TRIM in self.augmentation_config:
            trim_config = self.augmentation_config[AUG_TRIM]
            self.collate_fn = partial(collate_fn_generic,
                                      lengths_lim=trim_config[LENGTHS],
                                      length_divider=trim_config[LENGTH_DIVIDER],
                                      trim_prob=trim_config[TRIM_PROB])

    def filter_data(self, snr) -> bool:
        """Tell whether a sample with the given SNR passes the SNR filter.

        Args:
            snr: SNR value of the candidate sample.

        Returns:
            ``True`` when no filter is set, when ``snr`` equals a scalar
            filter, or when ``snr`` belongs to a container filter.
        """
        if self.snr_filter is None:
            return True
        # The constructor annotates snr_filter as a float, on which `in`
        # would raise TypeError: treat a scalar as an exact-match filter.
        if isinstance(self.snr_filter, (int, float)):
            return bool(snr == self.snr_filter)
        return snr in self.snr_filter

    def load_data(self):
        """Populate ``self.file_list``; must be overridden by subclasses."""
        raise NotImplementedError("load_data method must be implemented")

    def augment_data(self, mixed_audio_signal, clean_audio_signal, noise_audio_signal):
        """Apply the configured amplitude rescaling and extra white noise.

        Args:
            mixed_audio_signal: Mixed signal tensor.
            clean_audio_signal: Clean signal tensor.
            noise_audio_signal: Noise signal tensor.

        Returns:
            Tuple ``(mixed_audio_signal, clean_audio_signal, noise_audio_signal)``
            after augmentation.

        Note:
            The rescaling multiplies the three input tensors in place, so the
            caller's tensors are modified. The AWGN step is not in place:
            always use the returned tensors.
        """
        if AUG_RESCALE in self.augmentation_config:
            # One common gain, uniform in [0.5, 2.0), applied to all three
            # signals so their relative levels are unchanged.
            current_amplitude = 0.5 + 1.5*torch.rand(1, device=mixed_audio_signal.device)
            mixed_audio_signal *= current_amplitude
            noise_audio_signal *= current_amplitude
            clean_audio_signal *= current_amplitude
        if AUG_AWGN in self.augmentation_config:
            # TODO: read the value from self.augmentation_config[AUG_AWGN]["noise_std"]
            noise_std = 0.01
            # Random noise level for this call. It is drawn on the signal's
            # device: a one-element CPU tensor cannot be multiplied with a
            # tensor living on another device. The sign of the factor is
            # irrelevant since the Gaussian noise is symmetric.
            current_noise_std = torch.randn(1, device=mixed_audio_signal.device) * noise_std
            extra_awgn = torch.randn(mixed_audio_signal.shape, device=mixed_audio_signal.device) * current_noise_std
            mixed_audio_signal = mixed_audio_signal+extra_awgn
            # Open question: should we add noise to the noise signal as well?

        return mixed_audio_signal, clean_audio_signal, noise_audio_signal

    def __len__(self):
        """Return the number of entries found by ``load_data``."""
        return self.length

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return the sample at index ``idx``; must be overridden by subclasses."""
        raise NotImplementedError("__getitem__ method must be implemented")


def collate_fn_generic(
    batch, lengths_lim, length_divider=1024, trim_prob=0.5
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Collate function to allow trimming (=crop the time dimension) of the signals in a batch.

    The batch is first stacked with ``default_collate``, then a single time
    window is chosen and applied to the three signals, so they stay aligned.
    The window length is always rounded down to a multiple of
    ``length_divider``.

    Args:
        batch (list): A list of tuples (triplets), where each tuple contains:
            - mixed_audio_signal
            - clean_audio_signal
            - noise_audio_signal
        lengths_lim (list): A pair ``(min_length, max_length)`` bounding the
            trimmed length before rounding. ``max_length`` must not be
            smaller than ``min_length``.
        length_divider (int): The returned length is a multiple of this value.
        trim_prob (float): Probability of trimming; otherwise the full signal
            is kept (still rounded down to a multiple of ``length_divider``).

    Returns:
        - Tensor: A batch of mixed_audio_signal, cropped along the last dimension.
        - Tensor: A batch of clean_audio_signal, cropped to the same window.
        - Tensor: A batch of noise_audio_signal, cropped to the same window.

    Note:
        The rounding can shrink the window below ``min_length`` (down to an
        empty window) when ``min_length`` or the signal length is smaller
        than ``length_divider``.
    """
    mixed_audio_signal, clean_audio_signal, noise_audio_signal = default_collate(batch)
    # The time length is read from the first item of the stacked batch.
    length = mixed_audio_signal[0].shape[-1]
    min_length, max_length = lengths_lim
    # Always draw the decision first so the random stream does not depend on
    # the signal length.
    take_full_signal = bool(torch.rand(1) > trim_prob)
    # A random start can only be drawn when the signal is strictly longer
    # than min_length; otherwise randint would be given an empty range.
    can_trim = length > min_length
    if take_full_signal or not can_trim:
        start = 0
        end = length - length % length_divider
    else:
        start = int(torch.randint(0, length-min_length, (1,)))
        # Keep the window inside the signal (one sample of margin at the end).
        longest_allowed = min(max_length, length-start-1)
        trim_length = int(torch.randint(min_length, longest_allowed+1, (1,)))
        trim_length -= trim_length % length_divider
        end = start + trim_length
    mixed_audio_signal = mixed_audio_signal[..., start:end]
    clean_audio_signal = clean_audio_signal[..., start:end]
    noise_audio_signal = noise_audio_signal[..., start:end]
    return mixed_audio_signal, clean_audio_signal, noise_audio_signal