ZDPLI commited on
Commit
0c9da96
·
verified ·
1 Parent(s): 65b9234

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -0
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import ViTForImageClassification, ViTImageProcessor
3
+
4
+ import matplotlib.pyplot as plt
5
+ import gradio as gr
6
+ import plotly.graph_objects as go
7
+ import torch
8
+ import numpy as np
9
+ from PIL import Image
10
+
11
+ model_name = "./best_model"
12
+ processor = ViTImageProcessor.from_pretrained(model_name)
13
+ labels = ['Acne or Rosacea', 'Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions', 'Atopic Dermatitis', 'Bullous Disease', 'Cellulitis Impetigo and other Bacterial Infections', 'Contact Dermatitis', 'Eczema', 'Exanthems and Drug Eruptions', 'Hair Loss Photos Alopecia and other Hair Diseases', 'Herpes HPV and other STDs', 'Light Diseases and Disorders of Pigmentation', 'Lupus and other Connective Tissue diseases', 'Melanoma Skin Cancer Nevi and Moles', 'Nail Fungus and other Nail Disease', 'Psoriasis pictures Lichen Planus and related diseases', 'Scabies Lyme Disease and other Infestations and Bites', 'Seborrheic Keratoses and other Benign Tumors', 'Systemic Disease', 'Tinea Ringworm Candidiasis and other Fungal Infections', 'Urticaria Hives', 'Vascular Tumors', 'Vasculitis', 'Warts Molluscum and other Viral Infections']
14
+
15
+ class ViTForImageClassificationWithAttention(ViTForImageClassification):
16
+ def forward(self, pixel_values, output_attentions=True):
17
+ outputs = super().forward(pixel_values, output_attentions=output_attentions)
18
+ attention = outputs.attentions
19
+ return outputs, attention
20
+
21
+ model = ViTForImageClassificationWithAttention.from_pretrained(model_name,attn_implementation="eager")
22
+
23
+ def classify_image(image):
24
+ img = Image.open(image)
25
+ inputs = processor(images=img, return_tensors="pt")
26
+ outputs, attention = model(**inputs, output_attentions=True)
27
+ logits = outputs.logits
28
+ probs = torch.nn.functional.softmax(logits, dim=1)
29
+ top_k_probs, top_k_indices = torch.topk(probs, k=5) # show top 5 predicted labels
30
+ predicted_class_idx = torch.argmax(logits)
31
+ predicted_class_label = labels[predicted_class_idx]
32
+ top_k_labels = [labels[idx] for idx in top_k_indices[0]]
33
+ top_k_label_probs = [(label, prob.item()) for label, prob in zip(top_k_labels, top_k_probs[0])]
34
+
35
+ # Create a bar chart
36
+ fig_bar = go.Figure(
37
+ data=[go.Bar(x=[label for label, prob in top_k_label_probs], y=[prob for label, prob in top_k_label_probs])])
38
+ fig_bar.update_layout(title="Top 5 Predicted Labels with Probabilities", xaxis_title="Label",
39
+ yaxis_title="Probability")
40
+
41
+ # Create a heatmap
42
+ if attention is not None:
43
+ fig_heatmap = go.Figure(
44
+ data=[go.Heatmap(z=attention[0][0, 0, :, :].detach().numpy(), colorscale='Viridis', showscale=False)])
45
+ fig_heatmap.update_layout(title="Attention Heatmap")
46
+ else:
47
+ fig_heatmap = None
48
+
49
+ # Overlay the attention heatmap on the input image
50
+ # Overlay the attention heatmap on the input image
51
+ # Overlay the attention heatmap on the input image
52
+ if attention is not None:
53
+ img_array = np.array(img)
54
+ heatmap = np.array(attention[0][0, 0, :, :].detach().numpy())
55
+ heatmap = np.resize(heatmap, (img_array.shape[0], img_array.shape[1]))
56
+ heatmap = heatmap / heatmap.max() * 255 # Normalize heatmap to [0, 255]
57
+ heatmap = heatmap.astype(np.uint8)
58
+ heatmap_color = np.zeros((img_array.shape[0], img_array.shape[1], 3), dtype=np.uint8)
59
+ heatmap_color[:, :, 0] = heatmap # Red channel
60
+ heatmap_color[:, :, 1] = heatmap # Green channel
61
+ heatmap_color[:, :, 2] = 0 # Blue channel
62
+ attention_overlay = (img_array * 0.5 + heatmap_color * 0.5).astype(np.uint8)
63
+ img = Image.fromarray(attention_overlay)
64
+ img.save("attention_overlay.png")
65
+ attention_overlay = gr.Image("attention_overlay.png")
66
+ else:
67
+ attention_overlay = None
68
+
69
+ # Return the predicted label, the bar chart, and the heatmap
70
+ return predicted_class_label, fig_bar, attention_overlay
71
+
72
+ demo = gr.Interface(
73
+ fn=classify_image,
74
+ inputs=gr.Image(type="filepath"),
75
+ outputs=[
76
+ gr.Label(label="Predicted Class Label"),
77
+ gr.Plot(label="Top 5 Predicted Labels with Probabilities"),
78
+ gr.Image(label="Attention Overlay")
79
+ ],
80
+ title="Dermatological Image Classification Demo",
81
+ description="Upload an image to see the predicted class label, top 5 predicted labels with probabilities, and attention overlay",
82
+ )
83
+
84
+ if __name__ == "__main__":
85
+ demo.launch()