HoneyTian's picture
add frcrn model
1d4c9c3
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/conv_stft.py
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.signal import get_window
def init_kernels(nfft: int, win_size: int, hop_size: int, win_type: str = None, inverse=False):
    """Build the conv1d kernel (and window) that realises an STFT or its inverse.

    The un-windowed basis is ``rfft(eye(nfft))[:win_size]``, i.e. the rFFT of every
    unit impulse located in the first ``win_size`` samples of an ``nfft``-long frame
    (equivalent to zero-padding each frame to ``nfft``). Its real parts are stacked
    on top of its imaginary parts and transposed, giving one kernel row per output
    channel. With ``inverse=True`` the transposed pseudo-inverse of that matrix is
    used instead, so a transposed convolution maps the stacked spectrum back to
    samples. In both cases every row is multiplied by the window.

    :param nfft: FFT size; must be >= win_size.
    :param win_size: frame length in samples (kernel width).
    :param hop_size: not used here; kept for call-site symmetry with the STFT modules.
    :param win_type: scipy window name. ``None`` or the string ``"None"`` selects a
        rectangular window; otherwise the square root of the scipy window is used.
    :param inverse: build the synthesis (pseudo-inverse) kernel instead of the analysis one.
    :return: ``(kernel, window)`` as float32 tensors of shapes
        ``[2 * (nfft // 2 + 1), 1, win_size]`` and ``[1, win_size, 1]``.
    :raises ValueError: if ``win_size > nfft``.
    """
    if win_size > nfft:
        # rfft(eye(nfft)) only has nfft rows, so a longer window cannot be represented;
        # fail with a clear message instead of a numpy broadcast error further down.
        raise ValueError(f"win_size ({win_size}) must not exceed nfft ({nfft})")

    if win_type is None or win_type == "None":
        window = np.ones(win_size)
    else:
        # sqrt window: applied once in analysis and once in synthesis, so the
        # effective weighting of a round trip is the original window.
        window = get_window(win_type, win_size, fftbins=True) ** 0.5

    fourier_basis = np.fft.rfft(np.eye(nfft))[:win_size]
    kernel = np.concatenate([np.real(fourier_basis), np.imag(fourier_basis)], 1).T
    if inverse:
        kernel = np.linalg.pinv(kernel).T
    # window each row, then add the in_channels=1 axis expected by conv1d weights
    kernel = (kernel * window)[:, None, :]

    return (
        torch.from_numpy(kernel.astype(np.float32)),
        torch.from_numpy(window[None, :, None].astype(np.float32)),
    )
class ConvSTFT(nn.Module):
    """STFT computed as a strided conv1d with (optionally trainable) Fourier kernels.

    ``forward`` returns the stacked real/imaginary spectrum when
    ``feature_type == "complex"``, otherwise a ``(magnitude, phase)`` pair.
    """

    def __init__(self,
                 nfft: int = None,
                 win_size: int = None,
                 hop_size: int = None,
                 win_type: str = "hamming",
                 feature_type: str = "real",
                 requires_grad: bool = False):
        """
        :param nfft: FFT size; when None, the next power of two >= win_size is used
            (same rule as ConviSTFT).
        :param win_size: frame length in samples. Required.
        :param hop_size: frame shift (conv stride) in samples. Required.
        :param win_type: scipy window name, or None / "None" for a rectangular window.
        :param feature_type: "complex" to return stacked real/imag channels; any other
            value returns magnitude and phase.
        :param requires_grad: whether the Fourier kernel is trainable.
        """
        super(ConvSTFT, self).__init__()
        # win_size / hop_size only carry a None default so that nfft (kept first for
        # positional backward compatibility) can be omitted; they remain mandatory.
        if win_size is None or hop_size is None:
            raise TypeError("ConvSTFT requires both win_size and hop_size")
        if nfft is None:
            self.nfft = int(2 ** np.ceil(np.log2(win_size)))
        else:
            self.nfft = nfft
        kernel, _ = init_kernels(self.nfft, win_size, hop_size, win_type)
        self.weight = nn.Parameter(kernel, requires_grad=requires_grad)
        self.win_size = win_size
        self.hop_size = hop_size
        self.stride = hop_size
        self.dim = self.nfft
        self.feature_type = feature_type

    def forward(self, inputs: torch.Tensor):
        """
        :param inputs: waveform of shape [b, num_samples] or [b, 1, num_samples].
        :return: [b, 2 * (nfft // 2 + 1), t] (real channels then imaginary channels) when
            feature_type == "complex"; otherwise (mags, phase), each [b, nfft // 2 + 1, t].
        """
        if inputs.dim() == 2:
            # conv1d needs an explicit in_channels=1 axis
            inputs = torch.unsqueeze(inputs, 1)
        outputs = F.conv1d(inputs, self.weight, stride=self.stride)
        if self.feature_type == "complex":
            return outputs

        num_bins = self.dim // 2 + 1
        real = outputs[:, :num_bins, :]
        imag = outputs[:, num_bins:, :]
        mags = torch.sqrt(real ** 2 + imag ** 2)
        phase = torch.atan2(imag, real)
        return mags, phase
