BorisovMaksim commited on
Commit
1160793
·
1 Parent(s): 9ff4511

refactored train loop

Browse files

added Multi_STFT loss from paper
added minimal dataset

README.md CHANGED
@@ -1,3 +1,11 @@
 
 
 
 
 
 
 
 
1
 
2
 
3
 
@@ -7,4 +15,3 @@
7
  | ideal denoising | 1.9709 | 0.9211 |
8
  | baseline | 1.7433 | 0.8844 |
9
 
10
-
 
1
+ # MVP
2
+ Сервисом является web interface, в котором пользователь
3
+ сможет записать своей голос в шумных условиях и получить на выход аудиозапись без шума.
4
+ Для обработки шумных аудио файлов есть доступ к API на питоне.
5
+
6
+ Web interface реализован на gradio. Сама работа пишется в контексте фрейморка pytorch.
7
+ В качестве системы контроля экспериментов выбран wandb. Для управления конфигами - hydra.
8
+ Архитектура модели базируется на работе "Real Time Speech Enhancement in the Waveform Domain" от facebook.
9
 
10
 
11
 
 
15
  | ideal denoising | 1.9709 | 0.9211 |
16
  | baseline | 1.7433 | 0.8844 |
17
 
 
checkpoing_saver.py CHANGED
@@ -5,7 +5,7 @@ import torch
5
  import wandb
6
 
7
  class CheckpointSaver:
8
- def __init__(self, dirpath, decreasing=True, top_n=5):
9
  """
10
  dirpath: Directory path where to store all model weights
11
  decreasing: If decreasing is `True`, then lower metric is better
@@ -17,9 +17,10 @@ class CheckpointSaver:
17
  self.decreasing = decreasing
18
  self.top_model_paths = []
19
  self.best_metric_val = np.Inf if decreasing else -np.Inf
 
20
 
21
  def __call__(self, model, epoch, metric_val):
22
- model_path = os.path.join(self.dirpath, model.__class__.__name__ + f'_epoch{epoch}.pt')
23
  save = metric_val < self.best_metric_val if self.decreasing else metric_val > self.best_metric_val
24
  if save:
25
  logging.info(
 
5
  import wandb
6
 
7
  class CheckpointSaver:
8
+ def __init__(self, dirpath, run_name='', decreasing=True, top_n=5):
9
  """
10
  dirpath: Directory path where to store all model weights
11
  decreasing: If decreasing is `True`, then lower metric is better
 
17
  self.decreasing = decreasing
18
  self.top_model_paths = []
19
  self.best_metric_val = np.Inf if decreasing else -np.Inf
20
+ self.run_name = run_name
21
 
22
  def __call__(self, model, epoch, metric_val):
23
+ model_path = os.path.join(self.dirpath, self.run_name, model.__class__.__name__ + f'_epoch{epoch}.pt')
24
  save = metric_val < self.best_metric_val if self.decreasing else metric_val > self.best_metric_val
25
  if save:
