Spaces:
Running
Running
update
Browse files
examples/clean_unet_aishell/step_2_train_model.py
CHANGED
@@ -2,6 +2,8 @@
|
|
2 |
# -*- coding: utf-8 -*-
|
3 |
"""
|
4 |
https://github.com/NVIDIA/CleanUNet/blob/main/train.py
|
|
|
|
|
5 |
"""
|
6 |
import argparse
|
7 |
import json
|
@@ -20,6 +22,7 @@ sys.path.append(os.path.join(pwd, "../../"))
|
|
20 |
|
21 |
import numpy as np
|
22 |
import torch
|
|
|
23 |
from torch.nn import functional as F
|
24 |
from torch.utils.data.dataloader import DataLoader
|
25 |
from tqdm import tqdm
|
@@ -27,6 +30,9 @@ from tqdm import tqdm
|
|
27 |
from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
|
28 |
from toolbox.torchaudio.models.clean_unet.configuration_clean_unet import CleanUnetConfig
|
29 |
from toolbox.torchaudio.models.clean_unet.modeling_clean_unet import CleanUNetPretrainedModel
|
|
|
|
|
|
|
30 |
|
31 |
|
32 |
def get_args():
|
@@ -36,6 +42,9 @@ def get_args():
|
|
36 |
|
37 |
parser.add_argument("--max_epochs", default=100, type=int)
|
38 |
|
|
|
|
|
|
|
39 |
parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
|
40 |
parser.add_argument("--patience", default=5, type=int)
|
41 |
parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
|
@@ -158,56 +167,37 @@ def main():
|
|
158 |
model = CleanUNetPretrainedModel(config).to(device)
|
159 |
|
160 |
# optimizer
|
161 |
-
logger.info("prepare optimizer, lr_scheduler")
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
generator.load_state_dict(state_dict, strict=True)
|
184 |
-
logger.info(f"load state dict for discriminator.")
|
185 |
-
with open(discriminator_pt.as_posix(), "rb") as f:
|
186 |
-
state_dict = torch.load(f, map_location="cpu", weights_only=True)
|
187 |
-
discriminator.load_state_dict(state_dict, strict=True)
|
188 |
-
|
189 |
-
logger.info(f"load state dict for optim_g.")
|
190 |
-
with open(optim_g_pth.as_posix(), "rb") as f:
|
191 |
-
state_dict = torch.load(f, map_location="cpu", weights_only=True)
|
192 |
-
optim_g.load_state_dict(state_dict)
|
193 |
-
logger.info(f"load state dict for optim_d.")
|
194 |
-
with open(optim_d_pth.as_posix(), "rb") as f:
|
195 |
-
state_dict = torch.load(f, map_location="cpu", weights_only=True)
|
196 |
-
optim_d.load_state_dict(state_dict)
|
197 |
-
|
198 |
-
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=last_epoch)
|
199 |
-
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=last_epoch)
|
200 |
|
201 |
# training loop
|
202 |
|
203 |
# state
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
com_err = 10000000000
|
210 |
-
stft_err = 10000000000
|
211 |
|
212 |
model_list = list()
|
213 |
best_idx_epoch = None
|
@@ -215,96 +205,74 @@ def main():
|
|
215 |
patience_count = 0
|
216 |
|
217 |
logger.info("training")
|
218 |
-
for idx_epoch in range(
|
219 |
# train
|
220 |
-
|
221 |
-
discriminator.train()
|
222 |
|
223 |
-
|
224 |
-
|
|
|
|
|
|
|
225 |
total_batches = 0.
|
|
|
226 |
progress_bar = tqdm(
|
227 |
total=len(train_data_loader),
|
228 |
desc="Training; epoch: {}".format(idx_epoch),
|
229 |
)
|
230 |
for batch in train_data_loader:
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
# print("pesq is None!")
|
257 |
-
loss_disc_g = 0
|
258 |
-
|
259 |
-
loss_disc_all = loss_disc_r + loss_disc_g
|
260 |
-
loss_disc_all.backward()
|
261 |
-
optim_d.step()
|
262 |
-
|
263 |
-
# Generator
|
264 |
-
optim_g.zero_grad()
|
265 |
-
# L2 Magnitude Loss
|
266 |
-
loss_mag = F.mse_loss(clean_mag, mag_g)
|
267 |
-
# Anti-wrapping Phase Loss
|
268 |
-
loss_ip, loss_gd, loss_iaf = phase_losses(clean_pha, pha_g)
|
269 |
-
loss_pha = loss_ip + loss_gd + loss_iaf
|
270 |
-
# L2 Complex Loss
|
271 |
-
loss_com = F.mse_loss(clean_com, com_g) * 2
|
272 |
-
# L2 Consistency Loss
|
273 |
-
loss_stft = F.mse_loss(com_g, com_g_hat) * 2
|
274 |
-
# Time Loss
|
275 |
-
loss_time = F.l1_loss(clean_audio, audio_g)
|
276 |
-
# Metric Loss
|
277 |
-
metric_g = discriminator.forward(clean_mag, mag_g_hat)
|
278 |
-
loss_metric = F.mse_loss(metric_g.flatten(), one_labels)
|
279 |
-
|
280 |
-
loss_gen_all = loss_mag * 0.9 + loss_pha * 0.3 + loss_com * 0.1 + loss_stft * 0.1 + loss_metric * 0.05 + loss_time * 0.2
|
281 |
-
|
282 |
-
loss_gen_all.backward()
|
283 |
-
optim_g.step()
|
284 |
-
|
285 |
-
total_loss_d += loss_disc_all.item()
|
286 |
-
total_loss_g += loss_gen_all.item()
|
287 |
total_batches += 1
|
288 |
|
289 |
-
|
290 |
-
|
|
|
|
|
|
|
291 |
|
292 |
progress_bar.update(1)
|
293 |
progress_bar.set_postfix({
|
294 |
-
"
|
295 |
-
"
|
|
|
|
|
|
|
296 |
})
|
297 |
|
298 |
# evaluation
|
299 |
-
|
300 |
-
discriminator.eval()
|
301 |
|
302 |
torch.cuda.empty_cache()
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
|
|
308 |
total_batches = 0.
|
309 |
|
310 |
progress_bar = tqdm(
|
@@ -313,61 +281,52 @@ def main():
|
|
313 |
)
|
314 |
with torch.no_grad():
|
315 |
for batch in valid_data_loader:
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
total_com_err += F.mse_loss(clean_com, com_g).item()
|
337 |
-
total_stft_err += F.mse_loss(com_g, com_g_hat).item()
|
338 |
-
|
339 |
total_batches += 1
|
340 |
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
|
347 |
progress_bar.update(1)
|
348 |
progress_bar.set_postfix({
|
349 |
-
"pesq_metric":
|
350 |
-
"
|
351 |
-
"
|
352 |
-
"
|
353 |
-
"
|
354 |
})
|
355 |
|
356 |
# scheduler
|
357 |
-
|
358 |
-
scheduler_d.step()
|
359 |
|
360 |
# save path
|
361 |
epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
|
362 |
epoch_dir.mkdir(parents=True, exist_ok=False)
|
363 |
|
364 |
# save models
|
365 |
-
|
366 |
-
discriminator.save_pretrained(epoch_dir.as_posix())
|
367 |
-
|
368 |
-
# save optim
|
369 |
-
torch.save(optim_d.state_dict(), (epoch_dir / "optim_d.pth").as_posix())
|
370 |
-
torch.save(optim_g.state_dict(), (epoch_dir / "optim_g.pth").as_posix())
|
371 |
|
372 |
model_list.append(epoch_dir)
|
373 |
if len(model_list) >= args.num_serialized_models_to_keep:
|
@@ -377,25 +336,23 @@ def main():
|
|
377 |
# save metric
|
378 |
if best_metric is None:
|
379 |
best_idx_epoch = idx_epoch
|
380 |
-
best_metric =
|
381 |
-
elif
|
382 |
# great is better.
|
383 |
best_idx_epoch = idx_epoch
|
384 |
-
best_metric =
|
385 |
else:
|
386 |
pass
|
387 |
|
388 |
metrics = {
|
389 |
"idx_epoch": idx_epoch,
|
390 |
"best_idx_epoch": best_idx_epoch,
|
391 |
-
|
392 |
-
"
|
393 |
-
|
394 |
-
"
|
395 |
-
"
|
396 |
-
"
|
397 |
-
"com_err": com_err,
|
398 |
-
"stft_err": stft_err,
|
399 |
|
400 |
}
|
401 |
metrics_filename = epoch_dir / "metrics_epoch.json"
|
|
|
2 |
# -*- coding: utf-8 -*-
|
3 |
"""
|
4 |
https://github.com/NVIDIA/CleanUNet/blob/main/train.py
|
5 |
+
|
6 |
+
https://github.com/NVIDIA/CleanUNet/blob/main/configs/DNS-large-full.json
|
7 |
"""
|
8 |
import argparse
|
9 |
import json
|
|
|
22 |
|
23 |
import numpy as np
|
24 |
import torch
|
25 |
+
import torch.nn as nn
|
26 |
from torch.nn import functional as F
|
27 |
from torch.utils.data.dataloader import DataLoader
|
28 |
from tqdm import tqdm
|
|
|
30 |
from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
|
31 |
from toolbox.torchaudio.models.clean_unet.configuration_clean_unet import CleanUnetConfig
|
32 |
from toolbox.torchaudio.models.clean_unet.modeling_clean_unet import CleanUNetPretrainedModel
|
33 |
+
from toolbox.torchaudio.models.clean_unet.training import LinearWarmupCosineDecay
|
34 |
+
from toolbox.torchaudio.models.clean_unet.loss import MultiResolutionSTFTLoss
|
35 |
+
from toolbox.torchaudio.models.clean_unet.metrics import batch_pesq
|
36 |
|
37 |
|
38 |
def get_args():
|
|
|
42 |
|
43 |
parser.add_argument("--max_epochs", default=100, type=int)
|
44 |
|
45 |
+
parser.add_argument("--batch_size", default=64, type=int)
|
46 |
+
parser.add_argument("--learning_rate", default=2e-4, type=float)
|
47 |
+
|
48 |
parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
|
49 |
parser.add_argument("--patience", default=5, type=int)
|
50 |
parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
|
|
|
167 |
model = CleanUNetPretrainedModel(config).to(device)
|
168 |
|
169 |
# optimizer
|
170 |
+
logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
|
171 |
+
optimizer = torch.optim.AdamW(model.parameters(), config.learning_rate)
|
172 |
+
lr_scheduler = LinearWarmupCosineDecay(
|
173 |
+
optimizer,
|
174 |
+
lr_max=args.learning_rate,
|
175 |
+
n_iter=250000,
|
176 |
+
iteration=250000,
|
177 |
+
divider=25,
|
178 |
+
warmup_proportion=0.05,
|
179 |
+
phase=("linear", "cosine"),
|
180 |
+
)
|
181 |
+
# ae_loss_fn = nn.MSELoss(reduction="mean")
|
182 |
+
ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
|
183 |
+
|
184 |
+
mr_stft_loss_fn = MultiResolutionSTFTLoss(
|
185 |
+
fft_sizes=[512, 1024, 2048],
|
186 |
+
hop_sizes=[50, 120, 240],
|
187 |
+
win_lengths=[240, 600, 1200],
|
188 |
+
sc_lambda=0.5,
|
189 |
+
mag_lambda=0.5,
|
190 |
+
band="full"
|
191 |
+
).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
|
193 |
# training loop
|
194 |
|
195 |
# state
|
196 |
+
average_pesq_metric = 10000000000
|
197 |
+
average_loss = 10000000000
|
198 |
+
average_ae_loss = 10000000000
|
199 |
+
average_sc_loss = 10000000000
|
200 |
+
average_mag_loss = 10000000000
|
|
|
|
|
201 |
|
202 |
model_list = list()
|
203 |
best_idx_epoch = None
|
|
|
205 |
patience_count = 0
|
206 |
|
207 |
logger.info("training")
|
208 |
+
for idx_epoch in range(args.max_epochs):
|
209 |
# train
|
210 |
+
model.train()
|
|
|
211 |
|
212 |
+
total_pesq_metric = 0.
|
213 |
+
total_loss = 0.
|
214 |
+
total_ae_loss = 0.
|
215 |
+
total_sc_loss = 0.
|
216 |
+
total_mag_loss = 0.
|
217 |
total_batches = 0.
|
218 |
+
|
219 |
progress_bar = tqdm(
|
220 |
total=len(train_data_loader),
|
221 |
desc="Training; epoch: {}".format(idx_epoch),
|
222 |
)
|
223 |
for batch in train_data_loader:
|
224 |
+
clean_audios, noisy_audios = batch
|
225 |
+
clean_audios = clean_audios.to(device)
|
226 |
+
noisy_audios = noisy_audios.to(device)
|
227 |
+
|
228 |
+
enhanced_audios = model.forward(noisy_audios)
|
229 |
+
|
230 |
+
ae_loss = ae_loss_fn(enhanced_audios, clean_audios)
|
231 |
+
sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios.squeeze(1), clean_audios.squeeze(1))
|
232 |
+
|
233 |
+
loss = ae_loss + sc_loss + mag_loss
|
234 |
+
|
235 |
+
enhanced_audios_list_r = list(enhanced_audios.cpu().numpy())
|
236 |
+
clean_audios_list_r = list(clean_audios.cpu().numpy())
|
237 |
+
pesq_metric = batch_pesq(enhanced_audios_list_r, clean_audios_list_r)
|
238 |
+
|
239 |
+
optimizer.zero_grad()
|
240 |
+
loss.backward()
|
241 |
+
optimizer.step()
|
242 |
+
lr_scheduler.step()
|
243 |
+
|
244 |
+
total_pesq_metric += pesq_metric.item()
|
245 |
+
total_loss += loss.item()
|
246 |
+
total_ae_loss += ae_loss.item()
|
247 |
+
total_sc_loss += sc_loss.item()
|
248 |
+
total_mag_loss += mag_loss.item()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
249 |
total_batches += 1
|
250 |
|
251 |
+
average_pesq_metric = round(total_pesq_metric / total_batches, 4)
|
252 |
+
average_loss = round(total_loss / total_batches, 4)
|
253 |
+
average_ae_loss = round(total_ae_loss / total_batches, 4)
|
254 |
+
average_sc_loss = round(total_sc_loss / total_batches, 4)
|
255 |
+
average_mag_loss = round(total_mag_loss / total_batches, 4)
|
256 |
|
257 |
progress_bar.update(1)
|
258 |
progress_bar.set_postfix({
|
259 |
+
"pesq_metric": average_pesq_metric,
|
260 |
+
"loss": average_loss,
|
261 |
+
"ae_loss": average_ae_loss,
|
262 |
+
"sc_loss": average_sc_loss,
|
263 |
+
"mag_loss": average_mag_loss,
|
264 |
})
|
265 |
|
266 |
# evaluation
|
267 |
+
model.eval()
|
|
|
268 |
|
269 |
torch.cuda.empty_cache()
|
270 |
+
|
271 |
+
total_pesq_metric = 0.
|
272 |
+
total_loss = 0.
|
273 |
+
total_ae_loss = 0.
|
274 |
+
total_sc_loss = 0.
|
275 |
+
total_mag_loss = 0.
|
276 |
total_batches = 0.
|
277 |
|
278 |
progress_bar = tqdm(
|
|
|
281 |
)
|
282 |
with torch.no_grad():
|
283 |
for batch in valid_data_loader:
|
284 |
+
clean_audios, noisy_audios = batch
|
285 |
+
clean_audios = clean_audios.to(device)
|
286 |
+
noisy_audios = noisy_audios.to(device)
|
287 |
+
|
288 |
+
enhanced_audios = model.forward(noisy_audios)
|
289 |
+
enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
|
290 |
+
ae_loss = ae_loss_fn(enhanced_audios, enhanced_audios)
|
291 |
+
sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios.squeeze(1), clean_audios.squeeze(1))
|
292 |
+
|
293 |
+
loss = ae_loss + sc_loss + mag_loss
|
294 |
+
|
295 |
+
enhanced_audios_list_r = list(enhanced_audios.cpu().numpy())
|
296 |
+
clean_audios_list_r = list(clean_audios.cpu().numpy())
|
297 |
+
pesq_metric = batch_pesq(enhanced_audios_list_r, clean_audios_list_r)
|
298 |
+
|
299 |
+
total_pesq_metric += pesq_metric.item()
|
300 |
+
total_loss += loss.item()
|
301 |
+
total_ae_loss += ae_loss.item()
|
302 |
+
total_sc_loss += sc_loss.item()
|
303 |
+
total_mag_loss += mag_loss.item()
|
|
|
|
|
|
|
304 |
total_batches += 1
|
305 |
|
306 |
+
average_pesq_metric = round(total_pesq_metric / total_batches, 4)
|
307 |
+
average_loss = round(total_loss / total_batches, 4)
|
308 |
+
average_ae_loss = round(total_ae_loss / total_batches, 4)
|
309 |
+
average_sc_loss = round(total_sc_loss / total_batches, 4)
|
310 |
+
average_mag_loss = round(total_mag_loss / total_batches, 4)
|
311 |
|
312 |
progress_bar.update(1)
|
313 |
progress_bar.set_postfix({
|
314 |
+
"pesq_metric": average_pesq_metric,
|
315 |
+
"loss": average_loss,
|
316 |
+
"ae_loss": average_ae_loss,
|
317 |
+
"sc_loss": average_sc_loss,
|
318 |
+
"mag_loss": average_mag_loss,
|
319 |
})
|
320 |
|
321 |
# scheduler
|
322 |
+
lr_scheduler.step()
|
|
|
323 |
|
324 |
# save path
|
325 |
epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
|
326 |
epoch_dir.mkdir(parents=True, exist_ok=False)
|
327 |
|
328 |
# save models
|
329 |
+
model.save_pretrained(epoch_dir.as_posix())
|
|
|
|
|
|
|
|
|
|
|
330 |
|
331 |
model_list.append(epoch_dir)
|
332 |
if len(model_list) >= args.num_serialized_models_to_keep:
|
|
|
336 |
# save metric
|
337 |
if best_metric is None:
|
338 |
best_idx_epoch = idx_epoch
|
339 |
+
best_metric = average_pesq_metric
|
340 |
+
elif average_pesq_metric > best_metric:
|
341 |
# great is better.
|
342 |
best_idx_epoch = idx_epoch
|
343 |
+
best_metric = average_pesq_metric
|
344 |
else:
|
345 |
pass
|
346 |
|
347 |
metrics = {
|
348 |
"idx_epoch": idx_epoch,
|
349 |
"best_idx_epoch": best_idx_epoch,
|
350 |
+
|
351 |
+
"pesq_metric": average_pesq_metric,
|
352 |
+
"loss": average_loss,
|
353 |
+
"ae_loss": average_ae_loss,
|
354 |
+
"sc_loss": average_sc_loss,
|
355 |
+
"mag_loss": average_mag_loss,
|
|
|
|
|
356 |
|
357 |
}
|
358 |
metrics_filename = epoch_dir / "metrics_epoch.json"
|
toolbox/torchaudio/models/clean_unet/loss.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import torch
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
# from distutils.version import LooseVersion
|
8 |
+
|
9 |
+
|
10 |
+
# is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion("1.7")
|
11 |
+
is_pytorch_17plus = True
|
12 |
+
|
13 |
+
|
14 |
+
def stft(x, fft_size, hop_size, win_length, window):
    """
    Perform STFT and convert to magnitude spectrogram.
    :param x: Tensor, Input signal tensor (B, T).
    :param fft_size: int, FFT size.
    :param hop_size: int, Hop size.
    :param win_length: int, Window length.
    :param window: Tensor, window passed straight to ``torch.stft``
        (presumably of length ``win_length`` - confirm against the caller).
    :return: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
    """
    # ``return_complex=False`` is deprecated since PyTorch 2.0. Ask for the
    # complex result and view it as (..., 2) real/imag pairs, which is the
    # documented way to recover the old output layout.
    x_stft = torch.view_as_real(
        torch.stft(
            x, fft_size, hop_size, win_length, window, return_complex=True
        )
    )
    real = x_stft[..., 0]
    imag = x_stft[..., 1]

    # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
    # transpose(2, 1) swaps the frequency and frame axes of the torch.stft output.
    return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
