Spaces:
Sleeping
Sleeping
import torch | |
from torch.nn import functional as F | |
class Stretch2d(torch.nn.Module):
    """Resize the last two axes of a 4D tensor by fixed scale factors.

    Args:
        x_scale: Scale factor applied to the last axis (T).
        y_scale: Scale factor applied to the second-to-last axis (F).
        mode: Interpolation mode forwarded to ``F.interpolate``.
    """

    def __init__(self, x_scale, y_scale, mode="nearest"):
        super().__init__()
        self.x_scale, self.y_scale, self.mode = x_scale, y_scale, mode

    def forward(self, x):
        """Interpolate the input tensor.

        Args:
            x (Tensor): Input tensor (B, C, F, T).

        Returns:
            Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale).
        """
        # F.interpolate takes the factors in axis order: F axis first, then T.
        factors = (self.y_scale, self.x_scale)
        return F.interpolate(x, scale_factor=factors, mode=self.mode)
class UpsampleNetwork(torch.nn.Module):
    """Upsample features along time with interpolation + convolution stages.

    For every entry ``scale`` of ``upsample_factors`` the network appends:

    1. a ``Stretch2d(scale, 1, interpolate_mode)`` layer that stretches the
       time axis by ``scale``,
    2. a single-channel, bias-free ``Conv2d`` with kernel
       ``(freq_axis_kernel_size, 2 * scale + 1)``,
    3. optionally an activation module looked up by name in ``torch.nn``.

    Args:
        upsample_factors: Iterable of integer time-axis scale factors, one
            per stage.
        nonlinear_activation: Name of a ``torch.nn`` module class to insert
            after every convolution, or ``None`` for no activation.
        nonlinear_activation_params: Keyword arguments used to construct the
            activation. The default dict is only read, never mutated.
        interpolate_mode: Interpolation mode handed to ``Stretch2d``.
        freq_axis_kernel_size: Convolution kernel size along the frequency
            axis; must be odd.
        use_causal_conv: If True, every convolution is padded by
            ``2 * scale`` in time and its output is cut back to the input
            length in ``forward`` so that no future frame is used.

    Raises:
        AssertionError: If ``freq_axis_kernel_size`` is even.
    """

    # pylint: disable=dangerous-default-value
    def __init__(
        self,
        upsample_factors,
        nonlinear_activation=None,
        nonlinear_activation_params={},
        interpolate_mode="nearest",
        freq_axis_kernel_size=1,
        use_causal_conv=False,
    ):
        super().__init__()
        # The check and the frequency padding do not depend on the stage, so
        # they are computed once up front instead of once per factor.
        assert (freq_axis_kernel_size - 1) % 2 == 0, "Not support even number freq axis kernel size."
        freq_axis_padding = (freq_axis_kernel_size - 1) // 2

        self.use_causal_conv = use_causal_conv
        self.up_layers = torch.nn.ModuleList()
        for scale in upsample_factors:
            # interpolation layer
            self.up_layers.append(Stretch2d(scale, 1, interpolate_mode))

            # conv layer
            kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
            # Conv2d pads both sides symmetrically. ``scale`` keeps the time
            # length unchanged; ``2 * scale`` yields 2 * scale extra frames
            # which forward() removes again in causal mode.
            time_padding = scale * 2 if use_causal_conv else scale
            conv = torch.nn.Conv2d(
                1,
                1,
                kernel_size=kernel_size,
                padding=(freq_axis_padding, time_padding),
                bias=False,
            )
            self.up_layers.append(conv)

            # nonlinear
            if nonlinear_activation is not None:
                nonlinear = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)
                self.up_layers.append(nonlinear)

    def forward(self, c):
        """Upsample the conditioning features.

        Args:
            c (Tensor): Input features (B, C, T_in).

        Returns:
            Tensor: Upsampled features (B, C, T_upsample), where T_upsample
                is T_in multiplied by every upsample factor.
        """
        c = c.unsqueeze(1)  # (B, 1, C, T)
        for layer in self.up_layers:
            if self.use_causal_conv and isinstance(layer, torch.nn.Conv2d):
                # The causal conv was padded by 2 * scale on both sides, so it
                # returns T + 2 * scale frames. Keeping the first T frames
                # makes frame t depend only on frames <= t and restores the
                # expected length.
                num_frames = c.size(-1)
                c = layer(c)[..., :num_frames]
            else:
                c = layer(c)
        return c.squeeze(1)  # (B, C, T')
class ConvUpsample(torch.nn.Module):
    """Context convolution over the conditioning features, then upsampling.

    A bias-free ``Conv1d`` without padding first mixes neighbouring frames,
    and the result is handed to an ``UpsampleNetwork``.

    Args:
        upsample_factors: Forwarded to ``UpsampleNetwork``.
        nonlinear_activation: Forwarded to ``UpsampleNetwork``.
        nonlinear_activation_params: Forwarded to ``UpsampleNetwork``.
        interpolate_mode: Forwarded to ``UpsampleNetwork``.
        freq_axis_kernel_size: Forwarded to ``UpsampleNetwork``.
        aux_channels: Number of input and output channels of the context
            convolution.
        aux_context_window: Context size; the convolution kernel spans
            ``aux_context_window + 1`` frames in causal mode and
            ``2 * aux_context_window + 1`` frames otherwise.
        use_causal_conv: Selects the causal kernel size and is forwarded to
            ``UpsampleNetwork``.
    """

    # pylint: disable=dangerous-default-value
    def __init__(
        self,
        upsample_factors,
        nonlinear_activation=None,
        nonlinear_activation_params={},
        interpolate_mode="nearest",
        freq_axis_kernel_size=1,
        aux_channels=80,
        aux_context_window=0,
        use_causal_conv=False,
    ):
        super().__init__()
        self.aux_context_window = aux_context_window
        # Output trimming is only enabled when there is a context to trim;
        # with a zero window the slice ``[: -0]`` would be empty.
        self.use_causal_conv = use_causal_conv and aux_context_window > 0

        # To capture wide-context information in conditional features
        if use_causal_conv:
            context_kernel = aux_context_window + 1
        else:
            context_kernel = 2 * aux_context_window + 1

        # NOTE(kan-bayashi): Here do not use padding because the input is already padded
        self.conv_in = torch.nn.Conv1d(
            aux_channels,
            aux_channels,
            kernel_size=context_kernel,
            bias=False,
        )

        upsample_kwargs = {
            "upsample_factors": upsample_factors,
            "nonlinear_activation": nonlinear_activation,
            "nonlinear_activation_params": nonlinear_activation_params,
            "interpolate_mode": interpolate_mode,
            "freq_axis_kernel_size": freq_axis_kernel_size,
            "use_causal_conv": use_causal_conv,
        }
        self.upsample = UpsampleNetwork(**upsample_kwargs)

    def forward(self, c):
        """Apply the context convolution and upsample the result.

        Args:
            c (Tensor): Input features (B, C, T_in).

        Returns:
            Tensor: Upsampled features (B, C, T_upsampled).
        """
        features = self.conv_in(c)
        if self.use_causal_conv:
            # Drop the trailing ``aux_context_window`` frames of the conv output.
            features = features[:, :, : -self.aux_context_window]
        return self.upsample(features)