oculotest commited on
Commit
d811463
·
verified ·
1 Parent(s): 1144bb3

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +14 -0
model.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import models
3
+ from huggingface_hub import PyTorchModelHubMixin
4
+
5
+ class EyeNet(models.ResNet18_Weights, PyTorchModelHubMixin):
6
+ def __init__(self):
7
+ super().__init__()
8
+ self.model = models.resnet18()
9
+ self.model.fc = torch.nn.Sequential(
10
+ torch.nn.Linear(512, 256),
11
+ torch.nn.Linear(256, 128),
12
+ torch.nn.Linear(128, 64),
13
+ torch.nn.Linear(64, 5)
14
+ )