HoneyTian's picture
update
14f8597
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://zhuanlan.zhihu.com/p/627039860
"""
import torch
import torch.nn as nn
from torch_stoi import NegSTOILoss as TorchNegSTOILoss
from torch_pesq import PesqLoss as TorchPesqLoss
class PMSQELoss:
    """
    Placeholder for the PMSQE loss; this class defines no methods or attributes
    and only records where the method is described.

    Paper: "A Deep Learning Loss Function based on the Perceptual Evaluation
    of the Speech Quality"
    https://sigmat.ugr.es/PMSQE/

    Paper: "On Loss Functions for Supervised Monaural Time-Domain Speech
    Enhancement"
    https://arxiv.org/abs/1909.01019

    Reference implementation:
    https://github.com/asteroid-team/asteroid/blob/master/asteroid/losses/pmsqe.py
    """
class NegSTOILoss(nn.Module):
    """
    Negative-STOI training loss: wraps ``torch_stoi.NegSTOILoss`` and reduces
    its output to a single value.

    STOI (Short-Time Objective Intelligibility) predicts how intelligible
    speech is by measuring the correlation between time- and frequency-domain
    features of the speech signals. Scores range from 0 to 1, and a higher
    score means higher intelligibility. It is suited to evaluating
    intelligibility improvements in noisy conditions.

    https://github.com/mpariente/pytorch_stoi
    https://github.com/mpariente/pystoi
    https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/nnet/loss/stoi_loss.py

    :param sample_rate: sample rate forwarded to the wrapped loss.
    :param reduction: how the wrapped loss output is reduced, "mean" or "sum".
    :raises AssertionError: if ``reduction`` is neither "sum" nor "mean".
    """
    def __init__(self,
                 sample_rate: int,
                 reduction: str = "mean",
                 ):
        super(NegSTOILoss, self).__init__()
        # Validate before building the wrapped loss so that a bad argument
        # fails immediately, without paying for the wrapped module's setup.
        if reduction not in ("sum", "mean"):
            raise AssertionError("param reduction must be sum or mean.")
        self.reduction = reduction
        self.loss_fn = TorchNegSTOILoss(sample_rate=sample_rate)

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
        """
        Compute the reduced loss between an enhanced signal and its reference.

        :param denoise: enhanced signal, passed as the first argument to the wrapped loss.
        :param clean: reference signal, passed as the second argument to the wrapped loss.
        :return: the wrapped loss output reduced with ``torch.mean`` or ``torch.sum``.
        :raises AssertionError: if ``self.reduction`` was changed to an unsupported value.
        """
        # Call the module itself rather than ``.forward`` so that any hooks
        # registered on the wrapped loss are executed.
        batch_loss = self.loss_fn(denoise, clean)

        if self.reduction == "mean":
            return torch.mean(batch_loss)
        if self.reduction == "sum":
            return torch.sum(batch_loss)
        # Only reachable if ``reduction`` was mutated after construction.
        raise AssertionError(f"unsupported reduction: {self.reduction!r}")
class PesqLoss(nn.Module):
    """
    PESQ-based training loss: wraps ``torch_pesq.PesqLoss`` and reduces its
    output to a single value.

    All constructor arguments except ``reduction`` are stored on the instance
    and forwarded unchanged to the wrapped loss.

    :param factor: forwarded to the wrapped loss.
    :param sample_rate: forwarded to the wrapped loss.
    :param nbarks: forwarded to the wrapped loss.
    :param win_length: forwarded to the wrapped loss.
    :param n_fft: forwarded to the wrapped loss.
    :param hop_length: forwarded to the wrapped loss.
    :param reduction: how the wrapped loss output is reduced, "mean" or "sum".
    :raises AssertionError: if ``reduction`` is neither "sum" nor "mean".
    """
    def __init__(self,
                 factor: float,
                 sample_rate: int = 48000,
                 nbarks: int = 49,
                 win_length: int = 512,
                 n_fft: int = 512,
                 hop_length: int = 256,
                 reduction: str = "mean",
                 ):
        super(PesqLoss, self).__init__()
        # Same check as NegSTOILoss: reject a bad reduction at construction
        # time instead of at the first forward pass.
        if reduction not in ("sum", "mean"):
            raise AssertionError("param reduction must be sum or mean.")

        self.factor = factor
        self.sample_rate = sample_rate
        self.nbarks = nbarks
        self.win_length = win_length
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.reduction = reduction

        self.loss_fn = TorchPesqLoss(
            factor=factor,
            sample_rate=sample_rate,
            nbarks=nbarks,
            win_length=win_length,
            n_fft=n_fft,
            hop_length=hop_length,
        )

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
        """
        Compute the reduced loss between an enhanced signal and its reference.

        :param denoise: enhanced signal, passed as the second argument to the wrapped loss.
        :param clean: reference signal, passed as the first argument to the wrapped loss.
        :return: the wrapped loss output reduced with ``torch.mean`` or ``torch.sum``.
        :raises AssertionError: if ``self.reduction`` was changed to an unsupported value.
        """
        # The wrapped loss takes the reference first, so the argument order is
        # swapped relative to this method's signature. Call the module itself
        # rather than ``.forward`` so that registered hooks are executed.
        batch_loss = self.loss_fn(clean, denoise)

        # NOTE: NaN / Inf entries are not filtered out here; they propagate
        # into the reduced loss.
        if self.reduction == "mean":
            return torch.mean(batch_loss)
        if self.reduction == "sum":
            return torch.sum(batch_loss)
        # Only reachable if ``reduction`` was mutated after construction.
        raise AssertionError(f"unsupported reduction: {self.reduction!r}")
def main():
    """Smoke test: evaluate NegSTOILoss on random tensors and print the result."""
    sr = 16000
    criterion = NegSTOILoss(sample_rate=sr, reduction="mean")

    # Two random rows of ``sr`` samples each; the estimate is drawn first,
    # then the reference.
    estimate = torch.randn(2, sr)
    reference = torch.randn(2, sr)

    print(criterion.forward(estimate, reference))


if __name__ == "__main__":
    main()