waidhoferj committed
Commit
248f682
Parent: ec63e8e

updated resample pipeline

preprocessing/dataset.py CHANGED
@@ -29,6 +29,7 @@ class SongDataset(Dataset):
         audio_window_duration=6,  # seconds
         audio_window_jitter=1.0,  # seconds
         audio_durations=None,
+        target_sample_rate=16000,
     ):
         assert (
             audio_window_duration > audio_window_jitter
@@ -54,6 +55,7 @@ class SongDataset(Dataset):
         self.audio_window_duration = int(audio_window_duration)
         self.audio_start_offset = audio_start_offset
         self.audio_window_jitter = audio_window_jitter
+        self.target_sample_rate = target_sample_rate
 
     def __len__(self):
         return int(
@@ -125,9 +127,9 @@ class SongDataset(Dataset):
         waveform, sample_rate = ta.load(
             audio_filepath, frame_offset=frame_offset, num_frames=num_frames
         )
-        assert (
-            sample_rate == self.sample_rate
-        ), f"Expected sample rate of {self.sample_rate}. Found {sample_rate}"
+        waveform = ta.functional.resample(
+            waveform, orig_freq=sample_rate, new_freq=self.target_sample_rate
+        )
         return waveform
 
     def _label_from_index(self, idx: int) -> torch.Tensor:
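For context, a minimal standalone sketch of the new load step, assuming only that torchaudio is importable and a local audio file path is available; the load_resampled helper name and its defaults are illustrative, not part of the repository. Instead of asserting that the file's sample rate equals a fixed value, the loaded window is resampled to target_sample_rate:

import torchaudio as ta

def load_resampled(audio_filepath, target_sample_rate=16000, frame_offset=0, num_frames=-1):
    # Load the requested window at the file's native sample rate.
    waveform, sample_rate = ta.load(
        audio_filepath, frame_offset=frame_offset, num_frames=num_frames
    )
    # Resample to the target rate rather than rejecting mismatched files.
    return ta.functional.resample(
        waveform, orig_freq=sample_rate, new_freq=target_sample_rate
    )

With rate conversion handled here at load time, the modules downstream of the dataset can assume a single known sample rate.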
preprocessing/pipelines.py CHANGED
@@ -7,21 +7,17 @@ import torch.nn as nn
 class WaveformTrainingPipeline(torch.nn.Module):
     def __init__(
         self,
-        input_freq=16000,
-        resample_freq=16000,
         expected_duration=6,
         snr_mean=6.0,
         noise_path=None,
     ):
         super().__init__()
-        self.input_freq = input_freq
         self.snr_mean = snr_mean
         self.noise = self.get_noise(noise_path)
-        self.resample_frequency = resample_freq
-        self.resample = taT.Resample(input_freq, resample_freq)
+        self.sample_rate = 16000
 
         self.preprocess_waveform = WaveformPreprocessing(
-            resample_freq * expected_duration
+            self.sample_rate * expected_duration
         )
 
     def get_noise(self, path) -> torch.Tensor:
@@ -30,8 +26,8 @@ class WaveformTrainingPipeline(torch.nn.Module):
         noise, sr = torchaudio.load(path)
         if noise.shape[0] > 1:
             noise = noise.mean(0, keepdim=True)
-        if sr != self.input_freq:
-            noise = taF.resample(noise, sr, self.input_freq)
+        if sr != self.sample_rate:
+            noise = taF.resample(noise, sr, self.sample_rate)
         return noise
 
     def add_noise(self, waveform: torch.Tensor) -> torch.Tensor:
@@ -49,7 +45,6 @@ class WaveformTrainingPipeline(torch.nn.Module):
         return noisy_waveform
 
     def forward(self, waveform: torch.Tensor) -> torch.Tensor:
-        waveform = self.resample(waveform)
         waveform = self.preprocess_waveform(waveform)
         if self.noise is not None:
             waveform = self.add_noise(waveform)
@@ -63,7 +58,7 @@ class SpectrogramTrainingPipeline(WaveformTrainingPipeline):
         super().__init__(*args, **kwargs)
         self.mask_count = mask_count
         self.audio_to_spectrogram = AudioToSpectrogram(
-            sample_rate=self.resample_frequency,
+            sample_rate=self.sample_rate,
        )
         self.freq_mask = taT.FrequencyMasking(freq_mask_size)
         self.time_mask = taT.TimeMasking(time_mask_size)
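Since WaveformTrainingPipeline now fixes self.sample_rate = 16000 instead of resampling internally, the dataset's target_sample_rate has to match it. Below is a self-contained sanity check of that assumption; the synthetic sine clip is illustrative, not project code. A 6-second clip resampled from 44.1 kHz to 16 kHz yields exactly 16000 * 6 samples, the window length passed to WaveformPreprocessing.

import torch
import torchaudio.functional as taF

orig_sr, target_sr = 44_100, 16_000
duration = 6  # seconds, matching expected_duration

# Build a 6 s, single-channel 440 Hz tone at the original rate.
t = torch.arange(orig_sr * duration) / orig_sr
waveform = torch.sin(2 * torch.pi * 440.0 * t).unsqueeze(0)  # shape (1, samples)

resampled = taF.resample(waveform, orig_freq=orig_sr, new_freq=target_sr)
assert resampled.shape[-1] == target_sr * duration  # 96000 samples = self.sample_rate * expected_duration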