|
36 |
+
|
37 |
+
|
38 |
+
class SpectralConvergenceLoss(torch.nn.Module):
    """Spectral convergence loss: norm of the error relative to the norm of the target."""

    def __init__(self):
        super().__init__()

    def forward(self, x_mag, y_mag):
        """
        Compute the spectral convergence loss.
        :param x_mag: Tensor, Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
        :param y_mag: Tensor, Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
        :return: Tensor, Frobenius norm of ``y_mag - x_mag`` divided by the Frobenius norm of ``y_mag``.
        """
        error_norm = torch.norm(y_mag - x_mag, p="fro")
        reference_norm = torch.norm(y_mag, p="fro")
        return error_norm / reference_norm
|
52 |
+
|
53 |
+
|
54 |
+
class LogSTFTMagnitudeLoss(torch.nn.Module):
    """Log STFT magnitude loss: L1 distance between log-magnitude spectrograms."""

    def __init__(self):
        super().__init__()

    def forward(self, x_mag, y_mag):
        """
        Compute the log STFT magnitude loss.
        :param x_mag: Tensor, Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
        :param y_mag: Tensor, Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
        :return: Tensor, mean absolute difference between ``log(y_mag)`` and ``log(x_mag)``.
        """
        log_reference = torch.log(y_mag)
        log_predicted = torch.log(x_mag)
        return F.l1_loss(log_reference, log_predicted)
