Commit 248f682 · waidhoferj committed
Parent: ec63e8e

updated resample pipeline

Files changed:
- preprocessing/dataset.py (+5 -3)
- preprocessing/pipelines.py (+5 -10)
preprocessing/dataset.py CHANGED

```diff
@@ -29,6 +29,7 @@ class SongDataset(Dataset):
         audio_window_duration=6,  # seconds
         audio_window_jitter=1.0,  # seconds
         audio_durations=None,
+        target_sample_rate=16000,
     ):
         assert (
             audio_window_duration > audio_window_jitter
@@ -54,6 +55,7 @@ class SongDataset(Dataset):
         self.audio_window_duration = int(audio_window_duration)
         self.audio_start_offset = audio_start_offset
         self.audio_window_jitter = audio_window_jitter
+        self.target_sample_rate = target_sample_rate
 
     def __len__(self):
         return int(
@@ -125,9 +127,9 @@ class SongDataset(Dataset):
         waveform, sample_rate = ta.load(
             audio_filepath, frame_offset=frame_offset, num_frames=num_frames
         )
-        … (removed line truncated in the page render)
-        … sample_rate
-        )
+        waveform = ta.functional.resample(
+            waveform, orig_freq=sample_rate, new_freq=self.target_sample_rate
+        )
         return waveform
 
     def _label_from_index(self, idx: int) -> torch.Tensor:
```
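For context, a minimal sketch of the loading path after this commit: each window is loaded with torchaudio and then resampled to the new `target_sample_rate`. The file path and frame values below are illustrative stand-ins, not from the repo; in `SongDataset` they come from the dataset's own fields.

```python
import torchaudio as ta  # the `ta` alias used in the diff

# Hypothetical stand-ins; in the repo these come from SongDataset.
audio_filepath = "example.wav"    # illustrative path
frame_offset, num_frames = 0, -1  # -1 loads the whole file
target_sample_rate = 16000        # the new constructor default

# Mirrors the added lines in dataset.py: load, then resample.
waveform, sample_rate = ta.load(
    audio_filepath, frame_offset=frame_offset, num_frames=num_frames
)
waveform = ta.functional.resample(
    waveform, orig_freq=sample_rate, new_freq=target_sample_rate
)
# waveform is now at 16 kHz regardless of the source file's rate.
```

Resampling at load time gives every downstream consumer a uniform rate, which is what lets the training pipeline below drop its own resample step.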
preprocessing/pipelines.py CHANGED

```diff
@@ -7,21 +7,17 @@ import torch.nn as nn
 class WaveformTrainingPipeline(torch.nn.Module):
     def __init__(
         self,
-        input_freq=16000,
-        resample_freq=16000,
         expected_duration=6,
         snr_mean=6.0,
         noise_path=None,
     ):
         super().__init__()
-        self.input_freq = input_freq
         self.snr_mean = snr_mean
         self.noise = self.get_noise(noise_path)
-        self.… (truncated in the page render)
-        self.resample = taT.Resample(input_freq, resample_freq)
+        self.sample_rate = 16000
 
         self.preprocess_waveform = WaveformPreprocessing(
-            … (truncated in the page render)
+            self.sample_rate * expected_duration
         )
 
     def get_noise(self, path) -> torch.Tensor:
@@ -30,8 +26,8 @@ class WaveformTrainingPipeline(torch.nn.Module):
         noise, sr = torchaudio.load(path)
         if noise.shape[0] > 1:
             noise = noise.mean(0, keepdim=True)
-        if sr != self.… (truncated in the page render)
-            noise = taF.resample(noise, sr, self.… (truncated)
+        if sr != self.sample_rate:
+            noise = taF.resample(noise, sr, self.sample_rate)
         return noise
 
     def add_noise(self, waveform: torch.Tensor) -> torch.Tensor:
@@ -49,7 +45,6 @@ class WaveformTrainingPipeline(torch.nn.Module):
         return noisy_waveform
 
     def forward(self, waveform: torch.Tensor) -> torch.Tensor:
-        waveform = self.resample(waveform)
         waveform = self.preprocess_waveform(waveform)
         if self.noise is not None:
             waveform = self.add_noise(waveform)
@@ -63,7 +58,7 @@ class SpectrogramTrainingPipeline(WaveformTrainingPipeline):
         super().__init__(*args, **kwargs)
         self.mask_count = mask_count
         self.audio_to_spectrogram = AudioToSpectrogram(
-            sample_rate=self.… (truncated in the page render)
+            sample_rate=self.sample_rate,
         )
         self.freq_mask = taT.FrequencyMasking(freq_mask_size)
         self.time_mask = taT.TimeMasking(time_mask_size)
```
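With resampling moved into the dataset, the pipeline now hard-codes `self.sample_rate = 16000` and drops the resample step from `forward()`. Below is a minimal sketch of that division of labor, assuming illustrative names; the pad/trim helper stands in for the repo's `WaveformPreprocessing`, which is not shown in this diff.

```python
import torch
import torchaudio.functional as taF

SAMPLE_RATE = 16000  # matches the hard-coded self.sample_rate = 16000

def load_window_resampled(waveform: torch.Tensor, orig_freq: int) -> torch.Tensor:
    """Dataset side after this commit: resample once, at load time."""
    return taF.resample(waveform, orig_freq=orig_freq, new_freq=SAMPLE_RATE)

class PipelineSketch(torch.nn.Module):
    """Pipeline side after this commit: input is assumed to already be
    16 kHz, so forward() no longer begins with self.resample(waveform).
    A simple pad/trim to a fixed sample count stands in for
    WaveformPreprocessing."""

    def __init__(self, expected_duration: int = 6):
        super().__init__()
        self.sample_rate = SAMPLE_RATE
        self.expected_samples = self.sample_rate * expected_duration

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        # Pad short clips, then trim to exactly expected_samples.
        n = waveform.shape[-1]
        if n < self.expected_samples:
            waveform = torch.nn.functional.pad(
                waveform, (0, self.expected_samples - n)
            )
        return waveform[..., : self.expected_samples]

# Usage: a 3-second clip recorded at 44.1 kHz flows through both halves.
clip = torch.randn(1, 44100 * 3)
clip = load_window_resampled(clip, orig_freq=44100)
print(PipelineSketch()(clip).shape)  # torch.Size([1, 96000]) == 16000 * 6
```

One caveat visible in the hunk's line order: in the new `__init__`, `self.get_noise(noise_path)` runs before `self.sample_rate` is assigned, so constructing the pipeline with a `noise_path` would likely hit the `sr != self.sample_rate` comparison before the attribute exists.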