Spaces:
Running
Running
update
Browse files
examples/clean_unet_aishell/step_2_train_model.py
CHANGED
@@ -227,12 +227,8 @@ def main():
|
|
227 |
clean_audios = clean_audios.to(device)
|
228 |
noisy_audios = noisy_audios.to(device)
|
229 |
|
230 |
-
print(f"clean_audios shape: {clean_audios.shape}, dtype: {clean_audios.dtype}")
|
231 |
-
print(f"noisy_audios shape: {noisy_audios.shape}, dtype: {noisy_audios.dtype}")
|
232 |
-
|
233 |
enhanced_audios = model.forward(noisy_audios)
|
234 |
enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
|
235 |
-
print(f"enhanced_audios shape: {enhanced_audios.shape}, dtype: {enhanced_audios.dtype}")
|
236 |
|
237 |
ae_loss = ae_loss_fn(enhanced_audios, clean_audios)
|
238 |
sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios, clean_audios)
|
@@ -248,7 +244,7 @@ def main():
|
|
248 |
optimizer.step()
|
249 |
lr_scheduler.step()
|
250 |
|
251 |
-
total_pesq_metric += pesq_metric.item()
|
252 |
total_loss += loss.item()
|
253 |
total_ae_loss += ae_loss.item()
|
254 |
total_sc_loss += sc_loss.item()
|
|
|
227 |
clean_audios = clean_audios.to(device)
|
228 |
noisy_audios = noisy_audios.to(device)
|
229 |
|
|
|
|
|
|
|
230 |
enhanced_audios = model.forward(noisy_audios)
|
231 |
enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
|
|
|
232 |
|
233 |
ae_loss = ae_loss_fn(enhanced_audios, clean_audios)
|
234 |
sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios, clean_audios)
|
|
|
244 |
optimizer.step()
|
245 |
lr_scheduler.step()
|
246 |
|
247 |
+
total_pesq_metric += 0 if pesq_metric is None else pesq_metric.item()
|
248 |
total_loss += loss.item()
|
249 |
total_ae_loss += ae_loss.item()
|
250 |
total_sc_loss += sc_loss.item()
|