Spaces:
Running
Running
import torch | |
from torch.nn import Module | |
from .conv1d import DepthWiseConv1d, PointwiseConv1d | |
class BSConv1d(Module):
    r"""One-dimensional Blueprint Separable Convolution (BSConv-U).

    `BSConv1d` implements the unconstrained blueprint separable convolution
    (BSConv-U) introduced in `Rethinking Depthwise Separable Convolutions: How
    Intra-Kernel Correlations Lead to Improved MobileNets
    <https://arxiv.org/abs/2003.13549>`_ (Haase & Amthor, CVPR 2020), applied
    here to one-dimensional inputs.

    BSConv-U reverses the order of a classic depthwise separable convolution:
    a pointwise convolution is applied first and maps ``channels_in`` to
    ``channels_out`` by combining channels at each position (no spatial
    context); a depthwise convolution is then applied to those
    ``channels_out`` channels, filtering each one independently along the
    sequence axis with a kernel of size ``kernel_size``.

    Args:
        channels_in (int): Number of input channels.
        channels_out (int): Number of output channels produced by the block.
        kernel_size (int): Kernel size of the depthwise convolution.
        padding (int): Padding forwarded to the depthwise convolution. The
            input is 1D, so this pads the single length axis only (there is
            no height/width).

    Attributes:
        pointwise (PointwiseConv1d): Pointwise convolution, applied first.
        depthwise (DepthWiseConv1d): Depthwise convolution, applied second.
    """

    def __init__(
        self,
        channels_in: int,
        channels_out: int,
        kernel_size: int,
        padding: int,
    ):
        super().__init__()
        # Step 1 of BSConv-U: pointwise convolution that changes the channel
        # count from channels_in to channels_out. Registered first so the
        # submodule/state_dict order stays pointwise -> depthwise.
        self.pointwise = PointwiseConv1d(channels_in, channels_out)
        # Step 2 of BSConv-U: depthwise convolution over the channels_out
        # channels produced by the pointwise step (input == output channels).
        self.depthwise = DepthWiseConv1d(
            channels_out,
            channels_out,
            kernel_size=kernel_size,
            padding=padding,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the pointwise convolution, then the depthwise convolution.

        Args:
            x (torch.Tensor): Input tensor. Presumably shaped
                ``(batch, channels_in, length)`` — TODO confirm against the
                ``conv1d`` helpers.

        Returns:
            torch.Tensor: ``self.depthwise(self.pointwise(x))``.
        """
        # Pointwise first: mixes channels, mapping channels_in -> channels_out.
        x = self.pointwise(x)
        # Depthwise second: filters each channel independently along length.
        return self.depthwise(x)