poiqazwsx's picture
Upload 57 files
51e2f90
raw
history blame
8.24 kB
from typing import List, Tuple
import torch
import torch.nn as nn
from models.scnet_unofficial.utils import get_convtranspose_output_padding
class FusionLayer(nn.Module):
    """
    Fuse two equally shaped tensors with an F-axis convolution followed by a GLU gate.

    The two inputs are summed and the channel axis is tiled to twice its width.
    A (kernel_size x 1) convolution is then applied along the F axis, and a GLU
    over the channel axis halves the width back to ``input_dim``.

    Args:
    - input_dim (int): Number of channels C of each input tensor.
    - kernel_size (int, optional): Convolution kernel extent along F. Default is 3.
    - stride (int, optional): Convolution stride along F. Default is 1.
    - padding (int, optional): Zero padding along F. Default is 1.

    Shapes:
    - Input: two tensors of shape (B, F, T, C).
    - Output: (B, F', T, C) with
      F' = floor((F + 2 * padding - kernel_size) / stride) + 1,
      which equals F for the default kernel_size/stride/padding.
    """

    def __init__(
        self, input_dim: int, kernel_size: int = 3, stride: int = 1, padding: int = 1
    ):
        """
        Build the doubled-width convolution and the GLU gate.
        """
        super().__init__()
        doubled = 2 * input_dim
        # Only the F axis is convolved; the T axis uses a width-1 kernel and stride 1.
        self.conv = nn.Conv2d(
            in_channels=doubled,
            out_channels=doubled,
            kernel_size=(kernel_size, 1),
            stride=(stride, 1),
            padding=(padding, 0),
        )
        # GLU splits the last axis in half and returns first_half * sigmoid(second_half).
        self.activation = nn.GLU(dim=-1)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """
        Fuse ``x1`` and ``x2``.

        Args:
        - x1 (torch.Tensor): Tensor of shape (B, F, T, C).
        - x2 (torch.Tensor): Tensor of the same shape as ``x1``.

        Returns:
        - torch.Tensor: Gated tensor of shape (B, F', T, C).
        """
        # Tile the channel axis so the conv sees 2C channels and the GLU can halve them.
        fused = (x1 + x2).repeat(1, 1, 1, 2)
        # Conv2d expects channels first: (B, 2C, F, T).
        channels_first = fused.permute(0, 3, 1, 2)
        convolved = self.conv(channels_first)
        # Return to channels last so the GLU gates the channel axis.
        channels_last = convolved.permute(0, 2, 3, 1)
        return self.activation(channels_last)
