File size: 7,763 Bytes
51e2f90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 |
'''
SCNet - great paper, great implementation
https://arxiv.org/pdf/2401.13276.pdf
https://github.com/amanteur/SCNet-PyTorch
'''
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from models.scnet_unofficial.modules import DualPathRNN, SDBlock, SUBlock
from models.scnet_unofficial.utils import compute_sd_layer_shapes, compute_gcr
from einops import rearrange, pack, unpack
from functools import partial
from beartype.typing import Tuple, Optional, List, Callable
from beartype import beartype
def exists(val):
    """Report whether *val* holds a value, i.e. is anything other than ``None``."""
    if val is None:
        return False
    return True
def default(v, d):
    """Return *v*, or the fallback *d* when *v* is ``None``."""
    if not exists(v):
        return d
    return v
def pack_one(t, pattern):
    """
    Pack a single tensor with ``einops.pack``.

    Returns the ``(packed, packed_shapes)`` pair produced by ``einops.pack``;
    ``packed_shapes`` is what ``unpack_one`` needs to undo the packing.
    """
    packed, packed_shapes = pack([t], pattern)
    return packed, packed_shapes
def unpack_one(t, ps, pattern):
    """
    Undo ``pack_one``: unpack ``t`` with ``einops.unpack`` using the packed
    shapes ``ps`` and return only the first resulting tensor.
    """
    unpacked = unpack(t, ps, pattern)
    first = unpacked[0]
    return first
class RMSNorm(nn.Module):
    """
    Normalization over the last axis with a learnable per-feature gain.

    The input is L2-normalized along its last dimension (``F.normalize``), then
    multiplied by ``sqrt(dim)``, which gives non-zero rows a unit root-mean-square.
    Finally it is multiplied by the learnable ``gamma`` (initialized to ones).

    Args:
    - dim (int): Size of the last axis of the input.
    """

    def __init__(self, dim):
        super().__init__()
        # sqrt(dim) turns a unit-L2 vector of length `dim` into a unit-RMS one.
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        unit = F.normalize(x, dim=-1)
        rescaled = unit * self.scale
        return rescaled * self.gamma
class BandSplit(nn.Module):
    """
    Project consecutive chunks (bands) of the last axis to a shared feature size.

    The last axis of the input is split into consecutive chunks whose sizes are
    given by ``dim_inputs``. Each chunk goes through its own ``RMSNorm`` +
    ``nn.Linear`` projection to ``dim`` features, and the per-band results are
    stacked on a new band axis.

    Args:
    - dim: Output feature size of every band projection.
    - dim_inputs (Tuple[int, ...]): Sizes of the consecutive chunks of the
      input's last axis. ``torch.split`` requires them to sum to
      ``x.shape[-1]``.

    Shapes:
    - Input: (..., sum(dim_inputs))
    - Output: (..., len(dim_inputs), dim)
    """

    @beartype
    def __init__(
        self,
        dim,
        dim_inputs: Tuple[int, ...]
    ):
        super().__init__()
        self.dim_inputs = dim_inputs
        # Must be nn.ModuleList: a bare `ModuleList` is never imported, and a
        # plain list would not register the projections as submodules.
        self.to_features = nn.ModuleList([
            nn.Sequential(
                RMSNorm(dim_in),
                nn.Linear(dim_in, dim)
            )
            for dim_in in dim_inputs
        ])

    def forward(self, x):
        bands = x.split(self.dim_inputs, dim=-1)
        # One projection per band, then stack on a new band axis before features.
        outs = [to_feature(band) for band, to_feature in zip(bands, self.to_features)]
        return torch.stack(outs, dim=-2)