|
68 |
+
|
69 |
+
|
70 |
+
class STFTLoss(torch.nn.Module):
    """Single-resolution STFT loss (spectral convergence + log STFT magnitude)."""

    def __init__(
        self, fft_size=1024, shift_size=120, win_length=600, window="hann_window",
        band="full"
    ):
        """
        Initialize STFT loss module.
        :param fft_size: int, FFT size.
        :param shift_size: int, hop size handed to ``stft``.
        :param win_length: int, Window length.
        :param window: str, name of a ``torch`` window factory (looked up with ``getattr(torch, window)``).
        :param band: str, "full" or "high"; any other value makes ``forward`` raise.
        """
        super(STFTLoss, self).__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        self.band = band

        self.spectral_convergence_loss = SpectralConvergenceLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
        # NOTE(kan-bayashi): Use register_buffer to fix #223
        self.register_buffer("window", getattr(torch, window)(win_length))

    def forward(self, x, y):
        """
        Calculate forward propagation.
        :param x: Tensor, Predicted signal (B, T).
        :param y: Tensor, Groundtruth signal (B, T).
        :return:
            Tensor, Spectral convergence loss value.
            Tensor, Log STFT magnitude loss value.
        :raises NotImplementedError: if ``band`` is neither "high" nor "full".
        """
        x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
        y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)

        if self.band == "high":
            # ``stft`` returns (B, #frames, #freq_bins): frequency is the LAST
            # axis. Slicing axis 1 would drop the first half of the time frames
            # instead of the low-frequency bins.
            freq_mask_ind = x_mag.shape[-1] // 2  # only select high frequency bands
            x_mag = x_mag[..., freq_mask_ind:]
            y_mag = y_mag[..., freq_mask_ind:]
        elif self.band != "full":
            raise NotImplementedError(f"unsupported band: {self.band!r}")

        sc_loss = self.spectral_convergence_loss(x_mag, y_mag)
        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)

        return sc_loss, mag_loss
