RedEyeClassifier / source /predict_sample.py
nssharmaofficial's picture
Add source code and saved weights
3bdf51a
raw
history blame
1.38 kB
from torchvision import transforms
import torch
import torch.utils.data
from PIL import Image
from source.model import CNN
def classify_eye(image: torch.Tensor,
                 model: torch.nn.Module) -> str:
    """
    Classify a single eye image as 'Normal' or 'Red'.

    The image gets a leading batch dimension, is passed through `model`,
    and the index of the highest-scoring class is mapped to its label
    (0 -> 'Normal', 1 -> 'Red').

    Args:
        image: unbatched image tensor, e.g. (3, 32, 32). It must already be
            preprocessed the way `model` expects.
        model: classifier returning one row of class scores per batch item.
            Switching it to eval mode is left to the caller.

    Returns:
        str: 'Normal' or 'Red'

    Raises:
        ValueError: if the predicted class index is neither 0 nor 1.
    """
    labels = ('Normal', 'Red')

    # image: (3, 32, 32) -> batch: (1, 3, 32, 32)
    batch = image.unsqueeze(0)

    # Inference only: no need to record an autograd graph.
    with torch.no_grad():
        # Call the module itself (not .forward) so registered hooks still run.
        scores = model(batch)

    _, prediction = torch.max(scores, dim=1)
    class_idx = int(prediction.item())

    # Previously an unexpected index fell through both branches and the raw
    # score tensor was returned in place of a label; fail loudly instead.
    if not 0 <= class_idx < len(labels):
        raise ValueError(
            f'Unexpected class index {class_idx}; expected 0 or 1'
        )
    return labels[class_idx]
def main_classification(image):
    """
    Run the red-eye classifier on one raw image array.

    The array is converted to an RGB PIL image, resized to 32x32, turned into
    a normalized tensor, and passed to `classify_eye` together with a freshly
    built CNN whose weights are loaded from disk.

    Args:
        image: array exposing `.astype` that `Image.fromarray` can read as
            RGB (presumably height x width x 3 -- confirm against the caller).

    Returns:
        Whatever `classify_eye` returns for the preprocessed image.
    """
    device = torch.device("cpu")

    # Raw array -> PIL image, so the torchvision transforms can be applied.
    pil_image = Image.fromarray(image.astype('uint8'), 'RGB')

    # Resize to the 32x32 input used here, convert to a tensor, then shift and
    # scale each channel with mean 0.5 / std 0.5.
    preprocess = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    tensor = preprocess(pil_image)
    tensor = tensor.to(device)

    # NOTE(review): the network is rebuilt and its weights re-read from disk
    # on every call; consider caching if this is called repeatedly.
    network = CNN().to(device)
    network.eval()

    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted source. The path is relative to the working directory.
    state_dict = torch.load(f='weights/CNN-B8-LR-0.01-E30.pt', map_location=device)
    network.load_state_dict(state_dict)

    return classify_eye(tensor, network)