Spaces:
Sleeping
Sleeping
File size: 618 Bytes
eb6a8a0 9a18613 b0ae40a eb6a8a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
import torch.nn as nn
from torchvision import models
# Custom ResNet model for monkey classification
class MonkeyResNet(nn.Module):
    """ResNet-18 backbone whose final layer is resized to ``num_classes`` outputs.

    The backbone is created with randomly initialised weights; no pretrained
    ImageNet weights are downloaded here.
    NOTE(review): presumably trained weights are restored by the caller
    (e.g. via ``load_state_dict``) — confirm against the calling code.
    """

    def __init__(self, num_classes: int) -> None:
        """Build the network.

        Args:
            num_classes: Number of output units of the replacement
                fully connected layer.
        """
        super().__init__()

        # Build ResNet-18 without pretrained weights. The no-argument call
        # is equivalent to the deprecated ``pretrained=False`` on older
        # torchvision and to ``weights=None`` on torchvision >= 0.13, so it
        # runs on either without a deprecation warning.
        self.model = models.resnet18()

        # Swap the stock classification head for one sized to our classes,
        # reusing the input width of the layer being replaced.
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return the wrapped ResNet's output for ``x``.

        ``x`` is passed to the underlying model unchanged.
        NOTE(review): presumably a batch of image tensors shaped
        (N, 3, H, W) — confirm against the caller.
        """
        return self.model(x)
|