Spaces:
Running
Running
update
Browse files
examples/clean_unet_aishell/step_1_prepare_data.py
CHANGED
@@ -101,7 +101,7 @@ def get_dataset(args):
|
|
101 |
count = 0
|
102 |
process_bar = tqdm(desc="build dataset excel")
|
103 |
for noise, speech in zip(noise_generator, speech_generator):
|
104 |
-
if count
|
105 |
break
|
106 |
|
107 |
noise_filename = noise["filename"]
|
|
|
101 |
count = 0
|
102 |
process_bar = tqdm(desc="build dataset excel")
|
103 |
for noise, speech in zip(noise_generator, speech_generator):
|
104 |
+
if count >= args.max_count:
|
105 |
break
|
106 |
|
107 |
noise_filename = noise["filename"]
|
examples/clean_unet_aishell/step_2_train_model.py
CHANGED
@@ -232,8 +232,8 @@ def main():
|
|
232 |
|
233 |
loss = ae_loss + sc_loss + mag_loss
|
234 |
|
235 |
-
enhanced_audios_list_r = list(enhanced_audios.cpu().numpy())
|
236 |
-
clean_audios_list_r = list(clean_audios.cpu().numpy())
|
237 |
pesq_metric = batch_pesq(enhanced_audios_list_r, clean_audios_list_r)
|
238 |
|
239 |
optimizer.zero_grad()
|
@@ -292,8 +292,8 @@ def main():
|
|
292 |
|
293 |
loss = ae_loss + sc_loss + mag_loss
|
294 |
|
295 |
-
enhanced_audios_list_r = list(enhanced_audios.cpu().numpy())
|
296 |
-
clean_audios_list_r = list(clean_audios.cpu().numpy())
|
297 |
pesq_metric = batch_pesq(enhanced_audios_list_r, clean_audios_list_r)
|
298 |
|
299 |
total_pesq_metric += pesq_metric.item()
|
|
|
232 |
|
233 |
loss = ae_loss + sc_loss + mag_loss
|
234 |
|
235 |
+
enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
|
236 |
+
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
237 |
pesq_metric = batch_pesq(enhanced_audios_list_r, clean_audios_list_r)
|
238 |
|
239 |
optimizer.zero_grad()
|
|
|
292 |
|
293 |
loss = ae_loss + sc_loss + mag_loss
|
294 |
|
295 |
+
enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
|
296 |
+
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
297 |
pesq_metric = batch_pesq(enhanced_audios_list_r, clean_audios_list_r)
|
298 |
|
299 |
total_pesq_metric += pesq_metric.item()
|