26
  logging.info(
conf/config.yaml CHANGED
@@ -16,10 +16,7 @@ dataloader:
16
 
17
  validation:
18
  path: /media/public/datasets/denoising/DS_10283_2791/noisy_testset_wav
19
- wavs:
20
- easy: p232_284.wav
21
- medium: p232_071.wav
22
- hard : p257_171.wav
23
 
24
 
25
  wandb:
 
16
 
17
  validation:
18
  path: /media/public/datasets/denoising/DS_10283_2791/noisy_testset_wav
19
+ sample_rate: 48000
 
 
 
20
 
21
 
22
  wandb:
conf/loss/L1_Multi_STFT.yaml ADDED
@@ -0,0 +1 @@
 
 
1
+ name: L1_Multi_STFT
datasets/minimal.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ from pathlib import Path
4
+ import torchaudio
5
+ import numpy as np
6
+ from torchaudio.transforms import Resample
7
+
8
+
9
+ class Minimal(Dataset):
10
+ def __init__(self, cfg):
11
+ self.wavs = ['p232_284.wav', 'p232_071.wav', 'p257_171.wav']
12
+ self.dataset_path = cfg['validation']['path']
13
+ self.target_rate = cfg['dataloader']['sample_rate']
14
+ self.resampler = Resample(orig_freq=cfg['validation']['sample_rate'],
15
+ new_freq=cfg['dataloader']['sample_rate'])
16
+
17
+ def __len__(self):
18
+ return len(self.wavs)
19
+
20
+ def __getitem__(self, idx):
21
+ wav, rate = torchaudio.load(self.wavs[idx])
22
+ wav = self.resampler(wav)
23
+ wav = torch.reshape(wav, (1, 1, -1))
24
+ return wav, self.target_rate
datasets/valentini.py CHANGED
@@ -36,6 +36,7 @@ class Valentini(Dataset):
36
  if self.transform:
37
  random_seed = 0 if self.valid else torch.randint(HIGH_RANDOM_SEED, (1,))[0]
38
  torch.manual_seed(random_seed)
 
39
  noisy_wav = self.transform(noisy_wav)
40
  torch.manual_seed(random_seed)
41
  clean_wav = self.transform(clean_wav)
 
36
  if self.transform:
37
  random_seed = 0 if self.valid else torch.randint(HIGH_RANDOM_SEED, (1,))[0]
38
  torch.manual_seed(random_seed)
39
+
40
  noisy_wav = self.transform(noisy_wav)
41
  torch.manual_seed(random_seed)
42
  clean_wav = self.transform(clean_wav)
losses.py CHANGED
@@ -1,7 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  LOSSES = {
4
- 'mse': torch.nn.MSELoss()
 
 
 
5
  }
6
 
7
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ # Original copyright 2019 Tomoki Hayashi
9
+ # MIT License (https://opensource.org/licenses/MIT)
10
+
11
+
12
  import torch
13
+ import torch.nn.functional as F
14
+
15
+ """STFT-based Loss modules."""
16
+
17
+
18
+ def stft(x, fft_size, hop_size, win_length, window):
19
+ """Perform STFT and convert to magnitude spectrogram.
20
+ Args:
21
+ x (Tensor): Input signal tensor (B, T).
22
+ fft_size (int): FFT size.
23
+ hop_size (int): Hop size.
24
+ win_length (int): Window length.
25
+ window (str): Window function type.
26
+ Returns:
27
+ Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
28
+ """
29
+ x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
30
+ real = x_stft[..., 0]
31
+ imag = x_stft[..., 1]
32
+
33
+ # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
34
+ return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
35
+
36
+
37
+ class SpectralConvergengeLoss(torch.nn.Module):
38
+ """Spectral convergence loss module."""
39
+
40
+ def __init__(self):
41
+ """Initilize spectral convergence loss module."""
42
+ super(SpectralConvergengeLoss, self).__init__()
43
+
44
+ def forward(self, x_mag, y_mag):
45
+ """Calculate forward propagation.
46
+ Args:
47
+ x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
48
+ y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
49
+ Returns:
50
+ Tensor: Spectral convergence loss value.
51
+ """
52
+ return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
53
+
54
+
55
+ class LogSTFTMagnitudeLoss(torch.nn.Module):
56
+ """Log STFT magnitude loss module."""
57
+
58
+ def __init__(self):
59
+ """Initilize los STFT magnitude loss module."""
60
+ super(LogSTFTMagnitudeLoss, self).__init__()
61
+
62
+ def forward(self, x_mag, y_mag):
63
+ """Calculate forward propagation.
64
+ Args:
65
+ x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
66
+ y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
67
+ Returns:
68
+ Tensor: Log STFT magnitude loss value.
69
+ """
70
+ return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
71
+
72
+
73
+ class STFTLoss(torch.nn.Module):
74
+ """STFT loss module."""
75
+
76
+ def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
77
+ """Initialize STFT loss module."""
78
+ super(STFTLoss, self).__init__()
79
+ self.fft_size = fft_size
80
+ self.shift_size = shift_size
81
+ self.win_length = win_length
82
+ self.register_buffer("window", getattr(torch, window)(win_length))
83
+ self.spectral_convergenge_loss = SpectralConvergengeLoss()
84
+ self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
85
+
86
+ def forward(self, x, y):
87
+ """Calculate forward propagation.
88
+ Args:
89
+ x (Tensor): Predicted signal (B, T).
90
+ y (Tensor): Groundtruth signal (B, T).
91
+ Returns:
92
+ Tensor: Spectral convergence loss value.
93
+ Tensor: Log STFT magnitude loss value.
94
+ """
95
+ x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
96
+ y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
97
+ sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
98
+ mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
99
+
100
+ return sc_loss, mag_loss
101
+
102
+
103
+ class MultiResolutionSTFTLoss(torch.nn.Module):
104
+ """Multi resolution STFT loss module."""
105
+
106
+ def __init__(self,
107
+ fft_sizes=[1024, 2048, 512],
108
+ hop_sizes=[120, 240, 50],
109
+ win_lengths=[600, 1200, 240],
110
+ window="hann_window", factor_sc=0.1, factor_mag=0.1):
111
+ """Initialize Multi resolution STFT loss module.
112
+ Args:
113
+ fft_sizes (list): List of FFT sizes.
114
+ hop_sizes (list): List of hop sizes.
115
+ win_lengths (list): List of window lengths.
116
+ window (str): Window function type.
117
+ factor (float): a balancing factor across different losses.
118
+ """
119
+ super(MultiResolutionSTFTLoss, self).__init__()
120
+ assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
121
+ self.stft_losses = torch.nn.ModuleList()
122
+ for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
123
+ self.stft_losses += [STFTLoss(fs, ss, wl, window)]
124
+ self.factor_sc = factor_sc
125
+ self.factor_mag = factor_mag
126
+
127
+ def forward(self, x, y):
128
+ """Calculate forward propagation.
129
+ Args:
130
+ x (Tensor): Predicted signal (B, T).
131
+ y (Tensor): Groundtruth signal (B, T).
132
+ Returns:
133
+ Tensor: Multi resolution spectral convergence loss value.
134
+ Tensor: Multi resolution log STFT magnitude loss value.
135
+ """
136
+ sc_loss = 0.0
137
+ mag_loss = 0.0
138
+ for f in self.stft_losses:
139
+ sc_l, mag_l = f(x, y)
140
+ sc_loss += sc_l
141
+ mag_loss += mag_l
142
+ sc_loss /= len(self.stft_losses)
143
+ mag_loss /= len(self.stft_losses)
144
+
145
+ return self.factor_sc*sc_loss, self.factor_mag*mag_loss
146
+
147
+
148
+
149
+
150
+ class L1_Multi_STFT(torch.nn.Module):
151
+ """STFT loss module."""
152
+
153
+ def __init__(self):
154
+ """Initialize STFT loss module."""
155
+ super(L1_Multi_STFT, self).__init__()
156
+ self.multi_STFT_loss = MultiResolutionSTFTLoss()
157
+ self.l1_loss = torch.nn.L1Loss()
158
+
159
+ def forward(self, x, y):
160
+ """Calculate forward propagation.
161
+ Args:
162
+ x (Tensor): Predicted signal (B, T).
163
+ y (Tensor): Groundtruth signal (B, T).
164
+ Returns:
165
+ Tensor: Spectral convergence loss value.
166
+ Tensor: Log STFT magnitude loss value.
167
+ """
168
+ sc_loss, mag_loss = self.multi_STFT_loss(x, y)
169
+ l1_loss = self.l1_loss(x, y)
170
+ return sc_loss + mag_loss + l1_loss
171
+
172
 
173
  LOSSES = {
174
+ 'mse': torch.nn.MSELoss(),
175
+ 'L1': torch.nn.L1Loss(),
176
+ 'Multi_STFT': MultiResolutionSTFTLoss,
177
+ 'L1_Multi_STFT': L1_Multi_STFT
178
  }
179
 
180
 
main.py CHANGED
@@ -5,7 +5,6 @@ from train import train
5
 
6
  @hydra.main(version_base=None, config_path="conf", config_name="config")
7
  def main(cfg: DictConfig):
8
- print(OmegaConf.to_yaml(cfg))
9
  train(cfg)
10
 
11
 
 
5
 
6
  @hydra.main(version_base=None, config_path="conf", config_name="config")
7
  def main(cfg: DictConfig):
 
8
  train(cfg)
9
 
10
 
notebooks/EDA.ipynb CHANGED
@@ -2,44 +2,25 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 1,
6
  "id": "f800718e-c29f-44d8-bf41-e02d50d0f730",
7
  "metadata": {
8
  "ExecuteTime": {
9
  "start_time": "2023-04-29T13:11:15.198687Z",
10
  "end_time": "2023-04-29T13:11:15.245584Z"
11
- }
12
- },
13
- "outputs": [
14
- {
15
- "name": "stderr",
16
- "output_type": "stream",
17
- "text": [
18
- "/home/maksim/.local/lib/python3.10/site-packages/torchaudio/compliance/kaldi.py:22: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:84.)\n",
19
- " EPSILON = torch.tensor(torch.finfo(torch.float).eps)\n"
20
- ]
21
  },
22
- {
23
- "ename": "ModuleNotFoundError",
24
- "evalue": "No module named 'matplotlib'",
25
- "output_type": "error",
26
- "traceback": [
27
- "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
28
- "\u001B[0;31mModuleNotFoundError\u001B[0m Traceback (most recent call last)",
29
- "Cell \u001B[0;32mIn[1], line 3\u001B[0m\n\u001B[1;32m 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mpathlib\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Path\n\u001B[0;32m----> 3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mdatasets\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Valentini\n\u001B[1;32m 5\u001B[0m dataset \u001B[38;5;241m=\u001B[39m Valentini()\n",
30
- "File \u001B[0;32m~/PycharmProjects/denoising/datasets.py:4\u001B[0m\n\u001B[1;32m 2\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mdata\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Dataset\n\u001B[1;32m 3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mpathlib\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Path\n\u001B[0;32m----> 4\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mutils\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m load_wav\n\u001B[1;32m 7\u001B[0m \u001B[38;5;28;01mclass\u001B[39;00m \u001B[38;5;21;01mValentini\u001B[39;00m(Dataset):\n\u001B[1;32m 8\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21m__init__\u001B[39m(\u001B[38;5;28mself\u001B[39m, dataset_path\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124m/media/public/datasets/denoising/DS_10283_2791/\u001B[39m\u001B[38;5;124m'\u001B[39m, transform\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m,\n\u001B[1;32m 9\u001B[0m valid\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mFalse\u001B[39;00m):\n",
31
- "File \u001B[0;32m~/PycharmProjects/denoising/utils.py:3\u001B[0m\n\u001B[1;32m 1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mtorchaudio\u001B[39;00m\n\u001B[1;32m 2\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\n\u001B[0;32m----> 3\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mmatplotlib\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mpyplot\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mplt\u001B[39;00m\n\u001B[1;32m 4\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mpathlib\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Path\n\u001B[1;32m 7\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mcollect_valentini_paths\u001B[39m(dataset_path):\n",
32
- "\u001B[0;31mModuleNotFoundError\u001B[0m: No module named 'matplotlib'"
33
- ]
34
  }
35
- ],
 
36
  "source": [
37
  "\n",
38
  "from pathlib import Path\n",
39
  "\n",
40
  "from datasets import Valentini\n",
41
  "\n",
42
- "dataset = Valentini()"
43
  ]
44
  },