|
111 |
+
|
112 |
+
|
113 |
+
class MultiResolutionSTFTLoss(torch.nn.Module):
    """Average of several STFTLoss modules, one per (fft, hop, window) resolution."""

    def __init__(self,
                 fft_sizes=None, hop_sizes=None, win_lengths=None,
                 window="hann_window", sc_lambda=0.1, mag_lambda=0.1, band="full",
                 ):
        """
        Build one STFTLoss per resolution.
        :param fft_sizes: list, List of FFT sizes; defaults to [1024, 2048, 512] when empty or None.
        :param hop_sizes: list, List of hop sizes; defaults to [120, 240, 50] when empty or None.
        :param win_lengths: list, List of window lengths; defaults to [600, 1200, 240] when empty or None.
        :param window: str, Window function type.
        :param sc_lambda: float, weight applied to the spectral convergence term.
        :param mag_lambda: float, weight applied to the log STFT magnitude term.
        :param band: str, high-band or full-band loss
        """
        super().__init__()
        if not fft_sizes:
            fft_sizes = [1024, 2048, 512]
        if not hop_sizes:
            hop_sizes = [120, 240, 50]
        if not win_lengths:
            win_lengths = [600, 1200, 240]

        self.sc_lambda = sc_lambda
        self.mag_lambda = mag_lambda

        # The three lists describe the resolutions position by position.
        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
        self.stft_losses = torch.nn.ModuleList([
            STFTLoss(fft_size, hop_size, win_length, window, band)
            for fft_size, hop_size, win_length in zip(fft_sizes, hop_sizes, win_lengths)
        ])

    def forward(self, x, y):
        """
        Evaluate every resolution and average the weighted results.
        :param x: Tensor, Predicted signal (B, T) or (B, #subband, T).
        :param y: Tensor, Groundtruth signal (B, T) or (B, #subband, T).
        :return:
            Tensor, Multi resolution spectral convergence loss value.
            Tensor, Multi resolution log STFT magnitude loss value.
        """
        if x.dim() == 3:
            # (B, C, T) -> (B x C, T)
            x = x.view(-1, x.size(2))
            y = y.view(-1, y.size(2))

        sc_total = 0.0
        mag_total = 0.0
        for stft_loss in self.stft_losses:
            sc_part, mag_part = stft_loss(x, y)
            sc_total = sc_total + sc_part
            mag_total = mag_total + mag_part

        num_resolutions = len(self.stft_losses)
        sc_loss = sc_total * self.sc_lambda / num_resolutions
        mag_loss = mag_total * self.mag_lambda / num_resolutions

        return sc_loss, mag_loss
