from PIL import Image
from torchvision import transforms
import torch

from source.model import CNN


def classify_eye(image: torch.Tensor, model: CNN) -> str:
    """
    Classify a single preprocessed eye image of size (3, 32, 32).

    The image is batched, passed through the CNN, and the class with the
    highest score is mapped to a human-readable label.

    Returns:
        str: 'Normal' or 'Red' depending on the predicted class.
    """
    # image: (3, 32, 32) -> add a batch dimension -> (1, 3, 32, 32)
    image = image.unsqueeze(0)

    with torch.no_grad():
        output = model(image)

    # Index of the highest-scoring class: 0 = Normal, 1 = Red
    _, prediction = torch.max(output, dim=1)

    if prediction.item() == 0:
        return 'Normal'
    return 'Red'


def main_classification(image):
    # Convert the incoming NumPy array of shape (H, W, 3) into a PIL image
    image = Image.fromarray(image.astype('uint8'), 'RGB')

    # Resize to the CNN's expected input size and normalize channels to [-1, 1]
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    image = transform(image).to(torch.device("cpu"))

    # Load the trained weights, then switch the model to inference mode
    cnn = CNN().to(torch.device("cpu"))
    cnn.load_state_dict(torch.load(f='source/weights/CNN-B8-LR-0.01-E30.pt',
                                   map_location=torch.device("cpu")))
    cnn.eval()

    return classify_eye(image, cnn)
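

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how main_classification might be invoked from a script,
# assuming an RGB eye image is available on disk; the file name "eye.jpg" is a
# placeholder, not an asset shipped with the repository.
if __name__ == "__main__":
    import numpy as np

    sample = Image.open("eye.jpg").convert("RGB")  # placeholder image path
    sample_array = np.array(sample)                # (H, W, 3) uint8 array
    print(main_classification(sample_array))       # prints 'Normal' or 'Red'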