Spaces:
Running
Running
update
Browse files
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -207,7 +207,7 @@ def main():
|
|
207 |
neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
|
208 |
neg_stoi_loss_fn = NegSTOILoss(sample_rate=8000, reduction="mean").to(device)
|
209 |
mr_stft_loss_fn = MultiResolutionSTFTLoss(
|
210 |
-
|
211 |
win_size_list=[120, 240, 480],
|
212 |
hop_size_list=[25, 50, 100],
|
213 |
reduction="mean"
|
|
|
207 |
neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
|
208 |
neg_stoi_loss_fn = NegSTOILoss(sample_rate=8000, reduction="mean").to(device)
|
209 |
mr_stft_loss_fn = MultiResolutionSTFTLoss(
|
210 |
+
fft_size_list=[256, 512, 1024],
|
211 |
win_size_list=[120, 240, 480],
|
212 |
hop_size_list=[25, 50, 100],
|
213 |
reduction="mean"
|
toolbox/torchaudio/losses/spectral.py
CHANGED
@@ -288,9 +288,9 @@ class MultiResolutionSTFTLoss(torch.nn.Module):
|
|
288 |
reduction: str = "mean",
|
289 |
):
|
290 |
super(MultiResolutionSTFTLoss, self).__init__()
|
291 |
-
fft_size_list = fft_size_list or [1024, 2048
|
292 |
-
win_size_list = win_size_list or [600, 1200
|
293 |
-
hop_size_list = hop_size_list or [120, 240
|
294 |
|
295 |
if not len(fft_size_list) == len(win_size_list) == len(hop_size_list):
|
296 |
raise AssertionError
|
|
|
288 |
reduction: str = "mean",
|
289 |
):
|
290 |
super(MultiResolutionSTFTLoss, self).__init__()
|
291 |
+
fft_size_list = fft_size_list or [512, 1024, 2048]
|
292 |
+
win_size_list = win_size_list or [240, 600, 1200]
|
293 |
+
hop_size_list = hop_size_list or [50, 120, 240]
|
294 |
|
295 |
if not len(fft_size_list) == len(win_size_list) == len(hop_size_list):
|
296 |
raise AssertionError
|