Spaces:
Running
Running
update
Browse files
examples/mpnet_aishell/step_2_train_model.py
CHANGED
@@ -212,7 +212,7 @@ def main():
|
|
212 |
clean_audio, noisy_audio = batch
|
213 |
clean_audio = clean_audio.to(device)
|
214 |
noisy_audio = noisy_audio.to(device)
|
215 |
-
one_labels = torch.ones(
|
216 |
|
217 |
clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
|
218 |
noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
|
@@ -234,7 +234,7 @@ def main():
|
|
234 |
if batch_pesq_score is not None:
|
235 |
loss_disc_g = F.mse_loss(batch_pesq_score.to(device), metric_g.flatten())
|
236 |
else:
|
237 |
-
print("pesq is None!")
|
238 |
loss_disc_g = 0
|
239 |
|
240 |
loss_disc_all = loss_disc_r + loss_disc_g
|
|
|
212 |
clean_audio, noisy_audio = batch
|
213 |
clean_audio = clean_audio.to(device)
|
214 |
noisy_audio = noisy_audio.to(device)
|
215 |
+
one_labels = torch.ones(clean_audio.shape[0]).to(device)
|
216 |
|
217 |
clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
|
218 |
noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
|
|
|
234 |
if batch_pesq_score is not None:
|
235 |
loss_disc_g = F.mse_loss(batch_pesq_score.to(device), metric_g.flatten())
|
236 |
else:
|
237 |
+
# print("pesq is None!")
|
238 |
loss_disc_g = 0
|
239 |
|
240 |
loss_disc_all = loss_disc_r + loss_disc_g
|