Spaces:
Sleeping
Sleeping
Update app.py
Browse filesadded model update capability based on known diagnoses from specilists
app.py
CHANGED
@@ -7,11 +7,18 @@ import plotly.graph_objects as go
|
|
7 |
import torch
|
8 |
import numpy as np
|
9 |
from PIL import Image
|
10 |
-
|
11 |
model_name = "./best_model"
|
12 |
processor = ViTImageProcessor.from_pretrained(model_name)
|
13 |
labels = ['Acne or Rosacea', 'Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions', 'Atopic Dermatitis', 'Bullous Disease', 'Cellulitis Impetigo and other Bacterial Infections', 'Contact Dermatitis', 'Eczema', 'Exanthems and Drug Eruptions', 'Hair Loss Photos Alopecia and other Hair Diseases', 'Herpes HPV and other STDs', 'Light Diseases and Disorders of Pigmentation', 'Lupus and other Connective Tissue diseases', 'Melanoma Skin Cancer Nevi and Moles', 'Nail Fungus and other Nail Disease', 'Psoriasis pictures Lichen Planus and related diseases', 'Scabies Lyme Disease and other Infestations and Bites', 'Seborrheic Keratoses and other Benign Tumors', 'Systemic Disease', 'Tinea Ringworm Candidiasis and other Fungal Infections', 'Urticaria Hives', 'Vascular Tumors', 'Vasculitis', 'Warts Molluscum and other Viral Infections']
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
class ViTForImageClassificationWithAttention(ViTForImageClassification):
|
16 |
def forward(self, pixel_values, output_attentions=True):
|
17 |
outputs = super().forward(pixel_values, output_attentions=output_attentions)
|
@@ -20,9 +27,11 @@ class ViTForImageClassificationWithAttention(ViTForImageClassification):
|
|
20 |
|
21 |
model = ViTForImageClassificationWithAttention.from_pretrained(model_name,attn_implementation="eager")
|
22 |
|
|
|
23 |
def classify_image(image):
|
24 |
-
|
25 |
-
|
|
|
26 |
outputs, attention = model(**inputs, output_attentions=True)
|
27 |
logits = outputs.logits
|
28 |
probs = torch.nn.functional.softmax(logits, dim=1)
|
@@ -44,13 +53,11 @@ def classify_image(image):
|
|
44 |
data=[go.Heatmap(z=attention[0][0, 0, :, :].detach().numpy(), colorscale='Viridis', showscale=False)])
|
45 |
fig_heatmap.update_layout(title="Attention Heatmap")
|
46 |
else:
|
47 |
-
fig_heatmap =
|
48 |
|
49 |
-
# Overlay the attention heatmap on the input image
|
50 |
-
# Overlay the attention heatmap on the input image
|
51 |
# Overlay the attention heatmap on the input image
|
52 |
if attention is not None:
|
53 |
-
img_array = np.array(
|
54 |
heatmap = np.array(attention[0][0, 0, :, :].detach().numpy())
|
55 |
heatmap = np.resize(heatmap, (img_array.shape[0], img_array.shape[1]))
|
56 |
heatmap = heatmap / heatmap.max() * 255 # Normalize heatmap to [0, 255]
|
@@ -60,26 +67,80 @@ def classify_image(image):
|
|
60 |
heatmap_color[:, :, 1] = heatmap # Green channel
|
61 |
heatmap_color[:, :, 2] = 0 # Blue channel
|
62 |
attention_overlay = (img_array * 0.5 + heatmap_color * 0.5).astype(np.uint8)
|
63 |
-
|
64 |
-
|
65 |
attention_overlay = gr.Image("attention_overlay.png")
|
66 |
else:
|
67 |
-
attention_overlay =
|
68 |
|
69 |
# Return the predicted label, the bar chart, and the heatmap
|
70 |
-
return predicted_class_label, fig_bar, attention_overlay
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
],
|
80 |
-
title="Dermatological Image Classification
|
81 |
-
description="Upload an image to see the predicted class label, top 5 predicted labels with probabilities, and attention overlay",
|
82 |
)
|
83 |
|
84 |
if __name__ == "__main__":
|
85 |
-
demo.launch()
|
|
|
7 |
import torch
|
8 |
import numpy as np
|
9 |
from PIL import Image
|
|
|
10 |
model_name = "./best_model"
|
11 |
processor = ViTImageProcessor.from_pretrained(model_name)
|
12 |
labels = ['Acne or Rosacea', 'Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions', 'Atopic Dermatitis', 'Bullous Disease', 'Cellulitis Impetigo and other Bacterial Infections', 'Contact Dermatitis', 'Eczema', 'Exanthems and Drug Eruptions', 'Hair Loss Photos Alopecia and other Hair Diseases', 'Herpes HPV and other STDs', 'Light Diseases and Disorders of Pigmentation', 'Lupus and other Connective Tissue diseases', 'Melanoma Skin Cancer Nevi and Moles', 'Nail Fungus and other Nail Disease', 'Psoriasis pictures Lichen Planus and related diseases', 'Scabies Lyme Disease and other Infestations and Bites', 'Seborrheic Keratoses and other Benign Tumors', 'Systemic Disease', 'Tinea Ringworm Candidiasis and other Fungal Infections', 'Urticaria Hives', 'Vascular Tumors', 'Vasculitis', 'Warts Molluscum and other Viral Infections']
|
13 |
|
14 |
+
class ViTForImageClassificationWithAttention(ViTForImageClassification):
|
15 |
+
def forward(self, pixel_values):
|
16 |
+
outputs = super().forward(pixel_values)
|
17 |
+
attention = self.vit.encoder.layers[0].attention.attention_weights
|
18 |
+
return outputs, attention
|
19 |
+
|
20 |
+
model = ViTForImageClassificationWithAttention.from_pretrained(model_name)
|
21 |
+
|
22 |
class ViTForImageClassificationWithAttention(ViTForImageClassification):
|
23 |
def forward(self, pixel_values, output_attentions=True):
|
24 |
outputs = super().forward(pixel_values, output_attentions=output_attentions)
|
|
|
27 |
|
28 |
model = ViTForImageClassificationWithAttention.from_pretrained(model_name,attn_implementation="eager")
|
29 |
|
30 |
+
i_count = 0
|
31 |
def classify_image(image):
|
32 |
+
model_name = "best_model.pth"
|
33 |
+
model.load_state_dict(torch.load(model_name))
|
34 |
+
inputs = processor(images=image, return_tensors="pt")
|
35 |
outputs, attention = model(**inputs, output_attentions=True)
|
36 |
logits = outputs.logits
|
37 |
probs = torch.nn.functional.softmax(logits, dim=1)
|
|
|
53 |
data=[go.Heatmap(z=attention[0][0, 0, :, :].detach().numpy(), colorscale='Viridis', showscale=False)])
|
54 |
fig_heatmap.update_layout(title="Attention Heatmap")
|
55 |
else:
|
56 |
+
fig_heatmap = go.Figure() # Return an empty plot
|
57 |
|
|
|
|
|
58 |
# Overlay the attention heatmap on the input image
|
59 |
if attention is not None:
|
60 |
+
img_array = np.array(image)
|
61 |
heatmap = np.array(attention[0][0, 0, :, :].detach().numpy())
|
62 |
heatmap = np.resize(heatmap, (img_array.shape[0], img_array.shape[1]))
|
63 |
heatmap = heatmap / heatmap.max() * 255 # Normalize heatmap to [0, 255]
|
|
|
67 |
heatmap_color[:, :, 1] = heatmap # Green channel
|
68 |
heatmap_color[:, :, 2] = 0 # Blue channel
|
69 |
attention_overlay = (img_array * 0.5 + heatmap_color * 0.5).astype(np.uint8)
|
70 |
+
attention_overlay = Image.fromarray(attention_overlay)
|
71 |
+
attention_overlay.save("attention_overlay.png")
|
72 |
attention_overlay = gr.Image("attention_overlay.png")
|
73 |
else:
|
74 |
+
attention_overlay = gr.Image() # Return an empty image
|
75 |
|
76 |
# Return the predicted label, the bar chart, and the heatmap
|
77 |
+
return predicted_class_label, fig_bar, fig_heatmap, attention_overlay
|
78 |
+
|
79 |
+
|
80 |
+
def update_model(image, label):
|
81 |
+
# Convert the label to an integer
|
82 |
+
label_idx = labels.index(label)
|
83 |
+
labels_tensor = torch.tensor([label_idx])
|
84 |
+
|
85 |
+
inputs = processor(images=image, return_tensors="pt")
|
86 |
+
loss_fn = torch.nn.CrossEntropyLoss()
|
87 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
88 |
+
|
89 |
+
# Zero the gradients
|
90 |
+
optimizer.zero_grad()
|
91 |
+
|
92 |
+
# Forward pass
|
93 |
+
outputs, attention = model(**inputs)
|
94 |
+
loss = loss_fn(outputs.logits, labels_tensor)
|
95 |
+
|
96 |
+
# Backward pass
|
97 |
+
loss.backward()
|
98 |
+
|
99 |
+
# Update the model parameters
|
100 |
+
optimizer.step()
|
101 |
+
|
102 |
+
# Save the updated model
|
103 |
+
torch.save(model.state_dict(), "best_model.pth")
|
104 |
+
|
105 |
+
return "Model updated successfully"
|
106 |
+
|
107 |
+
|
108 |
+
demo = gr.TabbedInterface(
|
109 |
+
[
|
110 |
+
gr.Interface(
|
111 |
+
fn=classify_image,
|
112 |
+
inputs=[
|
113 |
+
gr.Image(type="pil", label="Image")
|
114 |
+
],
|
115 |
+
outputs=[
|
116 |
+
gr.Label(label="Predicted Class Label"),
|
117 |
+
gr.Plot(label="Top 5 Predicted Labels with Probabilities")
|
118 |
+
],
|
119 |
+
title="Dermatological Image Classification Demo",
|
120 |
+
description="Upload an image to see the predicted class label, top 5 predicted labels with probabilities, and attention heatmap",
|
121 |
+
allow_flagging=False
|
122 |
+
),
|
123 |
+
gr.Interface(
|
124 |
+
fn=update_model,
|
125 |
+
inputs=[
|
126 |
+
gr.Image(type="pil", label="Image"),
|
127 |
+
gr.Radio(
|
128 |
+
choices=labels,
|
129 |
+
type="value",
|
130 |
+
label="Label",
|
131 |
+
value=labels[0]
|
132 |
+
)
|
133 |
+
],
|
134 |
+
outputs=[
|
135 |
+
gr.Textbox(label="Model Update Status")
|
136 |
+
],
|
137 |
+
title="Train Model",
|
138 |
+
description="Upload an image and label to update the model",
|
139 |
+
allow_flagging=False
|
140 |
+
)
|
141 |
],
|
142 |
+
title="Dermatological Image Classification and Training"
|
|
|
143 |
)
|
144 |
|
145 |
if __name__ == "__main__":
|
146 |
+
demo.launch(share=True)
|