"""
A modularized deep neural network for 1-d signal data, PyTorch version.

Shenda Hong, Mar 2020
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    # Nothing in this module uses PreTrainedModel; the import is kept so the
    # name stays available, but a missing `transformers` install must not
    # prevent the network definitions below from being imported.
    from transformers import PreTrainedModel  # noqa: F401
except ImportError:
    PreTrainedModel = None


class MyConv1dPadSame(nn.Module):
    """
    Extend nn.Conv1d to support SAME padding.

    params:
        in_channels: number of input channels
        out_channels: number of output channels
        kernel_size: kernel window length
        stride: kernel step size
        groups: number of groups of the convolution

    input: (n_sample, in_channels, n_length)
    output: (n_sample, out_channels, (n_length+stride-1)//stride)
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups
        self.conv = torch.nn.Conv1d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            groups=self.groups)

    def forward(self, x):
        # Total padding needed so that the output length is ceil(in_dim / stride).
        in_dim = x.shape[-1]
        out_dim = (in_dim + self.stride - 1) // self.stride
        p = max(0, (out_dim - 1) * self.stride + self.kernel_size - in_dim)
        # When p is odd, the extra element goes to the right side.
        pad_left = p // 2
        pad_right = p - pad_left
        net = F.pad(x, (pad_left, pad_right), "constant", 0)
        return self.conv(net)


