Spaces:
Sleeping
Sleeping
import torch.nn as nn | |
from torchvision import models | |
class CustomResNet18(nn.Module):
    """ResNet-18 whose final fully connected layer is resized for a custom task.

    The wrapped torchvision model is stored as ``self.resnet``; its ``fc``
    layer is replaced by a new ``nn.Linear`` that keeps the original input
    width and produces ``num_classes`` outputs.

    Args:
        num_classes: Output size of the replacement ``fc`` layer.
        pretrained: Passed through to ``models.resnet18``. Keep the default
            to start from torchvision's pretrained weights; pass ``False``
            to skip loading them, e.g. when a saved ``state_dict`` is
            restored right after construction.
    """

    def __init__(self, num_classes: int = 11, *, pretrained: bool = True) -> None:
        super().__init__()
        # NOTE(review): ``pretrained`` is torchvision's legacy flag; newer
        # releases prefer ``weights=``. Kept as-is so the module still works
        # with whichever torchvision version this project pins.
        self.resnet = models.resnet18(pretrained=pretrained)
        # Read the width before replacing the head so the new layer accepts
        # exactly what the backbone feeds into ``fc``.
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        """Run ``x`` through the wrapped ResNet-18 and return its output."""
        return self.resnet(x)