ApexID / models /resnet_model.py
Michael Rey
modified resnet_model.py
8e3a2ae
raw
history blame contribute delete
611 Bytes
import torch.nn as nn
from torchvision import models
# Custom ResNet model for monkey classification
class MonkeyResNet(nn.Module):
    """ResNet-18 classifier whose final layer is sized for ``num_classes``.

    The torchvision ResNet-18 is stored in ``self.model`` and its ``fc`` head
    is replaced by a fresh ``nn.Linear`` with ``num_classes`` outputs. The
    attribute name ``model`` is kept so saved ``state_dict`` keys (prefixed
    ``model.``) stay compatible.
    """

    def __init__(self, num_classes: int, *, weights=None) -> None:
        """Build the network.

        Args:
            num_classes: Number of output logits produced by the new head.
            weights: Optional torchvision weights for the backbone, e.g.
                ``models.ResNet18_Weights.DEFAULT`` to start from ImageNet
                pretrained weights. The default ``None`` yields a randomly
                initialized backbone, which matches the previous
                ``pretrained=False`` behaviour. Requires torchvision >= 0.13.
        """
        super().__init__()
        # ``weights=`` supersedes the deprecated ``pretrained=`` flag;
        # ``None`` means no pretrained weights are loaded.
        self.model = models.resnet18(weights=weights)
        # Replace the final fully connected layer so the output width matches
        # the number of classes; the input width is read from the old layer.
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return the output of running ``x`` through ``self.model``."""
        return self.model(x)