# mrfakename's picture
# Super-squash branch 'main' using huggingface_hub
# 0102e16 verified
import torch
import torch.nn as nn
class MossFormerDecoder(nn.ConvTranspose1d):
    """A decoder layer that consists of ConvTranspose1d.

    All constructor arguments are forwarded unchanged to
    ``torch.nn.ConvTranspose1d``.

    Arguments
    ---------
    kernel_size : int
        Length of filters.
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.

    Example
    ---------
    >>> x = torch.randn(2, 100, 1000)
    >>> decoder = MossFormerDecoder(kernel_size=4, in_channels=100, out_channels=1)
    >>> h = decoder(x)
    >>> h.shape
    torch.Size([2, 1003])
    """

    def forward(self, x):
        """Return the decoded output.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor with dimensionality [B, N, L] or [B, L].
            where, B = Batchsize,
                   N = number of filters
                   L = time points
            A 2-D input gets a singleton channel axis inserted at dim 1,
            so it is only valid when ``in_channels == 1``.

        Returns
        -------
        torch.Tensor
            The transposed-convolution output with singleton dimensions
            squeezed out. If squeezing everything would leave a 1-D tensor,
            only dim 1 is squeezed; for example, B=1 and out_channels=1 gives
            [1, L_out], which keeps the batch axis.
            NOTE: when B=1 and out_channels > 1, the batch axis IS dropped
            ([out_channels, L_out]). This matches the original behaviour.

        Raises
        ------
        RuntimeError
            If ``x`` is not a 2-D or 3-D tensor.
        """
        if x.dim() not in (2, 3):
            # Use the class name: nn.Module instances have no ``__name__``,
            # so ``self.__name__`` would raise AttributeError here instead.
            raise RuntimeError(
                "{} accepts a 2D or 3D tensor as input, got {}D".format(
                    self.__class__.__name__, x.dim()
                )
            )
        if x.dim() == 2:
            # Treat a 2-D input as [B, L] with a single channel.
            x = torch.unsqueeze(x, 1)
        x = super().forward(x)

        squeezed = torch.squeeze(x)
        if squeezed.dim() == 1:
            # A full squeeze would collapse to 1-D (e.g. B=1 and C=1), so
            # squeeze only the channel axis to keep a leading axis.
            return torch.squeeze(x, dim=1)
        return squeezed