lwm / utils /res1dcnn.py
wi-lab's picture
upload side scripts
713dc9d verified
raw
history blame
3.05 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR
class ResidualBlock(nn.Module):
    """Two-convolution 1D residual block with a skip connection.

    Main path: Conv1d(k=3, pad=1) -> BatchNorm1d -> ReLU -> Conv1d(k=3, pad=1)
    -> BatchNorm1d. The skip path is added to the main path and the sum goes
    through a final ReLU. Both convolutions use stride 1 with padding 1, so
    the temporal length of the input is preserved.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Main path. Creation order is kept as conv1, bn1, conv2, bn2,
        # shortcut so seeded initialisation and state_dict layout stay stable.
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(out_channels)

        # Skip path: an empty Sequential acts as identity; a 1x1 convolution
        # plus batch norm projects the input when the channel count changes.
        needs_projection = in_channels != out_channels
        self.shortcut = (
            nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm1d(out_channels),
            )
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, in_channels, length).

        Returns:
            Tensor of shape (batch, out_channels, length).
        """
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Merge the skip path before the final non-linearity.
        out += self.shortcut(x)
        return F.relu(out)
class ResNet1DCNN(nn.Module):
    """1D ResNet-style classifier for sequence inputs.

    Architecture: a strided 7-wide stem convolution with batch norm, ReLU and
    max pooling, followed by three stages of residual blocks (32, 64 and 128
    channels with 2, 3 and 4 blocks respectively), global average pooling,
    and a two-layer fully connected head with batch norm and dropout.

    Args:
        input_channels: Number of feature channels per time step.
        sequence_length: Length of the input sequences; only used to probe
            the size of the flattened convolutional features.
        num_classes: Number of output logits.
    """

    def __init__(self, input_channels: int, sequence_length: int, num_classes: int):
        super().__init__()

        # Initial convolution layer
        self.conv1 = nn.Conv1d(input_channels, 32, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm1d(32)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        # Residual layers
        self.layer1 = self._make_layer(32, 32, 2)
        self.layer2 = self._make_layer(32, 64, 3)
        self.layer3 = self._make_layer(64, 128, 4)

        # Probe the convolutional stack once to learn the flattened feature
        # size. The probe runs in eval mode on purpose: in training mode the
        # BatchNorm layers would fold the all-zero dummy batch into their
        # running statistics (torch.no_grad() does not prevent buffer
        # updates) and would reject a batch of one whenever the temporal
        # length collapses to a single step.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                dummy_input = torch.zeros(1, input_channels, sequence_length)
                self.flatten_size = self.compute_conv_output(dummy_input).numel()
        finally:
            # Restore whatever mode the module was in before the probe.
            self.train(was_training)

        # Fully connected layers
        self.fc1 = nn.Linear(self.flatten_size, 128)
        self.bn_fc1 = nn.BatchNorm1d(128)
        self.fc2 = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(0.5)

    def _make_layer(self, in_channels: int, out_channels: int, num_blocks: int) -> nn.Sequential:
        """Build one stage of residual blocks.

        The first block maps ``in_channels`` to ``out_channels``; the
        remaining ``num_blocks - 1`` blocks keep ``out_channels``.
        """
        layers = [ResidualBlock(in_channels, out_channels)]
        layers.extend(
            ResidualBlock(out_channels, out_channels) for _ in range(1, num_blocks)
        )
        return nn.Sequential(*layers)

    def compute_conv_output(self, x: torch.Tensor) -> torch.Tensor:
        """Run the convolutional feature extractor.

        Args:
            x: Tensor of shape (batch, input_channels, length).

        Returns:
            Tensor of shape (batch, 128, 1): the last stage's features
            averaged over the temporal axis.
        """
        x = self.maxpool(F.relu(self.bn1(self.conv1(x))))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        # Global average pooling makes the feature size independent of length.
        x = F.adaptive_avg_pool1d(x, 1)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify a batch of sequences.

        Args:
            x: Tensor of shape (batch, length, input_channels); it is
                transposed to channels-first before the convolutions.

        Returns:
            Logits of shape (batch, num_classes).
        """
        # Conv1d expects (batch, channels, length).
        x = x.transpose(1, 2)
        x = self.compute_conv_output(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.bn_fc1(self.fc1(x)))
        x = self.dropout(x)
        x = self.fc2(x)
        return x