File size: 2,329 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import torch
from typeguard import check_argument_types
from typing import Optional
from typing import Tuple
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
class LabelAggregate(torch.nn.Module):
    """Aggregate sample-level labels into frame-level labels.

    The input is cut into overlapping windows of ``win_length`` samples taken
    every ``hop_length`` samples. For every window and every label dimension
    the label values are summed, and the frame label is ``1.0`` when that sum
    is strictly greater than ``win_length // 2`` and ``0.0`` otherwise (for
    0/1 labels this amounts to a majority vote over the window).

    Args:
        win_length: Number of samples in one analysis window.
        hop_length: Number of samples between the starts of adjacent windows.
        center: If True, zero-pad ``win_length // 2`` samples on both sides of
            the time axis before framing.
    """

    def __init__(
        self,
        win_length: int = 512,
        hop_length: int = 128,
        center: bool = True,
    ):
        assert check_argument_types()
        super().__init__()

        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center

    def extra_repr(self):
        """Return the configuration string shown in the module's ``repr``."""
        return (
            f"win_length={self.win_length}, "
            f"hop_length={self.hop_length}, "
            f"center={self.center}, "
        )

    def forward(
        self, input: torch.Tensor, ilens: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """LabelAggregate forward function.

        Args:
            input: (Batch, Nsamples, Label_dim)
            ilens: (Batch), number of valid samples per batch item, or None
                to skip length processing.

        Returns:
            output: (Batch, Frames, Label_dim), float tensor of 0.0 / 1.0
            olens: (Batch), number of frames per batch item computed from
                ``ilens``, or None when ``ilens`` is None
        """
        bs = input.size(0)
        max_length = input.size(1)
        label_dim = input.size(2)

        # NOTE(jiatong):
        #   The default behaviour of label aggregation is compatible with
        #   torch.stft about framing and padding.

        # Step1: center padding
        if self.center:
            pad = self.win_length // 2
            max_length = max_length + 2 * pad
            # Pad only the time axis (dim 1); the label axis is left as is.
            input = torch.nn.functional.pad(input, (0, 0, pad, pad), "constant", 0)

        nframe = (max_length - self.win_length) // self.hop_length + 1

        # Step2: framing
        # The strides below describe a C-contiguous (Batch, Nsamples, Label_dim)
        # layout. as_strided reads the underlying storage directly, so a
        # non-contiguous input (e.g. a transposed view, reachable when
        # center=False) would silently be framed incorrectly. contiguous() is a
        # no-op when the layout is already contiguous.
        input = input.contiguous()
        output = input.as_strided(
            (bs, nframe, self.win_length, label_dim),
            (max_length * label_dim, self.hop_length * label_dim, label_dim, 1),
        )

        # Step3: aggregate label
        # Sum over the window axis and threshold at half the window length.
        output = torch.gt(output.sum(dim=2, keepdim=False), self.win_length // 2)
        output = output.float()

        # Step4: process lengths
        if ilens is not None:
            if self.center:
                pad = self.win_length // 2
                ilens = ilens + 2 * pad

            # Same frame-count formula as nframe, applied per batch item.
            olens = (ilens - self.win_length) // self.hop_length + 1
            # Zero the positions selected by make_pad_mask (presumably frames
            # at index >= olens along dim 1 -- confirm against the helper).
            output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
        else:
            olens = None

        return output, olens
|