Spaces:
Runtime error
Runtime error
Commit
·
1160793
1
Parent(s):
9ff4511
refactored train loop
Browse filesadded Multi_STFT loss from paper
added minimal dataset
- README.md +8 -1
- checkpoing_saver.py +3 -2
- conf/config.yaml +1 -4
- conf/loss/L1_Multi_STFT.yaml +1 -0
- datasets/minimal.py +24 -0
- datasets/valentini.py +1 -0
- losses.py +174 -1
- main.py +0 -1
- notebooks/EDA.ipynb +6 -25
- train.py +50 -50
- transforms.py +2 -0
README.md
CHANGED
@@ -1,3 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
|
2 |
|
3 |
|
@@ -7,4 +15,3 @@
|
|
7 |
| ideal denoising | 1.9709 | 0.9211 |
|
8 |
| baseline | 1.7433 | 0.8844 |
|
9 |
|
10 |
-
|
|
|
1 |
+
# MVP
|
2 |
+
Сервисом является web interface, в котором пользователь
|
3 |
+
сможет записать своей голос в шумных условиях и получить на выход аудиозапись без шума.
|
4 |
+
Для обработки шумных аудио файлов есть доступ к API на питоне.
|
5 |
+
|
6 |
+
Web interface реализован на gradio. Сама работа пишется в контексте фрейморка pytorch.
|
7 |
+
В качестве системы контроля экспериментов выбран wandb. Для управления конфигами - hydra.
|
8 |
+
Архитектура модели базируется на работе "Real Time Speech Enhancement in the Waveform Domain" от facebook.
|
9 |
|
10 |
|
11 |
|
|
|
15 |
| ideal denoising | 1.9709 | 0.9211 |
|
16 |
| baseline | 1.7433 | 0.8844 |
|
17 |
|
|
checkpoing_saver.py
CHANGED
@@ -5,7 +5,7 @@ import torch
|
|
5 |
import wandb
|
6 |
|
7 |
class CheckpointSaver:
|
8 |
-
def __init__(self, dirpath, decreasing=True, top_n=5):
|
9 |
"""
|
10 |
dirpath: Directory path where to store all model weights
|
11 |
decreasing: If decreasing is `True`, then lower metric is better
|
@@ -17,9 +17,10 @@ class CheckpointSaver:
|
|
17 |
self.decreasing = decreasing
|
18 |
self.top_model_paths = []
|
19 |
self.best_metric_val = np.Inf if decreasing else -np.Inf
|
|
|
20 |
|
21 |
def __call__(self, model, epoch, metric_val):
|
22 |
-
model_path = os.path.join(self.dirpath, model.__class__.__name__ + f'_epoch{epoch}.pt')
|
23 |
save = metric_val < self.best_metric_val if self.decreasing else metric_val > self.best_metric_val
|
24 |
if save:
|
25 |
logging.info(
|
|
|
5 |
import wandb
|
6 |
|
7 |
class CheckpointSaver:
|
8 |
+
def __init__(self, dirpath, run_name='', decreasing=True, top_n=5):
|
9 |
"""
|
10 |
dirpath: Directory path where to store all model weights
|
11 |
decreasing: If decreasing is `True`, then lower metric is better
|
|
|
17 |
self.decreasing = decreasing
|
18 |
self.top_model_paths = []
|
19 |
self.best_metric_val = np.Inf if decreasing else -np.Inf
|
20 |
+
self.run_name = run_name
|
21 |
|
22 |
def __call__(self, model, epoch, metric_val):
|
23 |
+
model_path = os.path.join(self.dirpath, self.run_name, model.__class__.__name__ + f'_epoch{epoch}.pt')
|
24 |
save = metric_val < self.best_metric_val if self.decreasing else metric_val > self.best_metric_val
|
25 |
if save:
|
26 |
logging.info(
|
conf/config.yaml
CHANGED
@@ -16,10 +16,7 @@ dataloader:
|
|
16 |
|
17 |
validation:
|
18 |
path: /media/public/datasets/denoising/DS_10283_2791/noisy_testset_wav
|
19 |
-
|
20 |
-
easy: p232_284.wav
|
21 |
-
medium: p232_071.wav
|
22 |
-
hard : p257_171.wav
|
23 |
|
24 |
|
25 |
wandb:
|
|
|
16 |
|
17 |
validation:
|
18 |
path: /media/public/datasets/denoising/DS_10283_2791/noisy_testset_wav
|
19 |
+
sample_rate: 48000
|
|
|
|
|
|
|
20 |
|
21 |
|
22 |
wandb:
|
conf/loss/L1_Multi_STFT.yaml
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
name: L1_Multi_STFT
|
datasets/minimal.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import Dataset
|
3 |
+
from pathlib import Path
|
4 |
+
import torchaudio
|
5 |
+
import numpy as np
|
6 |
+
from torchaudio.transforms import Resample
|
7 |
+
|
8 |
+
|
9 |
+
class Minimal(Dataset):
|
10 |
+
def __init__(self, cfg):
|
11 |
+
self.wavs = ['p232_284.wav', 'p232_071.wav', 'p257_171.wav']
|
12 |
+
self.dataset_path = cfg['validation']['path']
|
13 |
+
self.target_rate = cfg['dataloader']['sample_rate']
|
14 |
+
self.resampler = Resample(orig_freq=cfg['validation']['sample_rate'],
|
15 |
+
new_freq=cfg['dataloader']['sample_rate'])
|
16 |
+
|
17 |
+
def __len__(self):
|
18 |
+
return len(self.wavs)
|
19 |
+
|
20 |
+
def __getitem__(self, idx):
|
21 |
+
wav, rate = torchaudio.load(self.wavs[idx])
|
22 |
+
wav = self.resampler(wav)
|
23 |
+
wav = torch.reshape(wav, (1, 1, -1))
|
24 |
+
return wav, self.target_rate
|
datasets/valentini.py
CHANGED
@@ -36,6 +36,7 @@ class Valentini(Dataset):
|
|
36 |
if self.transform:
|
37 |
random_seed = 0 if self.valid else torch.randint(HIGH_RANDOM_SEED, (1,))[0]
|
38 |
torch.manual_seed(random_seed)
|
|
|
39 |
noisy_wav = self.transform(noisy_wav)
|
40 |
torch.manual_seed(random_seed)
|
41 |
clean_wav = self.transform(clean_wav)
|
|
|
36 |
if self.transform:
|
37 |
random_seed = 0 if self.valid else torch.randint(HIGH_RANDOM_SEED, (1,))[0]
|
38 |
torch.manual_seed(random_seed)
|
39 |
+
|
40 |
noisy_wav = self.transform(noisy_wav)
|
41 |
torch.manual_seed(random_seed)
|
42 |
clean_wav = self.transform(clean_wav)
|
losses.py
CHANGED
@@ -1,7 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
LOSSES = {
|
4 |
-
'mse': torch.nn.MSELoss()
|
|
|
|
|
|
|
5 |
}
|
6 |
|
7 |
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
# All rights reserved.
|
4 |
+
#
|
5 |
+
# This source code is licensed under the license found in the
|
6 |
+
# LICENSE file in the root directory of this source tree.
|
7 |
+
|
8 |
+
# Original copyright 2019 Tomoki Hayashi
|
9 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
10 |
+
|
11 |
+
|
12 |
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
|
15 |
+
"""STFT-based Loss modules."""
|
16 |
+
|
17 |
+
|
18 |
+
def stft(x, fft_size, hop_size, win_length, window):
|
19 |
+
"""Perform STFT and convert to magnitude spectrogram.
|
20 |
+
Args:
|
21 |
+
x (Tensor): Input signal tensor (B, T).
|
22 |
+
fft_size (int): FFT size.
|
23 |
+
hop_size (int): Hop size.
|
24 |
+
win_length (int): Window length.
|
25 |
+
window (str): Window function type.
|
26 |
+
Returns:
|
27 |
+
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
28 |
+
"""
|
29 |
+
x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
|
30 |
+
real = x_stft[..., 0]
|
31 |
+
imag = x_stft[..., 1]
|
32 |
+
|
33 |
+
# NOTE(kan-bayashi): clamp is needed to avoid nan or inf
|
34 |
+
return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
|
35 |
+
|
36 |
+
|
37 |
+
class SpectralConvergengeLoss(torch.nn.Module):
|
38 |
+
"""Spectral convergence loss module."""
|
39 |
+
|
40 |
+
def __init__(self):
|
41 |
+
"""Initilize spectral convergence loss module."""
|
42 |
+
super(SpectralConvergengeLoss, self).__init__()
|
43 |
+
|
44 |
+
def forward(self, x_mag, y_mag):
|
45 |
+
"""Calculate forward propagation.
|
46 |
+
Args:
|
47 |
+
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
48 |
+
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
49 |
+
Returns:
|
50 |
+
Tensor: Spectral convergence loss value.
|
51 |
+
"""
|
52 |
+
return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
|
53 |
+
|
54 |
+
|
55 |
+
class LogSTFTMagnitudeLoss(torch.nn.Module):
|
56 |
+
"""Log STFT magnitude loss module."""
|
57 |
+
|
58 |
+
def __init__(self):
|
59 |
+
"""Initilize los STFT magnitude loss module."""
|
60 |
+
super(LogSTFTMagnitudeLoss, self).__init__()
|
61 |
+
|
62 |
+
def forward(self, x_mag, y_mag):
|
63 |
+
"""Calculate forward propagation.
|
64 |
+
Args:
|
65 |
+
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
66 |
+
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
67 |
+
Returns:
|
68 |
+
Tensor: Log STFT magnitude loss value.
|
69 |
+
"""
|
70 |
+
return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
|
71 |
+
|
72 |
+
|
73 |
+
class STFTLoss(torch.nn.Module):
|
74 |
+
"""STFT loss module."""
|
75 |
+
|
76 |
+
def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
|
77 |
+
"""Initialize STFT loss module."""
|
78 |
+
super(STFTLoss, self).__init__()
|
79 |
+
self.fft_size = fft_size
|
80 |
+
self.shift_size = shift_size
|
81 |
+
self.win_length = win_length
|
82 |
+
self.register_buffer("window", getattr(torch, window)(win_length))
|
83 |
+
self.spectral_convergenge_loss = SpectralConvergengeLoss()
|
84 |
+
self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
|
85 |
+
|
86 |
+
def forward(self, x, y):
|
87 |
+
"""Calculate forward propagation.
|
88 |
+
Args:
|
89 |
+
x (Tensor): Predicted signal (B, T).
|
90 |
+
y (Tensor): Groundtruth signal (B, T).
|
91 |
+
Returns:
|
92 |
+
Tensor: Spectral convergence loss value.
|
93 |
+
Tensor: Log STFT magnitude loss value.
|
94 |
+
"""
|
95 |
+
x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
|
96 |
+
y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
|
97 |
+
sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
|
98 |
+
mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
|
99 |
+
|
100 |
+
return sc_loss, mag_loss
|
101 |
+
|
102 |
+
|
103 |
+
class MultiResolutionSTFTLoss(torch.nn.Module):
|
104 |
+
"""Multi resolution STFT loss module."""
|
105 |
+
|
106 |
+
def __init__(self,
|
107 |
+
fft_sizes=[1024, 2048, 512],
|
108 |
+
hop_sizes=[120, 240, 50],
|
109 |
+
win_lengths=[600, 1200, 240],
|
110 |
+
window="hann_window", factor_sc=0.1, factor_mag=0.1):
|
111 |
+
"""Initialize Multi resolution STFT loss module.
|
112 |
+
Args:
|
113 |
+
fft_sizes (list): List of FFT sizes.
|
114 |
+
hop_sizes (list): List of hop sizes.
|
115 |
+
win_lengths (list): List of window lengths.
|
116 |
+
window (str): Window function type.
|
117 |
+
factor (float): a balancing factor across different losses.
|
118 |
+
"""
|
119 |
+
super(MultiResolutionSTFTLoss, self).__init__()
|
120 |
+
assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
|
121 |
+
self.stft_losses = torch.nn.ModuleList()
|
122 |
+
for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
|
123 |
+
self.stft_losses += [STFTLoss(fs, ss, wl, window)]
|
124 |
+
self.factor_sc = factor_sc
|
125 |
+
self.factor_mag = factor_mag
|
126 |
+
|
127 |
+
def forward(self, x, y):
|
128 |
+
"""Calculate forward propagation.
|
129 |
+
Args:
|
130 |
+
x (Tensor): Predicted signal (B, T).
|
131 |
+
y (Tensor): Groundtruth signal (B, T).
|
132 |
+
Returns:
|
133 |
+
Tensor: Multi resolution spectral convergence loss value.
|
134 |
+
Tensor: Multi resolution log STFT magnitude loss value.
|
135 |
+
"""
|
136 |
+
sc_loss = 0.0
|
137 |
+
mag_loss = 0.0
|
138 |
+
for f in self.stft_losses:
|
139 |
+
sc_l, mag_l = f(x, y)
|
140 |
+
sc_loss += sc_l
|
141 |
+
mag_loss += mag_l
|
142 |
+
sc_loss /= len(self.stft_losses)
|
143 |
+
mag_loss /= len(self.stft_losses)
|
144 |
+
|
145 |
+
return self.factor_sc*sc_loss, self.factor_mag*mag_loss
|
146 |
+
|
147 |
+
|
148 |
+
|
149 |
+
|
150 |
+
class L1_Multi_STFT(torch.nn.Module):
|
151 |
+
"""STFT loss module."""
|
152 |
+
|
153 |
+
def __init__(self):
|
154 |
+
"""Initialize STFT loss module."""
|
155 |
+
super(L1_Multi_STFT, self).__init__()
|
156 |
+
self.multi_STFT_loss = MultiResolutionSTFTLoss()
|
157 |
+
self.l1_loss = torch.nn.L1Loss()
|
158 |
+
|
159 |
+
def forward(self, x, y):
|
160 |
+
"""Calculate forward propagation.
|
161 |
+
Args:
|
162 |
+
x (Tensor): Predicted signal (B, T).
|
163 |
+
y (Tensor): Groundtruth signal (B, T).
|
164 |
+
Returns:
|
165 |
+
Tensor: Spectral convergence loss value.
|
166 |
+
Tensor: Log STFT magnitude loss value.
|
167 |
+
"""
|
168 |
+
sc_loss, mag_loss = self.multi_STFT_loss(x, y)
|
169 |
+
l1_loss = self.l1_loss(x, y)
|
170 |
+
return sc_loss + mag_loss + l1_loss
|
171 |
+
|
172 |
|
173 |
LOSSES = {
|
174 |
+
'mse': torch.nn.MSELoss(),
|
175 |
+
'L1': torch.nn.L1Loss(),
|
176 |
+
'Multi_STFT': MultiResolutionSTFTLoss,
|
177 |
+
'L1_Multi_STFT': L1_Multi_STFT
|
178 |
}
|
179 |
|
180 |
|
main.py
CHANGED
@@ -5,7 +5,6 @@ from train import train
|
|
5 |
|
6 |
@hydra.main(version_base=None, config_path="conf", config_name="config")
|
7 |
def main(cfg: DictConfig):
|
8 |
-
print(OmegaConf.to_yaml(cfg))
|
9 |
train(cfg)
|
10 |
|
11 |
|
|
|
5 |
|
6 |
@hydra.main(version_base=None, config_path="conf", config_name="config")
|
7 |
def main(cfg: DictConfig):
|
|
|
8 |
train(cfg)
|
9 |
|
10 |
|
notebooks/EDA.ipynb
CHANGED
@@ -2,44 +2,25 @@
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
-
"execution_count":
|
6 |
"id": "f800718e-c29f-44d8-bf41-e02d50d0f730",
|
7 |
"metadata": {
|
8 |
"ExecuteTime": {
|
9 |
"start_time": "2023-04-29T13:11:15.198687Z",
|
10 |
"end_time": "2023-04-29T13:11:15.245584Z"
|
11 |
-
}
|
12 |
-
},
|
13 |
-
"outputs": [
|
14 |
-
{
|
15 |
-
"name": "stderr",
|
16 |
-
"output_type": "stream",
|
17 |
-
"text": [
|
18 |
-
"/home/maksim/.local/lib/python3.10/site-packages/torchaudio/compliance/kaldi.py:22: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:84.)\n",
|
19 |
-
" EPSILON = torch.tensor(torch.finfo(torch.float).eps)\n"
|
20 |
-
]
|
21 |
},
|
22 |
-
{
|
23 |
-
"
|
24 |
-
"evalue": "No module named 'matplotlib'",
|
25 |
-
"output_type": "error",
|
26 |
-
"traceback": [
|
27 |
-
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
|
28 |
-
"\u001B[0;31mModuleNotFoundError\u001B[0m Traceback (most recent call last)",
|
29 |
-
"Cell \u001B[0;32mIn[1], line 3\u001B[0m\n\u001B[1;32m 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mpathlib\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Path\n\u001B[0;32m----> 3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mdatasets\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Valentini\n\u001B[1;32m 5\u001B[0m dataset \u001B[38;5;241m=\u001B[39m Valentini()\n",
|
30 |
-
"File \u001B[0;32m~/PycharmProjects/denoising/datasets.py:4\u001B[0m\n\u001B[1;32m 2\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mdata\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Dataset\n\u001B[1;32m 3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mpathlib\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Path\n\u001B[0;32m----> 4\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mutils\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m load_wav\n\u001B[1;32m 7\u001B[0m \u001B[38;5;28;01mclass\u001B[39;00m \u001B[38;5;21;01mValentini\u001B[39;00m(Dataset):\n\u001B[1;32m 8\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21m__init__\u001B[39m(\u001B[38;5;28mself\u001B[39m, dataset_path\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124m/media/public/datasets/denoising/DS_10283_2791/\u001B[39m\u001B[38;5;124m'\u001B[39m, transform\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m,\n\u001B[1;32m 9\u001B[0m valid\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mFalse\u001B[39;00m):\n",
|
31 |
-
"File \u001B[0;32m~/PycharmProjects/denoising/utils.py:3\u001B[0m\n\u001B[1;32m 1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mtorchaudio\u001B[39;00m\n\u001B[1;32m 2\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\n\u001B[0;32m----> 3\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mmatplotlib\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mpyplot\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mplt\u001B[39;00m\n\u001B[1;32m 4\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mpathlib\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Path\n\u001B[1;32m 7\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mcollect_valentini_paths\u001B[39m(dataset_path):\n",
|
32 |
-
"\u001B[0;31mModuleNotFoundError\u001B[0m: No module named 'matplotlib'"
|
33 |
-
]
|
34 |
}
|
35 |
-
|
|
|
36 |
"source": [
|
37 |
"\n",
|
38 |
"from pathlib import Path\n",
|
39 |
"\n",
|
40 |
"from datasets import Valentini\n",
|
41 |
"\n",
|
42 |
-
"dataset = Valentini()"
|
43 |
]
|
44 |
},
|
45 |
{
|
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
"id": "f800718e-c29f-44d8-bf41-e02d50d0f730",
|
7 |
"metadata": {
|
8 |
"ExecuteTime": {
|
9 |
"start_time": "2023-04-29T13:11:15.198687Z",
|
10 |
"end_time": "2023-04-29T13:11:15.245584Z"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
},
|
12 |
+
"pycharm": {
|
13 |
+
"is_executing": true
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
}
|
15 |
+
},
|
16 |
+
"outputs": [],
|
17 |
"source": [
|
18 |
"\n",
|
19 |
"from pathlib import Path\n",
|
20 |
"\n",
|
21 |
"from datasets import Valentini\n",
|
22 |
"\n",
|
23 |
+
"dataset = Valentini('/media/public/datasets/denoising/DS_10283_2791/', valid=False)"
|
24 |
]
|
25 |
},
|
26 |
{
|
train.py
CHANGED
@@ -1,10 +1,9 @@
|
|
1 |
import os
|
2 |
import torch
|
3 |
from torch.utils.data import DataLoader
|
4 |
-
|
5 |
from omegaconf import DictConfig
|
6 |
import wandb
|
7 |
-
import torchaudio
|
8 |
|
9 |
from checkpoing_saver import CheckpointSaver
|
10 |
from denoisers import get_model
|
@@ -12,7 +11,7 @@ from optimizers import get_optimizer
|
|
12 |
from losses import get_loss
|
13 |
from datasets import get_datasets
|
14 |
from testing.metrics import Metrics
|
15 |
-
import
|
16 |
|
17 |
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
|
18 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
@@ -26,67 +25,68 @@ def train(cfg: DictConfig):
|
|
26 |
config=omegaconf.OmegaConf.to_container(
|
27 |
cfg, resolve=True, throw_on_missing=True))
|
28 |
|
29 |
-
checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'])
|
30 |
metrics = Metrics(rate=cfg['dataloader']['sample_rate'])
|
31 |
|
32 |
model = get_model(cfg['model']).to(device)
|
33 |
optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
|
34 |
loss_fn = get_loss(cfg['loss'])
|
35 |
train_dataset, valid_dataset = get_datasets(cfg)
|
|
|
36 |
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
39 |
|
40 |
wandb.watch(model, log_freq=100)
|
41 |
|
42 |
for epoch in range(cfg['training']['num_epochs']):
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
outputs = model(inputs)
|
49 |
-
loss = loss_fn(outputs, labels)
|
50 |
-
loss.backward()
|
51 |
-
optimizer.step()
|
52 |
-
|
53 |
-
if i % cfg['wandb']['log_interval'] == 0:
|
54 |
-
wandb.log({"loss": loss})
|
55 |
-
|
56 |
-
model.train(False)
|
57 |
-
|
58 |
-
running_vloss, running_pesq, running_stoi = 0.0, 0.0, 0.0
|
59 |
-
with torch.no_grad():
|
60 |
-
for i, vdata in enumerate(validation_loader):
|
61 |
-
vinputs, vlabels = vdata
|
62 |
-
vinputs, vlabels = vinputs.to(device), vlabels.to(device)
|
63 |
-
voutputs = model(vinputs)
|
64 |
-
vloss = loss_fn(voutputs, vlabels)
|
65 |
-
running_vloss += vloss
|
66 |
-
running_metrics = metrics.calculate(denoised=voutputs, clean=vlabels)
|
67 |
-
running_pesq += running_metrics['PESQ']
|
68 |
-
running_stoi += running_metrics['STOI']
|
69 |
|
|
|
|
|
|
|
|
|
70 |
|
71 |
-
|
72 |
-
avg_pesq = running_pesq / len(validation_loader)
|
73 |
-
avg_stoi = running_stoi / len(validation_loader)
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
prediction = model(wav)
|
83 |
-
wandb.log({
|
84 |
-
f"{tag}_epoch_{epoch}": wandb.Audio(
|
85 |
-
prediction.cpu()[0][0],
|
86 |
-
sample_rate=rate)})
|
87 |
-
|
88 |
-
checkpoint_saver(model, epoch, metric_val=avg_pesq)
|
89 |
|
|
|
|
|
|
|
|
|
90 |
|
91 |
-
if
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
import torch
|
3 |
from torch.utils.data import DataLoader
|
4 |
+
import omegaconf
|
5 |
from omegaconf import DictConfig
|
6 |
import wandb
|
|
|
7 |
|
8 |
from checkpoing_saver import CheckpointSaver
|
9 |
from denoisers import get_model
|
|
|
11 |
from losses import get_loss
|
12 |
from datasets import get_datasets
|
13 |
from testing.metrics import Metrics
|
14 |
+
from datasets.minimal import Minimal
|
15 |
|
16 |
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
|
17 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
25 |
config=omegaconf.OmegaConf.to_container(
|
26 |
cfg, resolve=True, throw_on_missing=True))
|
27 |
|
28 |
+
checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'], run_name=wandb.run.name)
|
29 |
metrics = Metrics(rate=cfg['dataloader']['sample_rate'])
|
30 |
|
31 |
model = get_model(cfg['model']).to(device)
|
32 |
optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
|
33 |
loss_fn = get_loss(cfg['loss'])
|
34 |
train_dataset, valid_dataset = get_datasets(cfg)
|
35 |
+
minimal_dataset = Minimal(cfg)
|
36 |
|
37 |
+
dataloaders = {
|
38 |
+
'train': DataLoader(train_dataset, batch_size=cfg['dataloader']['train_batch_size'], shuffle=True),
|
39 |
+
'val': DataLoader(valid_dataset, batch_size=cfg['dataloader']['valid_batch_size'], shuffle=True),
|
40 |
+
'minimal': DataLoader(minimal_dataset)
|
41 |
+
}
|
42 |
|
43 |
wandb.watch(model, log_freq=100)
|
44 |
|
45 |
for epoch in range(cfg['training']['num_epochs']):
|
46 |
+
for phase in ['train', 'val']:
|
47 |
+
if phase == 'train':
|
48 |
+
model.train()
|
49 |
+
else:
|
50 |
+
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
+
running_loss, running_pesq, running_stoi = 0.0, 0.0, 0.0
|
53 |
+
for i, (inputs, labels) in enumerate(dataloaders[phase]):
|
54 |
+
inputs = inputs.to(device)
|
55 |
+
labels = labels.to(device)
|
56 |
|
57 |
+
optimizer.zero_grad()
|
|
|
|
|
58 |
|
59 |
+
with torch.set_grad_enabled(phase == 'train'):
|
60 |
+
outputs = model(inputs)
|
61 |
+
loss = loss_fn(outputs, labels)
|
62 |
|
63 |
+
if phase == 'train':
|
64 |
+
loss.backward()
|
65 |
+
optimizer.step()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
+
running_metrics = metrics.calculate(denoised=outputs, clean=labels)
|
68 |
+
running_loss += loss.item() * inputs.size(0)
|
69 |
+
running_pesq += running_metrics['PESQ']
|
70 |
+
running_stoi += running_metrics['STOI']
|
71 |
|
72 |
+
if phase == 'train' and i % cfg['wandb']['log_interval'] == 0:
|
73 |
+
wandb.log({"train_loss": running_loss / (i + 1),
|
74 |
+
"train_pesq": running_pesq / (i + 1),
|
75 |
+
"train_stoi": running_stoi / (i + 1)})
|
76 |
+
epoch_loss = running_loss / len(dataloaders[phase])
|
77 |
+
eposh_pesq = running_pesq / len(dataloaders[phase])
|
78 |
+
eposh_stoi = running_stoi / len(dataloaders[phase])
|
79 |
+
|
80 |
+
wandb.log({f"{phase}_loss": epoch_loss,
|
81 |
+
f"{phase}_pesq": eposh_pesq,
|
82 |
+
f"{phase}_stoi": eposh_stoi})
|
83 |
+
|
84 |
+
if phase == 'val':
|
85 |
+
for i, (wav, rate) in enumerate(dataloaders['minimal']):
|
86 |
+
prediction = model(wav)
|
87 |
+
wandb.log({
|
88 |
+
f"{i}_example": wandb.Audio(
|
89 |
+
prediction.cpu()[0][0],
|
90 |
+
sample_rate=rate)})
|
91 |
+
|
92 |
+
checkpoint_saver(model, epoch, metric_val=eposh_pesq)
|
transforms.py
CHANGED
@@ -3,6 +3,8 @@ import torch
|
|
3 |
from torchaudio.transforms import Resample
|
4 |
from torchvision.transforms import RandomCrop
|
5 |
|
|
|
|
|
6 |
class Transform(torch.nn.Module):
|
7 |
def __init__(
|
8 |
self,
|
|
|
3 |
from torchaudio.transforms import Resample
|
4 |
from torchvision.transforms import RandomCrop
|
5 |
|
6 |
+
|
7 |
+
|
8 |
class Transform(torch.nn.Module):
|
9 |
def __init__(
|
10 |
self,
|