hieupt commited on
Commit
31f503b
·
verified ·
1 Parent(s): cadff64

Upload conv.py

Browse files
Files changed (1) hide show
  1. model/conv.py +72 -0
model/conv.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn as nn
2
+ from torch.nn import functional as F
3
+
4
+
5
+ class ConvLayer(nn.Module):
6
+ def __init__(self, n_inputs, n_outputs, kernel_size, stride, conv_type, transpose=False):
7
+ super(ConvLayer, self).__init__()
8
+ self.transpose = transpose
9
+ self.stride = stride
10
+ self.kernel_size = kernel_size
11
+ self.conv_type = conv_type
12
+
13
+ # How many channels should be normalised as one group if GroupNorm is activated
14
+ # WARNING: Number of channels has to be divisible by this number!
15
+ NORM_CHANNELS = 8
16
+
17
+ if self.transpose:
18
+ self.filter = nn.ConvTranspose1d(n_inputs, n_outputs, self.kernel_size, stride, padding=kernel_size-1)
19
+ else:
20
+ self.filter = nn.Conv1d(n_inputs, n_outputs, self.kernel_size, stride)
21
+
22
+ if conv_type == "gn":
23
+ assert(n_outputs % NORM_CHANNELS == 0)
24
+ self.norm = nn.GroupNorm(n_outputs // NORM_CHANNELS, n_outputs)
25
+ elif conv_type == "bn":
26
+ self.norm = nn.BatchNorm1d(n_outputs, momentum=0.01)
27
+ # Add you own types of variations here!
28
+
29
+ def forward(self, x):
30
+ # Apply the convolution
31
+ if self.conv_type == "gn" or self.conv_type == "bn":
32
+ out = F.relu(self.norm((self.filter(x))))
33
+ else: # Add your own variations here with elifs conditioned on "conv_type" parameter!
34
+ assert(self.conv_type == "normal")
35
+ out = F.leaky_relu(self.filter(x))
36
+ return out
37
+
38
+ def get_input_size(self, output_size):
39
+ # Strided conv/decimation
40
+ if not self.transpose:
41
+ curr_size = (output_size - 1)*self.stride + 1 # o = (i-1)//s + 1 => i = (o - 1)*s + 1
42
+ else:
43
+ curr_size = output_size
44
+
45
+ # Conv
46
+ curr_size = curr_size + self.kernel_size - 1 # o = i + p - k + 1
47
+
48
+ # Transposed
49
+ if self.transpose:
50
+ assert ((curr_size - 1) % self.stride == 0)# We need to have a value at the beginning and end
51
+ curr_size = ((curr_size - 1) // self.stride) + 1
52
+ assert(curr_size > 0)
53
+ return curr_size
54
+
55
+ def get_output_size(self, input_size):
56
+ # Transposed
57
+ if self.transpose:
58
+ assert(input_size > 1)
59
+ curr_size = (input_size - 1)*self.stride + 1 # o = (i-1)//s + 1 => i = (o - 1)*s + 1
60
+ else:
61
+ curr_size = input_size
62
+
63
+ # Conv
64
+ curr_size = curr_size - self.kernel_size + 1 # o = i + p - k + 1
65
+ assert (curr_size > 0)
66
+
67
+ # Strided conv/decimation
68
+ if not self.transpose:
69
+ assert ((curr_size - 1) % self.stride == 0) # We need to have a value at the beginning and end
70
+ curr_size = ((curr_size - 1) // self.stride) + 1
71
+
72
+ return curr_size