File size: 1,127 Bytes
e27a095
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from joblib import Parallel, delayed
import numpy as np
from pesq import pesq
import torch


def cal_pesq(clean, noisy, sr=16000):
    """Score a single pair of signals with wide-band PESQ.

    Args:
        clean: Signal forwarded to ``pesq`` as its second positional
            argument (the reference).
        noisy: Signal forwarded to ``pesq`` as its third positional
            argument (the one being evaluated).
        sr: Sample rate handed to ``pesq``; defaults to 16000.

    Returns:
        Whatever ``pesq`` returns in ``"wb"`` mode, or the sentinel ``-1``
        when the call raises any exception.
    """
    try:
        return pesq(sr, clean, noisy, "wb")
    except Exception:
        # Scoring can fail (for example on a silent period); callers get the
        # -1 sentinel instead of the exception.
        return -1


def batch_pesq(clean, noisy, n_jobs=15):
    """Compute rescaled PESQ scores for a batch of signal pairs.

    Pairs are formed positionally with ``zip(clean, noisy)`` and each pair
    is scored by ``cal_pesq`` through ``joblib.Parallel``.

    Args:
        clean: Iterable of reference signals.
        noisy: Iterable of signals to evaluate, aligned with ``clean``.
        n_jobs: Worker count passed to ``joblib.Parallel``. Defaults to 15,
            the value that was previously hard-coded.

    Returns:
        ``None`` if any pair failed to score (``cal_pesq`` returned ``-1``);
        otherwise a ``torch.FloatTensor`` holding ``(score - 1) / 3.5`` for
        each pair, in input order.
    """
    scores = Parallel(n_jobs=n_jobs)(
        delayed(cal_pesq)(c, n) for c, n in zip(clean, noisy)
    )
    scores = np.array(scores)
    # cal_pesq reports failure with -1; a single failure invalidates the
    # whole batch so the caller can skip it.
    if np.any(scores == -1):
        return None
    # Linear rescale: a raw score of 1 maps to 0 and 4.5 maps to 1.
    return torch.FloatTensor((scores - 1) / 3.5)


def main():
    """Smoke-test ``batch_pesq`` on one pair of random 160000-sample signals.

    Prints the result returned by ``batch_pesq`` (a tensor, or ``None`` when
    scoring failed).
    """
    prediction = torch.rand(size=(1, 160000), dtype=torch.float32)
    ground_truth = torch.rand(size=(1, 160000), dtype=torch.float32)

    # list() over the leading axis yields one 1-D numpy array per signal.
    prediction_list_r = list(prediction.cpu().numpy())
    ground_truth_list_r = list(ground_truth.cpu().numpy())

    # batch_pesq expects (clean, noisy): the ground truth is the reference
    # and the prediction is the signal being evaluated. PESQ is not
    # symmetric, so the order matters.
    pesq_score = batch_pesq(ground_truth_list_r, prediction_list_r)
    print(pesq_score)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()