"""CausalConv1d module definition for custom decoder.""" import torch class CausalConv1d(torch.nn.Module): """CausalConv1d module for custom decoder. Args: idim (int): dimension of inputs odim (int): dimension of outputs kernel_size (int): size of convolving kernel stride (int): stride of the convolution dilation (int): spacing between the kernel points groups (int): number of blocked connections from ichannels to ochannels bias (bool): whether to add a learnable bias to the output """ def __init__( self, idim, odim, kernel_size, stride=1, dilation=1, groups=1, bias=True ): """Construct a CausalConv1d object.""" super().__init__() self._pad = (kernel_size - 1) * dilation self.causal_conv1d = torch.nn.Conv1d( idim, odim, kernel_size=kernel_size, stride=stride, padding=self._pad, dilation=dilation, groups=groups, bias=bias, ) def forward(self, x, x_mask, cache=None): """CausalConv1d forward for x. Args: x (torch.Tensor): input torch (B, U, idim) x_mask (torch.Tensor): (B, 1, U) Returns: x (torch.Tensor): input torch (B, sub(U), attention_dim) x_mask (torch.Tensor): (B, 1, sub(U)) """ x = x.permute(0, 2, 1) x = self.causal_conv1d(x) if self._pad != 0: x = x[:, :, : -self._pad] x = x.permute(0, 2, 1) return x, x_mask