import torch.nn as nn

from RepCodec.repcodec.layers.conv_layer import Conv1d, Conv1d1x1


class ResidualUnit(nn.Module):
    """Dilated residual block: two pre-activation 1D convolutions with a skip connection."""

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size=3,
            dilation=1,
            bias=False,
            nonlinear_activation="ELU",
            nonlinear_activation_params={},
    ):
        super().__init__()
        # Look up the activation class by name in torch.nn (e.g. nn.ELU)
        # and instantiate it with the given keyword arguments.
        self.activation = getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
        self.conv1 = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            dilation=dilation,
            bias=bias,
        )
        self.conv2 = Conv1d1x1(out_channels, out_channels, bias)

    def forward(self, x):
        # Pre-activation ordering: the nonlinearity runs before each convolution.
        y = self.conv1(self.activation(x))
        y = self.conv2(self.activation(y))
        # Residual connection; x and y must have matching shapes, so this
        # assumes in_channels == out_channels and length-preserving convolutions.
        return x + y
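

# Illustrative smoke test (not part of the original module): a minimal sketch
# assuming the repo's Conv1d pads its input so a stride-1 convolution preserves
# sequence length. The residual add in forward() requires matching channel
# counts, so in_channels == out_channels here; the channel count (64), batch
# size (2), and length (100) are arbitrary example values.
if __name__ == "__main__":
    import torch

    unit = ResidualUnit(in_channels=64, out_channels=64, kernel_size=3, dilation=3)
    x = torch.randn(2, 64, 100)  # (batch, channels, time)
    y = unit(x)
    assert y.shape == x.shape  # the residual block preserves the input shape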