class MyMaxPool1dPadSame(nn.Module):
    """
    Extend nn.MaxPool1d to support SAME padding.

    params:
        kernel_size: pooling window length; the pooling stride is left at the
            nn.MaxPool1d default, i.e. equal to kernel_size

    input: (n_sample, n_channel, n_length)
    output: (n_sample, n_channel, (n_length+kernel_size-1)//kernel_size)
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.kernel_size = kernel_size
        self.max_pool = torch.nn.MaxPool1d(kernel_size=self.kernel_size)

    def forward(self, x):
        # Pad kernel_size - 1 elements in total, the extra one on the right.
        p = max(0, self.kernel_size - 1)
        pad_left = p // 2
        pad_right = p - pad_left
        # NOTE: padding uses zeros, so the max of a window that touches the
        # padding is never below 0. Kept as-is so existing weights see the
        # same numbers.
        net = F.pad(x, (pad_left, pad_right), "constant", 0)
        return self.max_pool(net)


class Swish(nn.Module):
    """Swish activation: x * sigmoid(x)."""

    def forward(self, x):
        # torch.sigmoid is used instead of the legacy F.sigmoid alias.
        return x * torch.sigmoid(x)


class BasicBlock(nn.Module):
    """
    Basic Block:
        conv1 -> convk -> conv1, followed by Squeeze-and-Excitation and a
        residual shortcut

    params:
        in_channels: number of input channels
        out_channels: number of output channels
        ratio: middle_channels = int(out_channels * ratio)
        kernel_size: kernel window length
        stride: kernel step size (only applied when downsample is True)
        groups: number of groups in convk
        downsample: whether downsample length
        is_first_block: if True, skip bn/activation/dropout before conv1
        use_bn: whether use batch_norm
        use_do: whether use dropout

    input: (n_sample, in_channels, n_length)
    output: (n_sample, out_channels, (n_length+stride-1)//stride)
    """

    def __init__(self, in_channels, out_channels, ratio, kernel_size, stride,
                 groups, downsample, is_first_block=False, use_bn=True,
                 use_do=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.ratio = ratio
        self.kernel_size = kernel_size
        self.groups = groups
        self.downsample = downsample
        # A block that does not downsample always runs with stride 1.
        self.stride = stride if self.downsample else 1
        self.is_first_block = is_first_block
        self.use_bn = use_bn
        self.use_do = use_do

        self.middle_channels = int(self.out_channels * self.ratio)

        # the first conv, conv1
        self.bn1 = nn.BatchNorm1d(in_channels)
        self.activation1 = Swish()
        self.do1 = nn.Dropout(p=0.5)
        self.conv1 = MyConv1dPadSame(
            in_channels=self.in_channels,
            out_channels=self.middle_channels,
            kernel_size=1,
            stride=1,
            groups=1)

        # the second conv, convk
        self.bn2 = nn.BatchNorm1d(self.middle_channels)
        self.activation2 = Swish()
        self.do2 = nn.Dropout(p=0.5)
        self.conv2 = MyConv1dPadSame(
            in_channels=self.middle_channels,
            out_channels=self.middle_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            groups=self.groups)

        # the third conv, conv1
        self.bn3 = nn.BatchNorm1d(self.middle_channels)
        self.activation3 = Swish()
        self.do3 = nn.Dropout(p=0.5)
        self.conv3 = MyConv1dPadSame(
            in_channels=self.middle_channels,
            out_channels=self.out_channels,
            kernel_size=1,
            stride=1,
            groups=1)

        # Squeeze-and-Excitation
        r = 2
        self.se_fc1 = nn.Linear(self.out_channels, self.out_channels // r)
        self.se_fc2 = nn.Linear(self.out_channels // r, self.out_channels)
        self.se_activation = Swish()

        if self.downsample:
            self.max_pool = MyMaxPool1dPadSame(kernel_size=self.stride)

    def forward(self, x):
        identity = x
        out = x

        # the first conv, conv1
        if not self.is_first_block:
            if self.use_bn:
                out = self.bn1(out)
            out = self.activation1(out)
            if self.use_do:
                out = self.do1(out)
        out = self.conv1(out)

        # the second conv, convk
        if self.use_bn:
            out = self.bn2(out)
        out = self.activation2(out)
        if self.use_do:
            out = self.do2(out)
        out = self.conv2(out)

        # the third conv, conv1
        if self.use_bn:
            out = self.bn3(out)
        out = self.activation3(out)
        if self.use_do:
            out = self.do3(out)
        out = self.conv3(out)  # (n_sample, n_channel, n_length)

        # Squeeze-and-Excitation: per-channel gate from the length-averaged signal
        se = out.mean(-1)  # (n_sample, n_channel)
        se = self.se_fc1(se)
        se = self.se_activation(se)
        se = self.se_fc2(se)
        se = torch.sigmoid(se)  # (n_sample, n_channel)
        out = out * se.unsqueeze(-1)  # broadcast the gate over n_length

        # if downsample, also downsample identity
        if self.downsample:
            identity = self.max_pool(identity)

        # if expand channel, also pad zeros to identity
        if self.out_channels != self.in_channels:
            ch1 = (self.out_channels - self.in_channels) // 2
            ch2 = self.out_channels - self.in_channels - ch1
            # (0, 0) leaves the length axis alone; (ch1, ch2) pads the channel axis.
            identity = F.pad(identity, (0, 0, ch1, ch2), "constant", 0)

        # shortcut
        out = out + identity

        return out


class BasicStage(nn.Module):
    """
    Basic Stage:
        block_1 -> block_2 -> ... -> block_M

    Only the first block of a stage downsamples (with the given stride) and
    changes the channel count from in_channels to out_channels; the remaining
    blocks keep out_channels and use stride 1.

    params:
        in_channels: number of input channels of the stage
        out_channels: number of output channels of every block in the stage
        ratio: passed to each BasicBlock
        kernel_size: kernel window length of convk in each block
        stride: stride of the first (downsampling) block
        groups: number of groups in convk of each block
        i_stage: index of this stage in the network
        m_blocks: number of blocks in this stage
        use_bn: whether use batch_norm
        use_do: whether use dropout
        verbose: whether print per-block shapes during forward
    """

    def __init__(self, in_channels, out_channels, ratio, kernel_size, stride,
                 groups, i_stage, m_blocks, use_bn=True, use_do=True,
                 verbose=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.ratio = ratio
        self.kernel_size = kernel_size
        self.groups = groups
        self.i_stage = i_stage
        self.m_blocks = m_blocks
        self.use_bn = use_bn
        self.use_do = use_do
        self.verbose = verbose

        self.block_list = nn.ModuleList()
        for i_block in range(self.m_blocks):
            # NOTE: the per-block settings below are stored on self, so after
            # construction they describe the LAST block that was built.

            # Only the very first block of the whole network skips its
            # leading bn/activation/dropout.
            self.is_first_block = self.i_stage == 0 and i_block == 0

            # downsample, stride, input
            self.downsample = i_block == 0
            self.stride = stride if self.downsample else 1
            self.tmp_in_channels = (
                self.in_channels if self.downsample else self.out_channels)

            # build block
            tmp_block = BasicBlock(
                in_channels=self.tmp_in_channels,
                out_channels=self.out_channels,
                ratio=self.ratio,
                kernel_size=self.kernel_size,
                stride=self.stride,
                groups=self.groups,
                downsample=self.downsample,
                is_first_block=self.is_first_block,
                use_bn=self.use_bn,
                use_do=self.use_do)
            self.block_list.append(tmp_block)

    def forward(self, x):
        out = x

        for i_block, net in enumerate(self.block_list):
            out = net(out)
            if self.verbose:
                print('stage: {}, block: {}, in_channels: {}, out_channels: {}, outshape: {}'.format(self.i_stage, i_block, net.in_channels, net.out_channels, list(out.shape)))
                print('stage: {}, block: {}, conv1: {}->{} k={} s={} C={}'.format(self.i_stage, i_block, net.conv1.in_channels, net.conv1.out_channels, net.conv1.kernel_size, net.conv1.stride, net.conv1.groups))
                print('stage: {}, block: {}, convk: {}->{} k={} s={} C={}'.format(self.i_stage, i_block, net.conv2.in_channels, net.conv2.out_channels, net.conv2.kernel_size, net.conv2.stride, net.conv2.groups))
                print('stage: {}, block: {}, conv1: {}->{} k={} s={} C={}'.format(self.i_stage, i_block, net.conv3.in_channels, net.conv3.out_channels, net.conv3.kernel_size, net.conv3.stride, net.conv3.groups))

        return out


class Net1D(nn.Module):
    """
    Input:
        X: (n_samples, n_channel, n_length)

    Output:
        out: (n_samples, n_classes)

    params:
        in_channels: number of channels of the input signal
        base_filters: number of output channels of the first conv
        ratio: passed to every BasicBlock
        filter_list: list, filters for each stage; its length is the number
            of stages
        m_blocks_list: list, number of blocks of each stage
        kernel_size: kernel window length of the first conv and of each convk
        stride: stride of the downsampling block of each stage
        groups_width: convk of a stage uses out_channels // groups_width groups
        n_classes: size of the final linear layer output
        use_bn: whether use batch_norm
        use_do: whether use dropout
        verbose: whether print per-block shapes during forward
    """

    def __init__(self, in_channels, base_filters, ratio, filter_list,
                 m_blocks_list, kernel_size, stride, groups_width,
                 n_classes=1, use_bn=True, use_do=True, verbose=False):
        super().__init__()

        self.in_channels = in_channels
        self.base_filters = base_filters
        self.ratio = ratio
        self.filter_list = filter_list
        self.m_blocks_list = m_blocks_list
        self.kernel_size = kernel_size
        self.stride = stride
        self.groups_width = groups_width
        self.n_stages = len(filter_list)
        self.n_classes = n_classes
        self.use_bn = use_bn
        self.use_do = use_do
        self.verbose = verbose

        # first conv (always stride 2)
        self.first_conv = MyConv1dPadSame(
            in_channels=in_channels,
            out_channels=self.base_filters,
            kernel_size=self.kernel_size,
            stride=2)
        self.first_bn = nn.BatchNorm1d(base_filters)
        self.first_activation = Swish()

        # stages: each stage consumes the channel count the previous one produced
        self.stage_list = nn.ModuleList()
        in_channels = self.base_filters
        for i_stage in range(self.n_stages):
            out_channels = self.filter_list[i_stage]
            m_blocks = self.m_blocks_list[i_stage]
            tmp_stage = BasicStage(
                in_channels=in_channels,
                out_channels=out_channels,
                ratio=self.ratio,
                kernel_size=self.kernel_size,
                stride=self.stride,
                groups=out_channels // self.groups_width,
                i_stage=i_stage,
                m_blocks=m_blocks,
                use_bn=self.use_bn,
                use_do=self.use_do,
                verbose=self.verbose)
            self.stage_list.append(tmp_stage)
            in_channels = out_channels

        # final prediction
        self.dense = nn.Linear(in_channels, self.n_classes)

    def forward(self, x):
        out = x

        # first conv
        out = self.first_conv(out)
        if self.use_bn:
            out = self.first_bn(out)
        out = self.first_activation(out)

        # stages
        for stage in self.stage_list:
            out = stage(out)

        # final prediction: average over the length axis, then linear layer
        out = out.mean(-1)
        out = self.dense(out)

        return out


if __name__ == '__main__':
    import os

    # Respect a GPU selection made by the caller; otherwise keep the
    # original default of device 7.
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '7')
    # Fall back to CPU so the smoke test also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    inp_data = torch.randn((2, 1, 100), device=device)

    model = Net1D(
        in_channels=1,
        base_filters=24,
        ratio=1.0,
        filter_list=[24, 48, 96, 192],
        m_blocks_list=[2, 2, 2, 2],
        kernel_size=13,
        stride=1,
        groups_width=12,
        verbose=False,
        n_classes=1,
    ).to(device)

    # Net1D.forward takes a single input tensor.
    out = model(inp_data)
    print(out.shape)