class SCNet(nn.Module):
    """
    SCNet implements a source separation network that explicitly splits the
    spectrogram of the mixture into several subbands and uses a sparsity-based
    encoder to model the different frequency bands.

    Paper: "SCNET: SPARSE COMPRESSION NETWORK FOR MUSIC SOURCE SEPARATION"
    Authors: Weinan Tong, Jiaxu Zhu et al.
    Link: https://arxiv.org/abs/2401.13276

    Args:
    - n_fft (int): FFT size of the STFT; the spectrogram has n_fft // 2 + 1
      frequency bins.
    - dims (List[int]): Channel dimensions at the encoder block boundaries.
      len(dims) - 1 SD/SU block pairs are built, SD block i mapping
      dims[i] -> dims[i + 1]. The encoder input carries 2 * C features (real
      and imaginary part of each of the C audio channels). The last decoder
      block outputs dims[0] * n_sources features, which forward() reshapes
      into (C, 2, N). dims[0] therefore presumably must equal 2 * C; confirm
      against SDBlock/SUBlock.
    - bandsplit_ratios (List[float]): List of ratios for splitting the frequency bands.
    - downsample_strides (List[int]): List of stride values for downsampling in each block.
    - n_conv_modules (List[int]): List specifying the number of convolutional modules in each block.
      bandsplit_ratios, downsample_strides and n_conv_modules must all have
      the same length (ValueError otherwise).
    - n_rnn_layers (int): Number of recurrent layers in the dual path RNN.
    - rnn_hidden_dim (int): Dimensionality of the hidden state in the dual path RNN.
    - n_sources (int, optional): Number of sources to be separated. Default is 4.
    - hop_length (int, optional): STFT hop size in samples. Default is 1024.
    - win_length (int, optional): STFT window length in samples. Default is 4096.
    - stft_window_fn (Callable, optional): Window factory, called as
      stft_window_fn(win_length, device=...). Defaults to torch.hann_window.
    - stft_normalized (bool, optional): Passed as `normalized` to torch.stft
      and torch.istft. Default is False.
    - **kwargs: Forwarded to DualPathRNN.

    Shapes:
    - Input: (B, C, T) or (B, T) where
        B is batch size,
        C is channel dim (mono / stereo); a 2-D input is treated as C = 1,
        T is time dim
    - Output: (B, N, C, T) where
        B is batch size,
        N is the number of sources,
        C is channel dim (mono / stereo),
        T is the same number of samples as the input.
    """

    @beartype
    def __init__(
        self,
        n_fft: int,
        dims: List[int],
        bandsplit_ratios: List[float],
        downsample_strides: List[int],
        n_conv_modules: List[int],
        n_rnn_layers: int,
        rnn_hidden_dim: int,
        n_sources: int = 4,
        hop_length: int = 1024,
        win_length: int = 4096,
        stft_window_fn: Optional[Callable] = None,
        stft_normalized: bool = False,
        **kwargs
    ):
        """
        Initializes SCNet with input parameters.
        """
        super().__init__()
        self.assert_input_data(
            bandsplit_ratios,
            downsample_strides,
            n_conv_modules,
        )

        n_blocks = len(dims) - 1
        # Onesided STFT frequency bin count.
        n_freq_bins = n_fft // 2 + 1
        subband_shapes, sd_intervals = compute_sd_layer_shapes(
            input_shape=n_freq_bins,
            bandsplit_ratios=bandsplit_ratios,
            downsample_strides=downsample_strides,
            n_layers=n_blocks,
        )
        self.sd_blocks = nn.ModuleList(
            SDBlock(
                input_dim=dims[i],
                output_dim=dims[i + 1],
                bandsplit_ratios=bandsplit_ratios,
                downsample_strides=downsample_strides,
                n_conv_modules=n_conv_modules,
            )
            for i in range(n_blocks)
        )
        self.dualpath_blocks = DualPathRNN(
            n_layers=n_rnn_layers,
            input_dim=dims[-1],
            hidden_dim=rnn_hidden_dim,
            **kwargs
        )
        # Built deepest-first (reversed), so the shallowest block (i == 0) comes
        # last and emits one dims[0]-sized feature group per source.
        self.su_blocks = nn.ModuleList(
            SUBlock(
                input_dim=dims[i + 1],
                output_dim=dims[i] if i != 0 else dims[i] * n_sources,
                subband_shapes=subband_shapes[i],
                sd_intervals=sd_intervals[i],
                upsample_strides=downsample_strides,
            )
            for i in reversed(range(n_blocks))
        )
        self.gcr = compute_gcr(subband_shapes)

        self.stft_kwargs = dict(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            normalized=stft_normalized
        )
        self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), win_length)
        self.n_sources = n_sources
        self.hop_length = hop_length

    @staticmethod
    def assert_input_data(*args):
        """
        Checks that all given sequences have the same length.

        Raises:
        - ValueError: naming the first argument and the first argument whose
          length differs from it.
        """
        if not args:
            return
        reference = args[0]
        # If every argument matches the first, all lengths are equal, so a
        # single linear pass is enough (the original compared every pair twice).
        for other in args[1:]:
            if len(other) != len(reference):
                raise ValueError(
                    f"Shapes of input features {reference} and {other} are not equal."
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs forward pass through the SCNet.

        Args:
        - x (torch.Tensor): Input waveform of shape (B, C, T), or (B, T), which
          is treated as a single channel.

        Returns:
        - torch.Tensor: Separated waveforms of shape (B, N, C, T), with the
          same number of samples T as the input.
        """
        device = x.device
        stft_window = self.stft_window_fn(device=device)

        # A 2-D input has no channel axis: insert a singleton one.
        if x.ndim == 2:
            x = rearrange(x, 'b t -> b 1 t')

        c = x.shape[1]
        n_samples = x.shape[-1]

        # Right-pad to a multiple of hop_length. This always adds between 1 and
        # hop_length samples (a full hop when the length is already aligned).
        stft_pad = self.hop_length - n_samples % self.hop_length
        x = F.pad(x, (0, stft_pad))
        n_padded = x.shape[-1]

        # stft: fold (B, C) into one batch axis, take the complex STFT, expose
        # real/imag as a trailing size-2 axis, then restore (B, C).
        x, ps = pack_one(x, '* t')
        x = torch.stft(x, **self.stft_kwargs, window=stft_window, return_complex=True)
        x = torch.view_as_real(x)
        x = unpack_one(x, ps, '* f t r')
        # Merge channels and real/imag parts into the feature axis: (B, F, frames, C*2).
        x = rearrange(x, 'b c f t r -> b f t (c r)')

        # encoder part: keep every block's skip connection for the decoder
        x_skips = []
        for sd_block in self.sd_blocks:
            x, x_skip = sd_block(x)
            x_skips.append(x_skip)

        # separation part
        x = self.dualpath_blocks(x)

        # decoder part: su_blocks are ordered deepest-first, so pair them with
        # the skips in reverse order
        for su_block, x_skip in zip(self.su_blocks, reversed(x_skips)):
            x = su_block(x, x_skip)

        # istft: split the feature axis back into (C, real/imag, N)
        x = rearrange(x, 'b f t (c r n) -> b n c f t r', c=c, n=self.n_sources, r=2)
        # view_as_complex requires a size-2 last axis with stride 1
        x = x.contiguous()
        x = torch.view_as_complex(x)
        x = rearrange(x, 'b n c f t -> (b n c) f t')
        # Request the padded length explicitly. With center=True the implicit
        # istft length equals it only for even n_fft; for odd n_fft it is
        # hop_length - 1 samples shorter, and trimming relative to it would
        # truncate the signal.
        x = torch.istft(
            x,
            **self.stft_kwargs,
            window=stft_window,
            length=n_padded,
            return_complex=False,
        )
        x = rearrange(x, '(b n c) t -> b n c t', c=c, n=self.n_sources)

        # Drop the right padding added before the STFT.
        return x[..., :n_samples]
|