class ConviSTFT(nn.Module):
    """Inverse STFT via transposed conv1d (overlap-add) with pseudo-inverse Fourier kernels.

    The overlap-added output is divided by the overlap-added squared window so that
    the windowing applied by the analysis and synthesis kernels is compensated.
    """

    def __init__(self,
                 win_size: int,
                 hop_size: int,
                 nfft: int = None,
                 win_type: str = "hamming",
                 feature_type: str = "real",
                 requires_grad: bool = False):
        """
        :param win_size: frame length in samples.
        :param hop_size: frame shift (transposed-conv stride) in samples.
        :param nfft: FFT size; when None, the next power of two >= win_size is used.
        :param win_type: scipy window name, or None / "None" for a rectangular window.
        :param feature_type: stored on the module; forward() does not read it.
        :param requires_grad: whether the synthesis kernel is trainable.
        """
        super().__init__()
        self.nfft = int(2 ** np.ceil(np.log2(win_size))) if nfft is None else nfft

        synthesis_kernel, frame_window = init_kernels(
            self.nfft, win_size, hop_size, win_type, inverse=True,
        )
        self.weight = nn.Parameter(synthesis_kernel, requires_grad=requires_grad)

        self.win_size = win_size
        self.hop_size = hop_size
        self.win_type = win_type
        self.stride = hop_size
        self.dim = self.nfft
        self.feature_type = feature_type

        self.register_buffer("window", frame_window)
        # identity kernel: transposed conv with it overlap-adds frames without mixing samples
        self.register_buffer("enframe", torch.eye(win_size)[:, None, :])

    def forward(self,
                inputs: torch.Tensor,
                phase: torch.Tensor = None):
        """
        :param inputs: torch.Tensor, shape: [b, n+2, t] (stacked real/imag spectrum),
            or [b, n//2+1, t] magnitudes when `phase` is given.
        :param phase: optional torch.Tensor, shape: [b, n//2+1, t]; combined with the
            magnitudes in `inputs` as polar coordinates.
        :return: waveform of shape [b, 1, (t - 1) * hop_size + win_size].
        """
        spec = inputs
        if phase is not None:
            spec = torch.cat([inputs * torch.cos(phase), inputs * torch.sin(phase)], 1)

        signal = F.conv_transpose1d(spec, self.weight, stride=self.stride)

        # this is from torch-stft: https://github.com/pseeth/torch-stft
        num_frames = spec.size(-1)
        squared_window = self.window.repeat(1, 1, num_frames) ** 2
        norm = F.conv_transpose1d(squared_window, self.enframe, stride=self.stride)
        return signal / (norm + 1e-8)
def main():
    """Smoke-test ConvSTFT / ConviSTFT: round-trip a random waveform and print shapes."""
    win_size, hop_size = 512, 200
    # Pass nfft explicitly; 512 is also what nfft=None resolves to for win_size=512.
    nfft = 512
    stft = ConvSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size, feature_type="complex")
    istft = ConviSTFT(win_size=win_size, hop_size=hop_size, nfft=nfft, feature_type="complex")

    mixture = torch.rand(size=(1, 8000 * 40), dtype=torch.float32)

    spec = stft(mixture)
    # shape: [batch_size, nfft + 2, time_steps] (real channels stacked above imaginary ones)
    print(spec.shape)

    waveform = istft(spec)
    # shape: [batch_size, 1, num_samples]
    print(waveform.shape)


if __name__ == "__main__":
    main()