from typing import Optional

import torch
import torchaudio
from torchaudio import transforms as taT, functional as taF


class WaveformTrainingPipeline(torch.nn.Module):
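    """Training-time waveform pipeline: downmixes to mono, pads/truncates to a
    fixed duration, and optionally adds background noise at a randomly drawn
    SNR."""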
    def __init__(
        self,
        expected_duration=6,
        snr_mean=6.0,
        noise_path=None,
    ):
        super().__init__()
        self.snr_mean = snr_mean
        # The sample rate must be set before loading noise, because
        # get_noise resamples the noise file to self.sample_rate.
        self.sample_rate = 16000
        self.noise = self.get_noise(noise_path)

        self.preprocess_waveform = WaveformPreprocessing(
            self.sample_rate * expected_duration
        )

    def get_noise(self, path: Optional[str]) -> Optional[torch.Tensor]:
        if path is None:
            return None
        noise, sr = torchaudio.load(path)
        # Downmix multi-channel noise to mono.
        if noise.shape[0] > 1:
            noise = noise.mean(0, keepdim=True)
        # Resample to the pipeline's sample rate if necessary.
        if sr != self.sample_rate:
            noise = taF.resample(noise, sr, self.sample_rate)
        return noise

    def add_noise(self, waveform: torch.Tensor) -> torch.Tensor:
        assert (
            self.noise is not None
        ), "Cannot add noise because a noise file was not provided."
        # Tile the noise clip so it covers the full waveform, then trim.
        num_repeats = waveform.shape[-1] // self.noise.shape[-1] + 1
        noise = self.noise.repeat(1, num_repeats)[:, : waveform.shape[-1]]
        noise_power = noise.norm(p=2)
        signal_power = waveform.norm(p=2)
        # Draw a target SNR in dB, floored at 1 dB.
        snr_db = torch.normal(self.snr_mean, 1.5, (1,)).clamp_min(1.0)
        # The norms above are amplitudes, so convert dB to an amplitude
        # ratio with 10^(dB/20) (not a natural exponential).
        snr = 10 ** (snr_db / 20)
        scale = snr * noise_power / signal_power
        # Mix the scaled signal with the noise.
        noisy_waveform = (scale * waveform + noise) / 2
        return noisy_waveform

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        waveform = self.preprocess_waveform(waveform)
        if self.noise is not None:
            waveform = self.add_noise(waveform)
        return waveform


class SpectrogramTrainingPipeline(WaveformTrainingPipeline):
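    """Extends the waveform pipeline with log-mel spectrogram conversion and
    frequency/time masking augmentation."""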
    def __init__(
        self, freq_mask_size=10, time_mask_size=80, mask_count=2, *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.mask_count = mask_count
        self.audio_to_spectrogram = AudioToSpectrogram(
            sample_rate=self.sample_rate,
        )
        self.freq_mask = taT.FrequencyMasking(freq_mask_size)
        self.time_mask = taT.TimeMasking(time_mask_size)

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        waveform = super().forward(waveform)
        spec = self.audio_to_spectrogram(waveform)

        # Spectrogram augmentation
        for _ in range(self.mask_count):
            spec = self.freq_mask(spec)
            spec = self.time_mask(spec)
        return spec


class SpectrogramProductionPipeline(torch.nn.Module):
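    """Inference-time pipeline: waveform preprocessing and spectrogram
    conversion only, with no augmentation."""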
    def __init__(self, sample_rate=16000, expected_duration=6, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.preprocess_waveform = WaveformPreprocessing(
            sample_rate * expected_duration
        )
        self.audio_to_spectrogram = AudioToSpectrogram(
            sample_rate=sample_rate,
        )

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        waveform = self.preprocess_waveform(waveform)
        return self.audio_to_spectrogram(waveform)


class WaveformPreprocessing(torch.nn.Module):
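    """Downmixes audio to mono and pads/truncates it to a fixed sample
    length."""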
    def __init__(self, expected_sample_length: int):
        super().__init__()
        self.expected_sample_length = expected_sample_length

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        # Channel dim is 1 for batched (B, C, T) input, 0 for unbatched (C, T).
        c_dim = 1 if waveform.dim() == 3 else 0
        # Downmix any extra channels to mono.
        if waveform.shape[c_dim] > 1:
            waveform = waveform.mean(c_dim, keepdim=True)

        # Pad or truncate to the expected sample length.
        waveform = self._rectify_duration(waveform, c_dim)
        return waveform

    def _rectify_duration(self, waveform: torch.Tensor, channel_dim: int):
        expected_samples = self.expected_sample_length
        sample_count = waveform.shape[channel_dim + 1]
        if expected_samples == sample_count:
            return waveform
        elif expected_samples > sample_count:
            # F.pad pads the last (time) dimension first, so [0, pad_amount]
            # right-pads with silence for both 2D and 3D input.
            pad_amount = expected_samples - sample_count
            return torch.nn.functional.pad(
                waveform,
                [0, pad_amount],
                mode="constant",
                value=0.0,
            )
        else:
            # Truncate along the time (last) dimension.
            return waveform[..., :expected_samples]


class AudioToSpectrogram:
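    """Converts a waveform to a log-scaled (dB) mel spectrogram."""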
    def __init__(
        self,
        sample_rate=16000,
    ):
        self.spec = taT.MelSpectrogram(
            sample_rate=sample_rate, n_mels=128, n_fft=1024
        )  # Note: this doesn't work on mps right now.
        self.to_db = taT.AmplitudeToDB()

    def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
        spectrogram = self.spec(waveform)
        spectrogram = self.to_db(spectrogram)
        return spectrogram
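

# Usage sketch (illustrative, not part of the original module): runs the
# training and production pipelines on a random dummy waveform to show the
# expected input/output shapes. The shapes and defaults below are assumptions
# based on the definitions above.
if __name__ == "__main__":
    dummy = torch.randn(1, 16000 * 4)  # 4 s of mono audio at 16 kHz

    train_pipeline = SpectrogramTrainingPipeline()
    train_spec = train_pipeline(dummy)  # (1, n_mels, frames) log-mel, masked
    print("training spectrogram:", tuple(train_spec.shape))

    prod_pipeline = SpectrogramProductionPipeline()
    prod_spec = prod_pipeline(dummy)  # same shape, no augmentation
    print("production spectrogram:", tuple(prod_spec.shape))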