ZDPLI commited on
Commit
a05049f
·
verified ·
1 Parent(s): 0c9da96

Update app.py

Browse files

added model update capability based on known diagnoses from specilists

Files changed (1) hide show
  1. app.py +83 -22
app.py CHANGED
@@ -7,11 +7,18 @@ import plotly.graph_objects as go
7
  import torch
8
  import numpy as np
9
  from PIL import Image
10
-
11
  model_name = "./best_model"
12
  processor = ViTImageProcessor.from_pretrained(model_name)
13
  labels = ['Acne or Rosacea', 'Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions', 'Atopic Dermatitis', 'Bullous Disease', 'Cellulitis Impetigo and other Bacterial Infections', 'Contact Dermatitis', 'Eczema', 'Exanthems and Drug Eruptions', 'Hair Loss Photos Alopecia and other Hair Diseases', 'Herpes HPV and other STDs', 'Light Diseases and Disorders of Pigmentation', 'Lupus and other Connective Tissue diseases', 'Melanoma Skin Cancer Nevi and Moles', 'Nail Fungus and other Nail Disease', 'Psoriasis pictures Lichen Planus and related diseases', 'Scabies Lyme Disease and other Infestations and Bites', 'Seborrheic Keratoses and other Benign Tumors', 'Systemic Disease', 'Tinea Ringworm Candidiasis and other Fungal Infections', 'Urticaria Hives', 'Vascular Tumors', 'Vasculitis', 'Warts Molluscum and other Viral Infections']
14
 
 
 
 
 
 
 
 
 
15
  class ViTForImageClassificationWithAttention(ViTForImageClassification):
16
  def forward(self, pixel_values, output_attentions=True):
17
  outputs = super().forward(pixel_values, output_attentions=output_attentions)
@@ -20,9 +27,11 @@ class ViTForImageClassificationWithAttention(ViTForImageClassification):
20
 
21
  model = ViTForImageClassificationWithAttention.from_pretrained(model_name,attn_implementation="eager")
22
 
 
23
  def classify_image(image):
24
- img = Image.open(image)
25
- inputs = processor(images=img, return_tensors="pt")
 
26
  outputs, attention = model(**inputs, output_attentions=True)
27
  logits = outputs.logits
28
  probs = torch.nn.functional.softmax(logits, dim=1)
@@ -44,13 +53,11 @@ def classify_image(image):
44
  data=[go.Heatmap(z=attention[0][0, 0, :, :].detach().numpy(), colorscale='Viridis', showscale=False)])
45
  fig_heatmap.update_layout(title="Attention Heatmap")
46
  else:
47
- fig_heatmap = None
48
 
49
- # Overlay the attention heatmap on the input image
50
- # Overlay the attention heatmap on the input image
51
  # Overlay the attention heatmap on the input image
52
  if attention is not None:
53
- img_array = np.array(img)
54
  heatmap = np.array(attention[0][0, 0, :, :].detach().numpy())
55
  heatmap = np.resize(heatmap, (img_array.shape[0], img_array.shape[1]))
56
  heatmap = heatmap / heatmap.max() * 255 # Normalize heatmap to [0, 255]
@@ -60,26 +67,80 @@ def classify_image(image):
60
  heatmap_color[:, :, 1] = heatmap # Green channel
61
  heatmap_color[:, :, 2] = 0 # Blue channel
62
  attention_overlay = (img_array * 0.5 + heatmap_color * 0.5).astype(np.uint8)
63
- img = Image.fromarray(attention_overlay)
64
- img.save("attention_overlay.png")
65
  attention_overlay = gr.Image("attention_overlay.png")
66
  else:
67
- attention_overlay = None
68
 
69
  # Return the predicted label, the bar chart, and the heatmap
70
- return predicted_class_label, fig_bar, attention_overlay
71
-
72
- demo = gr.Interface(
73
- fn=classify_image,
74
- inputs=gr.Image(type="filepath"),
75
- outputs=[
76
- gr.Label(label="Predicted Class Label"),
77
- gr.Plot(label="Top 5 Predicted Labels with Probabilities"),
78
- gr.Image(label="Attention Overlay")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  ],
80
- title="Dermatological Image Classification Demo",
81
- description="Upload an image to see the predicted class label, top 5 predicted labels with probabilities, and attention overlay",
82
  )
83
 
84
  if __name__ == "__main__":
85
- demo.launch()
 
7
  import torch
8
  import numpy as np
9
  from PIL import Image
 
10
  model_name = "./best_model"
11
  processor = ViTImageProcessor.from_pretrained(model_name)
12
  labels = ['Acne or Rosacea', 'Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions', 'Atopic Dermatitis', 'Bullous Disease', 'Cellulitis Impetigo and other Bacterial Infections', 'Contact Dermatitis', 'Eczema', 'Exanthems and Drug Eruptions', 'Hair Loss Photos Alopecia and other Hair Diseases', 'Herpes HPV and other STDs', 'Light Diseases and Disorders of Pigmentation', 'Lupus and other Connective Tissue diseases', 'Melanoma Skin Cancer Nevi and Moles', 'Nail Fungus and other Nail Disease', 'Psoriasis pictures Lichen Planus and related diseases', 'Scabies Lyme Disease and other Infestations and Bites', 'Seborrheic Keratoses and other Benign Tumors', 'Systemic Disease', 'Tinea Ringworm Candidiasis and other Fungal Infections', 'Urticaria Hives', 'Vascular Tumors', 'Vasculitis', 'Warts Molluscum and other Viral Infections']
