nssharmaofficial commited on
Commit
a6ade1d
·
verified ·
1 Parent(s): 0fdbe91

Reverse Huggingface model

Browse files
Files changed (1) hide show
  1. source/predict_sample.py +3 -2
source/predict_sample.py CHANGED
@@ -3,7 +3,6 @@ import torch
3
  import torch.utils.data
4
  from PIL import Image
5
  from source.model import CNN
6
- from transformers import AutoModel
7
 
8
 
9
  def classify_eye(image: torch.Tensor,
@@ -43,9 +42,11 @@ def main_classification(image):
43
  image = transform(image)
44
  image = image.to(torch.device("cpu"))
45
 
46
- cnn = AutoModel.from_pretrained("nssharmaofficial/RedEyeDetector")
47
  cnn.eval()
48
 
 
 
49
  prediction_outcome = classify_eye(image, cnn)
50
 
51
  return prediction_outcome
 
3
  import torch.utils.data
4
  from PIL import Image
5
  from source.model import CNN
 
6
 
7
 
8
  def classify_eye(image: torch.Tensor,
 
42
  image = transform(image)
43
  image = image.to(torch.device("cpu"))
44
 
45
+ cnn = CNN().to(torch.device("cpu"))
46
  cnn.eval()
47
 
48
+ cnn.load_state_dict(torch.load(f='source/weights/CNN-B8-LR-0.01-E30.pt', map_location=torch.device("cpu")))
49
+
50
  prediction_outcome = classify_eye(image, cnn)
51
 
52
  return prediction_outcome