|
'''
|
|
SCNet - great paper, great implementation
|
|
https://arxiv.org/pdf/2401.13276.pdf
|
|
https://github.com/amanteur/SCNet-PyTorch
|
|
'''
|
|
|
|
from typing import List
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torchaudio
|
|
|
|
from models.scnet_unofficial.modules import DualPathRNN, SDBlock, SUBlock
|
|
from models.scnet_unofficial.utils import compute_sd_layer_shapes, compute_gcr
|
|
|
|
from einops import rearrange, pack, unpack
|
|
from functools import partial
|
|
|
|
from beartype.typing import Tuple, Optional, List, Callable
|
|
from beartype import beartype
|
|
|
|
def exists(val):
    """Return True when *val* is not None, False otherwise."""
    return not (val is None)
|
|
|
|
|
|
def default(v, d):
    """Return *v* unless it is None, in which case return the fallback *d*."""
    if v is None:
        return d
    return v
|
|
|
|
|
|
def pack_one(t, pattern):
    """Pack a single tensor *t* via einops.pack, returning (packed, pack_shapes)."""
    packed, pack_shapes = pack([t], pattern)
    return packed, pack_shapes
|
|
|
|
|
|
def unpack_one(t, ps, pattern):
    """Inverse of pack_one: unpack *t* using shapes *ps* and return the single tensor."""
    unpacked = unpack(t, ps, pattern)
    return unpacked[0]
|
|
|
|
|
|
class RMSNorm(nn.Module):
    """RMS-style normalization over the last dimension.

    The input is L2-normalized along its final axis, then rescaled by
    sqrt(dim) (restoring the magnitude lost by unit normalization) and a
    learned per-feature gain initialized to ones.
    """

    def __init__(self, dim):
        super().__init__()
        # sqrt(dim) compensates for the unit-norm shrinkage of F.normalize
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        normed = F.normalize(x, dim=-1)
        return normed * self.scale * self.gamma
|
|
|
|
|
|
class BandSplit(nn.Module):
    """Split the last (feature) dimension into bands and project each band.

    Each consecutive band of width ``dim_inputs[i]`` is RMS-normalized and
    linearly projected to a common dimension ``dim``; the projected bands
    are stacked along a new second-to-last axis.

    Args:
        dim: Output dimension of every per-band projection.
        dim_inputs: Widths of the consecutive input bands; these must sum
            to the size of the input's last dimension (required by
            ``Tensor.split``).

    Shapes:
        - Input: (..., sum(dim_inputs))
        - Output: (..., len(dim_inputs), dim)
    """

    @beartype
    def __init__(
        self,
        dim,
        dim_inputs: Tuple[int, ...]
    ):
        super().__init__()
        self.dim_inputs = dim_inputs
        # BUG FIX: was `ModuleList([])`, which raises NameError — only
        # `torch.nn as nn` is imported, so it must be `nn.ModuleList`.
        self.to_features = nn.ModuleList([])

        for dim_in in dim_inputs:
            net = nn.Sequential(
                RMSNorm(dim_in),
                nn.Linear(dim_in, dim)
            )
            self.to_features.append(net)

    def forward(self, x):
        # Carve the feature axis into per-band chunks of the configured widths.
        x = x.split(self.dim_inputs, dim=-1)

        outs = []
        for split_input, to_feature in zip(x, self.to_features):
            split_output = to_feature(split_input)
            outs.append(split_output)

        # Stack bands on a fresh axis: (..., n_bands, dim)
        return torch.stack(outs, dim=-2)
