# Author: Erick Garcia Espinosa
# Commit 1853757: "Add application file and dependencies"
# (Header copied from a Hugging Face file view: raw | history blame | 2.35 kB)
import gradio as gr
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
from timm import create_model
import matplotlib.pyplot as plt
# Mapping between diagnosis label names and the output indices of the classifier head.
# NOTE(review): the order presumably has to match the label encoding used when the
# checkpoint was trained — confirm before reordering.
class_to_idx = {
    label: index
    for index, label in enumerate(
        ('Monkeypox', 'Chickenpox', 'Measles', 'Melanoma', 'Herpes')
    )
}
# Reverse lookup: output index -> label name.
idx_to_class = dict(zip(class_to_idx.values(), class_to_idx.keys()))
# Preprocessing applied to every input image before it reaches the model:
# resize to a fixed 224x224 (the input size in the model name
# 'vit_base_patch16_224'), then convert the PIL image to a tensor.
# NOTE(review): there is no Normalize step; presumably this matches the
# preprocessing used when the checkpoint was trained — confirm before adding one.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)
def load_image(image_path):
    """Load an image file and turn it into a model-ready batch of one.

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        The image converted to RGB, passed through the module-level
        ``transform``, with a leading batch dimension added (``unsqueeze(0)``).
    """
    # Open inside a context manager so the underlying file handle is released
    # deterministically. convert() returns a converted copy, so the result stays
    # usable after the source file is closed.
    with Image.open(image_path) as img:
        rgb_image = img.convert('RGB')
    return transform(rgb_image).unsqueeze(0)  # Add batch dimension
# Load the model: build the timm 'vit_base_patch16_224' architecture with a head
# sized for our classes, then restore the fine-tuned weights on CPU.
model_name = 'vit_base_patch16_224'
# pretrained=True would download ImageNet weights from the hub at startup, only for
# every parameter to be overwritten by the strict load_state_dict() below, so the
# download is skipped; the resulting model is the same.
pretrained = False
num_classes = len(class_to_idx)
model = create_model(model_name, pretrained=pretrained, num_classes=num_classes)
# weights_only=True restricts unpickling to tensors and plain containers.
model.load_state_dict(torch.load('ARTmodelo5ns_vit_weights_epoch6.pth', map_location='cpu', weights_only=True))
# Evaluation mode: turns off train-only behavior in modules such as Dropout.
model.eval()
def _save_probability_chart(results, path='result.png'):
    """Render ``results`` (label -> percentage) as a horizontal bar chart saved at ``path``."""
    fig, ax = plt.subplots()
    try:
        ax.barh(list(results.keys()), list(results.values()), color='skyblue')
        ax.set_xlabel('Percentage')
        ax.set_title('Prediction Probabilities')
        # Act on this figure explicitly instead of pyplot's global "current figure".
        fig.tight_layout()
        fig.savefig(path)
    finally:
        # Release the figure even if drawing/saving fails, so figures don't accumulate.
        plt.close(fig)
    return path


def predict_image(image_path):
    """Classify a skin-lesion image and chart the class probabilities.

    Args:
        image_path: Path to the uploaded image, or ``None`` when the Gradio
            input is empty (with ``live=True`` the function also runs then).

    Returns:
        ``(results, chart_path)``, where ``results`` maps every class name to
        its softmax probability in percent and ``chart_path`` is the saved bar
        chart (``'result.png'``). Returns ``(None, None)`` when there is no
        image, so both outputs are left empty.
    """
    # Image.open(None) would raise, so an empty/cleared input short-circuits.
    if image_path is None:
        return None, None

    img_tensor = load_image(image_path)
    with torch.no_grad():
        logits = model(img_tensor)
    # Softmax over the class scores of the single item in the batch.
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    percentages = (probabilities * 100).tolist()
    results = {name: percentages[idx] for idx, name in idx_to_class.items()}
    return results, _save_probability_chart(results)
# Create the Gradio interface.
# live=True re-runs predict_image whenever the input changes; the Label shows the
# per-class percentages and the Image shows the chart file predict_image returns.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="filepath", label="Upload an image"),
    outputs=[gr.Label(label="Prediction"), gr.Image(label="Prediction Probabilities")],
    title="Skin Lesion Image Classification",
    # TODO: the example links in this description are placeholders ("#") and point nowhere.
    description="Upload an image of a skin lesion to get a prediction with confidence percentages. Example images: [Link to Example 1](#), [Link to Example 2](#)",
    # NOTE(review): "huggingface" is a legacy string theme name; newer Gradio
    # releases may not recognize it — confirm against the installed version.
    theme="huggingface",
    live=True,
)

# Launch only when executed as a script, so importing this module (e.g. from
# tests or other tooling) does not start a server.
if __name__ == "__main__":
    iface.launch(share=True)