Spaces:
Running
Running
update
Browse files
examples/frcrn/run.sh
CHANGED
@@ -4,7 +4,7 @@
|
|
4 |
|
5 |
|
6 |
sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-20-512-nx-dns3 \
|
7 |
-
--config_file "yaml/config-
|
8 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
9 |
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
|
10 |
|
|
|
4 |
|
5 |
|
6 |
sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-20-512-nx-dns3 \
|
7 |
+
--config_file "yaml/config-10.yaml" \
|
8 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
9 |
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
|
10 |
|
examples/frcrn/step_2_train_model.py
CHANGED
@@ -34,6 +34,7 @@ from tqdm import tqdm
|
|
34 |
|
35 |
from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
|
36 |
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
|
|
|
37 |
from toolbox.torchaudio.metrics.pesq import run_pesq_score
|
38 |
from toolbox.torchaudio.models.frcrn.configuration_frcrn import FRCRNConfig
|
39 |
from toolbox.torchaudio.models.frcrn.modeling_frcrn import FRCRN, FRCRNPretrainedModel
|
@@ -220,6 +221,14 @@ def main():
|
|
220 |
raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
|
221 |
|
222 |
neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
223 |
|
224 |
# training loop
|
225 |
|
@@ -248,6 +257,7 @@ def main():
|
|
248 |
|
249 |
total_pesq_score = 0.
|
250 |
total_loss = 0.
|
|
|
251 |
total_neg_si_snr_loss = 0.
|
252 |
total_mask_loss = 0.
|
253 |
total_batches = 0.
|
@@ -264,10 +274,11 @@ def main():
|
|
264 |
est_spec, est_wav, est_mask = model.forward(noisy_audios)
|
265 |
denoise_audios = est_wav
|
266 |
|
|
|
267 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
268 |
mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
|
269 |
|
270 |
-
loss = 1.0 * neg_si_snr_loss + 1.0 * mask_loss
|
271 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
272 |
logger.info(f"find nan or inf in loss.")
|
273 |
continue
|
@@ -284,12 +295,14 @@ def main():
|
|
284 |
|
285 |
total_pesq_score += pesq_score
|
286 |
total_loss += loss.item()
|
|
|
287 |
total_neg_si_snr_loss += neg_si_snr_loss.item()
|
288 |
total_mask_loss += mask_loss.item()
|
289 |
total_batches += 1
|
290 |
|
291 |
average_pesq_score = round(total_pesq_score / total_batches, 4)
|
292 |
average_loss = round(total_loss / total_batches, 4)
|
|
|
293 |
average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
|
294 |
average_mask_loss = round(total_mask_loss / total_batches, 4)
|
295 |
|
@@ -298,6 +311,7 @@ def main():
|
|
298 |
"lr": lr_scheduler.get_last_lr()[0],
|
299 |
"pesq_score": average_pesq_score,
|
300 |
"loss": average_loss,
|
|
|
301 |
"neg_si_snr_loss": average_neg_si_snr_loss,
|
302 |
"mask_loss": average_mask_loss,
|
303 |
})
|
@@ -311,6 +325,7 @@ def main():
|
|
311 |
|
312 |
total_pesq_score = 0.
|
313 |
total_loss = 0.
|
|
|
314 |
total_neg_si_snr_loss = 0.
|
315 |
total_mask_loss = 0.
|
316 |
total_batches = 0.
|
@@ -327,10 +342,11 @@ def main():
|
|
327 |
est_spec, est_wav, est_mask = model.forward(noisy_audios)
|
328 |
denoise_audios = est_wav
|
329 |
|
|
|
330 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
331 |
mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
|
332 |
|
333 |
-
loss = 1.0 * neg_si_snr_loss + 1.0 * mask_loss
|
334 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
335 |
logger.info(f"find nan or inf in loss.")
|
336 |
continue
|
@@ -347,6 +363,7 @@ def main():
|
|
347 |
|
348 |
average_pesq_score = round(total_pesq_score / total_batches, 4)
|
349 |
average_loss = round(total_loss / total_batches, 4)
|
|
|
350 |
average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
|
351 |
average_mask_loss = round(total_mask_loss / total_batches, 4)
|
352 |
|
@@ -355,12 +372,14 @@ def main():
|
|
355 |
"lr": lr_scheduler.get_last_lr()[0],
|
356 |
"pesq_score": average_pesq_score,
|
357 |
"loss": average_loss,
|
|
|
358 |
"neg_si_snr_loss": average_neg_si_snr_loss,
|
359 |
"mask_loss": average_mask_loss,
|
360 |
})
|
361 |
|
362 |
total_pesq_score = 0.
|
363 |
total_loss = 0.
|
|
|
364 |
total_neg_si_snr_loss = 0.
|
365 |
total_mask_loss = 0.
|
366 |
total_batches = 0.
|
|
|
34 |
|
35 |
from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
|
36 |
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
|
37 |
+
from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
|
38 |
from toolbox.torchaudio.metrics.pesq import run_pesq_score
|
39 |
from toolbox.torchaudio.models.frcrn.configuration_frcrn import FRCRNConfig
|
40 |
from toolbox.torchaudio.models.frcrn.modeling_frcrn import FRCRN, FRCRNPretrainedModel
|
|
|
221 |
raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
|
222 |
|
223 |
neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
|
224 |
+
mr_stft_loss_fn = MultiResolutionSTFTLoss(
|
225 |
+
fft_size_list=[256, 512, 1024],
|
226 |
+
win_size_list=[256, 512, 1024],
|
227 |
+
hop_size_list=[128, 256, 512],
|
228 |
+
factor_sc=1.5,
|
229 |
+
factor_mag=1.0,
|
230 |
+
reduction="mean"
|
231 |
+
).to(device)
|
232 |
|
233 |
# training loop
|
234 |
|
|
|
257 |
|
258 |
total_pesq_score = 0.
|
259 |
total_loss = 0.
|
260 |
+
total_mr_stft_loss = 0.
|
261 |
total_neg_si_snr_loss = 0.
|
262 |
total_mask_loss = 0.
|
263 |
total_batches = 0.
|
|
|
274 |
est_spec, est_wav, est_mask = model.forward(noisy_audios)
|
275 |
denoise_audios = est_wav
|
276 |
|
277 |
+
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
278 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
279 |
mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
|
280 |
|
281 |
+
loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss
|
282 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
283 |
logger.info(f"find nan or inf in loss.")
|
284 |
continue
|
|
|
295 |
|
296 |
total_pesq_score += pesq_score
|
297 |
total_loss += loss.item()
|
298 |
+
total_mr_stft_loss += mr_stft_loss.item()
|
299 |
total_neg_si_snr_loss += neg_si_snr_loss.item()
|
300 |
total_mask_loss += mask_loss.item()
|
301 |
total_batches += 1
|
302 |
|
303 |
average_pesq_score = round(total_pesq_score / total_batches, 4)
|
304 |
average_loss = round(total_loss / total_batches, 4)
|
305 |
+
average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
|
306 |
average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
|
307 |
average_mask_loss = round(total_mask_loss / total_batches, 4)
|
308 |
|
|
|
311 |
"lr": lr_scheduler.get_last_lr()[0],
|
312 |
"pesq_score": average_pesq_score,
|
313 |
"loss": average_loss,
|
314 |
+
"mr_stft_loss": average_mr_stft_loss,
|
315 |
"neg_si_snr_loss": average_neg_si_snr_loss,
|
316 |
"mask_loss": average_mask_loss,
|
317 |
})
|
|
|
325 |
|
326 |
total_pesq_score = 0.
|
327 |
total_loss = 0.
|
328 |
+
total_mr_stft_loss = 0.
|
329 |
total_neg_si_snr_loss = 0.
|
330 |
total_mask_loss = 0.
|
331 |
total_batches = 0.
|
|
|
342 |
est_spec, est_wav, est_mask = model.forward(noisy_audios)
|
343 |
denoise_audios = est_wav
|
344 |
|
345 |
+
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
346 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
347 |
mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
|
348 |
|
349 |
+
loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss
|
350 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
351 |
logger.info(f"find nan or inf in loss.")
|
352 |
continue
|
|
|
363 |
|
364 |
average_pesq_score = round(total_pesq_score / total_batches, 4)
|
365 |
average_loss = round(total_loss / total_batches, 4)
|
366 |
+
average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
|
367 |
average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
|
368 |
average_mask_loss = round(total_mask_loss / total_batches, 4)
|
369 |
|
|
|
372 |
"lr": lr_scheduler.get_last_lr()[0],
|
373 |
"pesq_score": average_pesq_score,
|
374 |
"loss": average_loss,
|
375 |
+
"mr_stft_loss": average_mr_stft_loss,
|
376 |
"neg_si_snr_loss": average_neg_si_snr_loss,
|
377 |
"mask_loss": average_mask_loss,
|
378 |
})
|
379 |
|
380 |
total_pesq_score = 0.
|
381 |
total_loss = 0.
|
382 |
+
total_mr_stft_loss = 0.
|
383 |
total_neg_si_snr_loss = 0.
|
384 |
total_mask_loss = 0.
|
385 |
total_batches = 0.
|
toolbox/torchaudio/models/frcrn/modeling_frcrn.py
CHANGED
@@ -300,25 +300,38 @@ class FRCRNPretrainedModel(FRCRN):
|
|
300 |
def main():
|
301 |
# model = FRCRN(
|
302 |
# use_complex_networks=True,
|
303 |
-
# model_complexity
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
304 |
# model_depth=14,
|
305 |
# padding_mode="zeros",
|
306 |
-
# nfft=
|
307 |
-
# win_size=
|
308 |
-
# hop_size=
|
309 |
# win_type="hann",
|
310 |
# )
|
|
|
311 |
model = FRCRN(
|
312 |
use_complex_networks=True,
|
313 |
-
model_complexity=
|
314 |
-
model_depth=
|
315 |
padding_mode="zeros",
|
316 |
-
nfft=
|
317 |
-
win_size=
|
318 |
-
hop_size=
|
319 |
win_type="hann",
|
320 |
)
|
321 |
-
|
|
|
322 |
|
323 |
est_spec, est_wav, est_mask = model.forward(mixture)
|
324 |
print(est_spec.shape)
|
|
|
300 |
def main():
|
301 |
# model = FRCRN(
|
302 |
# use_complex_networks=True,
|
303 |
+
# model_complexity=-1,
|
304 |
+
# model_depth=10,
|
305 |
+
# padding_mode="zeros",
|
306 |
+
# nfft=128,
|
307 |
+
# win_size=128,
|
308 |
+
# hop_size=64,
|
309 |
+
# win_type="hann",
|
310 |
+
# )
|
311 |
+
|
312 |
+
# model = FRCRN(
|
313 |
+
# use_complex_networks=True,
|
314 |
+
# model_complexity=-1,
|
315 |
# model_depth=14,
|
316 |
# padding_mode="zeros",
|
317 |
+
# nfft=640,
|
318 |
+
# win_size=640,
|
319 |
+
# hop_size=320,
|
320 |
# win_type="hann",
|
321 |
# )
|
322 |
+
|
323 |
model = FRCRN(
|
324 |
use_complex_networks=True,
|
325 |
+
model_complexity=20,
|
326 |
+
model_depth=20,
|
327 |
padding_mode="zeros",
|
328 |
+
nfft=512,
|
329 |
+
win_size=512,
|
330 |
+
hop_size=256,
|
331 |
win_type="hann",
|
332 |
)
|
333 |
+
|
334 |
+
mixture = torch.rand(size=(1, 32000), dtype=torch.float32)
|
335 |
|
336 |
est_spec, est_wav, est_mask = model.forward(mixture)
|
337 |
print(est_spec.shape)
|
toolbox/torchaudio/models/frcrn/unet.py
CHANGED
@@ -339,19 +339,8 @@ class UNet(nn.Module):
|
|
339 |
return cmp_spec
|
340 |
|
341 |
|
342 |
-
def
|
343 |
# [batch_size, 1, freq_bins, time_steps, 2]
|
344 |
-
# x = torch.rand(size=(1, 1, 257, 2000, 2))
|
345 |
-
# unet = UNet(
|
346 |
-
# in_channels=1,
|
347 |
-
# model_complexity=45,
|
348 |
-
# model_depth=20,
|
349 |
-
# use_complex_networks=True
|
350 |
-
# )
|
351 |
-
# print(unet)
|
352 |
-
# result = unet.forward(x)
|
353 |
-
# print(result.shape)
|
354 |
-
|
355 |
# x = torch.rand(size=(1, 1, 65, 2000, 2))
|
356 |
x = torch.rand(size=(1, 1, 65, 200, 2))
|
357 |
unet = UNet(
|
@@ -366,5 +355,20 @@ def main():
|
|
366 |
return
|
367 |
|
368 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
369 |
if __name__ == "__main__":
|
370 |
-
|
|
|
339 |
return cmp_spec
|
340 |
|
341 |
|
342 |
+
def main10():
|
343 |
# [batch_size, 1, freq_bins, time_steps, 2]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
344 |
# x = torch.rand(size=(1, 1, 65, 2000, 2))
|
345 |
x = torch.rand(size=(1, 1, 65, 200, 2))
|
346 |
unet = UNet(
|
|
|
355 |
return
|
356 |
|
357 |
|
358 |
+
def main20():
|
359 |
+
# [batch_size, 1, freq_bins, time_steps, 2]
|
360 |
+
x = torch.rand(size=(1, 1, 257, 2000, 2))
|
361 |
+
unet = UNet(
|
362 |
+
in_channels=1,
|
363 |
+
model_complexity=45,
|
364 |
+
model_depth=20,
|
365 |
+
use_complex_networks=True
|
366 |
+
)
|
367 |
+
print(unet)
|
368 |
+
result = unet.forward(x)
|
369 |
+
print(result.shape)
|
370 |
+
return
|
371 |
+
|
372 |
+
|
373 |
if __name__ == "__main__":
|
374 |
+
main20()
|