Spaces:
Running
Running
import torch | |
from torch.nn import Module | |
from .bsconv import BSConv1d | |
class ConvTransposed(Module):
    r"""Apply a `BSConv1d` convolution with dimensions 1 and 2 swapped around it.

    `ConvTransposed` transposes dimensions 1 and 2 of the input tensor, applies
    the wrapped `BSConv1d` layer, and then transposes dimensions 1 and 2 back.
    This is useful where a model keeps tensors in a different dimension order
    than the convolution expects.

    Layout note: assuming `BSConv1d` follows the `torch.nn.Conv1d` convention of
    ``(N, C, W)`` (NOTE(review): confirm against `bsconv.py`), the expected input
    is ``(N, W, in_channels)`` and the output is ``(N, W', out_channels)``, where
    ``W'`` depends on ``kernel_size`` and ``padding``.

    Args:
        in_channels (int): Number of channels in the input (dimension 2 of the
            input tensor under the layout assumption above).
        out_channels (int): Number of channels produced by the convolution
            (dimension 2 of the output tensor under the same assumption).
        kernel_size (int): Size of the convolution kernel. Defaults to 1.
        padding (int): Padding passed through to `BSConv1d`, presumably
            zero-padding along the width dimension. Defaults to 0.

    Attributes:
        conv (BSConv1d): The wrapped `BSConv1d` convolution module.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 1,
        padding: int = 0,
    ):
        super().__init__()
        # Wrapped convolution; it sees the tensor with dimensions 1 and 2 swapped.
        self.conv = BSConv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the convolution on ``x`` with dimensions 1 and 2 swapped.

        Args:
            x (torch.Tensor): Input tensor, presumably ``(N, W, in_channels)``
                (see the layout note in the class docstring).

        Returns:
            torch.Tensor: The convolved tensor with dimensions 1 and 2 returned
            to the input's order. The result is contiguous in memory.
        """
        # Swap dims 1 and 2 so the channel dimension sits where the convolution
        # reads it. Calling `.contiguous()` *before* a transpose cannot make the
        # transposed view contiguous, so no copy is made here; the convolution
        # received a strided (transposed) view in any case.
        x = self.conv(x.transpose(1, 2))
        # Swap back, then materialise a contiguous result so that callers can
        # safely apply `.view()` or other layout-sensitive ops to it.
        return x.transpose(1, 2).contiguous()