|
168 |
+
|
169 |
+
|
170 |
+
if __name__ == '__main__':
|
171 |
+
pass
|
toolbox/torchaudio/models/clean_unet/metrics.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
from joblib import Parallel, delayed
|
4 |
+
import numpy as np
|
5 |
+
from pesq import pesq
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
def cal_pesq(clean, noisy, sr=16000):
    """
    Score one utterance pair with ``pesq`` in "wb" mode.
    :param clean: reference signal, passed to ``pesq`` as its second argument.
    :param noisy: signal under test, passed to ``pesq`` as its third argument.
    :param sr: int, sample rate passed to ``pesq`` as its first argument.
    :return: the value returned by ``pesq``, or -1 if the call raised.
    """
    try:
        return pesq(sr, clean, noisy, "wb")
    except Exception:
        # error can happen due to silent period; -1 is the failure sentinel
        return -1
|
17 |
+
|
18 |
+
|
19 |
+
def batch_pesq(clean, noisy, n_jobs=15):
    """
    Score a batch of utterance pairs in parallel with ``cal_pesq``.
    :param clean: iterable of reference signals.
    :param noisy: iterable of signals under test, paired with ``clean`` by position.
    :param n_jobs: int, number of joblib workers; the default of 15 keeps the
        previous hard-coded value.
    :return: None if any pair failed (``cal_pesq`` returned -1); otherwise a
        ``torch.FloatTensor`` holding ``(score - 1) / 3.5`` for each pair.
    """
    pesq_score = Parallel(n_jobs=n_jobs)(
        delayed(cal_pesq)(c, n) for c, n in zip(clean, noisy)
    )
    pesq_score = np.array(pesq_score)
    # A single failed utterance invalidates the whole batch.
    if -1 in pesq_score:
        return None
    # presumably rescales the PESQ range to roughly [0, 1] - TODO confirm
    pesq_score = (pesq_score - 1) / 3.5
    return torch.FloatTensor(pesq_score)
