File size: 2,387 Bytes
9d61c9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
from torch.nn import Module

from .bsconv import BSConv1d


class ConvTransposed(Module):
    r"""`ConvTransposed` applies a `BSConv1d` convolution to a channels-last tensor by swapping the
    last two dimensions (1 and 2) of the input before the convolution and swapping them back after it.

    Note: despite its name, this is *not* a transposed (fractionally-strided) convolution like
    `torch.nn.ConvTranspose1d`; "transposed" refers only to the swap of tensor dimensions 1 and 2.

    The input's channel and width dimensions are swapped, the convolution is applied, and the result
    is swapped back so the output keeps the same dimension order as the input. This is useful in
    architectures that keep activations in (batch, width, channels) order while the convolution
    expects (batch, channels, width).

    Args:
        in_channels (int): Number of channels in the input (presumably the input's last dimension,
            assuming `BSConv1d` follows the channels-first `torch.nn.Conv1d` convention).
        out_channels (int): Number of channels produced by the convolution (the output's last dimension
            under the same assumption).
        kernel_size (int): Size of the kernel used in convolution. Default: 1.
        padding (int): Zero-padding along the width direction, passed to `BSConv1d`. Default: 0.

    Attributes:
        conv (BSConv1d): `BSConv1d` module to apply convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 1,
        padding: int = 0,
    ):
        super().__init__()

        # Define BSConv1d convolutional layer
        self.conv = BSConv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward propagation method for the ConvTransposed layer.

        Args:
            x (torch.Tensor): input tensor, presumably of shape (N, W, C_in) (channels last).

        Returns:
            torch.Tensor: contiguous output tensor, presumably of shape (N, W', C_out): the
            convolution output with dimensions 1 and 2 swapped back to the input's order.
        """
        # (N, W, C_in) -> (N, C_in, W): move channels to dimension 1 for the convolution.
        # No `.contiguous()` here: calling it *before* `transpose` could not make the transposed
        # view contiguous, so `self.conv` already received a strided view and only a copy was wasted.
        x = x.transpose(1, 2)

        # Apply BSConv1d convolution.
        x = self.conv(x)

        # (N, C_out, W') -> (N, W', C_out). `.contiguous()` must come *after* `transpose`,
        # otherwise the returned tensor is a non-contiguous view and a downstream `.view()` fails.
        return x.transpose(1, 2).contiguous()