# Retrieved from the Hugging Face Hub (repository owner: oculotest).
# Commit d811463 ("Create model.py", verified); original file size 451 bytes.
import torch
from torchvision import models
from huggingface_hub import PyTorchModelHubMixin
class EyeNet(torch.nn.Module, PyTorchModelHubMixin):
    """ResNet-18 backbone with a 5-output stacked-linear classification head.

    The final ``fc`` layer of a torchvision ResNet-18 is replaced by four
    ``Linear`` layers mapping 512 -> 256 -> 128 -> 64 -> 5. The module
    inherits from ``torch.nn.Module``; the original code inherited from the
    ``ResNet18_Weights`` enum, which raises ``TypeError`` at class creation.
    Mixing in ``PyTorchModelHubMixin`` adds the Hub save/load helpers
    (``save_pretrained`` / ``from_pretrained`` / ``push_to_hub``).
    """

    def __init__(self):
        super().__init__()
        # Called without ``weights``, so the backbone is randomly initialised
        # here; trained parameters are expected to come from a checkpoint.
        self.model = models.resnet18()
        # NOTE(review): there are no activations between these layers, so the
        # head is mathematically a single affine map 512 -> 5. Inserting
        # nonlinearities would renumber the Sequential's children
        # (fc.0, fc.1, ...) and break loading of existing checkpoints, so the
        # layer list is deliberately left unchanged.
        self.model.fc = torch.nn.Sequential(
            torch.nn.Linear(512, 256),
            torch.nn.Linear(256, 128),
            torch.nn.Linear(128, 64),
            torch.nn.Linear(64, 5),
        )

    def forward(self, x):
        """Run the wrapped ResNet-18 and return its unnormalised 5-way scores.

        No softmax is applied; callers apply their own normalisation or loss.
        Assumes ``x`` is an image batch shaped (N, 3, H, W) as torchvision's
        ResNet expects — TODO confirm the preprocessing used by callers.
        """
        return self.model(x)