File size: 7,512 Bytes
0c9da96
 
 
 
 
 
 
 
 
 
 
f349e34
0c9da96
a05049f
 
 
 
 
 
 
 
0c9da96
 
 
 
 
 
 
 
a05049f
0c9da96
a05049f
 
 
0c9da96
 
 
 
 
 
 
 
 
 
 
 
f349e34
 
0c9da96
 
 
 
 
f349e34
0c9da96
a05049f
0c9da96
 
 
a05049f
0c9da96
 
 
 
 
 
 
 
f349e34
a05049f
 
0c9da96
 
a05049f
0c9da96
 
a05049f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f349e34
a05049f
 
 
 
 
 
 
 
 
 
f349e34
 
a05049f
f349e34
 
a05049f
 
 
 
 
 
 
 
 
 
 
 
 
 
f349e34
a05049f
f349e34
 
a05049f
 
0c9da96
f349e34
0c9da96
 
 
12d56d9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import torch
from transformers import ViTForImageClassification, ViTImageProcessor

import matplotlib.pyplot as plt
import gradio as gr
import plotly.graph_objects as go
import torch
import numpy as np
from PIL import Image
model_name = "./best_model"
processor = ViTImageProcessor.from_pretrained(model_name)
labels = ['Акне или розацеа', 'Актинический кератоз, базальноклеточная карцинома и другие злокачественные поражения', 'Атопический дерматит', 'Буллезное заболевание', 'Целлюлит, импетиго и другие бактериальные инфекции', 'Контактный дерматит', 'Экзема', 'Экзантемы и лекарственные высыпания', 'Фотографии потери волос, алопеция и другие заболевания волос', 'Герпес, ВПЧ и другие ЗППП', 'Легкие заболевания и нарушения пигментации', 'Волчанка и другие заболевания соединительной ткани', 'Меланома, рак кожи, невусы и родинки', 'Грибок ногтей и другие заболевания ногтей', 'Фотографии псориаза, красный плоский лишай и связанные с ним заболевания', 'Чесотка, болезнь Лайма и другие инвазии и укусы', 'Себорейный кератоз и другие Доброкачественные опухоли', 'Системные заболевания', 'Опоясывающий лишай, кандидоз и другие грибковые инфекции', 'Крапивница', 'Сосудистые опухоли', 'Васкулит', 'Бородавки, моллюск и другие вирусные инфекции']

class ViTForImageClassificationWithAttention(ViTForImageClassification):
    def forward(self, pixel_values):
        outputs = super().forward(pixel_values)
        attention = self.vit.encoder.layers[0].attention.attention_weights
        return outputs, attention

model = ViTForImageClassificationWithAttention.from_pretrained(model_name)

class ViTForImageClassificationWithAttention(ViTForImageClassification):
    def forward(self, pixel_values, output_attentions=True):
        outputs = super().forward(pixel_values, output_attentions=output_attentions)
        attention = outputs.attentions
        return outputs, attention

model = ViTForImageClassificationWithAttention.from_pretrained(model_name,attn_implementation="eager")

i_count = 0
def classify_image(image):
    model_name = "best_model.pth"
    model.load_state_dict(torch.load(model_name))
    inputs = processor(images=image, return_tensors="pt")
    outputs, attention = model(**inputs, output_attentions=True)
    logits = outputs.logits
    probs = torch.nn.functional.softmax(logits, dim=1)
    top_k_probs, top_k_indices = torch.topk(probs, k=5)  # show top 5 predicted labels
    predicted_class_idx = torch.argmax(logits)
    predicted_class_label = labels[predicted_class_idx]
    top_k_labels = [labels[idx] for idx in top_k_indices[0]]
    top_k_label_probs = [(label, prob.item()) for label, prob in zip(top_k_labels, top_k_probs[0])]

    # Create a bar chart
    fig_bar = go.Figure(
        data=[go.Bar(x=[label for label, prob in top_k_label_probs], y=[prob for label, prob in top_k_label_probs])])
    fig_bar.update_layout(title="Топ 5 диагнозов в порядке убывания вероятности", xaxis_title="Диагноз",
                          yaxis_title="Вероятность")

    # Create a heatmap
    if attention is not None:
        fig_heatmap = go.Figure(
            data=[go.Heatmap(z=attention[0][0, 0, :, :].detach().numpy(), colorscale='Viridis', showscale=False)])
        fig_heatmap.update_layout(title="Карта внимания системы")
    else:
        fig_heatmap = go.Figure()  # Return an empty plot

    # Overlay the attention heatmap on the input image
    if attention is not None:
        img_array = np.array(image)
        heatmap = np.array(attention[0][0, 0, :, :].detach().numpy())
        heatmap = np.resize(heatmap, (img_array.shape[0], img_array.shape[1]))
        heatmap = heatmap / heatmap.max() * 255  # Normalize heatmap to [0, 255]
        heatmap = heatmap.astype(np.uint8)
        heatmap_color = np.zeros((img_array.shape[0], img_array.shape[1], 3), dtype=np.uint8)
        heatmap_color[:, :, 0] = heatmap  # Red channel
        heatmap_color[:, :, 1] = heatmap  # Green channel
        heatmap_color[:, :, 2] = 0  # Blue channel
        attention_overlay = (img_array * 0.35 + heatmap_color * 0.75).astype(np.uint8)
        attention_overlay = Image.fromarray(attention_overlay)
        attention_overlay.save("attention_overlay.png")
        attention_overlay = gr.Image("attention_overlay.png")
    else:
        attention_overlay = gr.Image()  # Return an empty image

    # Return the predicted label, the bar chart, and the heatmap
    return predicted_class_label, fig_bar, fig_heatmap, attention_overlay


def update_model(image, label):
    # Convert the label to an integer
    label_idx = labels.index(label)
    labels_tensor = torch.tensor([label_idx])

    inputs = processor(images=image, return_tensors="pt")
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Zero the gradients
    optimizer.zero_grad()

    # Forward pass
    outputs, attention = model(**inputs)
    loss = loss_fn(outputs.logits, labels_tensor)

    # Backward pass
    loss.backward()

    # Update the model parameters
    optimizer.step()

    # Save the updated model
    torch.save(model.state_dict(), "best_model.pth")

    return "Модель успешно обновлена"


demo = gr.TabbedInterface(
    [
        gr.Interface(
            fn=classify_image,
            inputs=[
                gr.Image(type="pil", label="Image")
            ],
            outputs=[
                gr.Label(label="Предсказанный диагноз"),
                gr.Plot(label="Топ 5 диагнозов в порядке убывания вероятности")
            ],
            title="DermaScan Demo",
            description="Загрузите изображение, чтобы увидеть прогнозируемую метку класса, 5 лучших прогнозируемых меток с вероятностями и тепловую карту внимания.",
            allow_flagging=False
        ),
        gr.Interface(
            fn=update_model,
            inputs=[
                gr.Image(type="pil", label="Image"),
                gr.Radio(
                    choices=labels,
                    type="value",
                    label="Label",
                    value=labels[0]
                )
            ],
            outputs=[
                gr.Textbox(label="Обновление модели")
            ],
            title="Обучить модель",
            description="Загрузите изображение и метку для обновления модели.",
            allow_flagging=False
        )
    ],
    title="DermaScan Demo"
)

if __name__ == "__main__":
    demo.launch()