Spaces:
Runtime error
Runtime error
waidhoferj
commited on
Commit
·
b6800ef
1
Parent(s):
14f49a9
updated paths to work remotely
Browse files- models/config/train.yaml +5 -5
- preprocessing/pipelines.py +10 -7
- train.py +2 -4
models/config/train.yaml
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
global:
|
2 |
id: ast_ptl
|
3 |
-
device:
|
4 |
seed: 42
|
5 |
dance_ids:
|
6 |
- ATN
|
@@ -19,10 +19,10 @@ global:
|
|
19 |
- VWZ
|
20 |
- WCS
|
21 |
data_module:
|
22 |
-
song_data_path:
|
23 |
-
song_audio_path:
|
24 |
-
batch_size:
|
25 |
-
num_workers:
|
26 |
min_votes: 1
|
27 |
dataset_kwargs:
|
28 |
audio_window_duration: 6
|
|
|
1 |
global:
|
2 |
id: ast_ptl
|
3 |
+
device: cuda
|
4 |
seed: 42
|
5 |
dance_ids:
|
6 |
- ATN
|
|
|
19 |
- VWZ
|
20 |
- WCS
|
21 |
data_module:
|
22 |
+
song_data_path: ../datastores/dance-music/songs_cleaned.csv
|
23 |
+
song_audio_path: ../datastores/dance-music
|
24 |
+
batch_size: 64
|
25 |
+
num_workers: 4
|
26 |
min_votes: 1
|
27 |
dataset_kwargs:
|
28 |
audio_window_duration: 6
|
preprocessing/pipelines.py
CHANGED
@@ -3,8 +3,6 @@ import torchaudio
|
|
3 |
from torchaudio import transforms as taT, functional as taF
|
4 |
import torch.nn as nn
|
5 |
|
6 |
-
NOISE_PATH = "data/augmentation/Lab41-SRI-VOiCES-rm1-babb-mc01-stu-clo.wav"
|
7 |
-
|
8 |
class AudioTrainingPipeline(torch.nn.Module):
|
9 |
def __init__(self,
|
10 |
input_freq=16000,
|
@@ -13,12 +11,13 @@ class AudioTrainingPipeline(torch.nn.Module):
|
|
13 |
freq_mask_size=10,
|
14 |
time_mask_size=80,
|
15 |
mask_count = 2,
|
16 |
-
snr_mean=6.0
|
|
|
17 |
super().__init__()
|
18 |
self.input_freq = input_freq
|
19 |
self.snr_mean = snr_mean
|
20 |
self.mask_count = mask_count
|
21 |
-
self.noise = self.get_noise()
|
22 |
self.resample = taT.Resample(input_freq,resample_freq)
|
23 |
self.preprocess_waveform = WaveformPreprocessing(resample_freq * expected_duration)
|
24 |
self.audio_to_spectrogram = AudioToSpectrogram(
|
@@ -28,8 +27,10 @@ class AudioTrainingPipeline(torch.nn.Module):
|
|
28 |
self.time_mask = taT.TimeMasking(time_mask_size)
|
29 |
|
30 |
|
31 |
-
def get_noise(self) -> torch.Tensor:
|
32 |
-
|
|
|
|
|
33 |
if noise.shape[0] > 1:
|
34 |
noise = noise.mean(0, keepdim=True)
|
35 |
if sr != self.input_freq:
|
@@ -37,6 +38,7 @@ class AudioTrainingPipeline(torch.nn.Module):
|
|
37 |
return noise
|
38 |
|
39 |
def add_noise(self, waveform:torch.Tensor) -> torch.Tensor:
|
|
|
40 |
num_repeats = waveform.shape[1] // self.noise.shape[1] + 1
|
41 |
noise = self.noise.repeat(1,num_repeats)[:, :waveform.shape[1]]
|
42 |
noise_power = noise.norm(p=2)
|
@@ -53,7 +55,8 @@ class AudioTrainingPipeline(torch.nn.Module):
|
|
53 |
except:
|
54 |
print("oops")
|
55 |
waveform = self.preprocess_waveform(waveform)
|
56 |
-
|
|
|
57 |
spec = self.audio_to_spectrogram(waveform)
|
58 |
|
59 |
# Spectrogram augmentation
|
|
|
3 |
from torchaudio import transforms as taT, functional as taF
|
4 |
import torch.nn as nn
|
5 |
|
|
|
|
|
6 |
class AudioTrainingPipeline(torch.nn.Module):
|
7 |
def __init__(self,
|
8 |
input_freq=16000,
|
|
|
11 |
freq_mask_size=10,
|
12 |
time_mask_size=80,
|
13 |
mask_count = 2,
|
14 |
+
snr_mean=6.0,
|
15 |
+
noise_path=None):
|
16 |
super().__init__()
|
17 |
self.input_freq = input_freq
|
18 |
self.snr_mean = snr_mean
|
19 |
self.mask_count = mask_count
|
20 |
+
self.noise = self.get_noise(noise_path)
|
21 |
self.resample = taT.Resample(input_freq,resample_freq)
|
22 |
self.preprocess_waveform = WaveformPreprocessing(resample_freq * expected_duration)
|
23 |
self.audio_to_spectrogram = AudioToSpectrogram(
|
|
|
27 |
self.time_mask = taT.TimeMasking(time_mask_size)
|
28 |
|
29 |
|
30 |
+
def get_noise(self, path) -> torch.Tensor:
|
31 |
+
if path is None:
|
32 |
+
return None
|
33 |
+
noise, sr = torchaudio.load(path)
|
34 |
if noise.shape[0] > 1:
|
35 |
noise = noise.mean(0, keepdim=True)
|
36 |
if sr != self.input_freq:
|
|
|
38 |
return noise
|
39 |
|
40 |
def add_noise(self, waveform:torch.Tensor) -> torch.Tensor:
|
41 |
+
assert self.noise is not None, "Cannot add noise because a noise file was not provided."
|
42 |
num_repeats = waveform.shape[1] // self.noise.shape[1] + 1
|
43 |
noise = self.noise.repeat(1,num_repeats)[:, :waveform.shape[1]]
|
44 |
noise_power = noise.norm(p=2)
|
|
|
55 |
except:
|
56 |
print("oops")
|
57 |
waveform = self.preprocess_waveform(waveform)
|
58 |
+
if self.noise is not None:
|
59 |
+
waveform = self.add_noise(waveform)
|
60 |
spec = self.audio_to_spectrogram(waveform)
|
61 |
|
62 |
# Spectrogram augmentation
|
train.py
CHANGED
@@ -14,9 +14,8 @@ from models.residual import ResidualDancer, TrainingEnvironment
|
|
14 |
import yaml
|
15 |
from preprocessing.dataset import DanceDataModule, WaveformSongDataset, HuggingFaceWaveformSongDataset
|
16 |
from torch.utils.data import random_split
|
17 |
-
from wakepy import keepawake
|
18 |
import numpy as np
|
19 |
-
from transformers import
|
20 |
from argparse import ArgumentParser
|
21 |
|
22 |
|
@@ -151,5 +150,4 @@ if __name__ == "__main__":
|
|
151 |
config = get_config(args.config)
|
152 |
training_id = config["global"]["id"]
|
153 |
train = get_training_fn(training_id)
|
154 |
-
|
155 |
-
train(config)
|
|
|
14 |
import yaml
|
15 |
from preprocessing.dataset import DanceDataModule, WaveformSongDataset, HuggingFaceWaveformSongDataset
|
16 |
from torch.utils.data import random_split
|
|
|
17 |
import numpy as np
|
18 |
+
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
|
19 |
from argparse import ArgumentParser
|
20 |
|
21 |
|
|
|
150 |
config = get_config(args.config)
|
151 |
training_id = config["global"]["id"]
|
152 |
train = get_training_fn(training_id)
|
153 |
+
train(config)
|
|