Spaces:
Sleeping
Sleeping
import functools
from collections.abc import Sequence

import torch
import torch.utils.data
from PIL import Image
from torchvision import transforms

from source.model import CNN
def classify_eye(image: torch.Tensor,
                 model: "CNN",
                 *,
                 labels: Sequence[str] = ('Normal', 'Red')) -> str:
    """
    Classify a single eye image and return its class label.

    The image gets a batch dimension of size 1 and is passed through
    ``model``. The index of the highest score along dim 1 is then mapped
    to a label.

    Args:
        image: one image tensor without a batch dimension. The
            preprocessing in this module produces shape (3, 32, 32).
        model: classifier whose output holds per-class scores along dim 1,
            presumably shape (1, num_classes) -- confirm against CNN.
        labels: label for each class index. The default maps index 0 to
            'Normal' and index 1 to 'Red'.

    Returns:
        str: the label of the highest-scoring class.

    Raises:
        ValueError: if the predicted class index has no entry in ``labels``.
    """
    # Add a batch dimension: (C, H, W) -> (1, C, H, W).
    batch = image.unsqueeze(0)
    # Pure inference: no_grad avoids recording an autograd graph.
    # Calling the module (not .forward) also runs any registered hooks.
    with torch.no_grad():
        scores = model(batch)
    _, predicted = torch.max(scores, dim=1)
    class_index = int(predicted.item())
    # Never fall through to a non-str result for an unknown class index.
    if not 0 <= class_index < len(labels):
        raise ValueError(
            f"model predicted class {class_index}, but only "
            f"{len(labels)} labels are known: {list(labels)}")
    return labels[class_index]
def _preprocess(image) -> torch.Tensor:
    """Convert a raw RGB pixel array into a normalized (3, 32, 32) tensor."""
    pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        # Maps ToTensor's [0, 1] range to [-1, 1] per channel.
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    return transform(pil_image)


@functools.lru_cache(maxsize=1)
def _load_cnn(weights_path: str = 'source/weights/CNN-B8-LR-0.01-E30.pt') -> "CNN":
    """Build the CNN on CPU, load its trained weights once, and reuse it."""
    device = torch.device("cpu")
    cnn = CNN().to(device)
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    cnn.load_state_dict(torch.load(f=weights_path, map_location=device))
    cnn.eval()
    return cnn


def main_classification(image) -> str:
    """
    Preprocess a raw image and classify it with the trained CNN.

    The model and its weights are loaded from disk on the first call only
    and cached for later calls.

    Args:
        image: array-like with an ``astype`` method, presumably an
            (H, W, 3) NumPy array of RGB pixel values -- confirm against
            the caller.

    Returns:
        str: the label produced by ``classify_eye``.
    """
    tensor = _preprocess(image).to(torch.device("cpu"))
    return classify_eye(tensor, _load_cnn())