13
 
14
+ class ViTForImageClassificationWithAttention(ViTForImageClassification):
15
+ def forward(self, pixel_values):
16
+ outputs = super().forward(pixel_values)
17
+ attention = self.vit.encoder.layers[0].attention.attention_weights
18
+ return outputs, attention
19
+
20
+ model = ViTForImageClassificationWithAttention.from_pretrained(model_name)
21
+
22
  class ViTForImageClassificationWithAttention(ViTForImageClassification):
23
  def forward(self, pixel_values, output_attentions=True):
24
  outputs = super().forward(pixel_values, output_attentions=output_attentions)
 
27
 
28
  model = ViTForImageClassificationWithAttention.from_pretrained(model_name,attn_implementation="eager")
29
 
30
+ i_count = 0
31
  def classify_image(image):
32
+ model_name = "best_model.pth"
33
+ model.load_state_dict(torch.load(model_name))
34
+ inputs = processor(images=image, return_tensors="pt")
35
  outputs, attention = model(**inputs, output_attentions=True)
36
  logits = outputs.logits
37
  probs = torch.nn.functional.softmax(logits, dim=1)
 
53
  data=[go.Heatmap(z=attention[0][0, 0, :, :].detach().numpy(), colorscale='Viridis', showscale=False)])
54
  fig_heatmap.update_layout(title="Attention Heatmap")
55
  else:
56
+ fig_heatmap = go.Figure() # Return an empty plot
57
 
 
 
58
  # Overlay the attention heatmap on the input image
59
  if attention is not None:
60
+ img_array = np.array(image)
61
  heatmap = np.array(attention[0][0, 0, :, :].detach().numpy())
62
  heatmap = np.resize(heatmap, (img_array.shape[0], img_array.shape[1]))
63
  heatmap = heatmap / heatmap.max() * 255 # Normalize heatmap to [0, 255]
 
67
  heatmap_color[:, :, 1] = heatmap # Green channel
68
  heatmap_color[:, :, 2] = 0 # Blue channel
69
  attention_overlay = (img_array * 0.5 + heatmap_color * 0.5).astype(np.uint8)
70
+ attention_overlay = Image.fromarray(attention_overlay)
71
+ attention_overlay.save("attention_overlay.png")
72
  attention_overlay = gr.Image("attention_overlay.png")
73
  else:
74
+ attention_overlay = gr.Image() # Return an empty image
75
 
76
  # Return the predicted label, the bar chart, and the heatmap
77
+ return predicted_class_label, fig_bar, fig_heatmap, attention_overlay
78
+
79
+
80
+ def update_model(image, label):
81
+ # Convert the label to an integer
82
+ label_idx = labels.index(label)
83
+ labels_tensor = torch.tensor([label_idx])
84
+
85
+ inputs = processor(images=image, return_tensors="pt")
86
+ loss_fn = torch.nn.CrossEntropyLoss()
87
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
88
+
89
+ # Zero the gradients
90
+ optimizer.zero_grad()
91
+
92
+ # Forward pass
93
+ outputs, attention = model(**inputs)
94
+ loss = loss_fn(outputs.logits, labels_tensor)
95
+
96
+ # Backward pass
97
+ loss.backward()
98
+
99
+ # Update the model parameters
100
+ optimizer.step()
101
+
102
+ # Save the updated model
103
+ torch.save(model.state_dict(), "best_model.pth")
104
+
105
+ return "Model updated successfully"
106
+
107
+
108
+ demo = gr.TabbedInterface(
109
+ [
110
+ gr.Interface(
111
+ fn=classify_image,
112
+ inputs=[
113
+ gr.Image(type="pil", label="Image")
114
+ ],
115
+ outputs=[
116
+ gr.Label(label="Predicted Class Label"),
117
+ gr.Plot(label="Top 5 Predicted Labels with Probabilities")
118
+ ],
119
+ title="Dermatological Image Classification Demo",
120
+ description="Upload an image to see the predicted class label, top 5 predicted labels with probabilities, and attention heatmap",
121
+ allow_flagging=False
122
+ ),
123
+ gr.Interface(
124
+ fn=update_model,
125
+ inputs=[
126
+ gr.Image(type="pil", label="Image"),
127
+ gr.Radio(
128
+ choices=labels,
129
+ type="value",
130
+ label="Label",
131
+ value=labels[0]
132
+ )
133
+ ],
134
+ outputs=[
135
+ gr.Textbox(label="Model Update Status")
136
+ ],
137
+ title="Train Model",
138
+ description="Upload an image and label to update the model",
139
+ allow_flagging=False
140
+ )
141
  ],
142
+ title="Dermatological Image Classification and Training"
 
143
  )
144
 
145
  if __name__ == "__main__":
146
+ demo.launch(share=True)