Spaces:
Sleeping
Sleeping
File size: 2,625 Bytes
685396e 1853757 685396e 179adf5 010feb3 1853757 685396e b73d7a5 685396e b73d7a5 179adf5 b73d7a5 685396e b73d7a5 685396e 6b394c3 685396e b73d7a5 179adf5 685396e b73d7a5 685396e 1853757 685396e 1853757 179adf5 685396e b73d7a5 685396e dcda142 179adf5 b73d7a5 179adf5 1853757 685396e b73d7a5 1853757 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from timm import create_model
from torchvision import transforms
# Label vocabulary of the classifier head: class name -> output index.
# The index order must match the order the fine-tuned weights were trained with.
class_to_idx = dict(
    Monkeypox=0,
    Measles=1,
    Chickenpox=2,
    Herpes=3,
    Melanoma=4,
)
# Reverse lookup used to turn an output index back into a human-readable label.
idx_to_class = dict(zip(class_to_idx.values(), class_to_idx.keys()))
# Data transformation
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])
# Function to load and preprocess an image
def load_image(image_path):
    """Open an image file and turn it into a single-image batch tensor.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (a filename, a
            ``pathlib.Path`` or a binary file object).

    Returns:
        The image converted to RGB, passed through the module-level
        ``transform``, with a leading batch dimension of size 1 added.
    """
    # Open the file in a context manager so it is closed deterministically;
    # Pillow otherwise keeps the handle open until garbage collection for some
    # formats (e.g. multi-frame images). convert() returns a new, fully
    # loaded image, so it stays usable after the file is closed.
    with Image.open(image_path) as img:
        rgb_image = img.convert('RGB')
    return transform(rgb_image).unsqueeze(0)  # Add batch dimension
# Load the model
# Build a ViT-B/16 with a classification head sized to our label set, then
# fill every parameter from the fine-tuned checkpoint on disk.
model_name = 'vit_base_patch16_224'
# load_state_dict below is strict (the default), so every parameter comes
# from the checkpoint. Any ImageNet-pretrained weights would be fully
# overwritten, so there is no point downloading them at startup.
pretrained = False
num_classes = len(class_to_idx)
model = create_model(model_name, pretrained=pretrained, num_classes=num_classes)
# weights_only=True restricts unpickling to tensors/primitive containers,
# and map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
model.load_state_dict(torch.load('ARTmodelo5ns_vit_weights_epoch6.pth', map_location='cpu', weights_only=True))
model.eval()  # inference mode (disables dropout etc.)
# Define the prediction function
def predict_image(image_path):
    """Classify a skin-lesion image and render its class probabilities.

    Args:
        image_path: Path to the uploaded image, as delivered by the
            ``gr.Image(type="filepath")`` input. May be ``None`` when the
            input is cleared, which re-triggers the function under ``live=True``.

    Returns:
        A ``(top_prediction, chart_path)`` tuple: the name of the most likely
        class and the path of a PNG horizontal bar chart showing every class's
        probability in percent. Returns ``("", None)`` when no image is given.
    """
    if image_path is None:
        # Nothing to classify (e.g. the user cleared the upload).
        return "", None

    results = _predict_percentages(image_path)
    chart_path = _plot_percentages(results)
    # Return the top prediction as well
    top_prediction = max(results, key=results.get)
    return top_prediction, chart_path


def _predict_percentages(image_path):
    """Run the model on one image; return ``{class_name: probability_in_percent}``."""
    img_tensor = load_image(image_path)
    with torch.no_grad():
        output = model(img_tensor)
    # load_image builds a batch of one, so output[0] holds that image's logits.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    percentages = (probabilities * 100).tolist()
    return {idx_to_class[i]: pct for i, pct in enumerate(percentages)}


def _plot_percentages(results):
    """Save a bar chart of *results* to a fresh temporary PNG; return its path."""
    fig, ax = plt.subplots()
    try:
        ax.barh(list(results.keys()), list(results.values()), color='skyblue')
        ax.set_xlabel('Percentage')
        ax.set_title('Prediction Probabilities')
        # Call the Figure's own methods instead of plt.tight_layout() /
        # plt.savefig(): those act on pyplot's global "current figure",
        # which may belong to another request when predictions run concurrently.
        fig.tight_layout()
        # A unique file per call stops concurrent users from overwriting each
        # other's chart, which a single shared 'result.png' would allow.
        # NOTE(review): these temp files are not deleted here.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
            fig.savefig(tmp, format='png')
        return tmp.name
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
# Create the Gradio interface
# One image-file input feeds predict_image. Its two return values map, in
# order, onto the text and image output components.
upload_input = gr.Image(type="filepath", label="Upload an image")
prediction_outputs = [
    gr.Textbox(label="Prediction"),
    gr.Image(label="Prediction Probabilities"),
]
# NOTE(review): the two links below are placeholders ("#") and still need
# real dataset/paper URLs.
app_description = (
    "Upload an image of a skin lesion to get a prediction. "
    "This tool helps to classify images of skin lesions into the following "
    "categories: Measles, Chickenpox, Herpes, Melanomas, and Monkeypox. "
    "Check out the dataset and paper at: "
    "[Link to Example 1](#), [Link to Example 2](#)"
)
iface = gr.Interface(
    fn=predict_image,
    inputs=upload_input,
    outputs=prediction_outputs,
    title="Skin Lesion Image Classification",
    description=app_description,
    theme="huggingface",
    # Per the Gradio docs, live=True re-runs fn automatically whenever the
    # input changes, instead of waiting for a Submit click.
    live=True,
)
# Launch the Gradio interface
# NOTE(review): share=True requests a public tunnel link; presumably it has
# no effect when hosted on Hugging Face Spaces — confirm it is needed.
iface.launch(share=True)
|