|
import functools

import torch
import torch.utils.data
from PIL import Image
from torchvision import transforms
from transformers import AutoModel

from source.model import CNN
|
|
|
|
|
def classify_eye(image: torch.Tensor,
                 model: torch.nn.Module) -> str:
    """
    Classify a single eye image as 'Normal' or 'Red'.

    A batch dimension is prepended to `image`, the model is run once with
    gradient tracking disabled, and the index of the highest score along
    dim 1 selects the label.

    Args:
        image: Image tensor without a batch dimension (presumably
            (channels, height, width) -- the exact shape the model expects
            is not visible here; confirm against the model definition).
        model: Classifier called as `model(batch)`; expected to return
            per-class scores of shape (1, num_classes).

    Returns:
        str: 'Normal' for class index 0, 'Red' for class index 1.

    Raises:
        ValueError: If the predicted class index has no associated label.
    """
    labels = ('Normal', 'Red')

    # The model consumes batches, so wrap the single image in a batch of 1.
    batch = image.unsqueeze(0)

    # Inference only: no autograd graph is needed. Calling the module itself
    # (instead of .forward) keeps any registered hooks working.
    with torch.no_grad():
        scores = model(batch)

    class_index = int(torch.argmax(scores, dim=1).item())

    # Previously an unexpected index silently returned the raw score tensor,
    # contradicting the declared `str` return type.
    if not 0 <= class_index < len(labels):
        raise ValueError(
            f"Model predicted class index {class_index}, "
            f"but only {len(labels)} labels are known: {labels}"
        )

    return labels[class_index]
|
|
|
@functools.lru_cache(maxsize=1)
def _load_red_eye_model():
    """Load the pretrained red-eye classifier once and return it in eval mode."""
    model = AutoModel.from_pretrained("nssharmaofficial/RedEyeDetector")
    model.eval()
    return model


def main_classification(image):
    """
    Classify an eye image given as an array and return its label.

    The array is converted to an RGB PIL image, resized to 32x32, turned
    into a tensor, normalized per channel with mean 0.5 and std 0.5, and
    passed to `classify_eye` together with the pretrained model.

    Args:
        image: Array-like object exposing `.astype` (presumably a NumPy
            array of shape (height, width, 3) -- confirm against callers).

    Returns:
        The value returned by `classify_eye` for the preprocessed image.

    Raises:
        ValueError: If `image` is None.
    """
    if image is None:
        raise ValueError("main_classification() requires an image array, got None")

    pil_image = Image.fromarray(image.astype('uint8'), 'RGB')

    preprocess = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    tensor = preprocess(pil_image).to(torch.device("cpu"))

    # The model is cached by the helper so repeated calls do not reload the
    # pretrained weights every time.
    return classify_eye(tensor, _load_red_eye_model())
|
|