File size: 2,625 Bytes
685396e
 
 
 
 
 
1853757
685396e
179adf5
010feb3
1853757
685396e
b73d7a5
685396e
 
 
 
 
b73d7a5
179adf5
 
b73d7a5
685396e
 
b73d7a5
685396e
 
 
 
6b394c3
685396e
 
b73d7a5
179adf5
 
 
685396e
b73d7a5
685396e
 
1853757
 
 
 
 
685396e
1853757
 
 
 
 
 
 
 
 
 
 
 
 
179adf5
 
 
 
685396e
b73d7a5
685396e
 
dcda142
179adf5
b73d7a5
179adf5
1853757
 
685396e
 
b73d7a5
1853757
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import gradio as gr
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
from timm import create_model
import matplotlib.pyplot as plt

# Label vocabulary: a class's position in this list is the output index that
# predict_image decodes it from.
_CLASS_NAMES = ['Monkeypox', 'Measles', 'Chickenpox', 'Herpes', 'Melanoma']

# Forward (name -> index) and inverse (index -> name) lookups.
class_to_idx = {name: position for position, name in enumerate(_CLASS_NAMES)}
idx_to_class = dict(enumerate(_CLASS_NAMES))

# Preprocessing pipeline applied to every input image: resize to a fixed
# 224x224, then convert the PIL image into a CHW float tensor.
# NOTE(review): there is no Normalize step; presumably this matches the
# preprocessing used when the loaded checkpoint was trained -- confirm before
# adding one.
_INPUT_SIZE = (224, 224)
transform = transforms.Compose(
    [
        transforms.Resize(size=_INPUT_SIZE),
        transforms.ToTensor(),
    ]
)

# Function to load and preprocess an image
def load_image(image_path):
    """Load an image file and turn it into a one-image model input batch.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (a file path here,
            as passed in by ``predict_image``).

    Returns:
        The module-level ``transform`` applied to the RGB image, with a
        leading batch dimension of size 1 added via ``unsqueeze(0)``.
    """
    # Open inside a context manager so the file handle is released right away;
    # ``convert`` loads the pixel data and returns an independent copy, so the
    # source image can safely be closed afterwards.
    with Image.open(image_path) as raw_image:
        rgb_image = raw_image.convert('RGB')
    return transform(rgb_image).unsqueeze(0)  # Add batch dimension

# Load the model
model_name = 'vit_base_patch16_224'
# The fine-tuned checkpoint loaded below replaces every parameter in the model
# (load_state_dict is strict by default, so all keys must match). Fetching the
# ImageNet-pretrained weights first would be a wasted download and an
# unnecessary network dependency at startup, so build the architecture only.
pretrained = False
num_classes = len(class_to_idx)
model = create_model(model_name, pretrained=pretrained, num_classes=num_classes)
# weights_only=True restricts unpickling to tensors/primitive containers.
state_dict = torch.load('ARTmodelo5ns_vit_weights_epoch6.pth', map_location='cpu', weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # switch layers such as dropout to inference behaviour

# Define the prediction function
def predict_image(image_path):
    """Classify a skin-lesion image and save a bar chart of class probabilities.

    Args:
        image_path: File path of the uploaded image, as provided by the
            ``gr.Image(type="filepath")`` input, or ``None`` when the input
            has been cleared.

    Returns:
        A ``(top_prediction, chart_path)`` tuple: the name of the most
        probable class and the path of the saved PNG chart. Returns
        ``("", None)`` when ``image_path`` is ``None`` so both outputs are
        simply cleared.
    """
    # With live=True the interface re-runs on every input change, including
    # when the user clears the image; bail out instead of crashing in Image.open.
    if image_path is None:
        return "", None

    # Load and transform the image from the file path
    img_tensor = load_image(image_path)

    # Perform the prediction; the batch holds a single image, so take item 0.
    with torch.no_grad():
        logits = model(img_tensor)[0]
        probabilities = torch.nn.functional.softmax(logits, dim=0)

    # Convert probabilities to percentages, keyed by class name in index order.
    percentages = (probabilities * 100).tolist()
    results = {name: percentages[idx] for idx, name in sorted(idx_to_class.items())}

    # Plot through the figure object itself rather than pyplot's implicit
    # "current figure", so layout and saving always target this chart; close it
    # in ``finally`` so a drawing/saving error does not leak the figure.
    fig, ax = plt.subplots()
    try:
        ax.barh(list(results.keys()), list(results.values()), color='skyblue')
        ax.set_xlabel('Percentage')
        ax.set_title('Prediction Probabilities')
        fig.tight_layout()
        # NOTE(review): fixed output path shared by every request; if the app
        # ever runs events concurrently, consider a per-request temp file.
        fig.savefig('result.png')
    finally:
        plt.close(fig)

    # Return the top prediction as well
    top_prediction = max(results, key=results.get)
    return top_prediction, 'result.png'

# Create the Gradio interface
# Derive the category list shown to users from class_to_idx so the description
# always names exactly the labels the model can return.
_category_names = list(class_to_idx)
if len(_category_names) > 1:
    _category_text = f"{', '.join(_category_names[:-1])}, and {_category_names[-1]}"
else:
    _category_text = ", ".join(_category_names)

iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="filepath", label="Upload an image"),
    outputs=[gr.Textbox(label="Prediction"), gr.Image(label="Prediction Probabilities")],
    title="Skin Lesion Image Classification",
    # TODO: replace the placeholder "(#)" links with the real dataset/paper URLs.
    description=(
        "Upload an image of a skin lesion to get a prediction. This tool helps to "
        f"classify images of skin lesions into the following categories: {_category_text}. "
        "Check out the dataset and paper at: [Link to Example 1](#), [Link to Example 2](#)"
    ),
    theme="huggingface",
    live=True,  # re-run predict_image whenever the input changes
)

# Launch the Gradio interface (share=True also requests a public share link)
iface.launch(share=True)