import numpy as np
import torch.nn as nn


class ConvBlock(nn.Module):
    """Conv -> ReLU -> BatchNorm -> MaxPool -> Dropout block with 'same'-style
    zero padding along the time axis, so the convolution itself does not shrink
    the time dimension."""

    def __init__(self, f, w, s, d, in_channels):
        super().__init__()
        # Asymmetric zero padding that exactly compensates for the dilated kernel
        p1 = d * (w - 1) // 2
        p2 = d * (w - 1) - p1
        self.pad = nn.ZeroPad2d((0, 0, p1, p2))
        self.conv2d = nn.Conv2d(in_channels=in_channels, out_channels=f,
                                kernel_size=(w, 1), stride=(s, 1), dilation=(d, 1))
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(f)
        self.pool = nn.MaxPool2d(kernel_size=(2, 1))
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        x = self.pad(x)
        x = self.conv2d(x)
        x = self.relu(x)
        x = self.bn(x)
        x = self.pool(x)
        x = self.dropout(x)
        return x


class NoPadConvBlock(nn.Module):
    """Same stack as ConvBlock but without padding, so each block shrinks the
    time dimension by dilation*(kernel-1) before the 2x max pooling."""

    def __init__(self, f, w, s, d, in_channels):
        super().__init__()
        self.conv2d = nn.Conv2d(in_channels=in_channels, out_channels=f,
                                kernel_size=(w, 1), stride=(s, 1), dilation=(d, 1))
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(f)
        self.pool = nn.MaxPool2d(kernel_size=(2, 1))
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        x = self.conv2d(x)
        x = self.relu(x)
        x = self.bn(x)
        x = self.pool(x)
        x = self.dropout(x)
        return x


class TinyPathway(nn.Module):
    """Stack of NoPadConvBlocks applied to a 1-D signal reshaped to (batch, 1, time, 1).

    Per-layer strides and dilations are derived from the requested hop and dilation
    so that the whole pathway has a total stride of `hop` samples per output frame.
    """

    def __init__(self, dilation=1, hop=256, localize=False,
                 model_capacity="full", n_layers=6, chunk_size=256):
        super().__init__()
        capacity_multiplier = {
            'tiny': 4, 'small': 8, 'medium': 16, 'large': 24, 'full': 32
        }[model_capacity]
        self.layers = [1, 2, 3, 4, 5, 6][:n_layers]
        filters = [n * capacity_multiplier for n in [32, 8, 8, 8, 8, 8]]
        filters = [1] + filters
        widths = [512, 64, 64, 64, 32, 32]
        # Max pooling contributes a factor of 2 per layer and the first conv gets an
        # extra 4x stride, so the conv strides only need to distribute the remaining
        # factor of hop // (4 * 2**n_layers).
        strides = self.deter_dilations(hop // (4 * (2 ** n_layers)), localize=localize)
        strides[0] = strides[0] * 4  # apply 4 times more stride at the first layer
        dilations = self.deter_dilations(dilation)

        for i in range(len(self.layers)):
            f, w, s, d, in_channel = filters[i + 1], widths[i], strides[i], dilations[i], filters[i]
            self.add_module("conv%d" % i, NoPadConvBlock(f, w, s, d, in_channel))
        self.chunk_size = chunk_size
        self.input_window, self.hop = self.find_input_size_for_pathway()
        self.out_dim = filters[n_layers]

    def find_input_size_for_pathway(self):
        """Walk the layers backwards to find the input length that yields exactly
        `chunk_size` output frames, and the total stride (hop) of the pathway."""

        def find_input_size(output_size, kernel_size, stride, dilation, padding):
            # Inverse of the standard conv/pool output-size formula
            num = (stride * (output_size - 1)) + 1
            input_size = num - 2 * padding + dilation * (kernel_size - 1)
            return input_size

        conv_calc, n = {}, 0
        for i in self.layers:
            layer = getattr(self, "conv%d" % (i - 1))
            for mm in layer.modules():
                if hasattr(mm, 'kernel_size'):  # picks up both Conv2d and MaxPool2d
                    try:
                        d = mm.dilation[0]
                    except TypeError:
                        d = mm.dilation  # MaxPool2d stores dilation as an int
                    conv_calc[n] = [mm.kernel_size[0], mm.stride[0], 0, d]
                    n += 1

        out = self.chunk_size
        hop = 1
        for n in sorted(conv_calc.keys())[::-1]:
            kernel_size_n, stride_n, padding_n, dilation_n = conv_calc[n]
            out = find_input_size(out, kernel_size_n, stride_n, dilation_n, padding_n)
            hop = hop * stride_n
        return out, hop

    def deter_dilations(self, total_dilation, localize=False):
        """Distribute a total dilation (or stride) factor across the layers."""
        n_layers = len(self.layers)
        if localize:  # e.g., 32*1023 window and 3 layers -> [1, 1, 32]
            a = [total_dilation] + [1 for _ in range(n_layers - 1)]
        else:  # e.g., 32*1023 window and 3 layers -> [4, 4, 2]
            total_dilation = int(np.log2(total_dilation))
            a = []
            for layer in range(n_layers):
                this_dilation = int(np.ceil(total_dilation / (n_layers - layer)))
                a.append(2 ** this_dilation)
                total_dilation = total_dilation - this_dilation
        return a[::-1]

    def forward(self, x):
        # (batch, samples) -> (batch, 1, samples, 1)
        x = x.view(x.shape[0], 1, -1, 1)
        for i in range(len(self.layers)):
            x = getattr(self, "conv%d" % i)(x)
        # (batch, channels, frames, 1) -> (batch, 1, frames, channels)
        x = x.permute(0, 3, 2, 1)
        return x
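

# --- Usage sketch (not part of the original module; assumes PyTorch is installed) ---
# A minimal example of how TinyPathway might be driven: feed a dummy batch of mono
# signals whose length equals the computed `input_window`, and the pathway returns
# `chunk_size` frames of `out_dim`-dimensional features. The names below (`model`,
# `dummy`) are illustrative only.
if __name__ == "__main__":
    import torch

    model = TinyPathway()  # defaults: hop=256, n_layers=6, chunk_size=256
    print("input window:", model.input_window, "| total hop:", model.hop)

    dummy = torch.randn(2, model.input_window)  # (batch, samples)
    with torch.no_grad():
        out = model(dummy)
    print("output shape:", tuple(out.shape))  # (batch, 1, chunk_size, model.out_dim)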