# Erick Garcia Espinosa
# improvements
# dcda142
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from timm import create_model
from torchvision import transforms
# Label vocabulary: each class name paired with its integer index, plus the
# inverse lookup used to turn an index back into a class name.
class_to_idx = {
    'Monkeypox': 0,
    'Measles': 1,
    'Chickenpox': 2,
    'Herpes': 3,
    'Melanoma': 4,
}
idx_to_class = dict(zip(class_to_idx.values(), class_to_idx.keys()))
# Preprocessing pipeline: resize to 224x224, then convert to a tensor.
# NOTE(review): there is no Normalize step here; confirm this matches the
# preprocessing that was used when the checkpoint was trained.
_resize_step = transforms.Resize((224, 224))
_to_tensor_step = transforms.ToTensor()
transform = transforms.Compose([_resize_step, _to_tensor_step])
# Function to load and preprocess an image
def load_image(image_path):
    """Read an image file and return it as a one-item input batch.

    Args:
        image_path: Location of the image, in any form accepted by
            ``PIL.Image.open``.

    Returns:
        The output of the module-level ``transform`` applied to the image
        converted to RGB, with a leading batch dimension of size 1 added.
    """
    # Open inside a context manager so the file handle is closed as soon as
    # the pixel data has been read, instead of waiting for garbage collection.
    # convert() returns a separate in-memory image, so it stays usable after
    # the source file is closed.
    with Image.open(image_path) as source:
        rgb_image = source.convert('RGB')
    return transform(rgb_image).unsqueeze(0)  # Add batch dimension
# Load the model
model_name = 'vit_base_patch16_224'
# The checkpoint loaded below replaces every weight (load_state_dict is strict
# by default and raises on any missing or unexpected key), so fetching the
# generic pretrained weights first would only add a large download at startup
# and make the app fail when the network is unavailable.
pretrained = False
num_classes = len(class_to_idx)
model = create_model(model_name, pretrained=pretrained, num_classes=num_classes)
# map_location='cpu' loads the tensors onto the CPU; weights_only=True limits
# unpickling to tensor data rather than arbitrary Python objects.
model.load_state_dict(
    torch.load(
        'ARTmodelo5ns_vit_weights_epoch6.pth',
        map_location='cpu',
        weights_only=True,
    )
)
# Switch to inference mode before serving predictions.
model.eval()
# Define the prediction function
def _save_probability_chart(results):
    """Draw the class percentages as a horizontal bar chart and save it as PNG.

    Args:
        results: Mapping of class name to percentage.

    Returns:
        Path of the PNG file that was written. Each call writes a new,
        uniquely named file in the system temporary directory so that
        overlapping requests cannot overwrite each other's chart. The file
        is not deleted here.
    """
    fig, ax = plt.subplots()
    try:
        ax.barh(list(results), list(results.values()), color='skyblue')
        ax.set_xlabel('Percentage')
        ax.set_title('Prediction Probabilities')
        # Use the figure's own methods rather than plt.*: the pyplot versions
        # act on the global "current figure", which may belong to another
        # request when calls overlap.
        fig.tight_layout()
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as chart_file:
            fig.savefig(chart_file, format='png')
        return chart_file.name
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)


def predict_image(image_path):
    """Classify an image and chart the per-class probabilities.

    Args:
        image_path: Path of the image to classify, or ``None`` when no image
            is available (for example after the input has been cleared).

    Returns:
        A ``(top_prediction, chart_path)`` tuple: the name of the class with
        the highest probability and the path of a PNG bar chart of all class
        percentages. Returns ``("", None)`` when ``image_path`` is ``None``.
    """
    # Nothing to classify; return empty outputs instead of raising.
    if image_path is None:
        return "", None

    # Load and transform the image from the file path
    img_tensor = load_image(image_path)

    # Perform the prediction
    with torch.no_grad():
        output = model(img_tensor)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)

    # Convert probabilities to percentages
    percentages = (probabilities * 100).tolist()
    results = {idx_to_class[i]: pct for i, pct in enumerate(percentages)}

    chart_path = _save_probability_chart(results)

    # Return the top prediction as well
    top_prediction = max(results, key=results.get)
    return top_prediction, chart_path
# Build the Gradio UI: one image upload (passed to predict_image as a file
# path) and two outputs, the predicted label and the probability chart.
_image_input = gr.Image(type="filepath", label="Upload an image")
_prediction_outputs = [
    gr.Textbox(label="Prediction"),
    gr.Image(label="Prediction Probabilities"),
]
_description = (
    "Upload an image of a skin lesion to get a prediction. "
    "This tool helps to classify images of skin lesions into the following categories: "
    "Measles, Chickenpox, Herpes, Melanomas, and Monkeypox. "
    "Check out the dataset and paper at: "
    "[Link to Example 1](#), [Link to Example 2](#)"
)
iface = gr.Interface(
    fn=predict_image,
    inputs=_image_input,
    outputs=_prediction_outputs,
    title="Skin Lesion Image Classification",
    description=_description,
    theme="huggingface",
    live=True,
)

# Start the app, requesting a shareable link.
iface.launch(share=True)