Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
class MossFormerDecoder(nn.ConvTranspose1d):
    """A decoder layer that consists of ConvTranspose1d.

    All constructor arguments are forwarded unchanged to
    ``torch.nn.ConvTranspose1d``.

    Arguments
    ---------
    kernel_size : int
        Length of filters.
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.

    Example
    ---------
    >>> x = torch.randn(2, 100, 1000)
    >>> decoder = MossFormerDecoder(kernel_size=4, in_channels=100, out_channels=1)
    >>> h = decoder(x)
    >>> h.shape
    torch.Size([2, 1003])
    """

    def forward(self, x):
        """Return the decoded output.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor with dimensionality [B, N, L] or [B, L],
            where B = batch size, N = number of filters, L = time points.
            A 2D input is given a channel axis of size 1 before the
            transposed convolution, so it requires ``in_channels == 1``.

        Returns
        -------
        torch.Tensor
            The ConvTranspose1d output with singleton axes squeezed.
            When B == 1 and out_channels == 1, only the channel axis is
            removed so the result keeps a batch axis ([1, T]).

        Raises
        ------
        RuntimeError
            If ``x`` is not a 2D or 3D tensor.
        """
        if x.dim() not in (2, 3):
            # type(self).__name__: nn.Module instances have no `__name__`
            # attribute, so `self.__name__` would raise AttributeError here.
            raise RuntimeError(
                "{} accepts 2/3D tensor as input".format(type(self).__name__)
            )
        x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))

        squeezed = torch.squeeze(x)
        if squeezed.dim() == 1:
            # A full squeeze would collapse to 1D and lose the batch axis;
            # drop only the channel axis instead.
            return torch.squeeze(x, dim=1)
        # NOTE(review): a full squeeze also removes the batch axis when
        # B == 1 and out_channels > 1; kept as-is for caller compatibility.
        return squeezed