#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Helpers for computing PESQ scores, one pair at a time or in parallel batches."""
from typing import List

from joblib import Parallel, delayed
import numpy as np
from pesq import cypesq
from pesq import pesq


def run_pesq(clean_audio: np.ndarray,
             noisy_audio: np.ndarray,
             sample_rate: int = 16000,
             mode: str = "wb",
             ) -> float:
    """Compute the PESQ score of ``noisy_audio`` measured against ``clean_audio``.

    Args:
        clean_audio: reference signal, forwarded unchanged to ``pesq.pesq``.
        noisy_audio: degraded signal, forwarded unchanged to ``pesq.pesq``.
        sample_rate: sampling rate in Hz, forwarded to ``pesq.pesq``.
        mode: PESQ mode string forwarded to ``pesq.pesq`` (``"wb"`` or ``"nb"``).

    Returns:
        The score returned by ``pesq.pesq``, or the sentinel ``-1`` when
        ``pesq`` raises ``NoUtterancesError`` (silently) or any other
        exception (the error is printed).

    Raises:
        AssertionError: if ``sample_rate`` is 8000 while ``mode`` is ``"wb"``.
            Kept as ``AssertionError`` so existing callers that catch it still work.
    """
    if sample_rate == 8000 and mode == "wb":
        raise AssertionError("mode should be `nb` when sample_rate is 8000")

    try:
        pesq_score = pesq(sample_rate, clean_audio, noisy_audio, mode)
    except cypesq.NoUtterancesError:
        # No speech detected in the signal: report the failure sentinel quietly.
        pesq_score = -1
    except Exception as e:
        # Best-effort metric: one bad pair must not abort a whole batch.
        print(f"pesq failed. error type: {type(e)}, error text: {str(e)}")
        pesq_score = -1
    return pesq_score


def run_batch_pesq(clean_audio_list: List[np.ndarray],
                   noisy_audio_list: List[np.ndarray],
                   sample_rate: int = 16000,
                   mode: str = "wb",
                   n_jobs: int = 4,
                   ) -> List[float]:
    """Score each (clean, noisy) pair with :func:`run_pesq` using joblib workers.

    Args:
        clean_audio_list: reference signals.
        noisy_audio_list: degraded signals, paired element-wise with
            ``clean_audio_list``.
        sample_rate: forwarded to :func:`run_pesq`.
        mode: forwarded to :func:`run_pesq`.
        n_jobs: number of joblib workers.

    Returns:
        One score per pair, in input order. A failed pair yields ``-1``.

    Raises:
        ValueError: if the two lists have different lengths. ``zip`` would
            otherwise silently drop the unmatched tail.
    """
    if len(clean_audio_list) != len(noisy_audio_list):
        raise ValueError(
            f"clean_audio_list and noisy_audio_list must have the same length, "
            f"got {len(clean_audio_list)} and {len(noisy_audio_list)}"
        )

    parallel = Parallel(n_jobs=n_jobs)
    pesq_score_list = parallel(
        delayed(run_pesq)(clean_audio, noisy_audio, sample_rate, mode)
        for clean_audio, noisy_audio in zip(clean_audio_list, noisy_audio_list)
    )
    return pesq_score_list


def run_pesq_score(clean_audio_list: List[np.ndarray],
                   noisy_audio_list: List[np.ndarray],
                   sample_rate: int = 16000,
                   mode: str = "wb",
                   n_jobs: int = 4,
                   ) -> float:
    """Return the mean PESQ score over all (clean, noisy) pairs.

    Arguments are the same as for :func:`run_batch_pesq`.

    Returns:
        The arithmetic mean of the per-pair scores, or ``nan`` when no pairs
        are given.

    Note:
        Pairs that failed to score contribute their ``-1`` sentinel to the
        mean. Filter them out with :func:`run_batch_pesq` if that is not
        wanted.
    """
    pesq_score_list = run_batch_pesq(clean_audio_list=clean_audio_list,
                                     noisy_audio_list=noisy_audio_list,
                                     sample_rate=sample_rate,
                                     mode=mode,
                                     n_jobs=n_jobs,
                                     )
    if not pesq_score_list:
        # np.mean([]) would also give nan, but with a RuntimeWarning.
        return float("nan")
    pesq_score = float(np.mean(pesq_score_list))
    return pesq_score


def main():
    """Smoke-test the helpers on two pairs of random 160000-sample signals."""
    clean_audio = np.random.uniform(low=0, high=1, size=(2, 160000,))
    noisy_audio = np.random.uniform(low=0, high=1, size=(2, 160000,))

    # Split each (2, 160000) array into a list of two 1-D signals.
    clean_audio_list = list(clean_audio)
    noisy_audio_list = list(noisy_audio)

    pesq_score_list = run_batch_pesq(clean_audio_list, noisy_audio_list)
    print(pesq_score_list)

    pesq_score = run_pesq_score(clean_audio_list, noisy_audio_list)
    print(pesq_score)


if __name__ == "__main__":
    main()