waidhoferj committed on
Commit
b6800ef
·
1 Parent(s): 14f49a9

updated paths to work remotely

Browse files
models/config/train.yaml CHANGED
@@ -1,6 +1,6 @@
1
  global:
2
  id: ast_ptl
3
- device: mps
4
  seed: 42
5
  dance_ids:
6
  - ATN
@@ -19,10 +19,10 @@ global:
19
  - VWZ
20
  - WCS
21
  data_module:
22
- song_data_path: data/samples/songs_cleaned.csv
23
- song_audio_path: data/samples
24
- batch_size: 256
25
- num_workers: 10
26
  min_votes: 1
27
  dataset_kwargs:
28
  audio_window_duration: 6
 
1
  global:
2
  id: ast_ptl
3
+ device: cuda
4
  seed: 42
5
  dance_ids:
6
  - ATN
 
19
  - VWZ
20
  - WCS
21
  data_module:
22
+ song_data_path: ../datastores/dance-music/songs_cleaned.csv
23
+ song_audio_path: ../datastores/dance-music
24
+ batch_size: 64
25
+ num_workers: 4
26
  min_votes: 1
27
  dataset_kwargs:
28
  audio_window_duration: 6
preprocessing/pipelines.py CHANGED
@@ -3,8 +3,6 @@ import torchaudio
3
  from torchaudio import transforms as taT, functional as taF
4
  import torch.nn as nn
5
 
6
- NOISE_PATH = "data/augmentation/Lab41-SRI-VOiCES-rm1-babb-mc01-stu-clo.wav"
7
-
8
  class AudioTrainingPipeline(torch.nn.Module):
9
  def __init__(self,
10
  input_freq=16000,
@@ -13,12 +11,13 @@ class AudioTrainingPipeline(torch.nn.Module):
13
  freq_mask_size=10,
14
  time_mask_size=80,
15
  mask_count = 2,
16
- snr_mean=6.0):
 
17
  super().__init__()
18
  self.input_freq = input_freq
19
  self.snr_mean = snr_mean
20
  self.mask_count = mask_count
21
- self.noise = self.get_noise()
22
  self.resample = taT.Resample(input_freq,resample_freq)
23
  self.preprocess_waveform = WaveformPreprocessing(resample_freq * expected_duration)
24
  self.audio_to_spectrogram = AudioToSpectrogram(
@@ -28,8 +27,10 @@ class AudioTrainingPipeline(torch.nn.Module):
28
  self.time_mask = taT.TimeMasking(time_mask_size)
29
 
30
 
31
- def get_noise(self) -> torch.Tensor:
32
- noise, sr = torchaudio.load(NOISE_PATH)
 
 
33
  if noise.shape[0] > 1:
34
  noise = noise.mean(0, keepdim=True)
35
  if sr != self.input_freq:
@@ -37,6 +38,7 @@ class AudioTrainingPipeline(torch.nn.Module):
37
  return noise
38
 
39
  def add_noise(self, waveform:torch.Tensor) -> torch.Tensor:
 
40
  num_repeats = waveform.shape[1] // self.noise.shape[1] + 1
41
  noise = self.noise.repeat(1,num_repeats)[:, :waveform.shape[1]]
42
  noise_power = noise.norm(p=2)
@@ -53,7 +55,8 @@ class AudioTrainingPipeline(torch.nn.Module):
53
  except:
54
  print("oops")
55
  waveform = self.preprocess_waveform(waveform)
56
- waveform = self.add_noise(waveform)
 
57
  spec = self.audio_to_spectrogram(waveform)
58
 
59
  # Spectrogram augmentation
 
3
  from torchaudio import transforms as taT, functional as taF
4
  import torch.nn as nn
5
 
 
 
6
  class AudioTrainingPipeline(torch.nn.Module):
7
  def __init__(self,
8
  input_freq=16000,
 
11
  freq_mask_size=10,
12
  time_mask_size=80,
13
  mask_count = 2,
14
+ snr_mean=6.0,
15
+ noise_path=None):
16
  super().__init__()
17
  self.input_freq = input_freq
18
  self.snr_mean = snr_mean
19
  self.mask_count = mask_count
20
+ self.noise = self.get_noise(noise_path)
21
  self.resample = taT.Resample(input_freq,resample_freq)
22
  self.preprocess_waveform = WaveformPreprocessing(resample_freq * expected_duration)
23
  self.audio_to_spectrogram = AudioToSpectrogram(
 
27
  self.time_mask = taT.TimeMasking(time_mask_size)
28
 
29
 
30
+ def get_noise(self, path) -> torch.Tensor:
31
+ if path is None:
32
+ return None
33
+ noise, sr = torchaudio.load(path)
34
  if noise.shape[0] > 1:
35
  noise = noise.mean(0, keepdim=True)
36
  if sr != self.input_freq:
 
38
  return noise
39
 
40
  def add_noise(self, waveform:torch.Tensor) -> torch.Tensor:
41
+ assert self.noise is not None, "Cannot add noise because a noise file was not provided."
42
  num_repeats = waveform.shape[1] // self.noise.shape[1] + 1
43
  noise = self.noise.repeat(1,num_repeats)[:, :waveform.shape[1]]
44
  noise_power = noise.norm(p=2)
 
55
  except:
56
  print("oops")
57
  waveform = self.preprocess_waveform(waveform)
58
+ if self.noise is not None:
59
+ waveform = self.add_noise(waveform)
60
  spec = self.audio_to_spectrogram(waveform)
61
 
62
  # Spectrogram augmentation
train.py CHANGED
@@ -14,9 +14,8 @@ from models.residual import ResidualDancer, TrainingEnvironment
14
  import yaml
15
  from preprocessing.dataset import DanceDataModule, WaveformSongDataset, HuggingFaceWaveformSongDataset
16
  from torch.utils.data import random_split
17
- from wakepy import keepawake
18
  import numpy as np
19
- from transformers import ASTFeatureExtractor, AutoFeatureExtractor, ASTConfig, AutoModelForAudioClassification
20
  from argparse import ArgumentParser
21
 
22
 
@@ -151,5 +150,4 @@ if __name__ == "__main__":
151
  config = get_config(args.config)
152
  training_id = config["global"]["id"]
153
  train = get_training_fn(training_id)
154
- with keepawake():
155
- train(config)
 
14
  import yaml
15
  from preprocessing.dataset import DanceDataModule, WaveformSongDataset, HuggingFaceWaveformSongDataset
16
  from torch.utils.data import random_split
 
17
  import numpy as np
18
+ from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
19
  from argparse import ArgumentParser
20
 
21
 
 
150
  config = get_config(args.config)
151
  training_id = config["global"]["id"]
152
  train = get_training_fn(training_id)
153
+ train(config)