import os
import tempfile

import gradio as gr
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
from timm import create_model
import matplotlib.pyplot as plt
from matplotlib.figure import Figure

# Class to index and index to class mappings.
# NOTE(review): presumably this order must match the classifier head the
# checkpoint below was trained with — confirm against the training script.
class_to_idx = {'Monkeypox': 0, 'Measles': 1, 'Chickenpox': 2, 'Herpes': 3, 'Melanoma': 4}
idx_to_class = {v: k for k, v in class_to_idx.items()}

# Data transformation: resize to the ViT input size and convert to a float
# tensor. No mean/std normalization is applied.
# NOTE(review): assumes training used this same (un-normalized) pipeline — confirm.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def load_image(image_path):
    """Load an image file and return it as a batched (1, 3, 224, 224) tensor.

    The image is converted to RGB, so grayscale/RGBA inputs are accepted.
    The underlying file handle is closed before returning.
    """
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')  # convert() yields an independent, fully loaded copy
    return transform(rgb).unsqueeze(0)  # add batch dimension


# Load the model.
model_name = 'vit_base_patch16_224'
# The strict load_state_dict call below overwrites every parameter, so
# downloading the ImageNet-pretrained weights first would be wasted work.
pretrained = False
num_classes = len(class_to_idx)
model = create_model(model_name, pretrained=pretrained, num_classes=num_classes)
model.load_state_dict(torch.load('ARTmodelo5ns_vit_weights_epoch6.pth', map_location='cpu', weights_only=True))
model.eval()


def _plot_probabilities(results):
    """Render *results* (class name -> percentage) as a bar chart PNG; return its path."""
    # Use the object-oriented Figure API instead of pyplot: pyplot keeps
    # global state and is not safe to use from Gradio's worker threads.
    fig = Figure()
    ax = fig.subplots()
    ax.barh(list(results.keys()), list(results.values()), color='skyblue')
    ax.set_xlabel('Percentage')
    ax.set_title('Prediction Probabilities')
    fig.tight_layout()
    # A unique file per request, so concurrent users (share=True) cannot
    # overwrite each other's chart before Gradio serves it.
    fd, path = tempfile.mkstemp(prefix='result_', suffix='.png')
    os.close(fd)
    fig.savefig(path)
    return path


def predict_image(image_path):
    """Classify a skin-lesion image and chart the per-class probabilities.

    Args:
        image_path: Path to the uploaded image file, or None when there is
            no image (the interface runs with live=True, so it may be
            called e.g. after the upload is cleared).

    Returns:
        A tuple ``(top_prediction, plot_path)``: the most likely class name
        and the path of a PNG bar chart of all class percentages. When
        *image_path* is None, returns ``('', None)`` so both outputs clear.
    """
    if image_path is None:
        return '', None

    img_tensor = load_image(image_path)

    with torch.no_grad():
        logits = model(img_tensor)
    # Single-image batch: softmax over the class dimension of item 0.
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)

    percentages = probabilities * 100
    results = {idx_to_class[i]: percentages[i].item() for i in range(len(idx_to_class))}

    plot_path = _plot_probabilities(results)
    top_prediction = max(results, key=results.get)
    return top_prediction, plot_path


# Create the Gradio interface.
# NOTE(review): recent Gradio releases may not ship a built-in "huggingface"
# theme — confirm against the pinned Gradio version.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="filepath", label="Upload an image"),
    outputs=[gr.Textbox(label="Prediction"), gr.Image(label="Prediction Probabilities")],
    title="Skin Lesion Image Classification",
    description="Upload an image of a skin lesion to get a prediction. This tool helps to classify images of skin lesions into the following categories: Measles, Chickenpox, Herpes, Melanomas, and Monkeypox. Check out the dataset and paper at: [Link to Example 1](#), [Link to Example 2](#)",
    theme="huggingface",
    live=True,
)

# Launch the Gradio interface.
iface.launch(share=True)