### 1. Imports and class names setup ###
import gradio as gr
import os
import torch
from torch import nn
from model import create_resnet50_model
from timeit import default_timer as timer
from typing import Tuple, Dict
import torch.nn.functional as F
# Setup class names
class_names = ['CRVO',
               'Choroidal Nevus',
               'Diabetic Retinopathy',
               'Laser Spots',
               'Macular Degeneration',
               'Macular Hole',
               'Myelinated Nerve Fiber',
               'Normal',
               'Pathological Myopia',
               'Retinitis Pigmentosa']
### 2. Model and transforms preparation ###
# Create ResNet50 model
resnet50, resnet50_transforms = create_resnet50_model(
    num_classes=len(class_names),  # could also pass 10 directly
)

# Replace the classifier head so it matches the number of retina classes
resnet50.fc = nn.Linear(in_features=2048, out_features=len(class_names))
# Load saved weights
resnet50.load_state_dict(
    torch.load(
        f="pretrained_resnet50_feature_extractor_drappcompressed.pth",
        map_location=torch.device("cpu"),  # load to CPU
    )
)
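
# For reference, a minimal sketch of what this app assumes `create_resnet50_model`
# in model.py provides (a torchvision ResNet50 feature extractor plus its matching
# transforms). This is an assumption for readers who don't have model.py open, not
# the actual implementation, so it is kept commented out:
#
# import torchvision
#
# def create_resnet50_model(num_classes: int = 10):
#     weights = torchvision.models.ResNet50_Weights.DEFAULT
#     transforms = weights.transforms()   # preprocessing matching the pretrained weights
#     model = torchvision.models.resnet50(weights=weights)
#     for param in model.parameters():
#         param.requires_grad = False     # freeze the backbone (feature extraction)
#     model.fc = nn.Linear(in_features=2048, out_features=num_classes)
#     return model, transforms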
### 3. Predict function ###
# Create predict function
def predict(img) -> Tuple[Dict, float]:
    """Transforms and performs a prediction on img and returns prediction and time taken."""
    # Start the timer
    start_time = timer()
    try:
        # Transform the target image and add a batch dimension
        img = resnet50_transforms(img).unsqueeze(0)

        # Put model into evaluation mode and turn on inference mode
        resnet50.eval()
        with torch.inference_mode():
            # Turn the prediction logits into prediction probabilities
            pred_probs = torch.softmax(resnet50(img), dim=1)

        # Entropy of the softmax distribution for OOD detection:
        # H = -sum(p * log p), ranging from 0 (fully confident) to log(10) ≈ 2.30 (uniform)
        entropy = -torch.sum(pred_probs * torch.log(pred_probs + 1e-8)).item()
        max_prob = torch.max(pred_probs).item()

        # Prediction label -> probability dictionary (the format Gradio's Label output expects)
        pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}

        # OOD detection - suspiciously overconfident or near-uniform predictions are
        # flagged by boosting an existing class instead of adding new keys
        if (max_prob > 0.95 and entropy < 0.2) or entropy > 2.0:
            pred_labels_and_probs[class_names[0]] = 0.99
            print("Warning: input may not be a retina scan")

        # Calculate the prediction time
        pred_time = round(timer() - start_time, 5)
        return pred_labels_and_probs, pred_time
    except Exception as e:
        print(f"Prediction failed: {e}")
        # Return a dictionary with the same structure as the normal case
        pred_labels_and_probs = {class_names[i]: 0.0 for i in range(len(class_names))}
        pred_labels_and_probs[class_names[0]] = 1.0  # surface the failure via the first class
        return pred_labels_and_probs, 0.0
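
# Quick local sanity check of predict() without launching the UI - a hedged sketch
# that assumes at least one image exists in examples/; kept commented out so it
# does not run on Space startup:
#
# from PIL import Image
# sample_path = "examples/" + os.listdir("examples")[0]
# labels, seconds = predict(Image.open(sample_path))
# print(sorted(labels.items(), key=lambda kv: kv[1], reverse=True)[:3], seconds)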
### 4. Gradio app ###
# Create title, description and article strings
title = "DeepFundus 👀"
description = "A ResNet50 feature extractor computer vision model to classify funduscopic images."
article = "Created with the help from [09. PyTorch Model Deployment](https://www.learnpytorch.io/09_pytorch_model_deployment/)."
# Create examples list from "examples/" directory
example_list = [["examples/" + example] for example in os.listdir("examples")]
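# Note: os.listdir raises FileNotFoundError if examples/ is missing (e.g. in a fresh
# local clone); a hedged fallback sketch if that needs guarding:
# example_list = ([["examples/" + f] for f in os.listdir("examples")]
#                 if os.path.isdir("examples") else None)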
# Create the Gradio demo
demo = gr.Interface(fn=predict,  # mapping function from input to output
                    inputs=gr.Image(type="pil"),  # what are the inputs?
                    outputs=[gr.Label(num_top_classes=3, label="Predictions"),  # our fn has two outputs,
                             gr.Number(label="Prediction time (s)")],  # so we need two output components
                    examples=example_list,
                    title=title,
                    description=description,
                    article=article)
# Launch the demo!
demo.launch()