File size: 2,547 Bytes
e27a095
 
 
 
 
4f045d5
e27a095
4f045d5
e27a095
4f045d5
 
 
 
 
 
 
 
e27a095
4f045d5
 
 
e27a095
4f045d5
e27a095
 
 
 
4f045d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e27a095
 
 
4f045d5
 
e27a095
4f045d5
 
e27a095
4f045d5
 
e27a095
4f045d5
e27a095
4f045d5
e27a095
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from joblib import Parallel, delayed
import numpy as np
from pesq import pesq
from typing import List

from pesq import cypesq


def run_pesq(clean_audio: np.ndarray,
             noisy_audio: np.ndarray,
             sample_rate: int = 16000,
             mode: str = "wb",
             ) -> float:
    """Compute the PESQ score of ``noisy_audio`` against ``clean_audio``.

    Thin wrapper around ``pesq.pesq`` that turns scoring failures into a
    sentinel value instead of raising, so one bad pair does not abort a batch.

    :param clean_audio: reference signal.
    :param noisy_audio: degraded signal to score against the reference.
    :param sample_rate: sample rate of both signals, forwarded to ``pesq.pesq``.
    :param mode: ``"wb"`` or ``"nb"``, forwarded to ``pesq.pesq``.
    :return: the PESQ score, or ``-1.0`` if ``pesq.pesq`` raised.
    :raises AssertionError: if ``sample_rate`` is 8000 and ``mode`` is ``"wb"``
        (kept as AssertionError rather than ValueError for caller compatibility).
    """
    # This combination is rejected up front: it is a caller configuration
    # error, not a per-signal failure, so it must not be masked as -1.
    if sample_rate == 8000 and mode == "wb":
        raise AssertionError("mode should be `nb` when sample_rate is 8000")
    try:
        pesq_score = pesq(sample_rate, clean_audio, noisy_audio, mode)
    except cypesq.NoUtterancesError:
        # Presumably raised when pesq finds no speech in the signal; treated
        # as an expected outcome, so no message is printed.
        return -1.0
    except Exception as e:
        # Best-effort: report any other failure and fall back to the sentinel.
        print(f"pesq failed. error type: {type(e)}, error text: {str(e)}")
        return -1.0
    return pesq_score


def run_batch_pesq(clean_audio_list: List[np.ndarray],
                   noisy_audio_list: List[np.ndarray],
                   sample_rate: int = 16000,
                   mode: str = "wb",
                   n_jobs: int = 4,
                   ) -> List[float]:
    """Score many (clean, noisy) pairs with :func:`run_pesq` in parallel.

    :param clean_audio_list: reference signals.
    :param noisy_audio_list: degraded signals, paired index-by-index with
        ``clean_audio_list``.
    :param sample_rate: forwarded to :func:`run_pesq`.
    :param mode: forwarded to :func:`run_pesq`.
    :param n_jobs: number of joblib workers.
    :return: one score per pair, in input order; pairs for which
        :func:`run_pesq` fails score -1.
    :raises ValueError: if the two lists have different lengths.
    """
    # zip() would silently drop the unmatched tail and return fewer scores
    # than inputs; reject the mismatch explicitly instead.
    if len(clean_audio_list) != len(noisy_audio_list):
        raise ValueError(
            "clean_audio_list and noisy_audio_list must have the same length, "
            f"got {len(clean_audio_list)} and {len(noisy_audio_list)}"
        )

    tasks = (
        delayed(run_pesq)(clean_audio, noisy_audio, sample_rate, mode)
        for clean_audio, noisy_audio in zip(clean_audio_list, noisy_audio_list)
    )
    return Parallel(n_jobs=n_jobs)(tasks)


def run_pesq_score(clean_audio_list: List[np.ndarray],
                   noisy_audio_list: List[np.ndarray],
                   sample_rate: int = 16000,
                   mode: str = "wb",
                   n_jobs: int = 4,
                   ) -> float:
    """Return the mean PESQ score over all (clean, noisy) pairs.

    Scores come from :func:`run_batch_pesq`. Pairs whose scoring failed carry
    the -1 sentinel and are included in the mean as-is.
    NOTE(review): -1 presumably lies below any valid PESQ score, so failures
    drag the mean down; consider filtering them out if that is unwanted.

    :param clean_audio_list: reference signals.
    :param noisy_audio_list: degraded signals, paired index-by-index.
    :param sample_rate: forwarded to :func:`run_batch_pesq`.
    :param mode: forwarded to :func:`run_batch_pesq`.
    :param n_jobs: number of joblib workers.
    :return: the mean score, or NaN when there are no pairs.
    """
    pesq_score_list = run_batch_pesq(clean_audio_list=clean_audio_list,
                                     noisy_audio_list=noisy_audio_list,
                                     sample_rate=sample_rate,
                                     mode=mode,
                                     n_jobs=n_jobs,
                                     )

    if len(pesq_score_list) == 0:
        # np.mean([]) also yields NaN, but emits a "Mean of empty slice"
        # RuntimeWarning; return NaN directly instead.
        return float("nan")

    pesq_score = np.mean(pesq_score_list)
    return pesq_score


def main():
    """Demo: score two random-noise signal pairs, per pair and on average."""
    n_pairs, n_samples = 2, 160000

    # Draw the reference batch first, then the degraded batch.
    clean_batch = np.random.uniform(low=0, high=1, size=(n_pairs, n_samples))
    noisy_batch = np.random.uniform(low=0, high=1, size=(n_pairs, n_samples))

    # Split each 2-D batch into a list of 1-D signals, one per row.
    clean_signals = list(clean_batch)
    noisy_signals = list(noisy_batch)

    # Per-pair scores first, then the mean over all pairs.
    print(run_batch_pesq(clean_signals, noisy_signals))
    print(run_pesq_score(clean_signals, noisy_signals))


if __name__ == "__main__":
    main()