Spaces:
Runtime error
Runtime error
from torch import nn as nn | |
from torch.nn import functional as F | |
class ConvLayer(nn.Module):
    """1-D convolution (or transposed convolution) followed by a nonlinearity.

    Depending on ``conv_type`` the convolution output is either normalised
    (GroupNorm for ``"gn"``, BatchNorm for ``"bn"``) and passed through ReLU,
    or (``"normal"``) passed directly through LeakyReLU.

    The regular convolution uses no padding, so it shrinks its input. The
    transposed convolution uses ``padding=kernel_size - 1``.
    ``get_input_size``/``get_output_size`` compute the exact length
    relationship for either direction.
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, conv_type,
                 transpose=False, norm_channels=8):
        """Build the convolution and (optionally) its normalisation layer.

        Args:
            n_inputs: Number of input channels.
            n_outputs: Number of output channels.
            kernel_size: Convolution kernel size.
            stride: Convolution stride (decimation factor for the regular
                convolution, upsampling factor for the transposed one).
            conv_type: ``"gn"``, ``"bn"`` or ``"normal"``. Other values are
                accepted here (no norm layer is created) but rejected by
                ``forward``.
            transpose: Use ``nn.ConvTranspose1d`` instead of ``nn.Conv1d``.
            norm_channels: How many channels are normalised as one group when
                ``conv_type == "gn"``; ``n_outputs`` must be divisible by it.

        Raises:
            ValueError: If ``conv_type == "gn"`` and ``norm_channels`` is not
                a positive divisor of ``n_outputs``.
        """
        super().__init__()
        self.transpose = transpose
        self.stride = stride
        self.kernel_size = kernel_size
        self.conv_type = conv_type

        if self.transpose:
            # padding=kernel_size-1 gives output length (i - 1) * stride - kernel_size + 2,
            # matching get_output_size's "upsample then valid conv" formula.
            self.filter = nn.ConvTranspose1d(
                n_inputs, n_outputs, kernel_size, stride, padding=kernel_size - 1)
        else:
            self.filter = nn.Conv1d(n_inputs, n_outputs, kernel_size, stride)

        if conv_type == "gn":
            # Explicit check (not ``assert``) so it also holds under ``python -O``;
            # otherwise e.g. 12 channels would silently become a single group.
            if norm_channels <= 0 or n_outputs % norm_channels != 0:
                raise ValueError(
                    f"GroupNorm needs n_outputs ({n_outputs}) divisible by "
                    f"norm_channels ({norm_channels})")
            self.norm = nn.GroupNorm(n_outputs // norm_channels, n_outputs)
        elif conv_type == "bn":
            self.norm = nn.BatchNorm1d(n_outputs, momentum=0.01)
        # Add your own types of variations here!

    def forward(self, x):
        """Apply the convolution, optional normalisation and activation.

        Raises:
            ValueError: If ``conv_type`` is not ``"gn"``, ``"bn"`` or ``"normal"``.
        """
        if self.conv_type in ("gn", "bn"):
            return F.relu(self.norm(self.filter(x)))
        if self.conv_type == "normal":
            return F.leaky_relu(self.filter(x))
        # Add your own variations here with further branches conditioned on "conv_type"!
        raise ValueError(f"Unknown conv_type: {self.conv_type!r}")

    # The size helpers below signal "no valid size" with AssertionError, raised
    # explicitly so the checks survive ``python -O``. The exception type is kept
    # from the original ``assert`` statements because callers may catch it.
    # NOTE(review): confirm against callers.

    def get_input_size(self, output_size):
        """Return the input length that yields exactly ``output_size`` samples.

        Raises:
            AssertionError: If no valid input length exists for ``output_size``.
        """
        if not self.transpose:
            # Strided conv/decimation: o = (i - 1) // s + 1  =>  i = (o - 1) * s + 1
            curr_size = (output_size - 1) * self.stride + 1
        else:
            curr_size = output_size

        # Unpadded conv: o = i - k + 1  =>  i = o + k - 1
        curr_size = curr_size + self.kernel_size - 1

        if self.transpose:
            # Upsampling: o = (i - 1) * s + 1. The first and last samples must both
            # fall on the stride grid.
            if (curr_size - 1) % self.stride != 0:
                raise AssertionError(
                    f"size {curr_size} is not reachable with stride {self.stride}")
            curr_size = (curr_size - 1) // self.stride + 1

        if curr_size <= 0:
            raise AssertionError(f"non-positive input size {curr_size}")
        return curr_size

    def get_output_size(self, input_size):
        """Return the output length produced from ``input_size`` input samples.

        Raises:
            AssertionError: If ``input_size`` does not yield a valid, exactly
                aligned output length.
        """
        if self.transpose:
            if input_size <= 1:
                raise AssertionError(
                    f"transposed conv needs input size > 1, got {input_size}")
            # Upsampling: o = (i - 1) * s + 1
            curr_size = (input_size - 1) * self.stride + 1
        else:
            curr_size = input_size

        # Unpadded conv: o = i - k + 1
        curr_size = curr_size - self.kernel_size + 1
        if curr_size <= 0:
            raise AssertionError(f"non-positive output size {curr_size}")

        if not self.transpose:
            # Strided conv/decimation: the last sample must land on the stride grid
            # so that both the beginning and end of the signal get an output value.
            if (curr_size - 1) % self.stride != 0:
                raise AssertionError(
                    f"size {curr_size} is not aligned with stride {self.stride}")
            curr_size = (curr_size - 1) // self.stride + 1

        return curr_size