import torch
import torch.nn as nn
import torch.nn.functional as F

# Adapted from https://github.com/samcw/ResNet18-Pytorch (1-D variant).


class ResBlock(nn.Module):
    """Basic two-convolution residual block operating on 1-D feature maps.

    Args:
        inchannel: Number of input channels.
        outchannel: Number of output channels.
        stride: Stride of the first convolution (and of the projection
            shortcut); a value other than 1 downsamples the length axis.
    """

    def __init__(self, inchannel, outchannel, stride=1):
        super().__init__()
        self.left = nn.Sequential(
            nn.Conv1d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm1d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv1d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm1d(outchannel),
        )
        # An empty Sequential passes its input through unchanged (identity
        # shortcut). It is replaced by a 1x1 projection whenever the main
        # path changes the channel count or the length, so the sum is valid.
        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != outchannel:
            self.shortcut = nn.Sequential(
                nn.Conv1d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(outchannel),
            )

    def forward(self, x):
        """Return ``relu(left(x) + shortcut(x))``."""
        out = self.left(x) + self.shortcut(x)
        return F.relu(out)


class ResNet18(nn.Module):
    """ResNet-18-style 1-D network that outputs one scalar per sample.

    The network has four stages of two ``ResBlock`` each, followed by global
    average pooling over the length axis and an MLP head ending in
    ``Linear(512, 1)``.

    Args:
        args: Optional configuration object. If it has a ``mean_label``
            attribute whose value is neither ``None`` nor ``False``, the bias
            of the final linear layer is initialised to that value (``0.0``
            included).
    """

    def __init__(self, args=None):
        super().__init__()
        # Running channel count consumed and updated by make_layer().
        self.inchannel = 64
        self.conv1 = nn.Sequential(
            nn.Conv1d(1, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm1d(64),
            nn.ReLU(),
        )
        self.layer1 = self.make_layer(ResBlock, 64, 2, stride=1)
        self.layer2 = self.make_layer(ResBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResBlock, 256, 2, stride=2)
        self.layer4 = self.make_layer(ResBlock, 512, 2, stride=2)
        self.pred_layer = nn.Sequential(
            nn.Linear(512, 512),
            nn.SiLU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, 1),
        )

        mean_label = getattr(args, "mean_label", None)
        # Compare explicitly instead of testing truthiness: a mean label of
        # 0.0 is a valid value and must still be applied. None/False (the
        # attribute being absent or disabled) leave the default init alone.
        if mean_label is not None and mean_label is not False:
            nn.init.constant_(self.pred_layer[-1].bias, mean_label)

    def make_layer(self, block, channels, num_blocks, stride):
        """Build one stage of residual blocks.

        Only the first block uses ``stride``; the rest use stride 1. Because
        the list always starts with ``stride``, at least one block is built
        even when ``num_blocks`` < 1. Side effect: ``self.inchannel`` is set
        to ``channels`` so the next stage chains correctly.

        Returns:
            ``nn.Sequential`` containing the blocks.
        """
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.inchannel, channels, block_stride))
            self.inchannel = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor of shape ``(batch, length)``; a channel axis of size 1
                is inserted to match ``conv1``'s single input channel.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        out = self.conv1(x.unsqueeze(1))
        out = F.max_pool1d(out, 3, 2, 1)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling over the length axis -> (batch, 512).
        out = out.mean(-1)
        return self.pred_layer(out)