Spaces:
Running
Running
File size: 5,864 Bytes
e27a095 877f661 e27a095 efe955e e27a095 852815e e27a095 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import torch
import torch
import torch.nn.functional as F
def stft(x, fft_size, hop_size, win_length, window):
    """
    Compute the magnitude of the short-time Fourier transform of a signal.

    :param x: Tensor, Input signal tensor (B, T).
    :param fft_size: int, FFT size.
    :param hop_size: int, Hop size.
    :param win_length: int, Window length.
    :param window: Tensor, Window values handed unchanged to ``torch.stft``.
    :return: Tensor, Magnitude spectrogram in the layout produced by
        ``torch.stft`` (frequency axis before the frame axis, i.e.
        (B, fft_size // 2 + 1, #frames) for real input with default options).
    """
    spectrum = torch.stft(
        x,
        n_fft=fft_size,
        hop_length=hop_size,
        win_length=win_length,
        window=window,
        return_complex=True,
    )
    # Drop the phase: only the modulus of each complex bin is kept.
    return torch.abs(spectrum)
class SpectralConvergenceLoss(torch.nn.Module):
    """Spectral convergence loss module."""

    def __init__(self, eps=1e-8):
        """
        Initialize spectral convergence loss module.

        :param eps: float, Lower bound applied to the norm of the groundtruth
            magnitude before dividing, so an all-zero target cannot produce
            NaN or inf.
        """
        super().__init__()
        self.eps = eps

    def forward(self, x_mag, y_mag):
        """
        Calculate forward propagation.

        The Frobenius norm is taken over all elements of the tensors, so the
        two inputs only need to have matching shapes.

        :param x_mag: Tensor, Magnitude spectrogram of predicted signal.
        :param y_mag: Tensor, Magnitude spectrogram of groundtruth signal.
        :return: Tensor, Spectral convergence loss value (scalar).
        """
        error_norm = torch.norm(y_mag - x_mag, p="fro")
        # A silent groundtruth has zero norm; without the clamp the ratio
        # would be 0/0 (NaN) or x/0 (inf) and poison the training step.
        target_norm = torch.clamp(torch.norm(y_mag, p="fro"), min=self.eps)
        return error_norm / target_norm
class LogSTFTMagnitudeLoss(torch.nn.Module):
    """L1 distance between log-magnitude spectrograms."""

    def __init__(self):
        super().__init__()

    def forward(self, x_mag, y_mag):
        """
        Compute the mean absolute difference of the log magnitudes.

        Both inputs are floored at 1e-8 before the logarithm so that zero
        magnitudes do not yield ``-inf``.

        :param x_mag: Tensor, Magnitude spectrogram of predicted signal.
        :param y_mag: Tensor, Magnitude spectrogram of groundtruth signal
            (same shape as ``x_mag``).
        :return: Tensor, Log STFT magnitude loss value (scalar).
        """
        floor = 1e-8
        log_target = y_mag.clamp(min=floor).log()
        log_pred = x_mag.clamp(min=floor).log()
        return F.l1_loss(log_target, log_pred)