class Upsample(nn.Module):
    """
    Upsample the F axis of a channels-first tensor with a 1x1 transposed convolution.

    Args:
    - input_dim (int): Number of input channels C_in.
    - output_dim (int): Number of output channels C_out.
    - stride (int): Upsampling factor along F.
    - output_padding (int): Extra rows appended along F. PyTorch requires this to be
      smaller than ``stride``.

    Shapes:
    - Input: (B, C_in, F, T) where
        B is batch size,
        C_in is the number of input channels,
        F is the frequency dimension,
        T is the time dimension.
    - Output: (B, C_out, (F - 1) * stride + 1 + output_padding, T).
      The kernel size is 1 and padding is 0, so the transposed-convolution length
      formula (F - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1
      reduces to the expression above. T is unchanged because both stride and kernel
      are 1 along that axis.
    """

    def __init__(
        self, input_dim: int, output_dim: int, stride: int, output_padding: int
    ):
        """
        Initializes Upsample with input dimension, output dimension, stride, and output padding.
        """
        super().__init__()
        self.conv = nn.ConvTranspose2d(
            in_channels=input_dim,
            out_channels=output_dim,
            kernel_size=1,
            stride=(stride, 1),
            output_padding=(output_padding, 0),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs forward pass through the Upsample module.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, C_in, F, T).

        Returns:
        - torch.Tensor: Output tensor of shape
          (B, C_out, (F - 1) * stride + 1 + output_padding, T).
        """
        return self.conv(x)
class SULayer(nn.Module):
    """
    SULayer class implements a subband upsampling layer using transposed convolution.

    It slices one interval of the F axis, then upsamples that slice along F with a
    1x1 transposed convolution that also maps ``input_dim`` channels to ``output_dim``.

    Args:
    - input_dim (int): Number of channels of the input tensor.
    - output_dim (int): Number of channels of the output tensor.
    - upsample_stride (int): Upsampling factor along F.
    - subband_shape (int): Target length of the upsampled F axis.
    - sd_interval (Tuple[int, int]): Start (inclusive) and end (exclusive) indices
      of the slice taken from the input's F axis.

    Shapes:
    - Input: (B, F, T, input_dim) where
        B is batch size,
        F is the number of features (only x[:, start:end] is used),
        T is sequence length.
    - Output: (B, F_out, T, output_dim) where
        F_out = (end - start - 1) * upsample_stride + 1 + output_padding.
      NOTE(review): output_padding comes from get_convtranspose_output_padding.
      It is presumably chosen so that F_out == subband_shape; confirm in
      models.scnet_unofficial.utils.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        upsample_stride: int,
        subband_shape: int,
        sd_interval: Tuple[int, int],
    ):
        """
        Initializes SULayer with input dimension, output dimension, upsample stride, subband shape, and subband interval.
        """
        super().__init__()
        # Length of the slice this layer reads from the input's F axis.
        sd_shape = sd_interval[1] - sd_interval[0]
        upsample_output_padding = get_convtranspose_output_padding(
            input_shape=sd_shape, output_shape=subband_shape, stride=upsample_stride
        )
        self.upsample = Upsample(
            input_dim=input_dim,
            output_dim=output_dim,
            stride=upsample_stride,
            output_padding=upsample_output_padding,
        )
        self.sd_interval = sd_interval

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs forward pass through the SULayer.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, F, T, input_dim).

        Returns:
        - torch.Tensor: Output tensor of shape (B, F_out, T, output_dim). See the
          class docstring for F_out.
        """
        x = x[:, self.sd_interval[0] : self.sd_interval[1]]
        # Upsample expects channels first: (B, C, F, T).
        x = x.permute(0, 3, 1, 2)
        x = self.upsample(x)
        # Back to channels last: (B, F_out, T, output_dim).
        x = x.permute(0, 2, 3, 1)
        return x
class SUBlock(nn.Module):
    """
    SUBlock class implements a block with fusion layer and subband upsampling layers.

    The main input and the skip input are fused by a FusionLayer. Each SULayer then
    upsamples its own interval of the fused F axis. The per-subband results are
    concatenated along F.

    Args:
    - input_dim (int): Number of channels of the inputs.
    - output_dim (int): Number of channels of the output.
    - upsample_strides (List[int]): Per-subband upsampling factor along F.
    - subband_shapes (List[int]): Per-subband target F length.
    - sd_intervals (List[Tuple[int, int]]): Per-subband [start, end) slice of the fused F axis.
      All three lists must have the same length; one SULayer is built per entry.

    Shapes:
    - Input: (B, Fi-1, T, input_dim) and (B, Fi-1, T, input_dim) where
        B is batch size,
        Fi-1 is the number of input subbands,
        T is sequence length.
    - Output: (B, Fi, T, output_dim) where
        Fi is the sum of the SULayer output lengths along F.
        NOTE(review): this presumably equals sum(subband_shapes); confirm against
        get_convtranspose_output_padding.

    Raises:
    - ValueError: If the three per-subband lists differ in length.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        upsample_strides: List[int],
        subband_shapes: List[int],
        sd_intervals: List[Tuple[int, int]],
    ):
        """
        Initializes SUBlock with input dimension, output dimension,
        upsample strides, subband shapes, and subband intervals.
        """
        super().__init__()
        # zip() would silently drop trailing subbands on a length mismatch,
        # so reject inconsistent configurations explicitly.
        if not (len(upsample_strides) == len(subband_shapes) == len(sd_intervals)):
            raise ValueError(
                "upsample_strides, subband_shapes and sd_intervals must have the same "
                f"length, got {len(upsample_strides)}, {len(subband_shapes)} "
                f"and {len(sd_intervals)}"
            )
        self.fusion_layer = FusionLayer(input_dim=input_dim)
        self.su_layers = nn.ModuleList(
            [
                SULayer(
                    input_dim=input_dim,
                    output_dim=output_dim,
                    upsample_stride=uss,
                    subband_shape=sbs,
                    sd_interval=sdi,
                )
                for uss, sbs, sdi in zip(upsample_strides, subband_shapes, sd_intervals)
            ]
        )

    def forward(self, x: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
        """
        Performs forward pass through the SUBlock.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, Fi-1, T, input_dim).
        - x_skip (torch.Tensor): Skip connection tensor of shape (B, Fi-1, T, input_dim).

        Returns:
        - torch.Tensor: Output tensor of shape (B, Fi, T, output_dim).
        """
        x = self.fusion_layer(x, x_skip)
        # Each SULayer reads its own F interval of the fused tensor; stitch the results along F.
        return torch.cat([layer(x) for layer in self.su_layers], dim=1)