Spaces:
Running
Running
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
from joblib import Parallel, delayed | |
import numpy as np | |
from pesq import pesq | |
from typing import List | |
from pesq import cypesq | |
def run_pesq(clean_audio: np.ndarray,
             noisy_audio: np.ndarray,
             sample_rate: int = 16000,
             mode: str = "wb",
             ) -> float:
    """Compute the PESQ score of ``noisy_audio`` against ``clean_audio``.

    Args:
        clean_audio: Reference signal.
        noisy_audio: Degraded signal to be scored.
        sample_rate: Sampling rate in Hz, passed through to ``pesq``.
        mode: ``"wb"`` or ``"nb"``, passed through to ``pesq``.

    Returns:
        The score returned by ``pesq``, or ``-1.0`` when it could not be
        computed (``pesq`` raised ``NoUtterancesError`` or any other error).

    Raises:
        AssertionError: if ``sample_rate`` is 8000 while ``mode`` is ``"wb"``.
    """
    # Checked before the ``try`` so this configuration error reaches the
    # caller instead of being swallowed by the broad handler below.
    if sample_rate == 8000 and mode == "wb":
        raise AssertionError("mode should be `nb` when sample_rate is 8000")
    try:
        return pesq(sample_rate, clean_audio, noisy_audio, mode)
    except cypesq.NoUtterancesError:
        # pesq found no utterances in the input: report failure quietly.
        return -1.0
    except Exception as e:
        # Best-effort: log and return the failure sentinel so one bad pair
        # does not abort a whole batch.
        print(f"pesq failed. error type: {type(e)}, error text: {str(e)}")
        return -1.0
def run_batch_pesq(clean_audio_list: List[np.ndarray],
                   noisy_audio_list: List[np.ndarray],
                   sample_rate: int = 16000,
                   mode: str = "wb",
                   n_jobs: int = 4,
                   ) -> List[float]:
    """Score each (clean, noisy) pair with ``run_pesq`` using joblib workers.

    Args:
        clean_audio_list: Reference signals.
        noisy_audio_list: Degraded signals, paired index-by-index with
            ``clean_audio_list``.
        sample_rate: Sampling rate forwarded to ``run_pesq``.
        mode: PESQ mode forwarded to ``run_pesq``.
        n_jobs: Number of joblib workers.

    Returns:
        One score per pair, in input order; ``-1`` marks a pair that
        ``run_pesq`` could not score.

    Raises:
        ValueError: if the two inputs contain different numbers of signals.
    """
    clean_items = list(clean_audio_list)
    noisy_items = list(noisy_audio_list)
    # ``zip`` would silently drop the unmatched tail and skew any aggregate
    # computed from the result, so a length mismatch is an error.
    if len(clean_items) != len(noisy_items):
        raise ValueError(
            "clean_audio_list and noisy_audio_list must have the same length, "
            f"got {len(clean_items)} and {len(noisy_items)}"
        )
    if not clean_items:
        # Nothing to score.
        return []
    return Parallel(n_jobs=n_jobs)(
        delayed(run_pesq)(clean_audio, noisy_audio, sample_rate, mode)
        for clean_audio, noisy_audio in zip(clean_items, noisy_items)
    )
def run_pesq_score(clean_audio_list: List[np.ndarray],
                   noisy_audio_list: List[np.ndarray],
                   sample_rate: int = 16000,
                   mode: str = "wb",
                   n_jobs: int = 4,
                   ) -> float:
    """Return the mean PESQ score over all (clean, noisy) pairs.

    Pairs reported with the ``-1`` failure sentinel are left out of the
    average, so failed computations do not drag the mean down.

    Args:
        clean_audio_list: Reference signals.
        noisy_audio_list: Degraded signals, paired index-by-index with
            ``clean_audio_list``.
        sample_rate: Sampling rate forwarded to ``run_batch_pesq``.
        mode: PESQ mode forwarded to ``run_batch_pesq``.
        n_jobs: Number of joblib workers forwarded to ``run_batch_pesq``.

    Returns:
        Mean of the successfully computed scores, or ``-1.0`` when no pair
        could be scored (including empty input).
    """
    scores = run_batch_pesq(clean_audio_list=clean_audio_list,
                            noisy_audio_list=noisy_audio_list,
                            sample_rate=sample_rate,
                            mode=mode,
                            n_jobs=n_jobs,
                            )
    # -1 is run_pesq's failure sentinel, not a real score.
    valid_scores = [score for score in scores if score != -1]
    if not valid_scores:
        return -1.0
    return float(np.mean(valid_scores))
def main():
    """Demo: score two random signal pairs, per pair and then on average."""
    shape = (2, 160000,)
    # Draw the clean batch first, then the noisy batch; each row is one signal.
    rows_clean = list(np.random.uniform(low=0, high=1, size=shape))
    rows_noisy = list(np.random.uniform(low=0, high=1, size=shape))

    print(run_batch_pesq(rows_clean, rows_noisy))
    print(run_pesq_score(rows_clean, rows_noisy))


if __name__ == "__main__":
    # Run the demo only when executed as a script, not on import.
    main()