# Page-header residue from extraction ("Spaces: Sleeping"); not part of the module.
import torch
import torch.nn as nn
from torchvision.models import ResNet
class ResNetClassifier(nn.Module):
    """Wrap a ResNet and replace its final ``fc`` layer with a new linear head.

    The wrapped model's ``fc`` attribute is swapped for a new ``nn.Linear``
    that maps the original ``fc.in_features`` to ``num_classes`` outputs.
    The replacement is done in place: the ``model`` object passed in is
    itself modified (not copied), so any other reference to it sees the new
    head as well.

    Args:
        model: Network whose ``fc`` attribute exposes ``in_features``.
        num_classes: Number of output units of the new head. Must be >= 1.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """

    def __init__(self, model: ResNet, num_classes: int = 2) -> None:
        # Validate before touching ``model`` so a bad argument never leaves
        # the caller's network half-modified.
        if num_classes < 1:
            raise ValueError(
                f"num_classes must be a positive integer, got {num_classes}"
            )
        super().__init__()
        self.num_classes = num_classes
        self.model = model
        # Read the input width of the existing head first, then overwrite it.
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass ``x`` through the wrapped model and return its output."""
        return self.model(x)