poiqazwsx's picture
Upload 57 files
51e2f90
raw
history blame
9.74 kB
from typing import List, Tuple
import torch
import torch.nn as nn
from models.scnet_unofficial.utils import create_intervals
class Downsample(nn.Module):
"""
Downsample class implements a module for downsampling input tensors using 2D convolution.
Args:
- input_dim (int): Dimensionality of the input channels.
- output_dim (int): Dimensionality of the output channels.
- stride (int): Stride value for the convolution operation.
Shapes:
- Input: (B, C_in, F, T) where
B is batch size,
C_in is the number of input channels,
F is the frequency dimension,
T is the time dimension.
- Output: (B, C_out, F // stride, T) where
B is batch size,
C_out is the number of output channels,
F // stride is the downsampled frequency dimension.
"""
def __init__(
self,
input_dim: int,
output_dim: int,
stride: int,
):
"""
Initializes Downsample with input dimension, output dimension, and stride.
"""
super().__init__()
self.conv = nn.Conv2d(input_dim, output_dim, 1, (stride, 1))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Performs forward pass through the Downsample module.
Args:
- x (torch.Tensor): Input tensor of shape (B, C_in, F, T).
Returns:
- torch.Tensor: Downsampled tensor of shape (B, C_out, F // stride, T).
"""
return self.conv(x)
class ConvolutionModule(nn.Module):
"""
ConvolutionModule class implements a module with a sequence of convolutional layers similar to Conformer.
Args:
- input_dim (int): Dimensionality of the input features.
- hidden_dim (int): Dimensionality of the hidden features.
- kernel_sizes (List[int]): List of kernel sizes for the convolutional layers.
- bias (bool, optional): If True, adds a learnable bias to the output. Default is False.
Shapes:
- Input: (B, T, D) where
B is batch size,
T is sequence length,
D is input dimensionality.
- Output: (B, T, D) where
B is batch size,
T is sequence length,
D is input dimensionality.
"""
def __init__(
self,
input_dim: int,
hidden_dim: int,
kernel_sizes: List[int],
bias: bool = False,
) -> None:
"""
Initializes ConvolutionModule with input dimension, hidden dimension, kernel sizes, and bias.
"""
super().__init__()
self.sequential = nn.Sequential(
nn.GroupNorm(num_groups=1, num_channels=input_dim),
nn.Conv1d(
input_dim,
2 * hidden_dim,
kernel_sizes[0],
stride=1,
padding=(kernel_sizes[0] - 1) // 2,
bias=bias,
),
nn.GLU(dim=1),
nn.Conv1d(
hidden_dim,
hidden_dim,
kernel_sizes[1],
stride=1,
padding=(kernel_sizes[1] - 1) // 2,
groups=hidden_dim,
bias=bias,
),
nn.GroupNorm(num_groups=1, num_channels=hidden_dim),
nn.SiLU(),
nn.Conv1d(
hidden_dim,
input_dim,
kernel_sizes[2],
stride=1,
padding=(kernel_sizes[2] - 1) // 2,
bias=bias,
),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Performs forward pass through the ConvolutionModule.
Args:
- x (torch.Tensor): Input tensor of shape (B, T, D).
Returns:
- torch.Tensor: Output tensor of shape (B, T, D).
"""
x = x.transpose(1, 2)
x = x + self.sequential(x)
x = x.transpose(1, 2)
return x
class SDLayer(nn.Module):
"""
SDLayer class implements a subband decomposition layer with downsampling and convolutional modules.
Args:
- subband_interval (Tuple[float, float]): Tuple representing the frequency interval for subband decomposition.
- input_dim (int): Dimensionality of the input channels.
- output_dim (int): Dimensionality of the output channels after downsampling.
- downsample_stride (int): Stride value for the downsampling operation.
- n_conv_modules (int): Number of convolutional modules.
- kernel_sizes (List[int]): List of kernel sizes for the convolutional layers.
- bias (bool, optional): If True, adds a learnable bias to the convolutional layers. Default is True.
Shapes:
- Input: (B, Fi, T, Ci) where
B is batch size,
Fi is the number of input subbands,
T is sequence length, and
Ci is the number of input channels.
- Output: (B, Fi+1, T, Ci+1) where
B is batch size,
Fi+1 is the number of output subbands,
T is sequence length,
Ci+1 is the number of output channels.
"""
def __init__(
self,
subband_interval: Tuple[float, float],
input_dim: int,
output_dim: int,
downsample_stride: int,
n_conv_modules: int,
kernel_sizes: List[int],
bias: bool = True,
):
"""
Initializes SDLayer with subband interval, input dimension,
output dimension, downsample stride, number of convolutional modules, kernel sizes, and bias.
"""
super().__init__()
self.subband_interval = subband_interval
self.downsample = Downsample(input_dim, output_dim, downsample_stride)
self.activation = nn.GELU()
conv_modules = [
ConvolutionModule(
input_dim=output_dim,
hidden_dim=output_dim // 4,
kernel_sizes=kernel_sizes,
bias=bias,
)
for _ in range(n_conv_modules)
]
self.conv_modules = nn.Sequential(*conv_modules)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Performs forward pass through the SDLayer.
Args:
- x (torch.Tensor): Input tensor of shape (B, Fi, T, Ci).
Returns:
- torch.Tensor: Output tensor of shape (B, Fi+1, T, Ci+1).
"""
B, F, T, C = x.shape
x = x[:, int(self.subband_interval[0] * F) : int(self.subband_interval[1] * F)]
x = x.permute(0, 3, 1, 2)
x = self.downsample(x)
x = self.activation(x)
x = x.permute(0, 2, 3, 1)
B, F, T, C = x.shape
x = x.reshape((B * F), T, C)
x = self.conv_modules(x)
x = x.reshape(B, F, T, C)
return x
class SDBlock(nn.Module):
"""
SDBlock class implements a block with subband decomposition layers and global convolution.
Args:
- input_dim (int): Dimensionality of the input channels.
- output_dim (int): Dimensionality of the output channels.
- bandsplit_ratios (List[float]): List of ratios for splitting the frequency bands.
- downsample_strides (List[int]): List of stride values for downsampling in each subband layer.
- n_conv_modules (List[int]): List specifying the number of convolutional modules in each subband layer.
- kernel_sizes (List[int], optional): List of kernel sizes for the convolutional layers. Default is None.
Shapes:
- Input: (B, Fi, T, Ci) where
B is batch size,
Fi is the number of input subbands,
T is sequence length,
Ci is the number of input channels.
- Output: (B, Fi+1, T, Ci+1) where
B is batch size,
Fi+1 is the number of output subbands,
T is sequence length,
Ci+1 is the number of output channels.
"""
def __init__(
self,
input_dim: int,
output_dim: int,
bandsplit_ratios: List[float],
downsample_strides: List[int],
n_conv_modules: List[int],
kernel_sizes: List[int] = None,
):
"""
Initializes SDBlock with input dimension, output dimension, band split ratios, downsample strides, number of convolutional modules, and kernel sizes.
"""
super().__init__()
if kernel_sizes is None:
kernel_sizes = [3, 3, 1]
assert sum(bandsplit_ratios) == 1, "The split ratios must sum up to 1."
subband_intervals = create_intervals(bandsplit_ratios)
self.sd_layers = nn.ModuleList(
SDLayer(
input_dim=input_dim,
output_dim=output_dim,
subband_interval=sbi,
downsample_stride=dss,
n_conv_modules=ncm,
kernel_sizes=kernel_sizes,
)
for sbi, dss, ncm in zip(
subband_intervals, downsample_strides, n_conv_modules
)
)
self.global_conv2d = nn.Conv2d(output_dim, output_dim, 1, 1)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Performs forward pass through the SDBlock.
Args:
- x (torch.Tensor): Input tensor of shape (B, Fi, T, Ci).
Returns:
- Tuple[torch.Tensor, torch.Tensor]: Output tensor and skip connection tensor.
"""
x_skip = torch.concat([layer(x) for layer in self.sd_layers], dim=1)
x = self.global_conv2d(x_skip.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
return x, x_skip