#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://zhuanlan.zhihu.com/p/627039860
https://github.com/facebookresearch/denoiser/blob/main/denoiser/stft_loss.py
"""
from typing import List

import torch
import torch.nn as nn
from torch.nn import functional as F
class LSDLoss(nn.Module):
    """
    Log Spectral Distance.

    Mean square error between the log10 power spectra of the estimated and
    target signals.

    NOTE(review): ``n_fft``, ``win_size``, ``hop_size`` and ``center`` are
    stored but never used here — ``forward`` expects precomputed power
    spectra, not waveforms. They are kept for interface compatibility.
    """
    def __init__(self,
                 n_fft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 center: bool = True,
                 eps: float = 1e-8,
                 reduction: str = "mean",
                 ):
        super(LSDLoss, self).__init__()
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_size = hop_size
        self.center = center
        self.eps = eps
        self.reduction = reduction

        if reduction not in ("sum", "mean"):
            # plain string: the message has no placeholders, so the f-prefix was pointless
            raise AssertionError("param reduction must be sum or mean.")

    def forward(self, denoise_power: torch.Tensor, clean_power: torch.Tensor):
        """
        :param denoise_power: power spectrum of the estimated signal, shape (batch_size, ...)
        :param clean_power: power spectrum of the target signal, shape (batch_size, ...)
        :return: scalar loss tensor (mean or sum over the batch, per ``reduction``)
        """
        # eps keeps log10 finite for zero-valued spectral bins
        denoise_power = denoise_power + self.eps
        clean_power = clean_power + self.eps

        log_denoise_power = torch.log10(denoise_power)
        log_clean_power = torch.log10(clean_power)

        # squared log-distance averaged over the last axis; shape: [b, ...]
        mean_square_error = torch.mean(torch.square(log_denoise_power - log_clean_power), dim=-1)

        if self.reduction == "mean":
            lsd_loss = torch.mean(mean_square_error)
        elif self.reduction == "sum":
            lsd_loss = torch.sum(mean_square_error)
        else:
            raise AssertionError
        return lsd_loss
class ComplexSpectralLoss(nn.Module):
    """
    Loss on the complex STFT: a weighted sum of a magnitude term, a phase
    term and a phase-gradient (temporal smoothness) term.
    """
    def __init__(self,
                 n_fft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 center: bool = True,
                 eps: float = 1e-8,
                 reduction: str = "mean",
                 factor_mag: float = 0.5,
                 factor_pha: float = 0.3,
                 factor_gra: float = 0.2,
                 ):
        super().__init__()
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_size = hop_size
        self.center = center
        self.eps = eps
        self.reduction = reduction
        self.factor_mag = factor_mag
        self.factor_pha = factor_pha
        self.factor_gra = factor_gra

        if reduction not in ("sum", "mean"):
            # plain string: no placeholders, so the f-prefix was pointless
            raise AssertionError("param reduction must be sum or mean.")

        # non-trainable analysis window; a Parameter so it moves with the module's device
        self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False)

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
        """
        :param denoise: The estimated signal (batch_size, signal_length)
        :param clean: The target signal (batch_size, signal_length)
        :return: scalar loss tensor (mean or sum over the batch, per ``reduction``)
        """
        if denoise.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")

        # denoise_stft, clean_stft shape: [b, f, t], complex dtype
        denoise_stft = torch.stft(
            denoise,
            n_fft=self.n_fft,
            win_length=self.win_size,
            hop_length=self.hop_size,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            return_complex=True
        )
        clean_stft = torch.stft(
            clean,
            n_fft=self.n_fft,
            win_length=self.win_size,
            hop_length=self.hop_size,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            return_complex=True
        )

        # complex_diff shape: [b, f, t], dtype: torch.complex64
        complex_diff = denoise_stft - clean_stft

        # magnitude / angle of the complex residual; shape: [b, f, t], float
        magnitude_diff = torch.abs(complex_diff)
        phase_diff = torch.angle(complex_diff)

        # per-sample norms over (freq, time); magnitude_loss, phase_loss shape: [b,]
        magnitude_loss = torch.norm(magnitude_diff, p=2, dim=(-1, -2))
        phase_loss = torch.norm(phase_diff, p=1, dim=(-1, -2))

        # temporal gradient of the estimated phase; shape: [b, f, t-1]
        # NOTE(review): this term penalizes the denoised signal's own phase
        # variation over time and never compares against the clean phase —
        # confirm that is the intended regularizer.
        phase_grad = torch.diff(torch.angle(denoise_stft), dim=-1)
        grad_loss = torch.mean(torch.abs(phase_grad), dim=(-1, -2))

        # batch_loss shape: [b,]
        batch_loss = self.factor_mag * magnitude_loss + self.factor_pha * phase_loss + self.factor_gra * grad_loss

        if self.reduction == "mean":
            loss = torch.mean(batch_loss)
        elif self.reduction == "sum":
            loss = torch.sum(batch_loss)
        else:
            raise AssertionError
        return loss
class SpectralConvergenceLoss(torch.nn.Module):
    """Spectral convergence loss module.

    Frobenius norm of the magnitude error, normalized by the Frobenius norm
    of the target magnitude, computed per batch element and then reduced.
    """
    def __init__(self,
                 reduction: str = "mean",
                 eps: float = 1e-8,
                 ):
        super(SpectralConvergenceLoss, self).__init__()
        self.reduction = reduction
        self.eps = eps
        if reduction not in ("sum", "mean"):
            # plain string: the message has no placeholders, so the f-prefix was pointless
            raise AssertionError("param reduction must be sum or mean.")

    def forward(self,
                denoise_magnitude: torch.Tensor,
                clean_magnitude: torch.Tensor,
                ):
        """
        :param denoise_magnitude: Tensor, shape: [batch_size, d1, d2] — the norm is
            taken over the last two axes, so (time, freq) vs (freq, time) is irrelevant.
        :param clean_magnitude: Tensor, same shape as ``denoise_magnitude``
        :return: scalar loss tensor (mean or sum over the batch, per ``reduction``)
        """
        # per-sample Frobenius norms over the last two axes; shape: [b,]
        error_norm = torch.norm(denoise_magnitude - clean_magnitude, p="fro", dim=(-1, -2))
        truth_norm = torch.norm(clean_magnitude, p="fro", dim=(-1, -2))
        # eps guards against division by zero for an all-zero target
        batch_loss = error_norm / (truth_norm + self.eps)
        if self.reduction == "mean":
            loss = torch.mean(batch_loss)
        elif self.reduction == "sum":
            loss = torch.sum(batch_loss)
        else:
            raise AssertionError
        return loss
class LogSTFTMagnitudeLoss(torch.nn.Module):
    """Log STFT magnitude loss module.

    L1 distance between the log magnitudes of the estimated and target spectra.
    """
    def __init__(self,
                 reduction: str = "mean",
                 eps: float = 1e-8,
                 ):
        super(LogSTFTMagnitudeLoss, self).__init__()
        self.reduction = reduction
        self.eps = eps
        if reduction not in ("sum", "mean"):
            # plain string: the message has no placeholders, so the f-prefix was pointless
            raise AssertionError("param reduction must be sum or mean.")

    def forward(self,
                denoise_magnitude: torch.Tensor,
                clean_magnitude: torch.Tensor,
                ):
        """
        :param denoise_magnitude: Tensor, shape: [batch_size, ...] (any layout; elementwise L1)
        :param clean_magnitude: Tensor, same shape as ``denoise_magnitude``
        :return: scalar loss tensor (mean or sum over all elements, per ``reduction``)
        """
        # eps keeps log finite for zero-valued magnitude bins.
        # bug fix: self.reduction was previously validated but never passed to
        # F.l1_loss, so reduction="sum" silently computed a mean.
        loss = F.l1_loss(
            torch.log(denoise_magnitude + self.eps),
            torch.log(clean_magnitude + self.eps),
            reduction=self.reduction,
        )
        if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
            # bug fix: the message previously named SpectralConvergenceLoss (copy-paste)
            raise AssertionError("LogSTFTMagnitudeLoss, nan or inf in loss")
        return loss
class STFTLoss(torch.nn.Module):
    """STFT loss module.

    Transforms both waveforms to magnitude spectrograms at a single
    resolution and evaluates spectral-convergence and log-magnitude losses.
    """
    def __init__(self,
                 n_fft: int = 1024,
                 win_size: int = 600,
                 hop_size: int = 120,
                 center: bool = True,
                 reduction: str = "mean",
                 ):
        super(STFTLoss, self).__init__()
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_size = hop_size
        self.center = center
        self.reduction = reduction
        # non-trainable analysis window; a Parameter so it moves with the module's device
        self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False)
        self.spectral_convergence_loss = SpectralConvergenceLoss(reduction=reduction)
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss(reduction=reduction)

    def _magnitude(self, signal: torch.Tensor) -> torch.Tensor:
        """Return the STFT magnitude of ``signal``; shape: [b, f, t]."""
        spectrum = torch.stft(
            signal,
            n_fft=self.n_fft,
            win_length=self.win_size,
            hop_length=self.hop_size,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            return_complex=True
        )
        return torch.abs(spectrum)

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
        """
        :param denoise: estimated signal, shape (batch_size, signal_length)
        :param clean: target signal, same shape as ``denoise``
        :return: tuple ``(sc_loss, mag_loss)``
        """
        if denoise.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")
        denoise_magnitude = self._magnitude(denoise)
        clean_magnitude = self._magnitude(clean)
        sc_loss = self.spectral_convergence_loss.forward(denoise_magnitude, clean_magnitude)
        mag_loss = self.log_stft_magnitude_loss.forward(denoise_magnitude, clean_magnitude)
        return sc_loss, mag_loss
class MultiResolutionSTFTLoss(torch.nn.Module):
    """Multi resolution STFT loss module.

    Averages the spectral-convergence and log-magnitude losses computed at
    several STFT resolutions, then combines them with the configured factors.
    """
    def __init__(self,
                 fft_size_list: List[int] = None,
                 win_size_list: List[int] = None,
                 hop_size_list: List[int] = None,
                 factor_sc=0.1,
                 factor_mag=0.1,
                 reduction: str = "mean",
                 ):
        super(MultiResolutionSTFTLoss, self).__init__()
        # fall back to the default three-resolution configuration
        fft_size_list = fft_size_list or [512, 1024, 2048]
        win_size_list = win_size_list or [240, 600, 1200]
        hop_size_list = hop_size_list or [50, 120, 240]
        if not len(fft_size_list) == len(win_size_list) == len(hop_size_list):
            raise AssertionError
        # one single-resolution STFT loss per (n_fft, win, hop) triple
        self.loss_fn_list = nn.ModuleList([
            STFTLoss(
                n_fft=n_fft,
                win_size=win_size,
                hop_size=hop_size,
                reduction=reduction,
            )
            for n_fft, win_size, hop_size in zip(fft_size_list, win_size_list, hop_size_list)
        ])
        self.factor_sc = factor_sc
        self.factor_mag = factor_mag

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor):
        """
        :param denoise: estimated signal, shape (batch_size, signal_length)
        :param clean: target signal, same shape as ``denoise``
        :return: scalar combined loss
        """
        if denoise.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")
        sc_loss = 0.0
        mag_loss = 0.0
        for stft_loss_fn in self.loss_fn_list:
            sc_component, mag_component = stft_loss_fn.forward(denoise, clean)
            sc_loss = sc_loss + sc_component
            mag_loss = mag_loss + mag_component
        num_resolutions = len(self.loss_fn_list)
        # average over resolutions, then apply the weighting factors
        sc_loss = self.factor_sc * (sc_loss / num_resolutions)
        mag_loss = self.factor_mag * (mag_loss / num_resolutions)
        return sc_loss + mag_loss
def main():
    """Smoke-test MultiResolutionSTFTLoss on random signals."""
    batch_size = 2
    signal_length = 16000

    estimated_signal = torch.randn(batch_size, signal_length)
    target_signal = torch.randn(batch_size, signal_length)

    loss = MultiResolutionSTFTLoss().forward(estimated_signal, target_signal)
    print(f"loss: {loss.item()}")
    return
if __name__ == "__main__": | |
main() | |