Spaces:
Sleeping
Sleeping
import tempfile

import gradio as gr
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
from timm import create_model
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
# Class-name <-> model-output-index mappings. A class's position in this tuple
# is its output index. NOTE(review): presumably this order must match the
# label order used when the checkpoint was trained; confirm.
_CLASS_NAMES = ('Monkeypox', 'Measles', 'Chickenpox', 'Herpes', 'Melanoma')
class_to_idx = {name: index for index, name in enumerate(_CLASS_NAMES)}
idx_to_class = dict(enumerate(_CLASS_NAMES))
# Data transformation: resize every image to 224x224 (the input size named in
# the 'vit_base_patch16_224' model) and convert it to a tensor.
# NOTE(review): no Normalize step is applied; confirm this matches the
# preprocessing used when the checkpoint was trained.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)
# Function to load and preprocess an image
def load_image(image_path):
    """Load an image file and turn it into a batched model-input tensor.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (a path string here,
            as supplied by the Gradio ``filepath`` input).

    Returns:
        The image converted to RGB, passed through ``transform`` and given a
        leading batch dimension of size 1.
    """
    # Open inside a context manager so the file handle is always released,
    # including for formats Pillow keeps open after decoding (e.g. multi-frame
    # images) and when conversion raises.
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')
    return transform(rgb).unsqueeze(0)  # Add batch dimension
# Load the model
model_name = 'vit_base_patch16_224'
# load_state_dict below is strict by default, so the fine-tuned checkpoint must
# supply (and therefore overwrites) every entry of the model's state dict.
# Downloading the ImageNet-pretrained weights first would only add a network
# fetch at startup without changing the final model.
pretrained = False
num_classes = len(class_to_idx)
model = create_model(model_name, pretrained=pretrained, num_classes=num_classes)
# weights_only=True restricts unpickling to tensors and plain containers.
model.load_state_dict(
    torch.load('ARTmodelo5ns_vit_weights_epoch6.pth', map_location='cpu', weights_only=True)
)
# Put layers such as dropout into evaluation behaviour for inference.
model.eval()
# Define the prediction function
def predict_image(image_path):
    """Classify a skin-lesion image and render a bar chart of class probabilities.

    Args:
        image_path: Path to the uploaded image, or ``None`` when no image is
            present (e.g. after the user clears the input with ``live=True``).

    Returns:
        A ``(top_prediction, chart_path)`` tuple: the name of the most probable
        class and the path of a PNG horizontal bar chart of per-class
        percentages. When ``image_path`` is ``None`` it returns ``("", None)``
        so the outputs are cleared instead of raising.
    """
    # With live=True the interface re-runs when the image is removed; there is
    # nothing to classify then, and Image.open(None) would raise.
    if image_path is None:
        return "", None

    results = _class_percentages(load_image(image_path))
    chart_path = _save_probability_chart(results)
    # Return the top prediction as well (the first class wins on exact ties).
    top_prediction = max(results, key=results.get)
    return top_prediction, chart_path


def _class_percentages(img_tensor):
    """Run the model on a batch of one image; map each class name to its percentage."""
    with torch.no_grad():
        output = model(img_tensor)
    # output[0] holds the logits of the single image in the batch.
    percentages = (torch.nn.functional.softmax(output[0], dim=0) * 100).tolist()
    return {name: percentages[idx] for idx, name in idx_to_class.items()}


def _save_probability_chart(results):
    """Draw ``results`` as a horizontal bar chart, save it to a new PNG file, return its path."""
    # Use the object-oriented Figure API instead of pyplot: pyplot keeps global
    # state and is not thread-safe, and a web server may call this from
    # worker threads. A bare Figure is not tracked by pyplot, so no close().
    fig = Figure()
    ax = fig.subplots()
    ax.barh(list(results.keys()), list(results.values()), color='skyblue')
    ax.set_xlabel('Percentage')
    ax.set_title('Prediction Probabilities')
    fig.tight_layout()
    # A fresh file per call prevents concurrent requests from overwriting a
    # shared 'result.png' in the working directory.
    # NOTE(review): these files are not deleted here; consider periodic cleanup.
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
        fig.savefig(tmp, format='png')
    return tmp.name
# Create the Gradio interface
# Derive the category list from the model's label mapping so the description
# names exactly the labels the "Prediction" box can show (the hand-written
# list said "Melanomas" while the model reports "Melanoma").
_class_names = [idx_to_class[i] for i in sorted(idx_to_class)]
_categories = (
    ", ".join(_class_names[:-1]) + ", and " + _class_names[-1]
    if len(_class_names) > 1
    else "".join(_class_names)
)
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="filepath", label="Upload an image"),
    outputs=[gr.Textbox(label="Prediction"), gr.Image(label="Prediction Probabilities")],
    title="Skin Lesion Image Classification",
    # TODO: replace the "(#)" placeholder links with the real dataset/paper URLs.
    description=(
        "Upload an image of a skin lesion to get a prediction. This tool helps to "
        "classify images of skin lesions into the following categories: "
        f"{_categories}. Check out the dataset and paper at: "
        "[Link to Example 1](#), [Link to Example 2](#)"
    ),
    # NOTE(review): string theme names are resolved differently across Gradio
    # versions; confirm "huggingface" is recognized by the installed release.
    theme="huggingface",
    # Re-run the prediction automatically whenever the input image changes.
    live=True
)
# Launch the Gradio interface.
# NOTE(review): share=True requests a public share link; confirm that exposing
# the app publicly is intended when it is run outside a hosted Space.
iface.launch(share=True)