# Author: nickovchinnikov
# Commit: 9d61c9b ("Init")
import torch
def stft(
    x: torch.Tensor,
    fft_size: int,
    hop_size: int,
    win_length: int,
    window: torch.Tensor,
) -> torch.Tensor:
    r"""Perform STFT and convert to magnitude spectrogram.

    STFT stands for Short-Time Fourier Transform. It analyzes how the
    frequency content of a signal changes over time: the signal is split into
    short overlapping frames, and a Fourier transform is computed for each
    frame. The result is a time-frequency representation of the signal.

    Args:
        x (Tensor): Input signal tensor, either batched (B, T) or a single
            signal (T,).
        fft_size (int): FFT size.
        hop_size (int): Hop size (number of samples between adjacent frames).
        win_length (int): Window length.
        window (Tensor): Window tensor of length ``win_length``
            (e.g. ``torch.hann_window(win_length)``).

    Returns:
        Tensor: Magnitude spectrogram of shape
        (B, #frames, fft_size // 2 + 1) for batched input, or
        (#frames, fft_size // 2 + 1) for a single signal.
    """
    x_stft = torch.stft(
        x,
        n_fft=fft_size,
        hop_length=hop_size,
        win_length=win_length,
        window=window,
        return_complex=True,
    )
    power = x_stft.real**2 + x_stft.imag**2
    # NOTE (kan-bayashi): clamp is needed to avoid nan or inf
    # (the gradient of sqrt is infinite at 0, so the power is floored first).
    magnitude = torch.sqrt(torch.clamp(power, min=1e-7))
    # torch.stft returns (..., freq, frames); swap the last two axes so the
    # result is (..., frames, freq) for both batched and unbatched input.
    return magnitude.transpose(-2, -1)