Erick Garcia Espinosa commited on
Commit
1853757
·
1 Parent(s): b73d7a5

Add application file and dependencies

Browse files
Files changed (1) hide show
  1. app.py +27 -10
app.py CHANGED
@@ -4,9 +4,10 @@ import numpy as np
4
  from torchvision import transforms
5
  from PIL import Image
6
  from timm import create_model
 
7
 
8
- # Define class to index mapping dictionary
9
- class_to_idx = {'Monkeypox': 0, 'Melanoma': 1, 'Herpes': 2, 'Measles': 3, 'Chickenpox': 4}
10
 
11
  # Data transformation
12
  transform = transforms.Compose([
@@ -36,21 +37,37 @@ def predict_image(image_path):
36
  # Perform the prediction
37
  with torch.no_grad():
38
  output = model(img_tensor)
39
- _, predicted = torch.max(output, 1)
40
-
41
- # Get the predicted label
42
- predicted_label = list(class_to_idx.keys())[predicted.item()]
 
43
 
44
- return predicted_label
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  # Create the Gradio interface
47
  iface = gr.Interface(
48
  fn=predict_image,
49
  inputs=gr.Image(type="filepath", label="Upload an image"),
50
- outputs=gr.Label(label="Prediction"),
51
  title="Skin Lesion Image Classification",
52
- description="Upload an image of a skin lesion to get a prediction."
 
 
53
  )
54
 
55
  # Launch the Gradio interface
56
- iface.launch()
 
4
  from torchvision import transforms
5
  from PIL import Image
6
  from timm import create_model
7
+ import matplotlib.pyplot as plt
8
 
9
+ class_to_idx = {'Monkeypox': 0, 'Chickenpox': 1, 'Measles': 2, 'Melanoma': 3, 'Herpes': 4}
10
+ idx_to_class = {v: k for k, v in class_to_idx.items()}
11
 
12
  # Data transformation
13
  transform = transforms.Compose([
 
37
  # Perform the prediction
38
  with torch.no_grad():
39
  output = model(img_tensor)
40
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
41
+
42
+ # Convert probabilities to percentages
43
+ percentages = probabilities * 100
44
+ results = {idx_to_class[i]: percentages[i].item() for i in range(len(idx_to_class))}
45
 
46
+ # Plotting the results
47
+ labels = list(results.keys())
48
+ values = list(results.values())
49
+
50
+ fig, ax = plt.subplots()
51
+ ax.barh(labels, values, color='skyblue')
52
+ ax.set_xlabel('Percentage')
53
+ ax.set_title('Prediction Probabilities')
54
+
55
+ plt.tight_layout()
56
+ plt.savefig('result.png')
57
+ plt.close(fig)
58
+
59
+ return results, 'result.png'
60
 
61
  # Create the Gradio interface
62
  iface = gr.Interface(
63
  fn=predict_image,
64
  inputs=gr.Image(type="filepath", label="Upload an image"),
65
+ outputs=[gr.Label(label="Prediction"), gr.Image(label="Prediction Probabilities")],
66
  title="Skin Lesion Image Classification",
67
+ description="Upload an image of a skin lesion to get a prediction with confidence percentages. Example images: [Link to Example 1](#), [Link to Example 2](#)",
68
+ theme="huggingface",
69
+ live=True
70
  )
71
 
72
  # Launch the Gradio interface
73
+ iface.launch(share=True)