|
import torch |
|
import torchaudio |
|
|
|
def stft(x, fft_size, hop_size, win_length, window, use_complex=False):
    """Perform STFT and return a magnitude or real/imaginary spectrogram.

    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (Tensor | str | None): Window tensor of length ``win_length``,
            or the name of a ``torch`` window factory (e.g. ``"hann_window"``),
            which is then built with ``win_length`` samples on ``x``'s device.
            ``None`` is forwarded to ``torch.stft`` unchanged.
        use_complex (bool): If True, return the real and imaginary parts
            of the STFT instead of its magnitude.

    Returns:
        Tensor: If ``use_complex`` is False, the magnitude spectrogram
        (B, #frames, fft_size // 2 + 1). Otherwise the real and imaginary
        parts stacked on dim 1: (B, 2, #frames, fft_size // 2 + 1).
    """
    if isinstance(window, str):
        # Resolve a window name such as "hann_window" to the torch factory.
        window = getattr(torch, window)(win_length, device=x.device)
    elif window is not None:
        # Ensure the window lives on the same device as the signal.
        window = window.to(x.device)

    # Complex STFT of shape (B, fft_size // 2 + 1, #frames).
    x_stft = torch.stft(x, fft_size, hop_size, win_length, window,
                        return_complex=True)
    real, imag = x_stft.real, x_stft.imag

    if use_complex:
        # (B, 2, F, #frames) -> (B, 2, #frames, F)
        return torch.stack((real, imag), dim=1).transpose(2, 3)

    # Power is clamped to [1e-7, 1e3] before the sqrt. The lower bound keeps
    # the sqrt (and its gradient) finite at zero power; the upper bound caps
    # magnitudes at sqrt(1e3) ~= 31.6.
    # NOTE(review): torch.clamp passes zero gradient for clamped elements, so
    # bins with magnitude above ~31.6 get no gradient -- confirm the upper
    # bound is intended.
    power = torch.clamp(real ** 2 + imag ** 2, min=1e-7, max=1e3)
    return torch.sqrt(power).transpose(2, 1)
|
|