Spaces:
Running
Running
update
Browse files
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -206,7 +206,7 @@ def main():
|
|
206 |
neg_stoi_loss_fn = NegSTOILoss(sample_rate=8000, reduction="mean").to(device)
|
207 |
lds_loss_fn = LSDLoss(reduction="mean").to(device)
|
208 |
mr_stft_loss_fn = MultiResolutionSTFTLoss(
|
209 |
-
fft_size_list=[256, 512, 1024],
|
210 |
win_size_list=[120, 240, 480],
|
211 |
hop_size_list=[25, 50, 100],
|
212 |
reduction="mean"
|
@@ -240,7 +240,7 @@ def main():
|
|
240 |
total_lds_loss = 0.
|
241 |
total_batches = 0.
|
242 |
progress_bar = tqdm(
|
243 |
-
desc="Training; epoch
|
244 |
)
|
245 |
for batch in train_data_loader:
|
246 |
clean_audios, noisy_audios = batch
|
@@ -305,7 +305,7 @@ def main():
|
|
305 |
total_batches = 0.
|
306 |
|
307 |
progress_bar = tqdm(
|
308 |
-
desc="Evaluation; epoch
|
309 |
)
|
310 |
with torch.no_grad():
|
311 |
for batch in valid_data_loader:
|
|
|
206 |
neg_stoi_loss_fn = NegSTOILoss(sample_rate=8000, reduction="mean").to(device)
|
207 |
lds_loss_fn = LSDLoss(reduction="mean").to(device)
|
208 |
mr_stft_loss_fn = MultiResolutionSTFTLoss(
|
209 |
+
# fft_size_list=[256, 512, 1024],
|
210 |
win_size_list=[120, 240, 480],
|
211 |
hop_size_list=[25, 50, 100],
|
212 |
reduction="mean"
|
|
|
240 |
total_lds_loss = 0.
|
241 |
total_batches = 0.
|
242 |
progress_bar = tqdm(
|
243 |
+
desc="Training; epoch-{}".format(idx_epoch),
|
244 |
)
|
245 |
for batch in train_data_loader:
|
246 |
clean_audios, noisy_audios = batch
|
|
|
305 |
total_batches = 0.
|
306 |
|
307 |
progress_bar = tqdm(
|
308 |
+
desc="Evaluation; epoch-{}".format(idx_epoch),
|
309 |
)
|
310 |
with torch.no_grad():
|
311 |
for batch in valid_data_loader:
|