Spaces:
Runtime error
Runtime error
File size: 1,310 Bytes
0102e16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import torch
import torch.nn as nn
class MossFormerDecoder(nn.ConvTranspose1d):
"""A decoder layer that consists of ConvTranspose1d.
Arguments
---------
kernel_size : int
Length of filters.
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
Example
---------
>>> x = torch.randn(2, 100, 1000)
>>> decoder = Decoder(kernel_size=4, in_channels=100, out_channels=1)
>>> h = decoder(x)
>>> h.shape
torch.Size([2, 1003])
"""
def __init__(self, *args, **kwargs):
super(MossFormerDecoder, self).__init__(*args, **kwargs)
def forward(self, x):
"""Return the decoded output.
Arguments
---------
x : torch.Tensor
Input tensor with dimensionality [B, N, L].
where, B = Batchsize,
N = number of filters
L = time points
"""
if x.dim() not in [2, 3]:
raise RuntimeError("{} accept 3/4D tensor as input".format(self.__name__))
x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
if torch.squeeze(x).dim() == 1:
x = torch.squeeze(x, dim=1)
else:
x = torch.squeeze(x)
return x
|