Spaces:
Sleeping
Sleeping
import torch.nn as nn | |
from torchvision import models | |
# Custom ResNet model for monkey classification
class MonkeyResNet(nn.Module):
    """ResNet-18 image classifier with a head sized for the monkey label set.

    Wraps a torchvision ``resnet18`` and replaces its final fully connected
    layer with a fresh ``nn.Linear`` producing ``num_classes`` outputs.

    The backbone is stored under the ``model`` attribute. Because of that,
    ``state_dict`` keys are prefixed with ``model.``. Keep the attribute name
    stable so previously saved checkpoints still load.

    Args:
        num_classes: Number of output values produced by ``forward``.
        weights: Optional torchvision weights for the backbone, for example
            ``models.ResNet18_Weights.DEFAULT``. Defaults to ``None``, which
            gives a randomly initialised backbone (the original behaviour).
    """

    def __init__(self, num_classes: int, weights=None):
        super().__init__()
        # ``weights=None`` is the non-deprecated equivalent of the old
        # ``pretrained=False`` (torchvision >= 0.13): random init, no download.
        self.model = models.resnet18(weights=weights)
        # Swap the stock classification head for one sized to our classes.
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Run ``x`` through the backbone and return the raw class scores.

        No softmax is applied.
        """
        # presumably x is a (batch, 3, H, W) image tensor — TODO confirm against caller
        return self.model(x)