Spaces:
Sleeping
Sleeping
# Load model | |
import torch | |
import torchvision | |
import os | |
import gradio as gr | |
from torchvision import transforms | |
from model import create_effnet | |
from typing import Tuple, Dict | |
from timeit import default_timer as timer | |
# Device-agnostic selection: prefer Apple-silicon MPS, then CUDA, else CPU.
# NOTE(review): `device` is not referenced anywhere below -- the model is
# created with device="cpu" and its weights are loaded with a CPU map_location.
device = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
# Class labels; their count sets the size of the classifier's output layer.
class_name = ["NORMAL", "COVID"]

# Build an EfficientNet-B0 through the project helper `create_effnet`, which
# returns the model plus the preprocessing transform that predict() applies to
# incoming images (presumably the transform bundled with the pretrained
# weights -- confirm in model.py). 1280 is the width of EfficientNet-B0's final
# feature vector, i.e. the input size of the replacement classifier head.
EffNetB0_load_model, EffNetB0_transforms = create_effnet(
    model=torchvision.models.efficientnet_b0,
    pretrained_weights=torchvision.models.EfficientNet_B0_Weights.DEFAULT,
    in_features=1280,
    out_features=len(class_name),
    dropout=0.2,
    device="cpu",
)
# Standalone preprocessing pipeline: resize to 64x64, randomly flip
# horizontally (p=0.5), replicate grayscale into 3 channels, convert to tensor.
# NOTE(review): nothing visible in this file uses `data_transform`; predict()
# preprocesses with EffNetB0_transforms instead. The random flip also makes it
# non-deterministic, so this looks like leftover training-time augmentation.
data_transform = transforms.Compose([
    transforms.Resize(size=(64, 64)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
])
# Restore the fine-tuned weights into the freshly built model. map_location
# remaps every stored tensor onto the CPU, whatever device it was saved from.
# NOTE(review): torch.load unpickles the checkpoint -- only load trusted files
# (on torch >= 1.13 consider passing weights_only=True).
_effnet_state_dict = torch.load(
    "./EffNetB0_data_auto_10_epochs.pth",
    map_location=torch.device("cpu"),
)
EffNetB0_load_model.load_state_dict(_effnet_state_dict)
# load_state_dict copies the tensors into the model, so drop the extra reference.
del _effnet_state_dict
### Predict function ---------------------------------------------------- ###
def predict(img) -> Tuple[Dict, float]:
    """Classify one X-ray image with the EfficientNet-B0 model.

    Args:
        img: The input image, as delivered by ``gr.Image(type="pil")`` (a PIL
            image). Images whose mode is not ``"RGB"`` (e.g. grayscale ``"L"``
            or ``"RGBA"``) are converted to RGB first, because the model
            expects 3-channel input.

    Returns:
        A tuple ``(label_to_prob, seconds)``. ``label_to_prob`` maps each class
        label to its softmax probability. ``seconds`` is the wall-clock time
        spent in this call, rounded to 4 decimal places.

    Raises:
        ValueError: if ``img`` is ``None`` (e.g. the form was submitted with no
            image uploaded).
    """
    start_time = timer()

    if img is None:
        raise ValueError("No image provided; please upload an X-ray image.")
    # A grayscale or RGBA image would become a 1- or 4-channel tensor, which
    # the 3-channel EfficientNet input cannot accept.
    if getattr(img, "mode", "RGB") != "RGB":
        img = img.convert("RGB")

    # Output labels, assumed index-aligned with the model's logits.
    # NOTE(review): lowercase, unlike the module-level `class_name`.
    class_names = ["normal", "covid"]

    # Preprocess for the EfficientNet model, then add a batch dimension:
    # (C, H, W) -> (1, C, H, W).
    batch = EffNetB0_transforms(img).unsqueeze(0)

    # eval() disables dropout and makes batch-norm use its running statistics.
    EffNetB0_load_model.eval()
    with torch.inference_mode():
        pred_logits = EffNetB0_load_model(batch)
        pred_probs = torch.softmax(pred_logits, dim=1)

    # Pair each label with the probability of the single image in the batch.
    pred_labels_and_probs = dict(zip(class_names, pred_probs[0].tolist()))

    pred_time = round(timer() - start_time, 4)
    return pred_labels_and_probs, pred_time
# Heading and subtitle text displayed by the Gradio interface.
title = "Covid Prediction: EfficientNetB0 Model"
description = "An EfficientNet model trained on Covid-19 Dataset to classify X-RAY images"
# Create example list
# gr.Interface expects one row per example, each row a list with one value per
# input component (here: a single image path).
# Entries are sorted because os.listdir() returns them in arbitrary order.
# Hidden files (e.g. .gitkeep, .DS_Store) and sub-directories are skipped so
# only real example files are offered to the model.
examples_dir = "examples"
example_list = [
    [os.path.join(examples_dir, name)]
    for name in sorted(os.listdir(examples_dir))
    if not name.startswith(".")
    and os.path.isfile(os.path.join(examples_dir, name))
]
# Wire predict() into a Gradio interface. It has one PIL-image input and two
# outputs: a label widget showing the two class probabilities, and a number
# holding the prediction time. The example images are listed below the form.
demo = gr.Interface(
    predict,
    title=title,
    description=description,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=2, label="Predictions"),
        gr.Number(label="Prediction time(s)"),
    ],
    examples=example_list,
)

# Start serving the app.
demo.launch()