#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://zhuanlan.zhihu.com/p/627039860
https://github.com/facebookresearch/denoiser/blob/main/denoiser/stft_loss.py
"""
from typing import List, Optional
import torch
import torch.nn as nn
from torch.nn import functional as F
class LSDLoss(nn.Module):
"""
Log Spectral Distance
Mean square error of power spectrum
"""
def __init__(self,
n_fft: int = 512,
win_size: int = 512,
hop_size: int = 256,
center: bool = True,
eps: float = 1e-8,
reduction: str = "mean",
):
        super(LSDLoss, self).__init__()
        # The STFT parameters are stored for interface parity with the other
        # losses; forward() consumes precomputed power spectra directly.
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_size = hop_size
        self.center = center
        self.eps = eps
        self.reduction = reduction
        if reduction not in ("sum", "mean"):
            raise ValueError("param reduction must be 'sum' or 'mean'.")
def forward(self, denoise_power: torch.Tensor, clean_power: torch.Tensor):
"""
:param denoise_power: power spectrum of the estimated signal power spectrum (batch_size, ...)
:param clean_power: power spectrum of the target signal (batch_size, ...)
:return:
"""
denoise_power = denoise_power + self.eps
clean_power = clean_power + self.eps
log_denoise_power = torch.log10(denoise_power)
log_clean_power = torch.log10(clean_power)
        # mean over the last axis; for [b, f, t] inputs, mean_square_error shape: [b, f]
mean_square_error = torch.mean(torch.square(log_denoise_power - log_clean_power), dim=-1)
if self.reduction == "mean":
lsd_loss = torch.mean(mean_square_error)
elif self.reduction == "sum":
lsd_loss = torch.sum(mean_square_error)
        else:
            raise AssertionError(f"unexpected reduction: {self.reduction}")
return lsd_loss
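
# A minimal usage sketch for LSDLoss (the shapes here are an assumption; the
# module itself only requires matching tensors). Power spectra of shape
# [batch, freq, time], i.e. |STFT|^2, match the [b, f] shape comment in
# forward():
#
#     lsd = LSDLoss()
#     loss = lsd(torch.abs(denoise_stft) ** 2, torch.abs(clean_stft) ** 2)
#
# Note that the textbook LSD applies a square root (often on 10 * log10,
# i.e. in dB) before averaging over frames; this variant keeps the raw mean
# squared log-power error.
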
class ComplexSpectralLoss(nn.Module):
def __init__(self,
n_fft: int = 512,
win_size: int = 512,
hop_size: int = 256,
center: bool = True,
eps: float = 1e-8,
reduction: str = "mean",
factor_mag: float = 0.5,
factor_pha: float = 0.3,
factor_gra: float = 0.2,
):
super().__init__()
self.n_fft = n_fft
self.win_size = win_size
self.hop_size = hop_size
self.center = center
self.eps = eps
self.reduction = reduction
self.factor_mag = factor_mag
self.factor_pha = factor_pha
self.factor_gra = factor_gra
        if reduction not in ("sum", "mean"):
            raise ValueError("param reduction must be 'sum' or 'mean'.")
        self.register_buffer("window", torch.hann_window(win_size))
def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
"""
:param denoise: The estimated signal (batch_size, signal_length)
:param clean: The target signal (batch_size, signal_length)
:return:
"""
if denoise.shape != clean.shape:
raise AssertionError("Input signals must have the same shape")
# denoise_stft, clean_stft shape: [b, f, t]
denoise_stft = torch.stft(
denoise,
n_fft=self.n_fft,
win_length=self.win_size,
hop_length=self.hop_size,
window=self.window,
center=self.center,
pad_mode="reflect",
normalized=False,
return_complex=True
)
clean_stft = torch.stft(
clean,
n_fft=self.n_fft,
win_length=self.win_size,
hop_length=self.hop_size,
window=self.window,
center=self.center,
pad_mode="reflect",
normalized=False,
return_complex=True
)
# complex_diff shape: [b, f, t], dtype: torch.complex64
complex_diff = denoise_stft - clean_stft
        # magnitude_diff, phase_diff shape: [b, f, t], dtype: torch.float32
        # note: phase_diff is the phase angle of the complex error itself,
        # not the wrapped difference between the two phase spectra.
        magnitude_diff = torch.abs(complex_diff)
        phase_diff = torch.angle(complex_diff)
# magnitude_loss, phase_loss shape: [b,]
magnitude_loss = torch.norm(magnitude_diff, p=2, dim=(-1, -2))
phase_loss = torch.norm(phase_diff, p=1, dim=(-1, -2))
# phase_grad shape: [b, f, t-1], dtype: torch.float32
phase_grad = torch.diff(torch.angle(denoise_stft), dim=-1)
grad_loss = torch.mean(torch.abs(phase_grad), dim=(-1, -2))
        # batch_loss shape: [b,]
batch_loss = self.factor_mag * magnitude_loss + self.factor_pha * phase_loss + self.factor_gra * grad_loss
# print(f"magnitude_loss: {magnitude_loss}")
# print(f"phase_loss: {phase_loss}")
# print(f"grad_loss: {grad_loss}")
if self.reduction == "mean":
loss = torch.mean(batch_loss)
elif self.reduction == "sum":
loss = torch.sum(batch_loss)
        else:
            raise AssertionError(f"unexpected reduction: {self.reduction}")
return loss
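
# A minimal usage sketch for ComplexSpectralLoss, assuming batched 1-D
# waveforms of shape [batch, samples]; factor_mag / factor_pha / factor_gra
# weight the magnitude, phase and phase-gradient terms of the complex
# spectral error:
#
#     loss_fn = ComplexSpectralLoss()
#     loss = loss_fn(denoise_waveform, clean_waveform)
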
class SpectralConvergenceLoss(torch.nn.Module):
"""Spectral convergence loss module."""
def __init__(self,
reduction: str = "mean",
eps: float = 1e-8,
):
super(SpectralConvergenceLoss, self).__init__()
self.reduction = reduction
self.eps = eps
        if reduction not in ("sum", "mean"):
            raise ValueError("param reduction must be 'sum' or 'mean'.")
def forward(self,
denoise_magnitude: torch.Tensor,
clean_magnitude: torch.Tensor,
):
"""
:param denoise_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins]
:param clean_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins]
:return:
"""
error_norm = torch.norm(denoise_magnitude - clean_magnitude, p="fro", dim=(-1, -2))
truth_norm = torch.norm(clean_magnitude, p="fro", dim=(-1, -2))
batch_loss = error_norm / (truth_norm + self.eps)
if self.reduction == "mean":
loss = torch.mean(batch_loss)
elif self.reduction == "sum":
loss = torch.sum(batch_loss)
        else:
            raise AssertionError(f"unexpected reduction: {self.reduction}")
return loss
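
# Spectral convergence, per batch item:
#
#     loss = ||A_denoise - A_clean||_F / (||A_clean||_F + eps)
#
# where A_* are STFT magnitudes and ||.||_F is the Frobenius norm over the
# frequency and time axes, so the error is measured relative to the energy
# of the target spectrum.
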
class LogSTFTMagnitudeLoss(torch.nn.Module):
"""Log STFT magnitude loss module."""
def __init__(self,
reduction: str = "mean",
eps: float = 1e-8,
):
super(LogSTFTMagnitudeLoss, self).__init__()
self.reduction = reduction
self.eps = eps
        if reduction not in ("sum", "mean"):
            raise ValueError("param reduction must be 'sum' or 'mean'.")
def forward(self,
denoise_magnitude: torch.Tensor,
clean_magnitude: torch.Tensor,
):
"""
:param denoise_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins]
:param clean_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins]
:return:
"""
loss = F.l1_loss(torch.log(denoise_magnitude + self.eps), torch.log(clean_magnitude + self.eps))
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
raise AssertionError("SpectralConvergenceLoss, nan or inf in loss")
return loss
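
# Log STFT magnitude loss, elementwise:
#
#     loss = mean(|log(A_denoise + eps) - log(A_clean + eps)|)
#
# The eps offset keeps the logarithm finite for silent (zero-magnitude)
# time-frequency bins.
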
class STFTLoss(torch.nn.Module):
"""STFT loss module."""
def __init__(self,
n_fft: int = 1024,
win_size: int = 600,
hop_size: int = 120,
center: bool = True,
reduction: str = "mean",
):
super(STFTLoss, self).__init__()
self.n_fft = n_fft
self.win_size = win_size
self.hop_size = hop_size
self.center = center
self.reduction = reduction
        self.register_buffer("window", torch.hann_window(win_size))
self.spectral_convergence_loss = SpectralConvergenceLoss(reduction=reduction)
self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss(reduction=reduction)
def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
"""
:param denoise:
:param clean:
:return:
"""
if denoise.shape != clean.shape:
raise AssertionError("Input signals must have the same shape")
# denoise_stft, clean_stft shape: [b, f, t]
denoise_stft = torch.stft(
denoise,
n_fft=self.n_fft,
win_length=self.win_size,
hop_length=self.hop_size,
window=self.window,
center=self.center,
pad_mode="reflect",
normalized=False,
return_complex=True
)
clean_stft = torch.stft(
clean,
n_fft=self.n_fft,
win_length=self.win_size,
hop_length=self.hop_size,
window=self.window,
center=self.center,
pad_mode="reflect",
normalized=False,
return_complex=True
)
denoise_magnitude = torch.abs(denoise_stft)
clean_magnitude = torch.abs(clean_stft)
        sc_loss = self.spectral_convergence_loss(denoise_magnitude, clean_magnitude)
        mag_loss = self.log_stft_magnitude_loss(denoise_magnitude, clean_magnitude)
return sc_loss, mag_loss
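
# A minimal usage sketch for a single-resolution STFTLoss; the two returned
# terms are typically weighted and summed by the caller, as
# MultiResolutionSTFTLoss does below:
#
#     stft_loss = STFTLoss(n_fft=1024, win_size=600, hop_size=120)
#     sc_loss, mag_loss = stft_loss(denoise_waveform, clean_waveform)
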
class MultiResolutionSTFTLoss(torch.nn.Module):
"""Multi resolution STFT loss module."""
def __init__(self,
                 fft_size_list: Optional[List[int]] = None,
                 win_size_list: Optional[List[int]] = None,
                 hop_size_list: Optional[List[int]] = None,
                 factor_sc: float = 0.1,
                 factor_mag: float = 0.1,
                 reduction: str = "mean",
):
super(MultiResolutionSTFTLoss, self).__init__()
fft_size_list = fft_size_list or [512, 1024, 2048]
win_size_list = win_size_list or [240, 600, 1200]
hop_size_list = hop_size_list or [50, 120, 240]
        if not len(fft_size_list) == len(win_size_list) == len(hop_size_list):
            raise AssertionError("fft_size_list, win_size_list and hop_size_list must have the same length.")
loss_fn_list = nn.ModuleList([])
for n_fft, win_size, hop_size in zip(fft_size_list, win_size_list, hop_size_list):
loss_fn_list.append(
STFTLoss(
n_fft=n_fft,
win_size=win_size,
hop_size=hop_size,
reduction=reduction,
)
)
self.loss_fn_list = loss_fn_list
self.factor_sc = factor_sc
self.factor_mag = factor_mag
def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
"""
:param denoise:
:param clean:
:return:
"""
if denoise.shape != clean.shape:
raise AssertionError("Input signals must have the same shape")
sc_loss = 0.0
mag_loss = 0.0
for loss_fn in self.loss_fn_list:
            sc_l, mag_l = loss_fn(denoise, clean)
sc_loss += sc_l
mag_loss += mag_l
sc_loss = sc_loss / len(self.loss_fn_list)
mag_loss = mag_loss / len(self.loss_fn_list)
sc_loss = self.factor_sc * sc_loss
mag_loss = self.factor_mag * mag_loss
loss = sc_loss + mag_loss
return loss
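
# A minimal usage sketch; the default resolutions (512/1024/2048-point FFTs
# with matching windows and hops) follow the multi-resolution setup of the
# denoiser implementation referenced in the module docstring:
#
#     mr_stft_loss = MultiResolutionSTFTLoss()
#     loss = mr_stft_loss(denoise_waveform, clean_waveform)
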
def main():
batch_size = 2
signal_length = 16000
estimated_signal = torch.randn(batch_size, signal_length)
target_signal = torch.randn(batch_size, signal_length)
# loss_fn = LSDLoss()
# loss_fn = ComplexSpectralLoss()
loss_fn = MultiResolutionSTFTLoss()
    loss = loss_fn(estimated_signal, target_signal)
print(f"loss: {loss.item()}")
return
if __name__ == "__main__":
main()