from typing import Optional
from typing import Tuple

import torch
from typeguard import check_argument_types

from espnet.nets.pytorch_backend.nets_utils import make_pad_mask


class LabelAggregate(torch.nn.Module):
    """Aggregate sample-level labels into frame-level labels.

    The label sequence is framed with the same win_length / hop_length /
    center convention as the STFT, and a label is marked active for a frame
    when it is active for more than half of the samples in that window.
    """

    def __init__(
        self,
        win_length: int = 512,
        hop_length: int = 128,
        center: bool = True,
    ):
        assert check_argument_types()
        super().__init__()

        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center

    def extra_repr(self):
        return (
            f"win_length={self.win_length}, "
            f"hop_length={self.hop_length}, "
            f"center={self.center}"
        )

    def forward(
        self, input: torch.Tensor, ilens: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""LabelAggregate forward function. |
|
|
|
Args: |
|
input: (Batch, Nsamples, Label_dim) |
|
ilens: (Batch) |
|
Returns: |
|
output: (Batch, Frames, Label_dim) |
|
|
|
""" |
|
bs = input.size(0) |
|
max_length = input.size(1) |
|
label_dim = input.size(2) |
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.center: |
|
pad = self.win_length // 2 |
|
max_length = max_length + 2 * pad |
|
input = torch.nn.functional.pad(input, (0, 0, pad, pad), "constant", 0) |
|
nframe = (max_length - self.win_length) // self.hop_length + 1 |
|
|
|
|
|
output = input.as_strided( |
|
(bs, nframe, self.win_length, label_dim), |
|
(max_length * label_dim, self.hop_length * label_dim, label_dim, 1), |
|
) |
|
|
|
|
|
output = torch.gt(output.sum(dim=2, keepdim=False), self.win_length // 2) |
|
output = output.float() |
|
|
|
|
|
if ilens is not None: |
|
if self.center: |
|
pad = self.win_length // 2 |
|
ilens = ilens + 2 * pad |
|
|
|
olens = (ilens - self.win_length) // self.hop_length + 1 |
|
output.masked_fill_(make_pad_mask(olens, output, 1), 0.0) |
|
else: |
|
olens = None |
|
|
|
return output, olens |
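

# Minimal usage sketch (illustrative only): aggregate random sample-level
# binary labels into frame-level labels.  The batch size, sequence length,
# label dimension, and length values here are arbitrary example choices.
if __name__ == "__main__":
    aggregator = LabelAggregate(win_length=512, hop_length=128, center=True)

    # (Batch, Nsamples, Label_dim) binary labels and their valid lengths.
    labels = (torch.rand(2, 16000, 4) > 0.5).float()
    ilens = torch.tensor([16000, 12000])

    frames, olens = aggregator(labels, ilens)
    print(frames.shape)  # e.g. torch.Size([2, 126, 4]) with the defaults above
    print(olens)         # frame-level lengths, e.g. tensor([126,  94])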