HoneyTian commited on
Commit
7b7acb0
·
1 Parent(s): a0cbcda
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -207,7 +207,7 @@ def main():
207
  neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
208
  neg_stoi_loss_fn = NegSTOILoss(sample_rate=8000, reduction="mean").to(device)
209
  mr_stft_loss_fn = MultiResolutionSTFTLoss(
210
- # fft_size_list=[256, 512, 1024],
211
  win_size_list=[120, 240, 480],
212
  hop_size_list=[25, 50, 100],
213
  reduction="mean"
 
207
  neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
208
  neg_stoi_loss_fn = NegSTOILoss(sample_rate=8000, reduction="mean").to(device)
209
  mr_stft_loss_fn = MultiResolutionSTFTLoss(
210
+ fft_size_list=[256, 512, 1024],
211
  win_size_list=[120, 240, 480],
212
  hop_size_list=[25, 50, 100],
213
  reduction="mean"
toolbox/torchaudio/losses/spectral.py CHANGED
@@ -288,9 +288,9 @@ class MultiResolutionSTFTLoss(torch.nn.Module):
288
  reduction: str = "mean",
289
  ):
290
  super(MultiResolutionSTFTLoss, self).__init__()
291
- fft_size_list = fft_size_list or [1024, 2048, 512]
292
- win_size_list = win_size_list or [600, 1200, 240]
293
- hop_size_list = hop_size_list or [120, 240, 50]
294
 
295
  if not len(fft_size_list) == len(win_size_list) == len(hop_size_list):
296
  raise AssertionError
 
288
  reduction: str = "mean",
289
  ):
290
  super(MultiResolutionSTFTLoss, self).__init__()
291
+ fft_size_list = fft_size_list or [512, 1024, 2048]
292
+ win_size_list = win_size_list or [240, 600, 1200]
293
+ hop_size_list = hop_size_list or [50, 120, 240]
294
 
295
  if not len(fft_size_list) == len(win_size_list) == len(hop_size_list):
296
  raise AssertionError