Spaces:
Sleeping
Sleeping
File size: 1,387 Bytes
3bdf51a a6ade1d 3bdf51a a6ade1d 3bdf51a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import functools

import torch
import torch.utils.data
from PIL import Image
from torchvision import transforms

from source.model import CNN


# Label for each class index produced by the CNN (index 0 -> 'Normal',
# index 1 -> 'Red'), replacing a hand-written if/elif chain.
_CLASS_LABELS = ('Normal', 'Red')


def classify_eye(image: torch.Tensor,
                 model: "CNN") -> str:
    """
    Classify a single preprocessed eye image.

    The image gets a leading batch dimension and is run through ``model``
    with autograd disabled. The arg-max class index is then mapped to a label.

    Args:
        image: preprocessed image tensor without a batch dimension
            (the preprocessing in this module yields shape (3, 32, 32)).
        model: classifier module; expected to return logits of shape
            (1, num_classes) for a batch of one image.

    Returns:
        str: 'Normal' for class index 0, 'Red' for class index 1.

    Raises:
        ValueError: if the model predicts a class index that has no label.
    """
    batch = image.unsqueeze(0)  # (C, H, W) -> (1, C, H, W)
    # Inference only: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        # Call the module (not .forward) so any registered hooks still run.
        logits = model(batch)
    # argmax returns the first maximal index on ties, same as torch.max(dim=1).
    class_index = int(logits.argmax(dim=1).item())
    if not 0 <= class_index < len(_CLASS_LABELS):
        raise ValueError(
            f'Unexpected class index {class_index}; known labels: {_CLASS_LABELS}'
        )
    return _CLASS_LABELS[class_index]
# Trained weights; the path is relative to the process's current working
# directory (the same root that makes `source.model` importable).
_WEIGHTS_PATH = 'source/weights/CNN-B8-LR-0.01-E30.pt'


@functools.lru_cache(maxsize=1)
def _load_model() -> "CNN":
    """Build the CNN, load its trained weights onto the CPU and switch it to eval mode.

    Cached so the checkpoint is read from disk once per process instead of on
    every request. Changes to the weights file are not picked up until the
    process restarts.
    """
    device = torch.device("cpu")
    model = CNN().to(device)
    # NOTE: torch.load unpickles the file. Only point it at trusted checkpoints.
    state_dict = torch.load(f=_WEIGHTS_PATH, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def _preprocess(image) -> torch.Tensor:
    """Turn an image array into the normalized (3, 32, 32) CPU tensor fed to the CNN."""
    # Let PIL infer the mode from the array, then force 3-channel RGB. This
    # gives the same result as mode='RGB' for H x W x 3 input, and also
    # handles grayscale (H x W) and RGBA (H x W x 4) arrays correctly.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    return transform(pil_image)


def main_classification(image):
    """Classify an eye photo given as an array.

    Args:
        image: image array supporting ``.astype`` (presumably a NumPy array
            from the UI layer; confirm with the caller). Values are cast to uint8.

    Returns:
        The label returned by ``classify_eye`` for the preprocessed image.
    """
    return classify_eye(_preprocess(image), _load_model())
|