|
from distutils.version import LooseVersion |
|
|
|
import torch |
|
|
|
from espnet.nets.pytorch_backend.nets_utils import pad_list |
|
|
|
|
|
# Interpolation mode used by time warping when none is given: "bicubic" on
# torch >= 1.1, otherwise "bilinear".
# NOTE(review): presumably older torch.nn.functional.interpolate lacks usable
# bicubic support — confirm against the torch changelog.
DEFAULT_TIME_WARP_MODE = (
    "bicubic"
    if LooseVersion(torch.__version__) >= LooseVersion("1.1")
    else "bilinear"
)
|
|
|
|
|
def time_warp(x: torch.Tensor, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
    """Time warping using torch.nn.functional.interpolate.

    A random boundary frame ``center`` is drawn from ``[window, Time - window)``
    and moved to a random position ``warped`` in
    ``[center - window + 1, center + window]``. The frames left of the
    boundary are resampled to ``warped`` frames and the frames right of it to
    ``Time - warped`` frames, so the total time length is unchanged.

    Args:
        x: (Batch, Time, Freq). A 4-D input is used as-is, with time on
            axis 2 and frequency on axis 3.
        window: time warp parameter (how far the boundary may move).
            Non-positive values disable warping.
        mode: Interpolate mode

    Returns:
        Tensor with the same size as ``x``. Inputs with ``Time <= 2 * window``,
        or called with ``window <= 0``, are returned unchanged.

    Note:
        If ``x`` does not require grad, the warped frames are written into
        ``x`` in place, so the returned tensor shares storage with the input.
    """
    # interpolate is called below with a 2-element output size, which
    # requires a 4-D (N, C, H, W) input.
    org_size = x.size()
    if x.dim() == 3:
        # (Batch, Time, Freq) -> (Batch, 1, Time, Freq); this is a view.
        x = x[:, None]

    t = x.shape[2]
    # A non-positive window would give torch.randint an empty range
    # (low >= high) and raise, so treat it like a too-short input: no warp.
    if window <= 0 or t - window <= window:
        return x.view(*org_size)

    # center >= window guarantees warped >= 1, and center + window <= t - 1
    # guarantees t - warped >= 1, so both segments below are non-empty.
    center = torch.randint(window, t - window, (1,))[0]
    warped = torch.randint(center - window, center + window, (1,))[0] + 1

    # left: (Batch, Channel, warped, Freq)
    # right: (Batch, Channel, Time - warped, Freq)
    left = torch.nn.functional.interpolate(
        x[:, :, :center], (warped, x.shape[3]), mode=mode, align_corners=False
    )
    right = torch.nn.functional.interpolate(
        x[:, :, center:], (t - warped, x.shape[3]), mode=mode, align_corners=False
    )

    if x.requires_grad:
        # Build a new tensor so the result stays differentiable; in-place
        # writes into a tensor that requires grad can fail under autograd.
        x = torch.cat([left, right], dim=-2)
    else:
        # Write back into the input's storage to avoid a new allocation.
        # left/right are fresh tensors, so there is no aliasing hazard.
        x[:, :, :warped] = left
        x[:, :, warped:] = right

    return x.view(*org_size)
|
|
|
|
|
class TimeWarp(torch.nn.Module):
    """Module form of ``time_warp`` (time warping via interpolation).

    Args:
        window: time warp parameter, forwarded to ``time_warp``
        mode: Interpolate mode, forwarded to ``time_warp``
    """

    def __init__(self, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
        super().__init__()
        self.window, self.mode = window, mode

    def extra_repr(self):
        return f"window={self.window}, mode={self.mode}"

    def forward(self, x: torch.Tensor, x_lengths: torch.Tensor = None):
        """Warp a batch of features along the time axis.

        If ``x_lengths`` is None or every entry equals the first one, the
        whole batch is warped by one ``time_warp`` call. Otherwise each
        utterance ``b`` is warped on its own over its first ``x_lengths[b]``
        frames, and the results are re-padded with ``pad_list`` using 0.0.

        Note that ``time_warp`` writes into its input in place when that input
        does not require grad, so ``x`` itself may be modified.

        Args:
            x: (Batch, Time, Freq)
            x_lengths: (Batch,)

        Returns:
            Tuple of the warped features and ``x_lengths`` (passed through
            unchanged).
        """
        uniform_length = x_lengths is None or all(
            length == x_lengths[0] for length in x_lengths
        )
        if uniform_length:
            return time_warp(x, window=self.window, mode=self.mode), x_lengths

        # Ragged batch: warp only the valid frames of each utterance.
        per_utterance = [
            time_warp(
                x[b][None, : x_lengths[b]],
                window=self.window,
                mode=self.mode,
            )[0]
            for b in range(x.size(0))
        ]
        return pad_list(per_utterance, 0.0), x_lengths
|
|