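"""A 1D convolution block with optional transposed convolution and
normalisation (GroupNorm/BatchNorm), plus helpers that relate input and
output sizes for valid (unpadded) convolutions."""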
from torch import nn
from torch.nn import functional as F


class ConvLayer(nn.Module):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, conv_type, transpose=False):
        super(ConvLayer, self).__init__()
        self.transpose = transpose
        self.stride = stride
        self.kernel_size = kernel_size
        self.conv_type = conv_type

        # How many channels should be normalised as one group if GroupNorm is activated
        # WARNING: Number of channels has to be divisible by this number!
        NORM_CHANNELS = 8

        if self.transpose:
            self.filter = nn.ConvTranspose1d(n_inputs, n_outputs, self.kernel_size, stride, padding=kernel_size - 1)
        else:
            self.filter = nn.Conv1d(n_inputs, n_outputs, self.kernel_size, stride)

        if conv_type == "gn":
            assert n_outputs % NORM_CHANNELS == 0
            self.norm = nn.GroupNorm(n_outputs // NORM_CHANNELS, n_outputs)
        elif conv_type == "bn":
            self.norm = nn.BatchNorm1d(n_outputs, momentum=0.01)
        # Add your own types of variations here!
    def forward(self, x):
        # Apply the convolution, then normalisation (if any) and the nonlinearity
        if self.conv_type == "gn" or self.conv_type == "bn":
            out = F.relu(self.norm(self.filter(x)))
        else:  # Add your own variations here with elifs conditioned on the "conv_type" parameter!
            assert self.conv_type == "normal"
            out = F.leaky_relu(self.filter(x))
        return out
    def get_input_size(self, output_size):
        """Return the input length required to produce `output_size` outputs
        (the inverse of `get_output_size`)."""
        # Strided conv/decimation
        if not self.transpose:
            curr_size = (output_size - 1) * self.stride + 1  # o = (i-1)//s + 1 => i = (o-1)*s + 1
        else:
            curr_size = output_size

        # Conv
        curr_size = curr_size + self.kernel_size - 1  # o = i - k + 1 => i = o + k - 1

        # Transposed
        if self.transpose:
            assert (curr_size - 1) % self.stride == 0  # We need to have a value at the beginning and end
            curr_size = ((curr_size - 1) // self.stride) + 1

        assert curr_size > 0
        return curr_size
    def get_output_size(self, input_size):
        """Return the output length produced from `input_size` inputs."""
        # Transposed
        if self.transpose:
            assert input_size > 1
            curr_size = (input_size - 1) * self.stride + 1  # o = (i-1)*s + 1
        else:
            curr_size = input_size

        # Conv
        curr_size = curr_size - self.kernel_size + 1  # o = i - k + 1
        assert curr_size > 0

        # Strided conv/decimation
        if not self.transpose:
            assert (curr_size - 1) % self.stride == 0  # We need to have a value at the beginning and end
            curr_size = ((curr_size - 1) // self.stride) + 1

        return curr_size
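

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal check, under assumed example shapes, that the lengths returned by
# get_output_size()/get_input_size() match what forward() actually produces.
# The channel counts, kernel size, stride, and input length below are
# arbitrary choices that satisfy the GroupNorm and stride-divisibility asserts.
if __name__ == "__main__":
    import torch

    down = ConvLayer(n_inputs=1, n_outputs=16, kernel_size=5, stride=2, conv_type="gn")
    up = ConvLayer(n_inputs=16, n_outputs=16, kernel_size=5, stride=2, conv_type="gn", transpose=True)

    # After the valid conv (101 - 5 + 1 = 97), (97 - 1) is divisible by
    # stride 2, so in_len = 101 passes the assert in get_output_size().
    in_len = 101
    x = torch.randn(2, 1, in_len)  # (batch, channels, time)

    y = down(x)
    assert y.shape[-1] == down.get_output_size(in_len)     # 49
    assert down.get_input_size(y.shape[-1]) == in_len      # 101

    z = up(y)
    assert z.shape[-1] == up.get_output_size(y.shape[-1])  # 93
    print(y.shape, z.shape)  # torch.Size([2, 16, 49]) torch.Size([2, 16, 93])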