RedEyeClassifier / source /predict_sample.py
nssharmaofficial's picture
Try with Hugging Face model
0fdbe91 verified
raw
history blame
1.34 kB
import functools

import torch
import torch.utils.data
from PIL import Image
from torchvision import transforms
from transformers import AutoModel

from source.model import CNN
def classify_eye(image: torch.Tensor,
                 model: "CNN") -> str:
    """
    Classify a single preprocessed eye image as 'Normal' or 'Red'.

    The model is run once on a batch of one image. The highest-scoring class
    index along dim 1 of its output is mapped to a label: index 0 ->
    'Normal', index 1 -> 'Red'.

    Args:
        image: one image tensor without a batch dimension. The caller in
            this module passes a normalized (3, 32, 32) tensor.
        model: classifier invoked as ``model(batch)``. Its output must hold
            per-class scores along dim 1. The annotation is quoted so it is
            not evaluated when the function is defined.

    Returns:
        str: 'Normal' or 'Red'.

    Raises:
        ValueError: if the predicted class index is neither 0 nor 1.
    """
    labels = ('Normal', 'Red')

    # Add a batch dimension of size 1: (3, H, W) -> (1, 3, H, W).
    batch = image.unsqueeze(0)

    # Inference only: skip autograd bookkeeping. Calling the module (rather
    # than .forward) also runs any registered hooks.
    with torch.no_grad():
        scores = model(batch)

    _, predicted = torch.max(scores, dim=1)
    class_index = predicted[0].item()

    # Guard against models with more outputs than known labels instead of
    # returning a non-string value.
    if not 0 <= class_index < len(labels):
        raise ValueError(
            f"model predicted unexpected class index {class_index}; "
            f"expected 0..{len(labels) - 1}")
    return labels[class_index]
# Preprocessing applied to every input image, built once instead of per call:
# resize to 32x32, convert to a float tensor, then normalize each channel with
# mean 0.5 / std 0.5.
_PREPROCESS = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the pretrained red-eye classifier once and put it in eval mode."""
    model = AutoModel.from_pretrained("nssharmaofficial/RedEyeDetector")
    model.eval()
    return model


def main_classification(image):
    """
    Classify a raw image array as 'Normal' or 'Red'.

    Args:
        image: array-like image data supporting ``astype``. It is presumably
            an (H, W, 3) numpy array from an image-upload widget (TODO:
            confirm against the caller). Grayscale and RGBA arrays are
            converted to RGB.

    Returns:
        str: the label returned by ``classify_eye``.
    """
    # Let Pillow infer the mode from the array's shape, then force RGB so
    # grayscale / RGBA inputs are converted rather than misread as RGB bytes.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')

    # ToTensor yields a CPU tensor; from_pretrained loads the model on CPU
    # by default, so no device move is needed.
    tensor = _PREPROCESS(pil_image)

    # The model is loaded (and possibly downloaded) only on the first call.
    with torch.no_grad():
        return classify_eye(tensor, _load_model())