File size: 451 Bytes
d811463
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
from torchvision import models
from huggingface_hub import PyTorchModelHubMixin

class EyeNet(torch.nn.Module, PyTorchModelHubMixin):
    """ResNet-18 backbone whose classifier is replaced by a 4-layer linear head.

    The final layer has 5 outputs (presumably class logits -- confirm against
    the training code). Mixing in ``PyTorchModelHubMixin`` adds the Hugging
    Face Hub helpers (``save_pretrained`` / ``from_pretrained`` /
    ``push_to_hub``).

    The base class must be ``torch.nn.Module``: the previous base,
    ``models.ResNet18_Weights``, is a torchvision weights *enum* with members.
    Python does not allow extending such an enum, so the class could not even
    be defined.
    """

    def __init__(self):
        super().__init__()
        # No ``weights=`` argument is passed, so torchvision builds a randomly
        # initialised ResNet-18 (pretrained weights are expected to come from
        # ``from_pretrained`` / a checkpoint instead).
        self.model = models.resnet18()
        # Replace the default 1000-way classifier. ResNet-18's pooled feature
        # vector has 512 channels, hence the first Linear's in_features.
        # NOTE(review): there is no activation between these Linear layers, so
        # the stack is equivalent to a single affine map 512 -> 5. It is kept
        # as-is because inserting activations would renumber the Sequential
        # children and break loading of existing "model.fc.<i>.*" checkpoint
        # keys.
        self.model.fc = torch.nn.Sequential(
            torch.nn.Linear(512, 256),
            torch.nn.Linear(256, 128),
            torch.nn.Linear(128, 64),
            torch.nn.Linear(64, 5),
        )

    def forward(self, x):
        """Run the wrapped ResNet-18 on *x* and return its 5 outputs per sample.

        Assumes *x* is an image batch shaped (N, 3, H, W), as torchvision
        ResNets expect -- TODO confirm the preprocessing used in training.
        """
        return self.model(x)