# HoneyTian's picture
# add dfnet2
# ed91efa
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/conv_stft.py
"""
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.signal import get_window
def init_kernels(nfft: int, win_size: int, hop_size: int, win_type: str = None, inverse=False):
    """
    Build the convolution kernel that implements a windowed real DFT
    (or its pseudo-inverse) for use with conv1d / conv_transpose1d.

    :param nfft: int, FFT size; the kernel has (nfft // 2 + 1) * 2 channels
        (real rows first, imaginary rows second).
    :param win_size: int, window length in samples; must not exceed nfft.
    :param hop_size: int, not used by this function; kept so existing callers
        that pass it keep working.
    :param win_type: str or None, window name handed to scipy.signal.get_window.
        None or the string "None" selects a rectangular window of ones.
        For any other value the square root of the window is used.
    :param inverse: bool, when True the Moore-Penrose pseudo-inverse of the
        DFT matrix is used, giving the synthesis kernel.
    :return: tuple (kernel, window).
        kernel: torch.float32, shape [(nfft // 2 + 1) * 2, 1, win_size].
        window: torch.float32, shape [1, win_size, 1].
    :raises ValueError: if win_size is larger than nfft.
    """
    if win_size > nfft:
        # Without this check the slice below silently truncates to nfft rows
        # and the failure only shows up later as a broadcasting error.
        raise ValueError(
            f"win_size ({win_size}) must not be larger than nfft ({nfft})"
        )

    if win_type == "None" or win_type is None:
        window = np.ones(win_size)
    else:
        window = get_window(win_type, win_size, fftbins=True)**0.5

    # Row i of rfft(eye(nfft)) is the spectrum of a unit impulse at sample i;
    # keeping the first win_size rows restricts the basis to the window span.
    fourier_basis = np.fft.rfft(np.eye(nfft))[:win_size]
    real_kernel = np.real(fourier_basis)
    image_kernel = np.imag(fourier_basis)
    # kernel shape: [(nfft // 2 + 1) * 2, win_size]
    kernel = np.concatenate([real_kernel, image_kernel], 1).T

    if inverse:
        kernel = np.linalg.pinv(kernel).T

    kernel = kernel * window
    # Insert the singleton in/out channel axis expected by the conv ops.
    kernel = kernel[:, None, :]

    result = (
        torch.from_numpy(kernel.astype(np.float32)),
        torch.from_numpy(window[None, :, None].astype(np.float32))
    )
    return result
class ConvSTFT(nn.Module):
    """
    Short-time Fourier transform computed with a strided 1-D convolution
    whose kernel comes from init_kernels.
    """

    def __init__(self,
                 nfft: int,
                 win_size: int,
                 hop_size: int,
                 win_type: str = "hamming",
                 power: int = None,
                 requires_grad: bool = False):
        """
        :param nfft: int or None, FFT size. When None, the next power of two
            greater than or equal to win_size is used.
        :param win_size: int, window length in samples.
        :param hop_size: int, hop between frames; used as the conv stride.
        :param win_type: str, window name forwarded to init_kernels.
        :param power: None, 1 or 2. Selects what forward returns:
            None -> complex spectrum, 1 -> magnitude, 2 -> power.
        :param requires_grad: bool, whether the kernel is trainable.
        """
        super().__init__()

        if nfft is None:
            self.nfft = int(2**np.ceil(np.log2(win_size)))
        else:
            self.nfft = nfft

        kernel, _ = init_kernels(self.nfft, win_size, hop_size, win_type)
        # weight shape: [(nfft // 2 + 1) * 2, 1, win_size]
        self.weight = nn.Parameter(kernel, requires_grad=requires_grad)

        self.win_size = win_size
        self.hop_size = hop_size

        self.stride = hop_size
        self.dim = self.nfft
        self.power = power

    def forward(self, waveform: torch.Tensor):
        """
        :param waveform: torch.Tensor, shape [b, num_samples] or
            [b, 1, num_samples]; a 2-D input gets a channel axis inserted.
        :return: torch.Tensor of shape [b, f, t] with f = nfft // 2 + 1.
            Complex when power is None, otherwise real magnitude (power == 1)
            or squared magnitude (power == 2).
        :raises AssertionError: if power is not None, 1 or 2.
        """
        if waveform.dim() == 2:
            waveform = torch.unsqueeze(waveform, 1)

        matrix = F.conv1d(waveform, self.weight, stride=self.stride)
        # The first f output channels hold the real part, the rest the imaginary part.
        dim = self.dim // 2 + 1
        real = matrix[:, :dim, :]
        imag = matrix[:, dim:, :]

        if self.power is None:
            # spec shape: [b, f, t], complex
            return torch.complex(real, imag)
        if self.power == 1:
            return torch.sqrt(real**2 + imag**2)
        if self.power == 2:
            return real**2 + imag**2
        raise AssertionError(
            f"power must be None, 1 or 2; got {self.power!r}"
        )
class ConviSTFT(nn.Module):
    """
    Inverse short-time Fourier transform computed with a strided transposed
    1-D convolution, followed by division by the overlap-added squared window.
    """

    def __init__(self,
                 win_size: int,
                 hop_size: int,
                 nfft: int = None,
                 win_type: str = "hamming",
                 requires_grad: bool = False):
        """
        :param win_size: int, window length in samples.
        :param hop_size: int, hop between frames; used as the conv stride.
        :param nfft: int or None, FFT size. When None, the next power of two
            greater than or equal to win_size is used.
        :param win_type: str, window name forwarded to init_kernels.
        :param requires_grad: bool, whether the kernel is trainable.
        """
        super().__init__()

        if nfft is None:
            self.nfft = int(2**np.ceil(np.log2(win_size)))
        else:
            self.nfft = nfft

        kernel, window = init_kernels(self.nfft, win_size, hop_size, win_type, inverse=True)
        self.weight = nn.Parameter(kernel, requires_grad=requires_grad)
        # weight shape: [f*2, 1, win_size]
        # f = nfft // 2 + 1

        self.win_size = win_size
        self.hop_size = hop_size
        self.win_type = win_type

        self.stride = hop_size
        self.dim = self.nfft

        self.register_buffer("window", window)
        self.register_buffer("enframe", torch.eye(win_size)[:, None, :])
        # window shape: [1, win_size, 1]
        # enframe shape: [win_size, 1, win_size]

    def _overlap_add(self, spec: torch.Tensor):
        """
        Overlap-add the synthesized frames and the squared window.

        :param spec: torch.Tensor, complex, shape [b, f, t].
        :return: tuple (waveform, coff), both not yet normalized.
            waveform shape: [b, 1, (t - 1) * hop_size + win_size]
            coff shape: [1, 1, (t - 1) * hop_size + win_size]
        """
        spec = torch.view_as_real(spec)
        # spec shape: [b, f, t, 2]
        matrix = torch.concat(tensors=[spec[..., 0], spec[..., 1]], dim=1)
        # matrix shape: [b, f*2, t]

        waveform = F.conv_transpose1d(matrix, self.weight, stride=self.stride)

        # this is from torch-stft: https://github.com/pseeth/torch-stft
        t = self.window.repeat(1, 1, matrix.size(-1))**2
        # t shape: [1, win_size, t]
        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
        return waveform, coff

    def forward(self,
                spec: torch.Tensor):
        """
        self.weight shape: [f*2, 1, win_size]
        self.window shape: [1, win_size, 1]
        self.enframe shape: [win_size, 1, win_size]

        :param spec: torch.Tensor, complex, shape: [b, f, t]
        :return: torch.Tensor, shape: [b, 1, (t - 1) * hop_size + win_size]
        """
        waveform, coff = self._overlap_add(spec)
        # The small epsilon guards against division by zero where the
        # overlap-added squared window is zero.
        waveform = waveform / (coff + 1e-8)
        return waveform

    @torch.no_grad()
    def forward_chunk(self,
                      spec: torch.Tensor,
                      cache_dict: dict = None
                      ):
        """
        Streaming inverse transform. Each call emits the samples that are
        complete after this chunk and carries the unfinished tail to the
        next call through the returned cache.

        :param spec: torch.Tensor, complex, shape: [b, f, t]; t may be 1 or more.
        :param cache_dict: dict or None, the cache returned by the previous
            call (None or an empty dict on the first call).
            waveform_cache shape: [b, 1, win_size - hop_size]
            coff_cache shape: [1, 1, win_size - hop_size]
        :return: tuple (waveform_output, new_cache_dict).
            waveform_output shape: [b, 1, t * hop_size]
        """
        if cache_dict is None:
            cache_dict = {}
        # .get() so a plain dict without the keys is treated as an empty cache.
        waveform_cache = cache_dict.get("waveform_cache")
        coff_cache = cache_dict.get("coff_cache")

        waveform_current, coff_current = self._overlap_add(spec)

        overlap_size = self.win_size - self.hop_size
        # Samples before t * hop_size receive no contribution from later
        # frames, so they can be emitted now; the rest is the new cache.
        output_size = spec.size(-1) * self.hop_size

        if waveform_cache is not None:
            waveform_current[:, :, :overlap_size] += waveform_cache
        if coff_cache is not None:
            coff_current[:, :, :overlap_size] += coff_cache

        waveform_output = waveform_current[:, :, :output_size]
        coff_output = coff_current[:, :, :output_size]
        waveform_output = waveform_output / (coff_output + 1e-8)

        new_cache_dict = {
            "waveform_cache": waveform_current[:, :, output_size:],
            "coff_cache": coff_current[:, :, output_size:],
        }
        return waveform_output, new_cache_dict
def main():
    """
    Demo: invert a spectrogram in one call, then invert it frame by frame
    (each frame written over the output buffer), printing two samples from
    each result for comparison.
    """
    nfft = 512
    win_size = 512
    hop_size = 256

    stft = ConvSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size, power=None)
    istft = ConviSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size)

    mixture = torch.rand(size=(1, 16000), dtype=torch.float32)
    b, num_samples = mixture.shape

    spec = stft.forward(mixture)
    # Number of frames actually produced by the STFT.
    t = spec.shape[-1]

    # If spec was produced by the STFT, the two reconstruction methods below
    # give the same waveform; otherwise the reconstructed waveforms differ.
    # To try the second case, perturb the spectrum:
    # spec = spec + 0.01 * torch.randn(size=(1, nfft//2+1, t), dtype=torch.float32)
    print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")

    waveform = istft.forward(spec)
    # shape: [batch_size, channels, num_samples]
    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
    print(waveform[:, :, 300: 302])

    waveform = torch.zeros(size=(b, 1, num_samples), dtype=torch.float32)
    for i in range(t):
        begin = i * hop_size
        end = begin + win_size
        sub_spec = spec[:, :, i:i+1]
        sub_waveform = istft.forward(sub_spec)
        # sub_waveform shape: [b, 1, win_size]
        # Each frame overwrites the overlap left by the previous one.
        waveform[:, :, begin:end] = sub_waveform
    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
    print(waveform[:, :, 300: 302])
def main2():
    """
    Demo: invert a perturbed spectrogram in one call, then invert it as a
    stream of single-frame chunks with ConviSTFT.forward_chunk, printing two
    samples from each result for comparison.
    """
    nfft = 512
    win_size = 512
    hop_size = 256

    stft = ConvSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size, power=None)
    istft = ConviSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size)

    mixture = torch.rand(size=(1, 16128), dtype=torch.float32)
    b, num_samples = mixture.shape

    spec = stft.forward(mixture)
    t = spec.shape[-1]

    # If spec was produced by the STFT, the two reconstruction methods below
    # give the same waveform; otherwise the reconstructed waveforms differ.
    spec = spec + 0.01 * torch.randn(size=(1, nfft//2+1, t), dtype=torch.float32)
    print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")

    waveform = istft.forward(spec)
    # shape: [batch_size, channels, num_samples]
    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
    print(waveform[:, :, 300: 302])

    cache_dict = None
    waveform = torch.zeros(size=(b, 1, num_samples), dtype=torch.float32)
    for i in range(t):
        sub_spec = spec[:, :, i:i+1]
        sub_waveform, cache_dict = istft.forward_chunk(sub_spec, cache_dict=cache_dict)
        # A single-frame chunk yields hop_size samples, so the destination
        # slice must be hop_size long (not win_size - hop_size, which only
        # coincides when win_size == 2 * hop_size).
        begin = i * hop_size
        end = begin + hop_size
        waveform[:, :, begin:end] = sub_waveform
    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
    print(waveform[:, :, 300: 302])
# Script entry point: run the streaming (chunked) iSTFT demo.
if __name__ == "__main__":
    main2()