|
26 |
+
|
27 |
+
|
28 |
+
def main():
    """Smoke-test ``batch_pesq`` on one pair of random signals and print the result."""
    prediction = torch.rand(size=(1, 160000), dtype=torch.float32)
    ground_truth = torch.rand(size=(1, 160000), dtype=torch.float32)

    prediction_list_r = list(prediction.cpu().numpy())
    ground_truth_list_r = list(ground_truth.cpu().numpy())

    # batch_pesq(clean, noisy): the ground truth is the reference and the
    # prediction is the signal being scored. Passing them positionally in the
    # opposite order swaps the two roles.
    pesq_score = batch_pesq(clean=ground_truth_list_r, noisy=prediction_list_r)
    print(pesq_score)
|
39 |
+
|
40 |
+
|
41 |
+
if __name__ == "__main__":
|
42 |
+
main()
|
toolbox/torchaudio/models/clean_unet/training.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import math
|
4 |
+
|
5 |
+
|
6 |
+
def anneal_linear(start, end, proportion):
    """
    Linearly interpolate between two values.
    :param start: value returned when ``proportion`` is 0.
    :param end: value returned when ``proportion`` is 1.
    :param proportion: float, fraction of the way from ``start`` to ``end``.
    :return: ``start + proportion * (end - start)``.
    """
    span = end - start
    return start + proportion * span
