Spaces:
Sleeping
Sleeping
File size: 7,512 Bytes
0c9da96 f349e34 0c9da96 a05049f 0c9da96 a05049f 0c9da96 a05049f 0c9da96 f349e34 0c9da96 f349e34 0c9da96 a05049f 0c9da96 a05049f 0c9da96 f349e34 a05049f 0c9da96 a05049f 0c9da96 a05049f f349e34 a05049f f349e34 a05049f f349e34 a05049f f349e34 a05049f f349e34 a05049f 0c9da96 f349e34 0c9da96 12d56d9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import torch
from transformers import ViTForImageClassification, ViTImageProcessor
import matplotlib.pyplot as plt
import gradio as gr
import plotly.graph_objects as go
import torch
import numpy as np
from PIL import Image
model_name = "./best_model"
processor = ViTImageProcessor.from_pretrained(model_name)
labels = ['Акне или розацеа', 'Актинический кератоз, базальноклеточная карцинома и другие злокачественные поражения', 'Атопический дерматит', 'Буллезное заболевание', 'Целлюлит, импетиго и другие бактериальные инфекции', 'Контактный дерматит', 'Экзема', 'Экзантемы и лекарственные высыпания', 'Фотографии потери волос, алопеция и другие заболевания волос', 'Герпес, ВПЧ и другие ЗППП', 'Легкие заболевания и нарушения пигментации', 'Волчанка и другие заболевания соединительной ткани', 'Меланома, рак кожи, невусы и родинки', 'Грибок ногтей и другие заболевания ногтей', 'Фотографии псориаза, красный плоский лишай и связанные с ним заболевания', 'Чесотка, болезнь Лайма и другие инвазии и укусы', 'Себорейный кератоз и другие Доброкачественные опухоли', 'Системные заболевания', 'Опоясывающий лишай, кандидоз и другие грибковые инфекции', 'Крапивница', 'Сосудистые опухоли', 'Васкулит', 'Бородавки, моллюск и другие вирусные инфекции']
class ViTForImageClassificationWithAttention(ViTForImageClassification):
def forward(self, pixel_values):
outputs = super().forward(pixel_values)
attention = self.vit.encoder.layers[0].attention.attention_weights
return outputs, attention
model = ViTForImageClassificationWithAttention.from_pretrained(model_name)
class ViTForImageClassificationWithAttention(ViTForImageClassification):
def forward(self, pixel_values, output_attentions=True):
outputs = super().forward(pixel_values, output_attentions=output_attentions)
attention = outputs.attentions
return outputs, attention
model = ViTForImageClassificationWithAttention.from_pretrained(model_name,attn_implementation="eager")
i_count = 0
def classify_image(image):
model_name = "best_model.pth"
model.load_state_dict(torch.load(model_name))
inputs = processor(images=image, return_tensors="pt")
outputs, attention = model(**inputs, output_attentions=True)
logits = outputs.logits
probs = torch.nn.functional.softmax(logits, dim=1)
top_k_probs, top_k_indices = torch.topk(probs, k=5) # show top 5 predicted labels
predicted_class_idx = torch.argmax(logits)
predicted_class_label = labels[predicted_class_idx]
top_k_labels = [labels[idx] for idx in top_k_indices[0]]
top_k_label_probs = [(label, prob.item()) for label, prob in zip(top_k_labels, top_k_probs[0])]
# Create a bar chart
fig_bar = go.Figure(
data=[go.Bar(x=[label for label, prob in top_k_label_probs], y=[prob for label, prob in top_k_label_probs])])
fig_bar.update_layout(title="Топ 5 диагнозов в порядке убывания вероятности", xaxis_title="Диагноз",
yaxis_title="Вероятность")
# Create a heatmap
if attention is not None:
fig_heatmap = go.Figure(
data=[go.Heatmap(z=attention[0][0, 0, :, :].detach().numpy(), colorscale='Viridis', showscale=False)])
fig_heatmap.update_layout(title="Карта внимания системы")
else:
fig_heatmap = go.Figure() # Return an empty plot
# Overlay the attention heatmap on the input image
if attention is not None:
img_array = np.array(image)
heatmap = np.array(attention[0][0, 0, :, :].detach().numpy())
heatmap = np.resize(heatmap, (img_array.shape[0], img_array.shape[1]))
heatmap = heatmap / heatmap.max() * 255 # Normalize heatmap to [0, 255]
heatmap = heatmap.astype(np.uint8)
heatmap_color = np.zeros((img_array.shape[0], img_array.shape[1], 3), dtype=np.uint8)
heatmap_color[:, :, 0] = heatmap # Red channel
heatmap_color[:, :, 1] = heatmap # Green channel
heatmap_color[:, :, 2] = 0 # Blue channel
attention_overlay = (img_array * 0.35 + heatmap_color * 0.75).astype(np.uint8)
attention_overlay = Image.fromarray(attention_overlay)
attention_overlay.save("attention_overlay.png")
attention_overlay = gr.Image("attention_overlay.png")
else:
attention_overlay = gr.Image() # Return an empty image
# Return the predicted label, the bar chart, and the heatmap
return predicted_class_label, fig_bar, fig_heatmap, attention_overlay
def update_model(image, label):
# Convert the label to an integer
label_idx = labels.index(label)
labels_tensor = torch.tensor([label_idx])
inputs = processor(images=image, return_tensors="pt")
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Zero the gradients
optimizer.zero_grad()
# Forward pass
outputs, attention = model(**inputs)
loss = loss_fn(outputs.logits, labels_tensor)
# Backward pass
loss.backward()
# Update the model parameters
optimizer.step()
# Save the updated model
torch.save(model.state_dict(), "best_model.pth")
return "Модель успешно обновлена"
demo = gr.TabbedInterface(
[
gr.Interface(
fn=classify_image,
inputs=[
gr.Image(type="pil", label="Image")
],
outputs=[
gr.Label(label="Предсказанный диагноз"),
gr.Plot(label="Топ 5 диагнозов в порядке убывания вероятности")
],
title="DermaScan Demo",
description="Загрузите изображение, чтобы увидеть прогнозируемую метку класса, 5 лучших прогнозируемых меток с вероятностями и тепловую карту внимания.",
allow_flagging=False
),
gr.Interface(
fn=update_model,
inputs=[
gr.Image(type="pil", label="Image"),
gr.Radio(
choices=labels,
type="value",
label="Label",
value=labels[0]
)
],
outputs=[
gr.Textbox(label="Обновление модели")
],
title="Обучить модель",
description="Загрузите изображение и метку для обновления модели.",
allow_flagging=False
)
],
title="DermaScan Demo"
)
if __name__ == "__main__":
demo.launch()
|