import torch import torch.nn as nn from torchvision import models class Stem(nn.Module): def __init__(self): super(Stem, self).__init__() self.conv = nn.Sequential( nn.Conv2d(3, 64, kernel_size=7, stride=2), nn.MaxPool2d(kernel_size=3, stride=2), ) def forward(self, x): x = self.conv(x) return x class ResidualBlock(nn.Module): def __init__(self, in_channels, out_channels, stride=1): super(ResidualBlock, self).__init__() self.conv1 = nn.Sequential( nn.Conv2d(in_channels, out_channels // 4, stride=1, kernel_size=1), nn.BatchNorm2d(out_channels // 4), nn.ReLU(inplace=True), ) self.conv2 = nn.Sequential( nn.Conv2d( out_channels // 4, out_channels // 4, stride=stride, kernel_size=3, padding=1, ), nn.BatchNorm2d(out_channels // 4), nn.ReLU(inplace=True), ) self.conv3 = nn.Sequential( nn.Conv2d(out_channels // 4, out_channels, kernel_size=1, stride=1), nn.BatchNorm2d(out_channels), ) self.shortcut = ( nn.Identity() if in_channels == out_channels else nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride), nn.BatchNorm2d(out_channels), ) ) self.relu = nn.ReLU(inplace=True) def forward(self, x): identity = self.shortcut(x) x = self.conv1(x) x = self.conv2(x) x = self.conv3(x) x += identity x = self.relu(x) return x def make_layer(in_channels, out_channels, block, num_blocks): layers = [] for i in range(num_blocks): layers.append(block(in_channels, out_channels)) in_channels = out_channels return layers class FromZero(nn.Module): def __init__(self, num_classes=10): super(FromZero, self).__init__() self.stem = Stem() self.layer1 = nn.Sequential(*make_layer(64, 64, ResidualBlock, 2)) self.layer2 = nn.Sequential( ResidualBlock(64, 128, stride=2), ResidualBlock(128, 128) ) self.layer3 = nn.Sequential( ResidualBlock(128, 256, stride=2), ResidualBlock(256, 256) ) self.layer4 = nn.Sequential( ResidualBlock(256, 512, stride=2), ResidualBlock(512, 512) ) self.flatten = nn.Flatten() self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(512, num_classes) def forward(self, x): x = self.stem(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.avgpool(x) x = self.flatten(x) x = self.fc(x) return x class PreTrained(nn.Module): def __init__(self, num_classes): super().__init__() self.model = models.resnet18( weights=models.ResNet18_Weights.IMAGENET1K_V1, progress=True ) for param in self.model.parameters(): param.requires_grad = False self.model.fc = nn.Sequential( nn.Linear(self.model.fc.in_features, 512), nn.ReLU(inplace=True), nn.Linear(512, num_classes), ) def forward(self, x): return self.model(x)