|
8 |
+
|
9 |
+
|
10 |
+
def anneal_cosine(start, end, proportion):
    """
    Interpolate between two values along half a cosine wave.
    :param start: value returned when ``proportion`` is 0.
    :param end: value returned when ``proportion`` is 1.
    :param proportion: float, fraction of the schedule completed.
    :return: ``end + (start - end) / 2 * (cos(pi * proportion) + 1)``.
    """
    return end + (start - end) / 2 * (math.cos(math.pi * proportion) + 1)
|
13 |
+
|
14 |
+
|
15 |
+
class Phase:
    """One segment of a schedule: an annealing function driven by a step counter."""

    def __init__(self, start, end, n_iter, cur_iter, anneal_fn):
        """
        :param start: first argument handed to ``anneal_fn``.
        :param end: second argument handed to ``anneal_fn``.
        :param n_iter: number of steps after which the phase reports done.
        :param cur_iter: initial value of the step counter.
        :param anneal_fn: callable ``(start, end, proportion) -> value``.
        """
        self.start = start
        self.end = end
        self.n_iter = n_iter
        self.anneal_fn = anneal_fn
        self.n = cur_iter

    def step(self):
        """Advance the counter by one and return the annealed value at ``n / n_iter``."""
        self.n += 1
        progress = self.n / self.n_iter
        return self.anneal_fn(self.start, self.end, progress)

    def reset(self):
        """Rewind the step counter to zero."""
        self.n = 0

    @property
    def is_done(self):
        """True once the counter has reached ``n_iter``."""
        return self.n_iter <= self.n