45
  {
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": null,
6
  "id": "f800718e-c29f-44d8-bf41-e02d50d0f730",
7
  "metadata": {
8
  "ExecuteTime": {
9
  "start_time": "2023-04-29T13:11:15.198687Z",
10
  "end_time": "2023-04-29T13:11:15.245584Z"
 
 
 
 
 
 
 
 
 
 
11
  },
12
+ "pycharm": {
13
+ "is_executing": true
 
 
 
 
 
 
 
 
 
 
14
  }
15
+ },
16
+ "outputs": [],
17
  "source": [
18
  "\n",
19
  "from pathlib import Path\n",
20
  "\n",
21
  "from datasets import Valentini\n",
22
  "\n",
23
+ "dataset = Valentini('/media/public/datasets/denoising/DS_10283_2791/', valid=False)"
24
  ]
25
  },
26
  {
train.py CHANGED
@@ -1,10 +1,9 @@
1
  import os
2
  import torch
3
  from torch.utils.data import DataLoader
4
- from pathlib import Path
5
  from omegaconf import DictConfig
6
  import wandb
7
- import torchaudio
8
 
9
  from checkpoing_saver import CheckpointSaver
10
  from denoisers import get_model
@@ -12,7 +11,7 @@ from optimizers import get_optimizer
12
  from losses import get_loss
13
  from datasets import get_datasets
14
  from testing.metrics import Metrics
15
- import omegaconf
16
 
17
  os.environ['CUDA_VISIBLE_DEVICES'] = "1"
18
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -26,67 +25,68 @@ def train(cfg: DictConfig):
26
  config=omegaconf.OmegaConf.to_container(
27
  cfg, resolve=True, throw_on_missing=True))