|
|
|
|
|
|
class SCNet(nn.Module):
    """
    SCNet class implements a source separation network,
    which explicitly split the spectrogram of the mixture into several subbands
    and introduce a sparsity-based encoder to model different frequency bands.

    Paper: "SCNET: SPARSE COMPRESSION NETWORK FOR MUSIC SOURCE SEPARATION"
    Authors: Weinan Tong, Jiaxu Zhu et al.
    Link: https://arxiv.org/abs/2401.13276

    Args:
    - n_fft (int): Number of FFTs to determine the frequency dimension of the input.
    - dims (List[int]): List of channel dimensions for each block.
    - bandsplit_ratios (List[float]): List of ratios for splitting the frequency bands.
    - downsample_strides (List[int]): List of stride values for downsampling in each block.
    - n_conv_modules (List[int]): List specifying the number of convolutional modules in each block.
    - n_rnn_layers (int): Number of recurrent layers in the dual path RNN.
    - rnn_hidden_dim (int): Dimensionality of the hidden state in the dual path RNN.
    - n_sources (int, optional): Number of sources to be separated. Default is 4.
    - hop_length (int, optional): STFT hop length. Default is 1024.
    - win_length (int, optional): STFT window length. Default is 4096.
    - stft_window_fn (Callable, optional): Window factory; defaults to torch.hann_window.
    - stft_normalized (bool, optional): Whether torch.stft normalizes. Default False.

    Shapes:
    - Input: (B, C, T) where
        B is batch size,
        C is channel dim (mono / stereo),
        T is time dim
    - Output: (B, N, C, T) where
        B is batch size,
        N is the number of sources.
        C is channel dim (mono / stereo),
        T is sequence length,
    """

    @beartype
    def __init__(
        self,
        n_fft: int,
        dims: List[int],
        bandsplit_ratios: List[float],
        downsample_strides: List[int],
        n_conv_modules: List[int],
        n_rnn_layers: int,
        rnn_hidden_dim: int,
        n_sources: int = 4,
        hop_length: int = 1024,
        win_length: int = 4096,
        stft_window_fn: Optional[Callable] = None,
        stft_normalized: bool = False,
        **kwargs
    ):
        """
        Initializes SCNet with input parameters.

        Raises:
            ValueError: if bandsplit_ratios, downsample_strides and
                n_conv_modules do not all have the same length.
        """
        super().__init__()
        # The three per-band configuration lists must be the same length.
        self.assert_input_data(
            bandsplit_ratios,
            downsample_strides,
            n_conv_modules,
        )

        # dims has one entry per feature level; consecutive pairs define a block.
        n_blocks = len(dims) - 1
        n_freq_bins = n_fft // 2 + 1
        # Pre-compute per-layer subband shapes / split intervals so the
        # decoder (SUBlock) can mirror the encoder's (SDBlock) band splits.
        subband_shapes, sd_intervals = compute_sd_layer_shapes(
            input_shape=n_freq_bins,
            bandsplit_ratios=bandsplit_ratios,
            downsample_strides=downsample_strides,
            n_layers=n_blocks,
        )
        # Encoder: sparse down-sampling blocks, each emitting a skip tensor.
        self.sd_blocks = nn.ModuleList(
            SDBlock(
                input_dim=dims[i],
                output_dim=dims[i + 1],
                bandsplit_ratios=bandsplit_ratios,
                downsample_strides=downsample_strides,
                n_conv_modules=n_conv_modules,
            )
            for i in range(n_blocks)
        )
        # Bottleneck: dual-path RNN over the deepest feature map.
        self.dualpath_blocks = DualPathRNN(
            n_layers=n_rnn_layers,
            input_dim=dims[-1],
            hidden_dim=rnn_hidden_dim,
            **kwargs
        )
        # Decoder: sparse up-sampling blocks, mirroring the encoder.  The
        # last block (i == 0) widens its output by n_sources so the final
        # feature axis can be reshaped into per-source spectrograms.
        self.su_blocks = nn.ModuleList(
            SUBlock(
                input_dim=dims[i + 1],
                output_dim=dims[i] if i != 0 else dims[i] * n_sources,
                subband_shapes=subband_shapes[i],
                sd_intervals=sd_intervals[i],
                upsample_strides=downsample_strides,
            )
            for i in reversed(range(n_blocks))
        )
        # Global compression ratio of the band-split scheme.
        # NOTE(review): appears unused inside this class — presumably kept
        # for logging / external inspection; confirm before removing.
        self.gcr = compute_gcr(subband_shapes)

        # Shared kwargs for the forward/inverse STFT pair below.
        self.stft_kwargs = dict(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            normalized=stft_normalized
        )

        # Window factory bound to win_length; called lazily in forward so
        # the window is created on the input's device.
        self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), win_length)
        self.n_sources = n_sources
        self.hop_length = hop_length

    @staticmethod
    def assert_input_data(*args):
        """
        Asserts that the shapes of input features are equal.

        Raises:
            ValueError: if any two of the passed sequences differ in length.
        """
        # Pairwise O(k^2) comparison; k is tiny (3 config lists), so fine.
        for arg1 in args:
            for arg2 in args:
                if len(arg1) != len(arg2):
                    raise ValueError(
                        f"Shapes of input features {arg1} and {arg2} are not equal."
                    )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs forward pass through the SCNet.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, C, T), or (B, T) for
          mono input (a channel axis is inserted automatically).

        Returns:
        - torch.Tensor: Output tensor of shape (B, N, C, T).
        """

        device = x.device
        # Build the STFT window on the same device as the input.
        stft_window = self.stft_window_fn(device=device)

        # Promote mono (B, T) input to (B, 1, T).
        if x.ndim == 2:
            x = rearrange(x, 'b t -> b 1 t')

        c = x.shape[1]

        # Right-pad so T becomes a multiple of hop_length.  stft_pad is
        # always in [1, hop_length] (it equals hop_length when T already
        # divides evenly), so the `[..., :-stft_pad]` crop at the end is safe.
        stft_pad = self.hop_length - x.shape[-1] % self.hop_length
        x = F.pad(x, (0, stft_pad))

        # Fold batch and channel together for one batched STFT call, then
        # split real/imag and move them into the feature axis: (B, F, T, C*2).
        x, ps = pack_one(x, '* t')
        x = torch.stft(x, **self.stft_kwargs, window=stft_window, return_complex=True)
        x = torch.view_as_real(x)
        x = unpack_one(x, ps, '* c f t')
        x = rearrange(x, 'b c f t r -> b f t (c r)')

        # Encoder: collect skip connections for the decoder.
        x_skips = []
        for sd_block in self.sd_blocks:
            x, x_skip = sd_block(x)
            x_skips.append(x_skip)

        # Bottleneck.
        x = self.dualpath_blocks(x)

        # Decoder: consume the skips in reverse (deepest first).
        for su_block, x_skip in zip(self.su_blocks, reversed(x_skips)):
            x = su_block(x, x_skip)

        # Unfold the feature axis into (channel, real/imag, source) and
        # rebuild complex per-source spectrograms for the inverse STFT.
        x = rearrange(x, 'b f t (c r n) -> b n c f t r', c=c, n=self.n_sources, r=2)
        # view_as_complex requires the last axis to be contiguous.
        x = x.contiguous()

        x = torch.view_as_complex(x)
        x = rearrange(x, 'b n c f t -> (b n c) f t')
        x = torch.istft(x, **self.stft_kwargs, window=stft_window, return_complex=False)
        x = rearrange(x, '(b n c) t -> b n c t', c=c, n=self.n_sources)

        # Drop the padding added before the STFT (stft_pad >= 1, see above).
        x = x[..., :-stft_pad]

        return x
|
|
|