Erick Garcia Espinosa commited on
Commit
179adf5
·
1 Parent(s): b009213

improvements

Browse files
Files changed (1) hide show
  1. app.py +13 -24
app.py CHANGED
@@ -6,6 +6,7 @@ from PIL import Image
6
  from timm import create_model
7
  import matplotlib.pyplot as plt
8
 
 
9
  class_to_idx = {'Monkeypox': 0, 'Measles': 1, 'Chickenpox': 2, 'Herpes': 3, 'Melanoma': 4}
10
  idx_to_class = {v: k for k, v in class_to_idx.items()}
11
 
@@ -16,11 +17,8 @@ transform = transforms.Compose([
16
  ])
17
 
18
  # Function to load and preprocess an image
19
- def load_image(image):
20
- if isinstance(image, str): # If image is a file path
21
- image = Image.open(image).convert('RGB')
22
- elif isinstance(image, np.ndarray): # If image is a numpy array
23
- image = Image.fromarray(image.astype('uint8'), 'RGB')
24
  image = transform(image).unsqueeze(0) # Add batch dimension
25
  return image
26
 
@@ -33,12 +31,9 @@ model.load_state_dict(torch.load('ARTmodelo5ns_vit_weights_epoch6.pth', map_loca
33
  model.eval()
34
 
35
  # Define the prediction function
36
- def predict_image(image):
37
- if image is None:
38
- return "No image provided", None
39
-
40
- # Load and transform the image
41
- img_tensor = load_image(image)
42
 
43
  # Perform the prediction
44
  with torch.no_grad():
@@ -49,9 +44,6 @@ def predict_image(image):
49
  percentages = probabilities * 100
50
  results = {idx_to_class[i]: percentages[i].item() for i in range(len(idx_to_class))}
51
 
52
- # Get the highest prediction
53
- predicted_label = max(results, key=results.get)
54
-
55
  # Plotting the results
56
  labels = list(results.keys())
57
  values = list(results.values())
@@ -65,21 +57,18 @@ def predict_image(image):
65
  plt.savefig('result.png')
66
  plt.close(fig)
67
 
68
- return predicted_label, 'result.png'
 
 
 
69
 
70
  # Create the Gradio interface
71
  iface = gr.Interface(
72
  fn=predict_image,
73
- inputs=[
74
- gr.Image(type="pil", tool="editor", label="Upload an image or take a photo"),
75
- gr.Image(source="webcam", type="pil", tool="editor", label="Take a photo")
76
- ],
77
- outputs=[
78
- gr.Textbox(label="Prediction"),
79
- gr.Image(label="Prediction Probabilities")
80
- ],
81
  title="Skin Lesion Image Classification",
82
- description="Upload an image of a skin lesion to get a prediction with confidence percentages. This model can classify images of skin lesions into one of the following categories: Measles, Chickenpox, Herpes, Melanoma, and Monkeypox. Check out the dataset and paper at: [Link to Dataset](#), [Link to Paper](#)",
83
  theme="huggingface",
84
  live=True
85
  )
 
6
  from timm import create_model
7
  import matplotlib.pyplot as plt
8
 
9
+ # Class to index and index to class mappings
10
  class_to_idx = {'Monkeypox': 0, 'Measles': 1, 'Chickenpox': 2, 'Herpes': 3, 'Melanoma': 4}
11
  idx_to_class = {v: k for k, v in class_to_idx.items()}
12
 
 
17
  ])
18
 
19
  # Function to load and preprocess an image
20
+ def load_image(image_path):
21
+ image = Image.open(image_path).convert('RGB')
 
 
 
22
  image = transform(image).unsqueeze(0) # Add batch dimension
23
  return image
24
 
 
31
  model.eval()
32
 
33
  # Define the prediction function
34
+ def predict_image(image_path):
35
+ # Load and transform the image from the file path
36
+ img_tensor = load_image(image_path)
 
 
 
37
 
38
  # Perform the prediction
39
  with torch.no_grad():
 
44
  percentages = probabilities * 100
45
  results = {idx_to_class[i]: percentages[i].item() for i in range(len(idx_to_class))}
46
 
 
 
 
47
  # Plotting the results
48
  labels = list(results.keys())
49
  values = list(results.values())
 
57
  plt.savefig('result.png')
58
  plt.close(fig)
59
 
60
+ # Return the top prediction as well
61
+ top_prediction = max(results, key=results.get)
62
+
63
+ return top_prediction, 'result.png'
64
 
65
  # Create the Gradio interface
66
  iface = gr.Interface(
67
  fn=predict_image,
68
+ inputs=gr.Image(type="filepath", label="Upload an image"),
69
+ outputs=[gr.Textbox(label="Prediction"), gr.Image(label="Prediction Probabilities")],
 
 
 
 
 
 
70
  title="Skin Lesion Image Classification",
71
+ description="Upload an image of a skin lesion to get a prediction. This tool helps to classify images of skin lesions into the following categories: Measles, Chickenpox, Herpes, Melanomas, and Monkeypox. Check out the dataset and paper at: [Link to Example 1](#), [Link to Example 2](#)",
72
  theme="huggingface",
73
  live=True
74
  )