Erick Garcia Espinosa commited on
Commit
4d70acc
·
1 Parent(s): 7b4760b

improvements

Browse files
Files changed (1) hide show
  1. app.py +22 -10
app.py CHANGED
@@ -6,7 +6,6 @@ from PIL import Image
6
  from timm import create_model
7
  import matplotlib.pyplot as plt
8
 
9
- # Class to index and index to class mappings
10
  class_to_idx = {'Monkeypox': 0, 'Measles': 1, 'Chickenpox': 2, 'Herpes': 3, 'Melanoma': 4}
11
  idx_to_class = {v: k for k, v in class_to_idx.items()}
12
 
@@ -31,9 +30,16 @@ model.load_state_dict(torch.load('ARTmodelo5ns_vit_weights_epoch6.pth', map_loca
31
  model.eval()
32
 
33
  # Define the prediction function
34
- def predict_image(image_path):
 
 
 
 
 
 
 
35
  # Load and transform the image from the file path
36
- img_tensor = load_image(image_path)
37
 
38
  # Perform the prediction
39
  with torch.no_grad():
@@ -44,6 +50,9 @@ def predict_image(image_path):
44
  percentages = probabilities * 100
45
  results = {idx_to_class[i]: percentages[i].item() for i in range(len(idx_to_class))}
46
 
 
 
 
47
  # Plotting the results
48
  labels = list(results.keys())
49
  values = list(results.values())
@@ -57,18 +66,21 @@ def predict_image(image_path):
57
  plt.savefig('result.png')
58
  plt.close(fig)
59
 
60
- # Return the top prediction as well
61
- top_prediction = max(results, key=results.get)
62
-
63
- return top_prediction, 'result.png'
64
 
65
  # Create the Gradio interface
66
  iface = gr.Interface(
67
  fn=predict_image,
68
- inputs=gr.Image(type="filepath", label="Upload an image"),
69
- outputs=[gr.Textbox(label="Prediction"), gr.Image(label="Prediction Probabilities")],
 
 
 
 
 
 
70
  title="Skin Lesion Image Classification",
71
- description="Upload an image of a skin lesion to get a prediction. This tool helps to classify images of skin lesions into the following categories: Measles, Chickenpox, Herpes, Melanomas, and Monkeypox. Check out the dataset and paper at: [Link to Example 1](#), [Link to Example 2](#)",
72
  theme="huggingface",
73
  live=True
74
  )
 
6
  from timm import create_model
7
  import matplotlib.pyplot as plt
8
 
 
9
  class_to_idx = {'Monkeypox': 0, 'Measles': 1, 'Chickenpox': 2, 'Herpes': 3, 'Melanoma': 4}
10
  idx_to_class = {v: k for k, v in class_to_idx.items()}
11
 
 
30
  model.eval()
31
 
32
  # Define the prediction function
33
+ def predict_image(image):
34
+ if image is None:
35
+ return "No image provided", None
36
+
37
+ # Convert the image to PIL if it's not a filepath
38
+ if isinstance(image, np.ndarray):
39
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
40
+
41
  # Load and transform the image from the file path
42
+ img_tensor = transform(image).unsqueeze(0)
43
 
44
  # Perform the prediction
45
  with torch.no_grad():
 
50
  percentages = probabilities * 100
51
  results = {idx_to_class[i]: percentages[i].item() for i in range(len(idx_to_class))}
52
 
53
+ # Get the highest prediction
54
+ predicted_label = max(results, key=results.get)
55
+
56
  # Plotting the results
57
  labels = list(results.keys())
58
  values = list(results.values())
 
66
  plt.savefig('result.png')
67
  plt.close(fig)
68
 
69
+ return predicted_label, 'result.png'
 
 
 
70
 
71
  # Create the Gradio interface
72
  iface = gr.Interface(
73
  fn=predict_image,
74
+ inputs=[
75
+ gr.Image(source="upload", type="pil", tool="editor", label="Upload an image or take a photo"),
76
+ gr.Image(source="webcam", type="pil", tool="editor", label="Take a photo")
77
+ ],
78
+ outputs=[
79
+ gr.Textbox(label="Prediction"),
80
+ gr.Image(label="Prediction Probabilities")
81
+ ],
82
  title="Skin Lesion Image Classification",
83
+ description="Upload an image of a skin lesion to get a prediction with confidence percentages. This model can classify images of skin lesions into one of the following categories: Measles, Chickenpox, Herpes, Melanoma, and Monkeypox. Check out the dataset and paper at: [Link to Dataset](#), [Link to Paper](#)",
84
  theme="huggingface",
85
  live=True
86
  )