HoneyTian's picture
update
e27a095
raw
history blame
1.13 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from joblib import Parallel, delayed
import numpy as np
from pesq import pesq
import torch
def cal_pesq(clean, noisy, sr=16000):
    """Score ``noisy`` against the reference ``clean`` with PESQ in "wb" mode.

    Args:
        clean: Reference signal, forwarded unchanged to ``pesq``.
        noisy: Signal under evaluation, forwarded unchanged to ``pesq``.
        sr: Sample rate handed to ``pesq`` (defaults to 16000).

    Returns:
        Whatever ``pesq(sr, clean, noisy, "wb")`` returns, or the sentinel
        ``-1`` if that call raises any ``Exception``.
    """
    try:
        return pesq(sr, clean, noisy, "wb")
    except Exception:
        # Deliberate best-effort: scoring can fail (e.g. on a silent
        # period), so the failure is reported as a sentinel rather than
        # propagated to the caller.
        return -1
def batch_pesq(clean, noisy, n_jobs=15, sr=16000):
    """Compute normalized PESQ scores for a batch of signal pairs.

    Each ``(clean[i], noisy[i])`` pair is scored by ``cal_pesq`` through a
    joblib ``Parallel`` pool.

    Args:
        clean: Iterable of reference signals.
        noisy: Iterable of signals to evaluate, paired with ``clean`` via
            ``zip`` (so extra items in the longer iterable are ignored).
        n_jobs: Number of joblib workers. Defaults to 15, the value that
            used to be hard-coded.
        sr: Sample rate forwarded to ``cal_pesq``. Defaults to 16000, which
            is also ``cal_pesq``'s own default.

    Returns:
        ``None`` if any pair produced the ``-1`` failure sentinel; otherwise
        a ``torch.FloatTensor`` holding ``(score - 1) / 3.5`` for each pair.
    """
    scores = Parallel(n_jobs=n_jobs)(
        delayed(cal_pesq)(c, n, sr=sr) for c, n in zip(clean, noisy)
    )
    scores = np.array(scores)

    # cal_pesq signals a failed computation with -1; a single failure
    # invalidates the whole batch.
    if np.any(scores == -1):
        return None

    # Rescale the raw scores with the fixed offset/scale used by this file.
    scores = (scores - 1) / 3.5
    return torch.FloatTensor(scores)
def main():
    """Smoke-test ``batch_pesq`` on one pair of random 160000-sample signals.

    Prints the result returned by ``batch_pesq`` (a tensor, or ``None`` if
    scoring failed).
    """
    prediction = torch.rand(size=(1, 160000), dtype=torch.float32)
    ground_truth = torch.rand(size=(1, 160000), dtype=torch.float32)

    # Split the batch dimension into a list of 1-D numpy arrays.
    prediction_list_r = list(prediction.cpu().numpy())
    ground_truth_list_r = list(ground_truth.cpu().numpy())

    # batch_pesq takes (clean, noisy): the ground truth is the reference and
    # the prediction is the signal being judged. PESQ is not symmetric, so
    # the argument order matters.
    pesq_score = batch_pesq(ground_truth_list_r, prediction_list_r)
    print(pesq_score)


if __name__ == "__main__":
    main()