conex / espnet2 /enh /decoder /conv_decoder.py
tobiasc's picture
Initial commit
ad16788
raw
history blame contribute delete
874 Bytes
import torch
from espnet2.enh.decoder.abs_decoder import AbsDecoder
class ConvDecoder(AbsDecoder):
    """Transposed convolutional decoder for speech enhancement and separation.

    Maps a sequence of feature frames back to a single-channel waveform using
    one bias-free ``torch.nn.ConvTranspose1d`` layer.
    """

    def __init__(
        self,
        channel: int,
        kernel_size: int,
        stride: int,
    ):
        """Build the decoder.

        Args:
            channel: number of input feature channels (the ``F`` axis of the
                tensor passed to ``forward``).
            kernel_size: kernel length of the transposed convolution.
            stride: stride of the transposed convolution.
        """
        super().__init__()
        # A single output channel: the layer emits one waveform per batch item.
        self.convtrans1d = torch.nn.ConvTranspose1d(
            channel, 1, kernel_size, bias=False, stride=stride
        )

    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (torch.Tensor): spectrum [Batch, T, F]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            wav (torch.Tensor): decoded waveform [Batch, max(ilens)]
            ilens (torch.Tensor): the lengths that were passed in, unchanged
        """
        # ConvTranspose1d is channels-first, so move F ahead of T: [B, F, T].
        input = input.transpose(1, 2)
        batch_size = input.shape[0]
        # ``output_size`` is specified as a sequence of plain ints; convert the
        # 0-dim tensor returned by ``max()`` instead of relying on implicit
        # tensor-to-int coercion inside PyTorch.
        max_len = int(ilens.max())
        wav = self.convtrans1d(input, output_size=(batch_size, 1, max_len))
        # Remove the singleton channel axis: [B, 1, N] -> [B, N].
        wav = wav.squeeze(1)
        return wav, ilens