|
33 |
+
|
34 |
+
|
35 |
+
class LinearWarmupCosineDecay(object):
    """Two-phase learning-rate schedule that writes the rate into an optimizer's param groups."""

    def __init__(
            self,
            optimizer,
            lr_max,
            n_iter,
            iteration=0,
            divider=25,
            warmup_proportion=0.3,
            phase=('linear', 'cosine'),
    ):
        """
        :param optimizer: object exposing ``param_groups``, each a mapping with an ``'lr'`` entry.
        :param lr_max: peak learning rate (end of phase one, start of phase two).
        :param n_iter: total number of steps across both phases.
        :param iteration: step count to start from.
        :param divider: the first phase starts at ``lr_max / divider``.
        :param warmup_proportion: fraction of ``n_iter`` assigned to the first phase.
        :param phase: pair of annealing names ('linear' or 'cosine'), one per phase.
        """
        self.optimizer = optimizer

        warmup_steps = int(n_iter * warmup_proportion)
        decay_steps = n_iter - warmup_steps
        lr_min = lr_max / divider

        anneal_by_name = {'linear': anneal_linear, 'cosine': anneal_cosine}

        # Phase one climbs lr_min -> lr_max; phase two falls lr_max -> lr_min / 1e4.
        warmup = Phase(
            lr_min, lr_max, warmup_steps, iteration, anneal_by_name[phase[0]]
        )
        decay = Phase(
            lr_max, lr_min / 1e4, decay_steps,
            max(0, iteration - warmup_steps), anneal_by_name[phase[1]],
        )
        self.lr_phase = [warmup, decay]

        # Index of the currently active phase.
        self.phase = 0 if iteration < warmup_steps else 1

    def step(self):
        """
        Advance the active phase by one step and apply the new rate.
        :return: the learning rate written to every param group.
        """
        active = self.lr_phase[self.phase]
        lr = active.step()

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        if active.is_done:
            self.phase += 1

        # After the last phase finishes, rewind every phase and start over.
        if self.phase >= len(self.lr_phase):
            for finished in self.lr_phase:
                finished.reset()

            self.phase = 0

        return lr
|
82 |
+
|
83 |
+
|
84 |
+
if __name__ == '__main__':
|
85 |
+
pass
|