Spaces:
Running
Running
update
Browse files
examples/clean_unet_aishell/run.sh
CHANGED
@@ -12,7 +12,7 @@ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name fi
|
|
12 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
13 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
14 |
|
15 |
-
sh run.sh --stage
|
16 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
17 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
18 |
|
|
|
12 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
13 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
14 |
|
15 |
+
sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir \
|
16 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
17 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
18 |
|
examples/clean_unet_aishell/step_2_train_model.py
CHANGED
@@ -32,7 +32,7 @@ from toolbox.torchaudio.models.clean_unet.configuration_clean_unet import CleanU
|
|
32 |
from toolbox.torchaudio.models.clean_unet.modeling_clean_unet import CleanUNetPretrainedModel
|
33 |
from toolbox.torchaudio.models.clean_unet.training import LinearWarmupCosineDecay
|
34 |
from toolbox.torchaudio.models.clean_unet.loss import MultiResolutionSTFTLoss
|
35 |
-
from toolbox.torchaudio.models.clean_unet.metrics import
|
36 |
|
37 |
torch.autograd.set_detect_anomaly(True)
|
38 |
|
@@ -217,7 +217,7 @@ def main():
|
|
217 |
# train
|
218 |
model.train()
|
219 |
|
220 |
-
|
221 |
total_loss = 0.
|
222 |
total_ae_loss = 0.
|
223 |
total_sc_loss = 0.
|
@@ -243,25 +243,21 @@ def main():
|
|
243 |
|
244 |
enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
|
245 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
246 |
-
|
247 |
-
if pesq_metric is None:
|
248 |
-
pesq_metric = 0
|
249 |
-
else:
|
250 |
-
pesq_metric = torch.mean(pesq_metric).item()
|
251 |
|
252 |
optimizer.zero_grad()
|
253 |
loss.backward()
|
254 |
optimizer.step()
|
255 |
lr_scheduler.step()
|
256 |
|
257 |
-
|
258 |
total_loss += loss.item()
|
259 |
total_ae_loss += ae_loss.item()
|
260 |
total_sc_loss += sc_loss.item()
|
261 |
total_mag_loss += mag_loss.item()
|
262 |
total_batches += 1
|
263 |
|
264 |
-
|
265 |
average_loss = round(total_loss / total_batches, 4)
|
266 |
average_ae_loss = round(total_ae_loss / total_batches, 4)
|
267 |
average_sc_loss = round(total_sc_loss / total_batches, 4)
|
@@ -269,7 +265,7 @@ def main():
|
|
269 |
|
270 |
progress_bar.update(1)
|
271 |
progress_bar.set_postfix({
|
272 |
-
"
|
273 |
"loss": average_loss,
|
274 |
"ae_loss": average_ae_loss,
|
275 |
"sc_loss": average_sc_loss,
|
@@ -281,7 +277,7 @@ def main():
|
|
281 |
|
282 |
torch.cuda.empty_cache()
|
283 |
|
284 |
-
|
285 |
total_loss = 0.
|
286 |
total_ae_loss = 0.
|
287 |
total_sc_loss = 0.
|
@@ -308,20 +304,16 @@ def main():
|
|
308 |
|
309 |
enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
|
310 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
311 |
-
|
312 |
-
if pesq_metric is None:
|
313 |
-
pesq_metric = 0
|
314 |
-
else:
|
315 |
-
pesq_metric = torch.mean(pesq_metric).item()
|
316 |
|
317 |
-
|
318 |
total_loss += loss.item()
|
319 |
total_ae_loss += ae_loss.item()
|
320 |
total_sc_loss += sc_loss.item()
|
321 |
total_mag_loss += mag_loss.item()
|
322 |
total_batches += 1
|
323 |
|
324 |
-
|
325 |
average_loss = round(total_loss / total_batches, 4)
|
326 |
average_ae_loss = round(total_ae_loss / total_batches, 4)
|
327 |
average_sc_loss = round(total_sc_loss / total_batches, 4)
|
@@ -329,7 +321,7 @@ def main():
|
|
329 |
|
330 |
progress_bar.update(1)
|
331 |
progress_bar.set_postfix({
|
332 |
-
"
|
333 |
"loss": average_loss,
|
334 |
"ae_loss": average_ae_loss,
|
335 |
"sc_loss": average_sc_loss,
|
|
|
32 |
from toolbox.torchaudio.models.clean_unet.modeling_clean_unet import CleanUNetPretrainedModel
|
33 |
from toolbox.torchaudio.models.clean_unet.training import LinearWarmupCosineDecay
|
34 |
from toolbox.torchaudio.models.clean_unet.loss import MultiResolutionSTFTLoss
|
35 |
+
from toolbox.torchaudio.models.clean_unet.metrics import run_pesq_score
|
36 |
|
37 |
torch.autograd.set_detect_anomaly(True)
|
38 |
|
|
|
217 |
# train
|
218 |
model.train()
|
219 |
|
220 |
+
total_pesq_score = 0.
|
221 |
total_loss = 0.
|
222 |
total_ae_loss = 0.
|
223 |
total_sc_loss = 0.
|
|
|
243 |
|
244 |
enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
|
245 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
246 |
+
pesq_score = run_pesq_score(enhanced_audios_list_r, clean_audios_list_r, sample_rate=8000, mode="nb")
|
|
|
|
|
|
|
|
|
247 |
|
248 |
optimizer.zero_grad()
|
249 |
loss.backward()
|
250 |
optimizer.step()
|
251 |
lr_scheduler.step()
|
252 |
|
253 |
+
total_pesq_score += pesq_score
|
254 |
total_loss += loss.item()
|
255 |
total_ae_loss += ae_loss.item()
|
256 |
total_sc_loss += sc_loss.item()
|
257 |
total_mag_loss += mag_loss.item()
|
258 |
total_batches += 1
|
259 |
|
260 |
+
average_pesq_score = round(total_pesq_score / total_batches, 4)
|
261 |
average_loss = round(total_loss / total_batches, 4)
|
262 |
average_ae_loss = round(total_ae_loss / total_batches, 4)
|
263 |
average_sc_loss = round(total_sc_loss / total_batches, 4)
|
|
|
265 |
|
266 |
progress_bar.update(1)
|
267 |
progress_bar.set_postfix({
|
268 |
+
"pesq_score": average_pesq_score,
|
269 |
"loss": average_loss,
|
270 |
"ae_loss": average_ae_loss,
|
271 |
"sc_loss": average_sc_loss,
|
|
|
277 |
|
278 |
torch.cuda.empty_cache()
|
279 |
|
280 |
+
total_pesq_score = 0.
|
281 |
total_loss = 0.
|
282 |
total_ae_loss = 0.
|
283 |
total_sc_loss = 0.
|
|
|
304 |
|
305 |
enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
|
306 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
307 |
+
pesq_score = run_pesq_score(enhanced_audios_list_r, clean_audios_list_r, sample_rate=8000, mode="nb")
|
|
|
|
|
|
|
|
|
308 |
|
309 |
+
total_pesq_score += pesq_score
|
310 |
total_loss += loss.item()
|
311 |
total_ae_loss += ae_loss.item()
|
312 |
total_sc_loss += sc_loss.item()
|
313 |
total_mag_loss += mag_loss.item()
|
314 |
total_batches += 1
|
315 |
|
316 |
+
average_pesq_score = round(total_pesq_score / total_batches, 4)
|
317 |
average_loss = round(total_loss / total_batches, 4)
|
318 |
average_ae_loss = round(total_ae_loss / total_batches, 4)
|
319 |
average_sc_loss = round(total_sc_loss / total_batches, 4)
|
|
|
321 |
|
322 |
progress_bar.update(1)
|
323 |
progress_bar.set_postfix({
|
324 |
+
"pesq_score": average_pesq_score,
|
325 |
"loss": average_loss,
|
326 |
"ae_loss": average_ae_loss,
|
327 |
"sc_loss": average_sc_loss,
|
toolbox/torchaudio/models/clean_unet/metrics.py
CHANGED
@@ -3,38 +3,76 @@
|
|
3 |
from joblib import Parallel, delayed
|
4 |
import numpy as np
|
5 |
from pesq import pesq
|
6 |
-
import
|
7 |
|
|
|
8 |
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
try:
|
11 |
-
pesq_score = pesq(
|
|
|
|
|
12 |
except Exception as e:
|
13 |
-
|
14 |
-
# error can happen due to silent period
|
15 |
pesq_score = -1
|
16 |
return pesq_score
|
17 |
|
18 |
|
19 |
-
def
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
|
28 |
def main():
|
|
|
|
|
29 |
|
30 |
-
|
31 |
-
|
32 |
|
33 |
-
|
34 |
-
|
35 |
|
36 |
-
pesq_score =
|
37 |
print(pesq_score)
|
|
|
38 |
return
|
39 |
|
40 |
|
|
|
3 |
from joblib import Parallel, delayed
|
4 |
import numpy as np
|
5 |
from pesq import pesq
|
6 |
+
from typing import List
|
7 |
|
8 |
+
from pesq import cypesq
|
9 |
|
10 |
+
|
11 |
+
def run_pesq(clean_audio: np.ndarray,
             noisy_audio: np.ndarray,
             sample_rate: int = 16000,
             mode: str = "wb",
             ) -> float:
    """
    Compute the PESQ score of a single clean/noisy waveform pair.

    The arguments are forwarded to ``pesq`` as
    ``pesq(sample_rate, clean_audio, noisy_audio, mode)``, i.e. the clean
    waveform is given first and the noisy one second.

    :param clean_audio: clean waveform (first signal handed to ``pesq``).
    :param noisy_audio: noisy / enhanced waveform (second signal handed to ``pesq``).
    :param sample_rate: sampling rate forwarded to ``pesq``.
    :param mode: ``pesq`` mode string, e.g. "wb" or "nb".
    :return: the value returned by ``pesq``, or -1 when ``pesq`` raises.
    :raises AssertionError: when ``sample_rate`` is 8000 and ``mode`` is "wb".
    """
    # Reject the one combination this wrapper refuses up front.
    if (sample_rate, mode) == (8000, "wb"):
        raise AssertionError(f"mode should be `nb` when sample_rate is 8000")

    try:
        score = pesq(sample_rate, clean_audio, noisy_audio, mode)
    except cypesq.NoUtterancesError:
        # Mapped to the -1 sentinel without logging.
        # NOTE(review): presumably raised for clips with no detectable speech - confirm.
        score = -1
    except Exception as error:
        # Best effort: report any other failure and fall back to the same sentinel.
        print(f"pesq failed. error type: {type(error)}, error text: {str(error)}")
        score = -1
    return score
|
26 |
|
27 |
|
28 |
+
def run_batch_pesq(clean_audio_list: List[np.ndarray],
                   noisy_audio_list: List[np.ndarray],
                   sample_rate: int = 16000,
                   mode: str = "wb",
                   n_jobs: int = 4,
                   ) -> List[float]:
    """
    Score several clean/noisy pairs with ``run_pesq`` through a joblib ``Parallel`` runner.

    The two lists are paired with ``zip``, so surplus items in the longer
    list are ignored.

    :param clean_audio_list: clean waveforms, one per pair.
    :param noisy_audio_list: noisy / enhanced waveforms, one per pair.
    :param sample_rate: sampling rate forwarded to ``run_pesq``.
    :param mode: mode string forwarded to ``run_pesq``.
    :param n_jobs: ``n_jobs`` value handed to joblib's ``Parallel``.
    :return: whatever the ``Parallel`` call yields for the submitted ``run_pesq`` tasks.
    """
    jobs = [
        delayed(run_pesq)(clean_item, noisy_item, sample_rate, mode)
        for clean_item, noisy_item in zip(clean_audio_list, noisy_audio_list)
    ]
    runner = Parallel(n_jobs=n_jobs)
    return runner(jobs)
|
43 |
+
|
44 |
+
|
45 |
+
def run_pesq_score(clean_audio_list: List[np.ndarray],
                   noisy_audio_list: List[np.ndarray],
                   sample_rate: int = 16000,
                   mode: str = "wb",
                   n_jobs: int = 4,
                   ) -> float:
    """
    Return the mean PESQ score over a batch of clean/noisy pairs.

    The per-pair scores come from ``run_batch_pesq`` and are reduced with
    ``np.mean``, so the result is a single scalar (the previous
    ``List[float]`` return annotation was incorrect).

    Pairs for which PESQ could not be computed are reported as -1 by
    ``run_pesq`` and are included in the mean as such.

    :param clean_audio_list: clean waveforms, one per pair.
    :param noisy_audio_list: noisy / enhanced waveforms, one per pair.
    :param sample_rate: sampling rate forwarded to ``run_batch_pesq``.
    :param mode: mode string forwarded to ``run_batch_pesq``.
    :param n_jobs: number of joblib workers forwarded to ``run_batch_pesq``.
    :return: mean of the per-pair scores.
    """
    pesq_score_list = run_batch_pesq(clean_audio_list=clean_audio_list,
                                     noisy_audio_list=noisy_audio_list,
                                     sample_rate=sample_rate,
                                     mode=mode,
                                     n_jobs=n_jobs,
                                     )

    pesq_score = np.mean(pesq_score_list)
    return pesq_score
|
61 |
|
62 |
|
63 |
def main():
    """
    Manual smoke test: score two pairs of random signals and print the results.

    Prints the per-pair list from ``run_batch_pesq`` and then the averaged
    value from ``run_pesq_score`` (both called with their default settings).
    """
    shape = (2, 160000,)
    # Draw the clean signals first, then the noisy ones, and split each array into rows.
    clean_rows = list(np.random.uniform(low=0, high=1, size=shape))
    noisy_rows = list(np.random.uniform(low=0, high=1, size=shape))

    print(run_batch_pesq(clean_rows, noisy_rows))
    print(run_pesq_score(clean_rows, noisy_rows))

    return
|
77 |
|
78 |
|
toolbox/torchaudio/models/mpnet/metrics.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
from joblib import Parallel, delayed
|
4 |
+
import numpy as np
|
5 |
+
from pesq import pesq
|
6 |
+
from typing import List
|
7 |
+
|
8 |
+
from pesq import cypesq
|
9 |
+
|
10 |
+
|
11 |
+
def run_pesq(clean_audio: np.ndarray,
             noisy_audio: np.ndarray,
             sample_rate: int = 16000,
             mode: str = "wb",
             ) -> float:
    """
    Score one clean/noisy waveform pair with PESQ.

    Calls ``pesq(sample_rate, clean_audio, noisy_audio, mode)``; the clean
    waveform is passed first and the noisy one second.

    :param clean_audio: clean waveform (first signal handed to ``pesq``).
    :param noisy_audio: noisy / enhanced waveform (second signal handed to ``pesq``).
    :param sample_rate: sampling rate forwarded to ``pesq``.
    :param mode: ``pesq`` mode string, e.g. "wb" or "nb".
    :return: the value returned by ``pesq``, or -1 when ``pesq`` raises.
    :raises AssertionError: when ``sample_rate`` is 8000 and ``mode`` is "wb".
    """
    wide_band_at_8k = (mode == "wb") and (sample_rate == 8000)
    if wide_band_at_8k:
        raise AssertionError(f"mode should be `nb` when sample_rate is 8000")

    try:
        return pesq(sample_rate, clean_audio, noisy_audio, mode)
    except cypesq.NoUtterancesError:
        # Silently mapped to the -1 sentinel.
        # NOTE(review): presumably means no speech was detected in the clip - confirm.
        return -1
    except Exception as error:
        # Any other failure is reported on stdout and mapped to the same sentinel.
        print(f"pesq failed. error type: {type(error)}, error text: {str(error)}")
        return -1
|
26 |
+
|
27 |
+
|
28 |
+
def run_batch_pesq(clean_audio_list: List[np.ndarray],
                   noisy_audio_list: List[np.ndarray],
                   sample_rate: int = 16000,
                   mode: str = "wb",
                   n_jobs: int = 4,
                   ) -> List[float]:
    """
    Run ``run_pesq`` on every clean/noisy pair through a joblib ``Parallel`` runner.

    Pairs are formed with ``zip``, so if the lists differ in length the
    extra items of the longer one are not scored.

    :param clean_audio_list: clean waveforms, one per pair.
    :param noisy_audio_list: noisy / enhanced waveforms, one per pair.
    :param sample_rate: sampling rate forwarded to ``run_pesq``.
    :param mode: mode string forwarded to ``run_pesq``.
    :param n_jobs: ``n_jobs`` value handed to joblib's ``Parallel``.
    :return: whatever the ``Parallel`` call yields for the submitted ``run_pesq`` tasks.
    """
    pairs = zip(clean_audio_list, noisy_audio_list)
    tasks = [delayed(run_pesq)(clean, noisy, sample_rate, mode) for clean, noisy in pairs]

    executor = Parallel(n_jobs=n_jobs)
    scores = executor(tasks)
    return scores
|
43 |
+
|
44 |
+
|
45 |
+
def run_pesq_score(clean_audio_list: List[np.ndarray],
                   noisy_audio_list: List[np.ndarray],
                   sample_rate: int = 16000,
                   mode: str = "wb",
                   n_jobs: int = 4,
                   ) -> float:
    """
    Return the mean PESQ score over a batch of clean/noisy pairs.

    Per-pair scores are obtained from ``run_batch_pesq`` and averaged with
    ``np.mean``; the result is one scalar, not a list (the earlier
    ``List[float]`` return annotation did not match the code).

    ``run_pesq`` reports a pair it could not score as -1, and those
    sentinel values take part in the average.

    :param clean_audio_list: clean waveforms, one per pair.
    :param noisy_audio_list: noisy / enhanced waveforms, one per pair.
    :param sample_rate: sampling rate forwarded to ``run_batch_pesq``.
    :param mode: mode string forwarded to ``run_batch_pesq``.
    :param n_jobs: number of joblib workers forwarded to ``run_batch_pesq``.
    :return: mean of the per-pair scores.
    """
    pesq_score_list = run_batch_pesq(clean_audio_list=clean_audio_list,
                                     noisy_audio_list=noisy_audio_list,
                                     sample_rate=sample_rate,
                                     mode=mode,
                                     n_jobs=n_jobs,
                                     )

    pesq_score = np.mean(pesq_score_list)
    return pesq_score
|
61 |
+
|
62 |
+
|
63 |
+
def main():
    """
    Manual smoke test for the helpers in this module.

    Builds two random clean/noisy pairs, prints the per-pair scores from
    ``run_batch_pesq`` and then the averaged score from ``run_pesq_score``,
    both with their default settings.
    """
    batch_shape = (2, 160000,)
    # Clean signals are sampled before the noisy ones; each array is split into its rows.
    clean_signals = list(np.random.uniform(low=0, high=1, size=batch_shape))
    noisy_signals = list(np.random.uniform(low=0, high=1, size=batch_shape))

    per_pair_scores = run_batch_pesq(clean_signals, noisy_signals)
    print(per_pair_scores)

    mean_score = run_pesq_score(clean_signals, noisy_signals)
    print(mean_score)

    return
|
77 |
+
|
78 |
+
|
79 |
+
if __name__ == "__main__":
|
80 |
+
main()
|