File size: 750 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import torch
from espnet2.layers.label_aggregation import LabelAggregate
class LabelProcessor(torch.nn.Module):
"""Label aggregator for speaker diarization """
def __init__(
self, win_length: int = 512, hop_length: int = 128, center: bool = True
):
super().__init__()
self.label_aggregator = LabelAggregate(win_length, hop_length, center)
def forward(self, input: torch.Tensor, ilens: torch.Tensor):
"""Forward.
Args:
input: (Batch, Nsamples, Label_dim)
ilens: (Batch)
Returns:
output: (Batch, Frames, Label_dim)
olens: (Batch)
"""
output, olens = self.label_aggregator(input, ilens)
return output, olens
|