28
 
29
- checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'])
30
  metrics = Metrics(rate=cfg['dataloader']['sample_rate'])
31
 
32
  model = get_model(cfg['model']).to(device)
33
  optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
34
  loss_fn = get_loss(cfg['loss'])
35
  train_dataset, valid_dataset = get_datasets(cfg)
 
36
 
37
- training_loader = DataLoader(train_dataset, batch_size=cfg['dataloader']['train_batch_size'], shuffle=True)
38
- validation_loader = DataLoader(valid_dataset, batch_size=cfg['dataloader']['valid_batch_size'], shuffle=True)
 
 
 
39
 
40
  wandb.watch(model, log_freq=100)
41
 
42
  for epoch in range(cfg['training']['num_epochs']):
43
- model.train(True)
44
- for i, data in enumerate(training_loader):
45
- inputs, labels = data
46
- inputs, labels = inputs.to(device), labels.to(device)
47
- optimizer.zero_grad()
48
- outputs = model(inputs)
49
- loss = loss_fn(outputs, labels)
50
- loss.backward()
51
- optimizer.step()
52
-
53
- if i % cfg['wandb']['log_interval'] == 0:
54
- wandb.log({"loss": loss})
55
-
56
- model.train(False)
57
-
58
- running_vloss, running_pesq, running_stoi = 0.0, 0.0, 0.0
59
- with torch.no_grad():
60
- for i, vdata in enumerate(validation_loader):
61
- vinputs, vlabels = vdata
62
- vinputs, vlabels = vinputs.to(device), vlabels.to(device)
63
- voutputs = model(vinputs)
64
- vloss = loss_fn(voutputs, vlabels)
65
- running_vloss += vloss
66
- running_metrics = metrics.calculate(denoised=voutputs, clean=vlabels)
67
- running_pesq += running_metrics['PESQ']
68
- running_stoi += running_metrics['STOI']
69
 
 
 
 
 
