#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
STFT and inverse STFT implemented as 1D (transposed) convolutions.

Adapted from:
https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/conv_stft.py
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.signal import get_window


def init_kernels(nfft: int, win_size: int, hop_size: int, win_type: str = None, inverse=False):
    """
    Build the windowed DFT convolution kernels used by ConvSTFT / ConviSTFT,
    and return them together with the analysis window.
    """
    if win_type == "None" or win_type is None:
        window = np.ones(win_size)
    else:
        # square root of the window, so analysis and synthesis together apply the full window
        window = get_window(win_type, win_size, fftbins=True)**0.5

    # rfft of the identity gives the DFT basis; keep only the first win_size rows
    fourier_basis = np.fft.rfft(np.eye(nfft))[:win_size]
    real_kernel = np.real(fourier_basis)
    image_kernel = np.imag(fourier_basis)
    # stack real and imaginary parts along the channel dimension: [nfft+2, win_size]
    kernel = np.concatenate([real_kernel, image_kernel], 1).T

    if inverse:
        # pseudo-inverse of the forward basis gives the synthesis kernel
        kernel = np.linalg.pinv(kernel).T

    kernel = kernel * window
    kernel = kernel[:, None, :]
    result = (
        torch.from_numpy(kernel.astype(np.float32)),
        torch.from_numpy(window[None, :, None].astype(np.float32))
    )
    return result


class ConvSTFT(nn.Module):
    """STFT as a strided 1D convolution over the waveform."""

    def __init__(self,
                 win_size: int,
                 hop_size: int,
                 nfft: int = None,
                 win_type: str = "hamming",
                 feature_type: str = "real",
                 requires_grad: bool = False):
        super(ConvSTFT, self).__init__()

        if nfft is None:
            # default to the next power of two that covers the window
            self.nfft = int(2**np.ceil(np.log2(win_size)))
        else:
            self.nfft = nfft

        kernel, _ = init_kernels(self.nfft, win_size, hop_size, win_type)
        self.weight = nn.Parameter(kernel, requires_grad=requires_grad)

        self.win_size = win_size
        self.hop_size = hop_size

        self.stride = hop_size
        self.dim = self.nfft
        self.feature_type = feature_type

    def forward(self, inputs: torch.Tensor):
        if inputs.dim() == 2:
            # [b, t] -> [b, 1, t]
            inputs = torch.unsqueeze(inputs, 1)

        outputs = F.conv1d(inputs, self.weight, stride=self.stride)

        if self.feature_type == "complex":
            # real and imaginary parts stacked along the channel dimension: [b, nfft+2, t]
            return outputs
        else:
            dim = self.dim // 2 + 1
            real = outputs[:, :dim, :]
            imag = outputs[:, dim:, :]
            mags = torch.sqrt(real**2 + imag**2)
            phase = torch.atan2(imag, real)
            return mags, phase


class ConviSTFT(nn.Module):
    """Inverse STFT as a transposed 1D convolution with overlap-add normalization."""

    def __init__(self,
                 win_size: int,
                 hop_size: int,
                 nfft: int = None,
                 win_type: str = "hamming",
                 feature_type: str = "real",
                 requires_grad: bool = False):
        super(ConviSTFT, self).__init__()

        if nfft is None:
            self.nfft = int(2**np.ceil(np.log2(win_size)))
        else:
            self.nfft = nfft

        kernel, window = init_kernels(self.nfft, win_size, hop_size, win_type, inverse=True)
        self.weight = nn.Parameter(kernel, requires_grad=requires_grad)

        self.win_size = win_size
        self.hop_size = hop_size
        self.win_type = win_type

        self.stride = hop_size
        self.dim = self.nfft
        self.feature_type = feature_type

        self.register_buffer("window", window)
        self.register_buffer("enframe", torch.eye(win_size)[:, None, :])

    def forward(self, inputs: torch.Tensor, phase: torch.Tensor = None):
        """
        :param inputs: torch.Tensor, shape: [b, n+2, t] (complex spec) or [b, n//2+1, t] (mags)
        :param phase: torch.Tensor, shape: [b, n//2+1, t]
        :return: torch.Tensor, shape: [b, 1, num_samples]
        """
        if phase is not None:
            real = inputs * torch.cos(phase)
            imag = inputs * torch.sin(phase)
            inputs = torch.cat([real, imag], 1)

        outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)

        # normalize by the summed squared window (overlap-add correction);
        # this is from torch-stft: https://github.com/pseeth/torch-stft
        t = self.window.repeat(1, 1, inputs.size(-1))**2
        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
        outputs = outputs / (coff + 1e-8)
        return outputs


def main():
    stft = ConvSTFT(win_size=512, hop_size=200, feature_type="complex")
    istft = ConviSTFT(win_size=512, hop_size=200, feature_type="complex")

    mixture = torch.rand(size=(1, 8000*40), dtype=torch.float32)

    spec = stft.forward(mixture)
    # shape: [batch_size, freq_bins, time_steps]
    print(spec.shape)

    waveform = istft.forward(spec)
    # shape: [batch_size, channels, num_samples]
    print(waveform.shape)
    return


if __name__ == "__main__":
    main()
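

# Round-trip sanity check (a minimal sketch, not part of the upstream file): with
# feature_type="complex", feeding ConvSTFT's output straight into ConviSTFT should
# recover the input waveform up to small numerical error from the pseudo-inverse
# kernel and the 1e-8 overlap-add normalization term. The helper name
# check_reconstruction is illustrative, not from the modelscope repository.
def check_reconstruction():
    stft = ConvSTFT(win_size=512, hop_size=200, feature_type="complex")
    istft = ConviSTFT(win_size=512, hop_size=200, feature_type="complex")

    waveform = torch.rand(size=(1, 16000), dtype=torch.float32)
    spec = stft.forward(waveform)
    reconstructed = istft.forward(spec)

    # conv1d drops the tail samples that do not fill a whole frame, so compare
    # only over the length both signals share
    length = min(waveform.size(-1), reconstructed.size(-1))
    error = (waveform[..., :length] - reconstructed[:, 0, :length]).abs().mean()
    print(f"mean absolute reconstruction error: {error.item():.6f}")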