Spaces:
Running
Running
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
""" | |
https://zhuanlan.zhihu.com/p/627039860 | |
""" | |
import torch | |
import torch.nn as nn | |
from torch_stoi import NegSTOILoss as TorchNegSTOILoss | |
from torch_pesq import PesqLoss as TorchPesqLoss | |
class PMSQELoss(object):
    """
    Placeholder for a PMSQE loss.

    This class defines no methods or attributes of its own; it only records
    the references below.

    PMSQE: A Deep Learning Loss Function based on the Perceptual Evaluation of the Speech Quality
    https://sigmat.ugr.es/PMSQE/

    On Loss Functions for Supervised Monaural Time-Domain Speech Enhancement
    https://arxiv.org/abs/1909.01019

    Reference implementation:
    https://github.com/asteroid-team/asteroid/blob/master/asteroid/losses/pmsqe.py
    """
class NegSTOILoss(nn.Module):
    """
    Negative STOI loss, reduced to a single tensor with ``mean`` or ``sum``.

    STOI (Short-Time Objective Intelligibility) predicts how intelligible
    speech is by measuring the correlation between time-domain and
    frequency-domain features of the speech signals. Scores range from 0 to 1;
    a higher score means higher intelligibility. It is suited to evaluating
    intelligibility improvements for speech in noisy environments.

    This module wraps ``TorchNegSTOILoss`` (``torch_stoi.NegSTOILoss``) and
    applies the configured reduction to whatever that loss returns.

    https://github.com/mpariente/pytorch_stoi
    https://github.com/mpariente/pystoi
    https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/nnet/loss/stoi_loss.py
    """

    def __init__(self,
                 sample_rate: int,
                 reduction: str = "mean",
                 ):
        """
        :param sample_rate: sample rate handed to the wrapped ``TorchNegSTOILoss``.
        :param reduction: how to reduce the wrapped loss output, ``"mean"`` or ``"sum"``.
        :raises AssertionError: if ``reduction`` is neither ``"sum"`` nor ``"mean"``.
        """
        super(NegSTOILoss, self).__init__()

        # Validate before building the wrapped loss so a bad argument fails fast.
        # AssertionError is kept (rather than ValueError) for existing callers.
        if reduction not in ("sum", "mean"):
            raise AssertionError("param reduction must be sum or mean.")

        self.reduction = reduction
        self.loss_fn = TorchNegSTOILoss(sample_rate=sample_rate)

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
        """
        :param denoise: estimated signal, passed as the first argument of the wrapped loss.
        :param clean: target signal, passed as the second argument of the wrapped loss.
        :return: the wrapped loss output reduced with ``torch.mean`` or ``torch.sum``.
        :raises AssertionError: if ``self.reduction`` is not ``"mean"`` or ``"sum"``.
        """
        # Call the module itself, not ``.forward``, so nn.Module hooks are honoured.
        batch_loss = self.loss_fn(denoise, clean)

        if self.reduction == "mean":
            return torch.mean(batch_loss)
        if self.reduction == "sum":
            return torch.sum(batch_loss)

        # Only reachable when ``reduction`` was reassigned after construction.
        raise AssertionError("param reduction must be sum or mean.")
class PesqLoss(nn.Module):
    """
    PESQ-based loss, reduced to a single tensor with ``mean`` or ``sum``.

    This module wraps ``TorchPesqLoss`` (``torch_pesq.PesqLoss``): every
    constructor argument except ``reduction`` is stored on the instance and
    forwarded unchanged to the wrapped loss.
    """

    def __init__(self,
                 factor: float,
                 sample_rate: int = 48000,
                 nbarks: int = 49,
                 win_length: int = 512,
                 n_fft: int = 512,
                 hop_length: int = 256,
                 reduction: str = "mean",
                 ):
        """
        :param factor: forwarded to ``TorchPesqLoss`` as ``factor``.
        :param sample_rate: forwarded to ``TorchPesqLoss`` as ``sample_rate``.
        :param nbarks: forwarded to ``TorchPesqLoss`` as ``nbarks``.
        :param win_length: forwarded to ``TorchPesqLoss`` as ``win_length``.
        :param n_fft: forwarded to ``TorchPesqLoss`` as ``n_fft``.
        :param hop_length: forwarded to ``TorchPesqLoss`` as ``hop_length``.
        :param reduction: how to reduce the wrapped loss output, ``"mean"`` or ``"sum"``.
        :raises AssertionError: if ``reduction`` is neither ``"sum"`` nor ``"mean"``.
        """
        super(PesqLoss, self).__init__()

        # Reject a bad reduction at construction time (as NegSTOILoss does)
        # instead of failing on the first forward pass.
        if reduction not in ("sum", "mean"):
            raise AssertionError("param reduction must be sum or mean.")

        self.factor = factor
        self.sample_rate = sample_rate
        self.nbarks = nbarks
        self.win_length = win_length
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.reduction = reduction

        self.loss_fn = TorchPesqLoss(
            factor=factor,
            sample_rate=sample_rate,
            nbarks=nbarks,
            win_length=win_length,
            n_fft=n_fft,
            hop_length=hop_length,
        )

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
        """
        :param denoise: processed signal, passed as the second argument of the wrapped loss.
        :param clean: clean signal, passed as the first argument of the wrapped loss.
        :return: the wrapped loss output reduced with ``torch.mean`` or ``torch.sum``.
        :raises AssertionError: if ``self.reduction`` is not ``"mean"`` or ``"sum"``.
        """
        # The wrapped loss is called with the clean signal first; this is the
        # opposite of this method's own (denoise, clean) parameter order.
        # Call the module itself, not ``.forward``, so nn.Module hooks are honoured.
        batch_loss = self.loss_fn(clean, denoise)

        # NOTE: NaN/inf entries are not masked out before the reduction, so any
        # non-finite per-sample value propagates into the returned loss.
        if self.reduction == "mean":
            return torch.mean(batch_loss)
        if self.reduction == "sum":
            return torch.sum(batch_loss)

        # Only reachable when ``reduction`` was reassigned after construction.
        raise AssertionError("param reduction must be sum or mean.")
def main():
    """
    Smoke-test ``NegSTOILoss`` on random input and print the reduced loss.
    """
    sample_rate = 16000

    loss_func = NegSTOILoss(
        sample_rate=sample_rate,
        reduction="mean",
    )

    # Random tensors of shape (2, sample_rate) stand in for real audio.
    denoise = torch.randn(2, sample_rate)
    clean = torch.randn(2, sample_rate)

    # Invoke the module itself rather than ``.forward`` so nn.Module hooks run.
    loss = loss_func(denoise, clean)
    print(loss)
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()