Spaces:
Running
Running
update
Browse files
examples/mpnet_aishell/step_2_train_model.py
CHANGED
@@ -171,7 +171,7 @@ def main():
|
|
171 |
num_params = 0
|
172 |
for p in generator.parameters():
|
173 |
num_params += p.numel()
|
174 |
-
logger.info("
|
175 |
|
176 |
optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
|
177 |
optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
|
@@ -210,9 +210,9 @@ def main():
|
|
210 |
)
|
211 |
for batch in train_data_loader:
|
212 |
clean_audio, noisy_audio = batch
|
213 |
-
clean_audio =
|
214 |
-
noisy_audio =
|
215 |
-
one_labels = torch.ones(config.batch_size).to(device
|
216 |
|
217 |
clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
|
218 |
noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
|
@@ -293,8 +293,8 @@ def main():
|
|
293 |
with torch.no_grad():
|
294 |
for batch in valid_data_loader:
|
295 |
clean_audio, noisy_audio = batch
|
296 |
-
clean_audio =
|
297 |
-
noisy_audio =
|
298 |
|
299 |
clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
|
300 |
noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
|
|
|
171 |
num_params = 0
|
172 |
for p in generator.parameters():
|
173 |
num_params += p.numel()
|
174 |
+
logger.info("total parameters (generator): {:.3f}M".format(num_params/1e6))
|
175 |
|
176 |
optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
|
177 |
optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
|
|
|
210 |
)
|
211 |
for batch in train_data_loader:
|
212 |
clean_audio, noisy_audio = batch
|
213 |
+
clean_audio = clean_audio.to(device)
|
214 |
+
noisy_audio = noisy_audio.to(device)
|
215 |
+
one_labels = torch.ones(config.batch_size).to(device)
|
216 |
|
217 |
clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
|
218 |
noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
|
|
|
293 |
with torch.no_grad():
|
294 |
for batch in valid_data_loader:
|
295 |
clean_audio, noisy_audio = batch
|
296 |
+
clean_audio = clean_audio.to(device)
|
297 |
+
noisy_audio = noisy_audio.to(device)
|
298 |
|
299 |
clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
|
300 |
noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
|