HoneyTian's picture
update
33aff71
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from joblib import Parallel, delayed
import numpy as np
from pesq import pesq
from typing import List
from pesq import cypesq
def run_pesq(clean_audio: np.ndarray,
             noisy_audio: np.ndarray,
             sample_rate: int = 16000,
             mode: str = "wb",
             ) -> float:
    """Score ``noisy_audio`` against the ``clean_audio`` reference with PESQ.

    Thin wrapper around ``pesq.pesq`` that does not propagate scoring
    failures: any exception from the underlying call is mapped to a sentinel
    score of ``-1.0``.

    :param clean_audio: reference signal.
    :param noisy_audio: degraded signal to be scored against the reference.
    :param sample_rate: sampling rate in Hz, passed through to ``pesq``.
    :param mode: PESQ mode string (``"wb"`` or ``"nb"``), passed through to
        ``pesq``.
    :return: the PESQ score as a float, or ``-1.0`` if scoring failed.
    :raises AssertionError: if ``sample_rate`` is 8000 and ``mode`` is
        ``"wb"`` (checked up front so it is not swallowed by the generic
        handler below).
    """
    # The pesq library does not support "wb" mode at 8 kHz; reject this combination
    # loudly rather than letting it degrade into a silent -1.
    if sample_rate == 8000 and mode == "wb":
        raise AssertionError("mode should be `nb` when sample_rate is 8000")
    try:
        pesq_score = pesq(sample_rate, clean_audio, noisy_audio, mode)
    except cypesq.NoUtterancesError:
        # pesq found no utterances in the input; report the sentinel
        # quietly since this is expected for silent / noise-only signals.
        return -1.0
    except Exception as e:
        # Best-effort: log and fall back to the sentinel so one bad pair does
        # not abort a whole batch.
        print(f"pesq failed. error type: {type(e)}, error text: {str(e)}")
        return -1.0
    return float(pesq_score)
def run_batch_pesq(clean_audio_list: List[np.ndarray],
                   noisy_audio_list: List[np.ndarray],
                   sample_rate: int = 16000,
                   mode: str = "wb",
                   n_jobs: int = 4,
                   ) -> List[float]:
    """Score many (clean, noisy) pairs with ``run_pesq`` using joblib.

    :param clean_audio_list: reference signals, one per pair.
    :param noisy_audio_list: degraded signals, aligned index-by-index with
        ``clean_audio_list``.
    :param sample_rate: sampling rate in Hz, forwarded to ``run_pesq``.
    :param mode: PESQ mode string, forwarded to ``run_pesq``.
    :param n_jobs: number of joblib workers.
    :return: one score per pair, in input order (joblib ``Parallel`` returns
        results in task order). Failed pairs carry ``run_pesq``'s -1 sentinel.
    :raises ValueError: if the two lists have different lengths.
    """
    # zip() would silently drop the unmatched tail; a length mismatch means
    # the pairs are misaligned, so refuse instead of returning wrong scores.
    if len(clean_audio_list) != len(noisy_audio_list):
        raise ValueError(
            "clean_audio_list and noisy_audio_list must have the same length, "
            f"got {len(clean_audio_list)} and {len(noisy_audio_list)}"
        )
    # Nothing to score: skip spinning up a worker pool.
    if len(clean_audio_list) == 0:
        return []

    parallel_tasks = [
        delayed(run_pesq)(clean_audio, noisy_audio, sample_rate, mode)
        for clean_audio, noisy_audio in zip(clean_audio_list, noisy_audio_list)
    ]
    return Parallel(n_jobs=n_jobs)(parallel_tasks)
def run_pesq_score(clean_audio_list: List[np.ndarray],
                   noisy_audio_list: List[np.ndarray],
                   sample_rate: int = 16000,
                   mode: str = "wb",
                   n_jobs: int = 4,
                   ) -> float:
    """Return the mean PESQ score over a batch of (clean, noisy) pairs.

    Per-pair scores come from ``run_batch_pesq``. Pairs that ``run_pesq``
    could not score carry its ``-1`` failure sentinel, and that value is
    averaged in unchanged. A batch with failures is therefore biased
    downwards.

    :param clean_audio_list: reference signals, one per pair.
    :param noisy_audio_list: degraded signals aligned with
        ``clean_audio_list``.
    :param sample_rate: sampling rate in Hz, forwarded to the scorer.
    :param mode: PESQ mode string, forwarded to the scorer.
    :param n_jobs: number of joblib workers.
    :return: the arithmetic mean of the per-pair scores as a Python float,
        or ``nan`` when there are no pairs to score.
    """
    pesq_score_list = run_batch_pesq(clean_audio_list=clean_audio_list,
                                     noisy_audio_list=noisy_audio_list,
                                     sample_rate=sample_rate,
                                     mode=mode,
                                     n_jobs=n_jobs,
                                     )
    if len(pesq_score_list) == 0:
        # np.mean([]) also yields nan but emits a RuntimeWarning; make the
        # empty case explicit instead.
        return float("nan")
    return float(np.mean(pesq_score_list))
def main():
    """Smoke-test the batch helpers on random signals.

    Draws two clean and two noisy signals of 160000 samples each (10 s at the
    default 16 kHz rate), uniform in [0, 1). Prints the per-pair scores and
    then the batch mean.
    """
    shape = (2, 160000)
    clean_batch = np.random.uniform(low=0, high=1, size=shape)
    noisy_batch = np.random.uniform(low=0, high=1, size=shape)

    # Iterating a 2-D array yields its rows: one 1-D signal per pair.
    clean_signals = list(clean_batch)
    noisy_signals = list(noisy_batch)

    print(run_batch_pesq(clean_signals, noisy_signals))
    print(run_pesq_score(clean_signals, noisy_signals))


if __name__ == "__main__":
    main()