|
import torch |
|
|
|
from espnet2.enh.decoder.abs_decoder import AbsDecoder |
|
|
|
|
|
class ConvDecoder(AbsDecoder):
    """Transposed-convolution decoder for speech enhancement and separation.

    A single bias-free ``ConvTranspose1d`` layer maps ``channel`` feature
    channels to a one-channel signal, which is returned as a waveform.
    """

    def __init__(
        self,
        channel: int,
        kernel_size: int,
        stride: int,
    ):
        """Build the decoder.

        Args:
            channel (int): number of input feature channels
            kernel_size (int): kernel size of the transposed convolution
            stride (int): stride of the transposed convolution
        """
        super().__init__()
        self.convtrans1d = torch.nn.ConvTranspose1d(
            in_channels=channel,
            out_channels=1,
            kernel_size=kernel_size,
            stride=stride,
            bias=False,
        )

    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (torch.Tensor): spectrum [Batch, T, F]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            tuple: ``(wav, ilens)`` where ``wav`` is the output of the
                transposed convolution with its channel dimension squeezed
                out, and ``ilens`` is passed through unchanged.
        """
        # The transposed convolution is channels-first: [Batch, F, T].
        feats = input.transpose(1, 2)

        # Request an output whose length equals the longest entry of ``ilens``;
        # this picks one length among those the stride makes possible.
        target_size = (feats.shape[0], 1, ilens.max())
        waveform = self.convtrans1d(feats, output_size=target_size).squeeze(1)

        return waveform, ilens
|
|