|
from typing import Callable, Sequence, Type, Union

import numpy as np
import torch
import torch.nn as nn

ModuleFactory = Union[Type[nn.Module], Callable[[], nn.Module]]
|
|
class FeedForwardModule(nn.Module):
    """Base class for modules whose forward pass is a single network.

    Subclasses must assign an ``nn.Module`` to ``self.net``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.net = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
|
|
class Residual(nn.Module):
    """Adds a skip connection around a module: ``x -> module(x) + x``."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.module(x) + x
|
|
class DilatedConvolutionalUnit(FeedForwardModule):
    """Activation -> dilated convolution -> activation -> pointwise convolution."""

    def __init__(
            self,
            hidden_dim: int,
            dilation: int,
            kernel_size: int,
            activation: ModuleFactory,
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        self.net = nn.Sequential(
            activation(),
            normalization(
                nn.Conv1d(
                    in_channels=hidden_dim,
                    out_channels=hidden_dim,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    # "same" padding: preserves the sequence length for odd
                    # kernel sizes.
                    padding=((kernel_size - 1) * dilation) // 2,
                )),
            activation(),
            nn.Conv1d(in_channels=hidden_dim,
                      out_channels=hidden_dim,
                      kernel_size=1),
        )
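# Shape sketch (example values are arbitrary): with an odd kernel size the
# unit preserves both channel count and sequence length, e.g.
#
#   unit = DilatedConvolutionalUnit(hidden_dim=16, dilation=9, kernel_size=3,
#                                   activation=nn.ReLU)
#   unit(torch.randn(1, 16, 100)).shape  # torch.Size([1, 16, 100])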
|
|
class UpsamplingUnit(FeedForwardModule):
    """Activation -> transposed convolution that upsamples by `stride`."""

    def __init__(
            self,
            input_dim: int,
            output_dim: int,
            stride: int,
            activation: ModuleFactory,
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        self.net = nn.Sequential(
            activation(),
            normalization(
                nn.ConvTranspose1d(
                    in_channels=input_dim,
                    out_channels=output_dim,
                    kernel_size=2 * stride,
                    stride=stride,
                    # Padding and output_padding are chosen so the output
                    # length is exactly `stride` times the input length.
                    padding=stride // 2 + stride % 2,
                    output_padding=1 if stride % 2 != 0 else 0,
                )))
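# Shape sketch (example values are arbitrary): the output length is exactly
# `stride` times the input length, for even and odd strides alike, e.g.
#
#   up = UpsamplingUnit(input_dim=32, output_dim=16, stride=3,
#                       activation=nn.ReLU)
#   up(torch.randn(1, 32, 50)).shape  # torch.Size([1, 16, 150])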
|
|
class DownsamplingUnit(FeedForwardModule):
    """Activation -> strided convolution that downsamples by `stride`."""

    def __init__(
            self,
            input_dim: int,
            output_dim: int,
            stride: int,
            activation: ModuleFactory,
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        self.net = nn.Sequential(
            activation(),
            normalization(
                nn.Conv1d(
                    in_channels=input_dim,
                    out_channels=output_dim,
                    kernel_size=2 * stride,
                    stride=stride,
                    # Padding chosen so input lengths divisible by `stride`
                    # shrink by exactly that factor.
                    padding=stride // 2 + stride % 2,
                )))
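# Shape sketch (example values are arbitrary): the inverse geometry of
# UpsamplingUnit, e.g.
#
#   down = DownsamplingUnit(input_dim=16, output_dim=32, stride=3,
#                           activation=nn.ReLU)
#   down(torch.randn(1, 16, 150)).shape  # torch.Size([1, 32, 50])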
|
|
class DilatedResidualEncoder(FeedForwardModule):
    """Pre-convolution, then one stage per ratio (a stack of residual dilated
    units followed by a downsampling unit), then a post-convolution.

    Channel width starts at `capacity` and doubles at each stage.
    """

    def __init__(
            self,
            capacity: int,
            dilated_unit: Type[DilatedConvolutionalUnit],
            downsampling_unit: Type[DownsamplingUnit],
            ratios: Sequence[int],
            dilations: Union[Sequence[int], Sequence[Sequence[int]]],
            pre_network_conv: Type[nn.Conv1d],
            post_network_conv: Type[nn.Conv1d],
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        channels = capacity * 2**np.arange(len(ratios) + 1)

        dilations_list = self.normalize_dilations(dilations, ratios)

        net = [normalization(pre_network_conv(out_channels=channels[0]))]

        for ratio, stage_dilations, input_dim, output_dim in zip(
                ratios, dilations_list, channels[:-1], channels[1:]):
            for dilation in stage_dilations:
                net.append(Residual(dilated_unit(input_dim, dilation)))
            net.append(downsampling_unit(input_dim, output_dim, ratio))

        net.append(post_network_conv(in_channels=channels[-1]))

        self.net = nn.Sequential(*net)

    @staticmethod
    def normalize_dilations(dilations: Union[Sequence[int],
                                             Sequence[Sequence[int]]],
                            ratios: Sequence[int]):
        # A flat list of dilations is repeated once per stage.
        if isinstance(dilations[0], int):
            dilations = [dilations for _ in ratios]
        return dilations
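# Note: the encoder calls `dilated_unit(input_dim, dilation)` and
# `downsampling_unit(input_dim, output_dim, ratio)`, so every other
# constructor argument (kernel size, activation, ...) must be pre-bound,
# e.g. with functools.partial:
#
#   dilated_unit = partial(DilatedConvolutionalUnit, kernel_size=3,
#                          activation=nn.ReLU)
#
# See the smoke test at the bottom of this file for a full wiring example.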
|
|
class DilatedResidualDecoder(FeedForwardModule):
    """Mirror of DilatedResidualEncoder: a pre-convolution, then one stage per
    ratio (an upsampling unit followed by a stack of residual dilated units),
    then a normalized post-convolution.

    Channel widths and per-stage dilations are reversed relative to the
    encoder; `ratios` is consumed in the order given.
    """

    def __init__(
            self,
            capacity: int,
            dilated_unit: Type[DilatedConvolutionalUnit],
            upsampling_unit: Type[UpsamplingUnit],
            ratios: Sequence[int],
            dilations: Union[Sequence[int], Sequence[Sequence[int]]],
            pre_network_conv: Type[nn.Conv1d],
            post_network_conv: Type[nn.Conv1d],
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        channels = capacity * 2**np.arange(len(ratios) + 1)
        channels = channels[::-1]

        dilations_list = self.normalize_dilations(dilations, ratios)
        dilations_list = dilations_list[::-1]

        net = [pre_network_conv(out_channels=channels[0])]

        for ratio, stage_dilations, input_dim, output_dim in zip(
                ratios, dilations_list, channels[:-1], channels[1:]):
            net.append(upsampling_unit(input_dim, output_dim, ratio))
            for dilation in stage_dilations:
                net.append(Residual(dilated_unit(output_dim, dilation)))

        net.append(normalization(post_network_conv(in_channels=channels[-1])))

        self.net = nn.Sequential(*net)

    @staticmethod
    def normalize_dilations(dilations: Union[Sequence[int],
                                             Sequence[Sequence[int]]],
                            ratios: Sequence[int]):
        # A flat list of dilations is repeated once per stage.
        if isinstance(dilations[0], int):
            dilations = [dilations for _ in ratios]
        return dilations
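if __name__ == "__main__":
    # Smoke test: a sketch of how the pieces fit together, not part of the
    # library API. All hyperparameters below (capacity, ratios, dilations,
    # kernel sizes, ReLU) are arbitrary example choices.
    from functools import partial

    ratios = [2, 2]
    dilated = partial(DilatedConvolutionalUnit, kernel_size=3,
                      activation=nn.ReLU)

    encoder = DilatedResidualEncoder(
        capacity=16,
        dilated_unit=dilated,
        downsampling_unit=partial(DownsamplingUnit, activation=nn.ReLU),
        ratios=ratios,
        dilations=[1, 3, 9],
        pre_network_conv=partial(nn.Conv1d, in_channels=1, kernel_size=7,
                                 padding=3),
        post_network_conv=partial(nn.Conv1d, out_channels=8, kernel_size=1),
    )
    decoder = DilatedResidualDecoder(
        capacity=16,
        dilated_unit=dilated,
        upsampling_unit=partial(UpsamplingUnit, activation=nn.ReLU),
        ratios=ratios,
        dilations=[1, 3, 9],
        pre_network_conv=partial(nn.Conv1d, in_channels=8, kernel_size=1),
        post_network_conv=partial(nn.Conv1d, out_channels=1, kernel_size=7,
                                  padding=3),
    )

    x = torch.randn(1, 1, 64)
    z = encoder(x)  # (1, 8, 16): length divided by prod(ratios) = 4
    y = decoder(z)  # (1, 1, 64): length restored
    print(z.shape, y.shape)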