class STFTLoss(torch.nn.Module):
    """Single-resolution STFT loss (spectral convergence + log magnitude)."""

    def __init__(
        self,
        fft_size=1024,
        shift_size=120,
        win_length=600,
        window="hann_window",
        band="full",
    ):
        """
        Initialize STFT loss module.

        :param fft_size: int, FFT size.
        :param shift_size: int, Hop size between frames.
        :param win_length: int, Window length.
        :param window: str, Name of a ``torch`` window factory
            (looked up with ``getattr(torch, window)``).
        :param band: str, ``"full"`` to use every bin or ``"high"`` to use
            only the upper half of axis 1 of the spectrogram.
        """
        super().__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        self.band = band
        self.spectral_convergence_loss = SpectralConvergenceLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
        # NOTE(kan-bayashi): Use register_buffer to fix #223
        window_fn = getattr(torch, window)
        self.register_buffer("window", window_fn(win_length))

    def forward(self, x, y):
        """
        Calculate forward propagation.

        :param x: Tensor, Predicted signal (B, T).
        :param y: Tensor, Groundtruth signal (B, T).
        :return:
            Tensor, Spectral convergence loss value.
            Tensor, Log STFT magnitude loss value.
        :raises NotImplementedError: if ``band`` is neither ``"high"`` nor
            ``"full"``.
        """
        pred_mag = stft(
            x, self.fft_size, self.shift_size, self.win_length, self.window
        )
        true_mag = stft(
            y, self.fft_size, self.shift_size, self.win_length, self.window
        )

        if self.band not in ("high", "full"):
            raise NotImplementedError

        if self.band == "high":
            # Keep only the upper half of axis 1 (the high-frequency bands).
            cutoff = pred_mag.shape[1] // 2
            pred_mag = pred_mag[:, cutoff:, :]
            true_mag = true_mag[:, cutoff:, :]

        sc_loss = self.spectral_convergence_loss(pred_mag, true_mag)
        mag_loss = self.log_stft_magnitude_loss(pred_mag, true_mag)
        return sc_loss, mag_loss
class MultiResolutionSTFTLoss(torch.nn.Module):
    """Multi resolution STFT loss module."""

    def __init__(
        self,
        fft_sizes=None,
        hop_sizes=None,
        win_lengths=None,
        window="hann_window",
        sc_lambda=0.1,
        mag_lambda=0.1,
        band="full",
    ):
        """
        Initialize Multi resolution STFT loss module.

        :param fft_sizes: list, List of FFT sizes (default [1024, 2048, 512]).
        :param hop_sizes: list, List of hop sizes (default [120, 240, 50]).
        :param win_lengths: list, List of window lengths
            (default [600, 1200, 240]).
        :param window: str, Window function type.
        :param sc_lambda: float, a balancing factor across different losses.
        :param mag_lambda: float, a balancing factor across different losses.
        :param band: str, high-band or full-band loss
        """
        super().__init__()
        fft_sizes = fft_sizes or [1024, 2048, 512]
        hop_sizes = hop_sizes or [120, 240, 50]
        win_lengths = win_lengths or [600, 1200, 240]
        self.sc_lambda = sc_lambda
        self.mag_lambda = mag_lambda
        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths), (
            "fft_sizes, hop_sizes and win_lengths must have the same length"
        )
        # One STFTLoss per (fft, hop, window-length) resolution.
        self.stft_losses = torch.nn.ModuleList(
            [
                STFTLoss(fs, ss, wl, window, band)
                for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths)
            ]
        )

    def forward(self, x, y):
        """
        Calculate forward propagation.

        :param x: Tensor, Predicted signal (B, T) or (B, #subband, T).
        :param y: Tensor, Groundtruth signal (B, T) or (B, #subband, T).
        :return:
            Tensor, Multi resolution spectral convergence loss value.
            Tensor, Multi resolution log STFT magnitude loss value.
        """
        if x.dim() == 3:
            # reshape (unlike view) also accepts non-contiguous inputs,
            # e.g. tensors produced by transpose/permute or slicing.
            x = x.reshape(-1, x.size(2))  # (B, C, T) -> (B x C, T)
            y = y.reshape(-1, y.size(2))  # (B, C, T) -> (B x C, T)

        sc_loss = 0.0
        mag_loss = 0.0
        for stft_loss in self.stft_losses:
            sc_l, mag_l = stft_loss(x, y)
            # Out-of-place accumulation keeps the per-resolution loss
            # tensors untouched for autograd.
            sc_loss = sc_loss + sc_l
            mag_loss = mag_loss + mag_l

        num_resolutions = len(self.stft_losses)
        sc_loss = sc_loss * self.sc_lambda / num_resolutions
        mag_loss = mag_loss * self.mag_lambda / num_resolutions
        return sc_loss, mag_loss
if __name__ == "__main__":
    # Library module: nothing to run when executed as a script.
    pass
|