hieupt commited on
Commit
57599d7
·
verified ·
1 Parent(s): 5378323

Upload resample.py

Browse files
Files changed (1) hide show
  1. model/resample.py +119 -0
model/resample.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn as nn
4
+ from torch.nn import functional as F
5
+
6
class Resample1d(nn.Module):
    def __init__(self, channels, kernel_size, stride, transpose=False, padding="reflect", trainable=False):
        '''
        Sinc-lowpass-based resampling layer for time-series data in (N, C, W) format,
        implemented as a depthwise (grouped) 1D convolution.
        :param channels: Number of feature channels C at each time step
        :param kernel_size: Width of the sinc lowpass filter (odd; >= 15 recommended for good filtering)
        :param stride: Integer resampling factor
        :param transpose: False for downsampling, True for upsampling
        :param padding: "reflect" to pad the input, "valid" to leave it unpadded
        :param trainable: If True, the sinc-initialised filter weights are trained further
        '''
        super(Resample1d, self).__init__()

        self.padding = padding
        self.kernel_size = kernel_size
        self.stride = stride
        self.transpose = transpose
        self.channels = channels

        # Cut off at the Nyquist frequency of the lower of the two sample rates.
        lowpass_cutoff = 0.5 / stride

        assert kernel_size > 2
        assert (kernel_size - 1) % 2 == 0  # odd width => symmetric filter with a centre tap
        assert padding == "reflect" or padding == "valid"

        # Same lowpass kernel repeated for every channel (depthwise conv weight layout).
        sinc_kernel = build_sinc_filter(kernel_size, lowpass_cutoff)
        weight = np.repeat(np.reshape(sinc_kernel, [1, 1, kernel_size]), channels, axis=0)
        self.filter = torch.nn.Parameter(torch.from_numpy(weight), requires_grad=trainable)

    def forward(self, x):
        '''
        Resample x of shape (N, C, W) by factor self.stride (up if transpose, else down).
        '''
        input_size = x.shape[2]

        # Pad symmetrically here unless "valid" mode was requested
        # (transposed conv does its own zero-insertion either way).
        if self.padding == "valid":
            out = x
        else:
            half_width = (self.kernel_size - 1) // 2
            out = F.pad(x, (half_width, half_width), mode=self.padding)

        if self.transpose:
            # Zero-insertion upsampling followed by lowpass filtering.
            expected_steps = (input_size - 1) * self.stride + 1
            if self.padding == "valid":
                expected_steps = expected_steps - self.kernel_size + 1

            out = F.conv_transpose1d(out, self.filter, stride=self.stride, padding=0, groups=self.channels)
            surplus = out.shape[2] - expected_steps
            if surplus > 0:
                # Trim the excess equally from both ends to hit the expected length.
                assert surplus % 2 == 0
                out = out[:, :, surplus // 2:-surplus // 2]
            return out

        # Decimation: stride must align so the first and last samples are kept.
        assert input_size % self.stride == 1
        return F.conv1d(out, self.filter, stride=self.stride, padding=0, groups=self.channels)

    def get_output_size(self, input_size):
        '''
        Returns the output dimensionality (number of timesteps) for a given input size
        :param input_size: Number of input time steps (scalar, each feature is one-dimensional)
        :return: Output size (scalar)
        '''
        assert input_size > 1
        if self.transpose:
            size = (input_size - 1) * self.stride + 1
            return size - self.kernel_size + 1 if self.padding == "valid" else size
        # Want to take first and last sample
        assert input_size % self.stride == 1
        return input_size - self.kernel_size + 1 if self.padding == "valid" else input_size

    def get_input_size(self, output_size):
        '''
        Returns the input dimensionality (number of timesteps) for a given output size
        :param output_size: Number of output time steps (scalar)
        :return: Input size (scalar)
        '''
        # Invert the strided conv / decimation: o = (i-1)//s + 1  =>  i = (o-1)*s + 1
        size = output_size if self.transpose else (output_size - 1) * self.stride + 1

        # Invert the valid-mode convolution: o = i - k + 1
        if self.padding == "valid":
            size = size + self.kernel_size - 1

        # Invert the zero-insertion of the transposed conv.
        if self.transpose:
            assert (size - 1) % self.stride == 0  # we need a value at the beginning and end
            size = (size - 1) // self.stride + 1

        assert size > 0
        return size
def build_sinc_filter(kernel_size, cutoff):
    '''
    Builds a Blackman-windowed sinc lowpass FIR filter, normalised to unit DC gain.
    FOLLOWING https://www.analog.com/media/en/technical-documentation/dsp-book/dsp_book_Ch16.pdf
    :param kernel_size: Odd filter width (number of taps)
    :param cutoff: Normalised cutoff frequency (cycles per sample, 0 < cutoff <= 0.5)
    :return: float32 numpy array of length kernel_size whose taps sum to 1
    '''
    assert kernel_size % 2 == 1  # symmetric filter with a single centre tap
    M = kernel_size - 1
    kernel = np.zeros(kernel_size, dtype=np.float32)
    for i in range(kernel_size):
        if i == M // 2:
            # sin(x)/x limit at the centre tap
            kernel[i] = 2 * np.pi * cutoff
        else:
            # Sinc multiplied by the Blackman window
            # w[i] = 0.42 - 0.5*cos(2*pi*i/M) + 0.08*cos(4*pi*i/M).
            # BUGFIX: the second cosine term must depend on i; the original used
            # cos(4*pi*M), which is the constant 1 for even M and distorts the window.
            kernel[i] = (np.sin(2 * np.pi * cutoff * (i - M // 2)) / (i - M // 2)) * \
                        (0.42 - 0.5 * np.cos((2 * np.pi * i) / M) + 0.08 * np.cos((4 * np.pi * i) / M))

    # Normalise so the filter has unit gain at DC.
    return kernel / np.sum(kernel)