RedEyeClassifier / source /predict_sample.py
nssharmaofficial's picture
Reverse Huggingface model
a6ade1d verified
raw
history blame
1.39 kB
import functools

import torch
import torch.utils.data
from PIL import Image
from torchvision import transforms

from source.model import CNN
# Class index -> label, in the order the classifier's output columns are read.
_EYE_LABELS = ('Normal', 'Red')


def classify_eye(image: torch.Tensor,
                 model: "CNN") -> str:
    """Classify a single preprocessed eye image as 'Normal' or 'Red'.

    Args:
        image: Preprocessed image tensor of shape (3, H, W). The preprocessing
            in ``main_classification`` produces (3, 32, 32). A batch dimension
            is added here before calling the model.
        model: Classifier returning a (1, num_classes) score tensor for a
            batch of one. Column 0 is read as 'Normal' and column 1 as 'Red'.
            The caller is responsible for putting the model in eval mode.

    Returns:
        'Normal' or 'Red', the label of the highest-scoring class.

    Raises:
        ValueError: if the model predicts a class index that has no label.
    """
    batch = image.unsqueeze(0)  # (3, H, W) -> (1, 3, H, W)

    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        # Call the module itself (not .forward) so any registered hooks run.
        scores = model(batch)

    _, predicted = torch.max(scores, dim=1)
    class_index = int(predicted.item())

    # Previously an unexpected index fell through and returned the raw tensor;
    # fail loudly instead so the declared `str` return type always holds.
    if not 0 <= class_index < len(_EYE_LABELS):
        raise ValueError(f"Model predicted unknown class index {class_index}")
    return _EYE_LABELS[class_index]
# Trained weights file. The path is resolved against the current working
# directory, so the app is expected to run from the repository root.
_WEIGHTS_PATH = 'source/weights/CNN-B8-LR-0.01-E30.pt'
_DEVICE = torch.device("cpu")

# Preprocessing applied to every input. It resizes to 32x32, converts to a
# float tensor in [0, 1] (ToTensor), then maps each channel to [-1, 1] via
# (x - 0.5) / 0.5. It is built once instead of on every call.
_PREPROCESS = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


@functools.lru_cache(maxsize=1)
def _load_model() -> "CNN":
    """Build the CNN, load its trained weights, and return it in eval mode.

    The result is cached, so the weights are read from disk only once per
    process. Callers share the returned module and must not modify it.
    """
    model = CNN().to(_DEVICE)
    # torch.load unpickles the file: only point _WEIGHTS_PATH at trusted files.
    model.load_state_dict(torch.load(f=_WEIGHTS_PATH, map_location=_DEVICE))
    model.eval()
    return model


def main_classification(image):
    """Classify a raw eye image array as 'Normal' or 'Red'.

    Args:
        image: Array with an ``astype`` method. Presumably a NumPy (H, W, 3)
            uint8-compatible array, e.g. from a UI image input (TODO confirm
            against the caller). Grayscale and RGBA arrays are converted to
            RGB.

    Returns:
        'Normal' or 'Red'.

    Raises:
        ValueError: if ``image`` is None (no image was supplied).
    """
    if image is None:
        raise ValueError("No image provided")

    # Let PIL infer the mode from the array shape, then normalize to RGB.
    # Forcing mode='RGB' on a grayscale or RGBA array misreads the buffer.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')
    tensor = _PREPROCESS(pil_image).to(_DEVICE)

    with torch.no_grad():
        return classify_eye(tensor, _load_model())