# NOTE: moved file to util (commit 1379e6f, author IlayMalinyak) — header
# recovered from scraped commit metadata.
import torch
import torch.nn as nn
class ConvBlock(nn.Module):
    """Conv1d -> BatchNorm1d -> SiLU over the encoder dimension.

    Accepts channels-last input of shape (batch, seq_len, encoder_dim);
    channels are moved to dim 1 for the conv stack and restored before
    returning, so the output shape equals the input shape
    (stride=1, padding='same').
    """

    def __init__(self, args) -> None:
        super().__init__()
        conv = nn.Conv1d(
            in_channels=args.encoder_dim,
            out_channels=args.encoder_dim,
            kernel_size=args.kernel_size,
            stride=1,
            padding='same',
            bias=False,  # bias is redundant in front of BatchNorm
        )
        # Same Sequential layout as the sibling blocks so checkpoint
        # state_dict keys stay stable.
        self.layers = nn.Sequential(
            conv,
            nn.BatchNorm1d(num_features=args.encoder_dim),
            nn.SiLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """(batch, seq, encoder_dim) -> (batch, seq, encoder_dim)."""
        channels_first = x.permute(0, 2, 1)
        out = self.layers(channels_first)
        return out.permute(0, 2, 1)
class ConvBlockDecoder(nn.Module):
    """Conv1d -> BatchNorm1d -> SiLU over the decoder dimension.

    Mirrors ConvBlock but is sized by ``args.decoder_dim``. Input is
    channels-last (batch, seq_len, decoder_dim); with stride 1 and
    'same' padding the output shape matches the input shape.
    """

    def __init__(self, args) -> None:
        super().__init__()
        dim = args.decoder_dim
        # Keep the Sequential ordering (conv, norm, activation) so
        # state_dict keys are unchanged.
        self.layers = nn.Sequential(
            nn.Conv1d(
                in_channels=dim,
                out_channels=dim,
                kernel_size=args.kernel_size,
                stride=1,
                padding='same',
                bias=False,  # BatchNorm absorbs any bias
            ),
            nn.BatchNorm1d(num_features=dim),
            nn.SiLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """(batch, seq, decoder_dim) -> (batch, seq, decoder_dim)."""
        y = self.layers(x.permute(0, 2, 1))
        return y.permute(0, 2, 1)
class ResNetLayer(nn.Module):
    """Residual unit: x + SiLU(BN(Conv1d(x))).

    Unlike ConvBlock, this layer does NOT transpose — it expects
    channels-first input (batch, encoder_dim, seq_len). The kernel size
    is fixed at 3 (independent of ``args.kernel_size``); 'same' padding
    keeps the sequence length, so the skip connection adds tensors of
    identical shape.
    """

    def __init__(self, args) -> None:
        super().__init__()
        dim = args.encoder_dim
        self.conv_layer = nn.Sequential(
            nn.Conv1d(
                in_channels=dim,
                out_channels=dim,
                kernel_size=3,  # fixed 3-tap kernel for the residual branch
                stride=1,
                padding='same',
                bias=False,  # bias redundant before BatchNorm
            ),
            nn.BatchNorm1d(num_features=dim),
            nn.SiLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """(batch, dim, seq) -> (batch, dim, seq) with identity skip."""
        residual = x
        return residual + self.conv_layer(x)
class ResNetBlock(nn.Module):
    """Stack of ResNetLayer residual units applied channels-first.

    Like ResNetLayer, this block expects (batch, encoder_dim, seq_len)
    input and preserves its shape.

    Args:
        args: config object; forwarded to each ResNetLayer (reads
            ``args.encoder_dim``).
        num_layers: depth of the stack. Defaults to 3, the previously
            hard-coded value, so existing callers are unaffected.
    """

    def __init__(self, args, num_layers: int = 3) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            *(ResNetLayer(args) for _ in range(num_layers))
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """(batch, dim, seq) -> (batch, dim, seq)."""
        return self.layers(x)