File size: 755 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import torch
from espnet2.diar.decoder.abs_decoder import AbsDecoder
class LinearDecoder(AbsDecoder):
    """Speaker-diarization decoder made of a single linear projection.

    Every feature vector coming out of the encoder is mapped to ``num_spk``
    values by one ``torch.nn.Linear`` layer.
    """

    def __init__(
        self,
        encoder_output_size: int,
        num_spk: int = 2,
    ):
        """Build the projection layer.

        Args:
            encoder_output_size (int): size of the last dimension of the
                tensor that will be passed to ``forward``.
            num_spk (int): number of values produced per input vector;
                also exposed through the ``num_spk`` property.
        """
        super().__init__()
        self.linear_decoder = torch.nn.Linear(
            in_features=encoder_output_size,
            out_features=num_spk,
        )
        # Kept privately; read access goes through the ``num_spk`` property.
        self._num_spk = num_spk

    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Project the encoder output with the linear layer.

        Args:
            input (torch.Tensor): hidden_space [Batch, T, F]
            ilens (torch.Tensor): input lengths [Batch]; accepted but not
                used inside this method.

        Returns:
            torch.Tensor: ``input`` after the linear layer, i.e. with the
                last dimension of size ``num_spk``.
        """
        return self.linear_decoder(input)

    @property
    def num_spk(self):
        """Value of ``num_spk`` given at construction time."""
        return self._num_spk
|