70
 
71
- avg_vloss = running_vloss / len(validation_loader)
72
- avg_pesq = running_pesq / len(validation_loader)
73
- avg_stoi = running_stoi / len(validation_loader)
74
 
75
- wandb.log({"valid_loss": avg_vloss,
76
- "valid_pesq": avg_pesq,
77
- "valid_stoi": avg_stoi})
78
 
79
- for tag, wav_path in cfg['validation']['wavs'].items():
80
- wav, rate = torchaudio.load(Path(cfg['validation']['path']) / wav_path)
81
- wav = torch.reshape(wav, (1, 1, -1)).to(device)
82
- prediction = model(wav)
83
- wandb.log({
84
- f"{tag}_epoch_{epoch}": wandb.Audio(
85
- prediction.cpu()[0][0],
86
- sample_rate=rate)})
87
-
88
- checkpoint_saver(model, epoch, metric_val=avg_pesq)
89
 
 
 
 
 
90
 
91
- if __name__ == '__main__':
92
- train()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import torch
3
  from torch.utils.data import DataLoader
4
+ import omegaconf
5
  from omegaconf import DictConfig
6
  import wandb
 
7
 
8
  from checkpoing_saver import CheckpointSaver
9
  from denoisers import get_model
 
11
  from losses import get_loss
