HoneyTian commited on
Commit
a556ebf
·
1 Parent(s): 4927a3a
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -206,7 +206,7 @@ def main():
206
  neg_stoi_loss_fn = NegSTOILoss(sample_rate=8000, reduction="mean").to(device)
207
  lds_loss_fn = LSDLoss(reduction="mean").to(device)
208
  mr_stft_loss_fn = MultiResolutionSTFTLoss(
209
- fft_size_list=[256, 512, 1024],
210
  win_size_list=[120, 240, 480],
211
  hop_size_list=[25, 50, 100],
212
  reduction="mean"
@@ -240,7 +240,7 @@ def main():
240
  total_lds_loss = 0.
241
  total_batches = 0.
242
  progress_bar = tqdm(
243
- desc="Training; epoch: {}".format(idx_epoch),
244
  )
245
  for batch in train_data_loader:
246
  clean_audios, noisy_audios = batch
@@ -305,7 +305,7 @@ def main():
305
  total_batches = 0.
306
 
307
  progress_bar = tqdm(
308
- desc="Evaluation; epoch: {}".format(idx_epoch),
309
  )
310
  with torch.no_grad():
311
  for batch in valid_data_loader:
 
206
  neg_stoi_loss_fn = NegSTOILoss(sample_rate=8000, reduction="mean").to(device)
207
  lds_loss_fn = LSDLoss(reduction="mean").to(device)
208
  mr_stft_loss_fn = MultiResolutionSTFTLoss(
209
+ # fft_size_list=[256, 512, 1024],
210
  win_size_list=[120, 240, 480],
211
  hop_size_list=[25, 50, 100],
212
  reduction="mean"
 
240
  total_lds_loss = 0.
241
  total_batches = 0.
242
  progress_bar = tqdm(
243
+ desc="Training; epoch-{}".format(idx_epoch),
244
  )
245
  for batch in train_data_loader:
246
  clean_audios, noisy_audios = batch
 
305
  total_batches = 0.
306
 
307
  progress_bar = tqdm(
308
+ desc="Evaluation; epoch-{}".format(idx_epoch),
309
  )
310
  with torch.no_grad():
311
  for batch in valid_data_loader: