import torch | |
from torchvision import models | |
from huggingface_hub import PyTorchModelHubMixin | |
class EyeNet(torch.nn.Module, PyTorchModelHubMixin):
    """ResNet-18 backbone whose final layer is replaced by a stacked linear head.

    Inherits from ``torch.nn.Module`` so parameters/submodules are registered,
    and mixes in ``PyTorchModelHubMixin`` for the Hugging Face Hub
    save/load helpers (``save_pretrained``, ``from_pretrained``,
    ``push_to_hub``).

    Args:
        num_classes: Number of output logits produced by the head. Defaults
            to 5, the value previously hard-coded.
    """

    def __init__(self, num_classes: int = 5):
        super().__init__()
        # No ``weights`` argument: torchvision builds the backbone with
        # random initialization rather than pretrained weights.
        self.model = models.resnet18()
        # Read the backbone's feature width instead of hard-coding 512 so the
        # head always matches the layer it replaces.
        in_features = self.model.fc.in_features
        # NOTE(review): there are no activations between these Linear layers,
        # so the head composes to a single affine map. Inserting
        # nonlinearities would shift the Sequential indices (and therefore the
        # state_dict keys), breaking any existing checkpoints — confirm before
        # changing the architecture.
        self.model.fc = torch.nn.Sequential(
            torch.nn.Linear(in_features, 256),
            torch.nn.Linear(256, 128),
            torch.nn.Linear(128, 64),
            torch.nn.Linear(64, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for the input batch ``x``.

        ``x`` is passed unchanged to the wrapped torchvision ResNet-18;
        presumably an image batch shaped ``(N, 3, H, W)`` — confirm against
        callers.
        """
        return self.model(x)