12
  from datasets import get_datasets
13
  from testing.metrics import Metrics
14
+ from datasets.minimal import Minimal
15
 
16
  os.environ['CUDA_VISIBLE_DEVICES'] = "1"
17
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
25
  config=omegaconf.OmegaConf.to_container(
26
  cfg, resolve=True, throw_on_missing=True))
27
 
28
+ checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'], run_name=wandb.run.name)
29
  metrics = Metrics(rate=cfg['dataloader']['sample_rate'])
30
 
31
  model = get_model(cfg['model']).to(device)
32
  optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
33
  loss_fn = get_loss(cfg['loss'])
34
  train_dataset, valid_dataset = get_datasets(cfg)
35
+ minimal_dataset = Minimal(cfg)
36
 
37
+ dataloaders = {
38
+ 'train': DataLoader(train_dataset, batch_size=cfg['dataloader']['train_batch_size'], shuffle=True),
39
+ 'val': DataLoader(valid_dataset, batch_size=cfg['dataloader']['valid_batch_size'], shuffle=True),
40
+ 'minimal': DataLoader(minimal_dataset)
41
+ }
42
 
43
  wandb.watch(model, log_freq=100)
44
 
45
  for epoch in range(cfg['training']['num_epochs']):
46
+ for phase in ['train', 'val']:
47
+ if phase == 'train':
48
+ model.train()
49
+ else:
50
+ model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
+ running_loss, running_pesq, running_stoi = 0.0, 0.0, 0.0
53
+ for i, (inputs, labels) in enumerate(dataloaders[phase]):
54
+ inputs = inputs.to(device)
55
+ labels = labels.to(device)
56
 
57
+ optimizer.zero_grad()
 
 
58
 
59
+ with torch.set_grad_enabled(phase == 'train'):
60
+ outputs = model(inputs)
61
+ loss = loss_fn(outputs, labels)
62
 
63
+ if phase == 'train':
64
+ loss.backward()
65
+ optimizer.step()
 
 
 
 
 
 
 
66
 
67
+ running_metrics = metrics.calculate(denoised=outputs, clean=labels)
68
+ running_loss += loss.item() * inputs.size(0)
69
+ running_pesq += running_metrics['PESQ']
70
+ running_stoi += running_metrics['STOI']
71
 
72
+ if phase == 'train' and i % cfg['wandb']['log_interval'] == 0:
73
+ wandb.log({"train_loss": running_loss / (i + 1),
74
+ "train_pesq": running_pesq / (i + 1),
75
+ "train_stoi": running_stoi / (i + 1)})
76
+ epoch_loss = running_loss / len(dataloaders[phase])
77
+ eposh_pesq = running_pesq / len(dataloaders[phase])
78
+ eposh_stoi = running_stoi / len(dataloaders[phase])
79
+
80
+ wandb.log({f"{phase}_loss": epoch_loss,
81
+ f"{phase}_pesq": eposh_pesq,
82
+ f"{phase}_stoi": eposh_stoi})
83
+
84
+ if phase == 'val':
85
+ for i, (wav, rate) in enumerate(dataloaders['minimal']):
86
+ prediction = model(wav)
87
+ wandb.log({
88
+ f"{i}_example": wandb.Audio(
89
+ prediction.cpu()[0][0],
90
+ sample_rate=rate)})
91
+
92
+ checkpoint_saver(model, epoch, metric_val=eposh_pesq)
transforms.py CHANGED
@@ -3,6 +3,8 @@ import torch
3
  from torchaudio.transforms import Resample
4
  from torchvision.transforms import RandomCrop
5
 
 
 
6
  class Transform(torch.nn.Module):
7
  def __init__(
8
  self,
 
3
  from torchaudio.transforms import Resample
4
  from torchvision.transforms import RandomCrop
5
 
6
+
7
+
8
  class Transform(torch.nn.Module):
9
  def __init__(
10
  self,