File size: 874 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import torch
from espnet2.enh.decoder.abs_decoder import AbsDecoder
class ConvDecoder(AbsDecoder):
    """Transposed-convolution decoder for speech enhancement and separation.

    A single bias-free ``torch.nn.ConvTranspose1d`` with one output channel
    turns a sequence of feature frames into a single-channel waveform.
    """

    def __init__(
        self,
        channel: int,
        kernel_size: int,
        stride: int,
    ):
        """Create the transposed-convolution layer.

        Args:
            channel: number of input channels of the transposed convolution
                (the feature dimension of the tensor given to ``forward``).
            kernel_size: kernel size of the transposed convolution.
            stride: stride of the transposed convolution.
        """
        super().__init__()
        self.convtrans1d = torch.nn.ConvTranspose1d(
            in_channels=channel,
            out_channels=1,
            kernel_size=kernel_size,
            stride=stride,
            bias=False,
        )

    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (torch.Tensor): spectrum [Batch, T, F]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            tuple: ``(wav, ilens)`` where ``wav`` is the decoded signal with
            its channel axis removed, [Batch, ilens.max()], and ``ilens`` is
            passed through unchanged.
        """
        # ConvTranspose1d expects channels first: [Batch, T, F] -> [Batch, F, T].
        features = input.transpose(1, 2)

        # Ask for an output whose time axis equals the longest length in the
        # batch; the layer picks the output padding needed to reach that size.
        requested_size = (features.shape[0], 1, ilens.max())

        # Drop the singleton channel axis: [Batch, 1, L] -> [Batch, L].
        wav = self.convtrans1d(features, output_size=requested_size).squeeze(1)
        return wav, ilens
|