Spaces:
Runtime error
Runtime error
import numpy as np | |
import torch | |
from torch import nn as nn | |
from torch.nn import functional as F | |
class Resample1d(nn.Module):
    def __init__(self, channels, kernel_size, stride, transpose=False, padding="reflect", trainable=False):
        '''
        Creates a resampling layer for time series data (using 1D convolution) - (N, C, W) input format
        :param channels: Number of features C at each time-step
        :param kernel_size: Width of sinc-based lowpass-filter (>= 15 recommended for good filtering performance);
                            must be odd and greater than 2
        :param stride: Resampling factor (integer)
        :param transpose: False for down-, true for upsampling
        :param padding: Either "reflect" to pad or "valid" to not pad
        :param trainable: Optionally activate this to train the lowpass-filter, starting from the sinc initialisation
        '''
        super().__init__()
        self.padding = padding
        self.kernel_size = kernel_size
        self.stride = stride
        self.transpose = transpose
        self.channels = channels

        # Cutoff at the Nyquist frequency of the lower sampling rate (fraction of the higher rate)
        cutoff = 0.5 / stride
        assert(kernel_size > 2)
        assert((kernel_size - 1) % 2 == 0)
        assert(padding == "reflect" or padding == "valid")

        taps = build_sinc_filter(kernel_size, cutoff)
        # One copy of the filter per channel, shaped (channels, 1, kernel_size) for the
        # depthwise (groups=channels) convolutions in forward()
        self.filter = torch.nn.Parameter(
            torch.from_numpy(np.repeat(np.reshape(taps, [1, 1, kernel_size]), channels, axis=0)),
            requires_grad=trainable)

    def forward(self, x):
        '''
        Lowpass-filters and resamples x along its last (time) axis
        :param x: Input tensor in (N, C, W) format with C == channels
        :return: Tensor of shape (N, C, get_output_size(W))
        '''
        input_size = x.shape[2]
        # Pad unless in "valid" mode (the transposed path pads its low-rate input as well)
        if self.padding != "valid":
            num_pad = (self.kernel_size - 1) // 2
            out = F.pad(x, (num_pad, num_pad), mode=self.padding)
        else:
            out = x

        if self.transpose:
            # Zero insertion + lowpass filter in one transposed conv, then centre-crop to the expected length
            expected_steps = (input_size - 1) * self.stride + 1
            if self.padding == "valid":
                expected_steps = expected_steps - self.kernel_size + 1
            out = F.conv_transpose1d(out, self.filter, stride=self.stride, padding=0, groups=self.channels)
            diff_steps = out.shape[2] - expected_steps
            if diff_steps > 0:
                assert(diff_steps % 2 == 0)
                crop = diff_steps // 2
                # Explicit end index: avoids relying on (-d)//2 == -(d//2) precedence subtleties
                out = out[:, :, crop:out.shape[2] - crop]
        else:
            # The strided windows must start on the first and end on the last sample of the sequence
            # the kernel slides over, i.e. (conv_in - k) % stride == 0 (this also holds for stride == 1)
            conv_in = out.shape[2]
            assert((conv_in - self.kernel_size) % self.stride == 0)
            out = F.conv1d(out, self.filter, stride=self.stride, padding=0, groups=self.channels)
        return out

    def get_output_size(self, input_size):
        '''
        Returns the output dimensionality (number of timesteps) for a given input size
        :param input_size: Number of input time steps (Scalar, each feature is one-dimensional)
        :return: Output size (scalar)
        '''
        assert(input_size > 1)
        if self.transpose:
            upsampled = (input_size - 1) * self.stride + 1
            if self.padding == "valid":
                return upsampled - self.kernel_size + 1
            return upsampled

        # Strided conv: length of the sequence the kernel actually slides over
        if self.padding == "valid":
            conv_in = input_size
        else:
            conv_in = input_size + self.kernel_size - 1
        assert((conv_in - self.kernel_size) % self.stride == 0)  # Want to take first and last sample
        return (conv_in - self.kernel_size) // self.stride + 1

    def get_input_size(self, output_size):
        '''
        Returns the input dimensionality (number of timesteps) needed to produce a given output size
        (inverse of get_output_size)
        :param output_size: Number of output time steps (Scalar, each feature is one-dimensional)
        :return: Input size (scalar)
        '''
        # Strided conv/decimation
        if not self.transpose:
            curr_size = (output_size - 1) * self.stride + 1  # o = (i-1)//s + 1 => i = (o - 1)*s + 1
        else:
            curr_size = output_size
        # Conv
        if self.padding == "valid":
            curr_size = curr_size + self.kernel_size - 1  # o = i + p - k + 1
        # Transposed
        if self.transpose:
            assert((curr_size - 1) % self.stride == 0)  # We need to have a value at the beginning and end
            curr_size = ((curr_size - 1) // self.stride) + 1
        assert(curr_size > 0)
        return curr_size
def build_sinc_filter(kernel_size, cutoff):
    '''
    Builds a Blackman-windowed sinc lowpass filter
    FOLLOWING https://www.analog.com/media/en/technical-documentation/dsp-book/dsp_book_Ch16.pdf
    :param kernel_size: Number of filter taps; must be odd so the filter has a centre tap
    :param cutoff: Cutoff frequency as a fraction of the sampling rate (Resample1d passes 0.5 / stride)
    :return: float32 numpy array of shape (kernel_size,), normalised so its taps sum to 1
    '''
    assert(kernel_size % 2 == 1)
    M = kernel_size - 1
    offsets = np.arange(kernel_size) - M // 2
    # 2*pi*fc * np.sinc(2*fc*n) == sin(2*pi*fc*n) / n, with the limit 2*pi*fc at n == 0
    sinc = 2 * np.pi * cutoff * np.sinc(2 * cutoff * offsets)
    # Blackman window 0.42 - 0.5*cos(2*pi*i/M) + 0.08*cos(4*pi*i/M); the last term must
    # depend on i (a constant 0.08 there widens the transition band). np.blackman(1) == [1.].
    window = np.blackman(kernel_size)
    taps = (sinc * window).astype(np.float32)
    # Normalise for unity gain at DC
    return taps / np.sum(taps)