|
import gradio as gr |
|
import torch |
|
from PIL import Image |
|
import numpy as np |
|
from torchvision import models |
|
from torchvision import transforms |
|
from transformers import ViTForImageClassification |
|
from torch import nn |
|
from torch.cuda.amp import autocast |
|
import os |
|
|
|
|
|
# Run inference on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Report the choice once at start-up.
print(f"Using device: {device}")
|
|
|
|
|
# Maps the classifier's output index to the Russian name shown in the UI.
# NOTE(review): the order presumably has to match the class order used when
# the checkpoints were trained -- confirm against the training code.
label_mapping = dict(
    enumerate(
        [
            "Меланома",
            "Меланоцитарный невус",
            "Базальноклеточная карцинома",
            "Актинический кератоз",
            "Доброкачественная кератоза",
            "Дерматофиброма",
            "Сосудистые поражения",
        ]
    )
)
|
|
|
|
|
# Directory containing the fine-tuned checkpoint files. Can be overridden with
# the CHECKPOINTS_PATH environment variable; defaults to the current directory.
CHECKPOINTS_PATH = os.environ.get("CHECKPOINTS_PATH", "./")
|
|
|
|
|
def get_efficientnet(num_classes: int = 7) -> nn.Module:
    """Build an EfficientNetV2-S classifier with a fresh output layer.

    The backbone is created with the ``IMAGENET1K_V1`` torchvision weights and
    the final linear layer of ``model.classifier`` is replaced so the network
    emits ``num_classes`` logits.

    Args:
        num_classes: Number of output classes for the new head (default 7).

    Returns:
        The model, moved to the module-level ``device``.
    """
    model = models.efficientnet_v2_s(weights="IMAGENET1K_V1")
    # Take the input width from the layer being replaced instead of
    # hard-coding it, so the head always matches the backbone.
    old_head = model.classifier[1]
    model.classifier[1] = nn.Linear(old_head.in_features, num_classes)
    return model.to(device)
|
|
|
def get_deit():
    """Build the DeiT-base (patch 16, 224 px) classifier with a 7-label head.

    The ``facebook/deit-base-patch16-224`` checkpoint is loaded through
    ``ViTForImageClassification``. ``ignore_mismatched_sizes=True`` lets the
    pretrained classification head be replaced by one with ``num_labels=7``.

    Returns:
        The model, moved to the module-level ``device``.
    """
    deit = ViTForImageClassification.from_pretrained(
        "facebook/deit-base-patch16-224",
        ignore_mismatched_sizes=True,
        num_labels=7,
    )
    deit = deit.to(device)
    return deit
|
|
|
|
|
def transform_image(image):
    """Convert a PIL image into a normalized, batched model input tensor.

    The image is forced to 3-channel RGB, resized to 224x224, converted to a
    tensor and normalized with the standard ImageNet channel statistics. A
    leading batch dimension is then added and the tensor is moved to the
    module-level ``device``.

    Args:
        image: A PIL image (the Gradio inputs in this file use ``type="pil"``).

    Returns:
        A tensor with a leading batch dimension of size 1 on ``device``.
    """
    # Uploaded files may be RGBA, palette or grayscale; Normalize below is
    # defined for exactly three channels, so convert before transforming.
    if image.mode != "RGB":
        image = image.convert("RGB")

    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
    batch = preprocess(image).unsqueeze(0)
    return batch.to(device)
|
|
|
|
|
class ModelHandler:
    """Own the two fine-tuned classifiers and expose prediction helpers.

    Both networks are built and their checkpoints loaded when the handler is
    constructed. If anything fails, the error is printed, ``models_loaded``
    stays False and every ``predict_*`` method returns
    ``{"error": "Модели не загружены"}`` instead of raising.
    """

    def __init__(self):
        # Networks are created in load_models(); None until then.
        self.efficientnet = None
        self.deit = None
        # Set to True only after both checkpoints load successfully.
        self.models_loaded = False
        self.load_models()

    def load_models(self):
        """Build both networks and load their weights from CHECKPOINTS_PATH.

        Expects ``efficientnet_best.pth`` and ``deit_best.pth`` in that
        directory. Never raises: any failure is printed and recorded by
        leaving ``models_loaded`` False so the UI can still start.
        """
        try:
            self.efficientnet = get_efficientnet()
            self._load_weights(self.efficientnet, "efficientnet_best.pth")

            self.deit = get_deit()
            self._load_weights(self.deit, "deit_best.pth")

            self.models_loaded = True
            print("✅ Models loaded successfully")
        except Exception as e:
            # Deliberately broad: missing files, bad checkpoints and download
            # errors are all reported the same way.
            print(f"❌ Error loading models: {str(e)}")
            self.models_loaded = False

    @staticmethod
    def _load_weights(model, filename):
        """Load the state dict stored in *filename* into *model* and call eval()."""
        path = os.path.join(CHECKPOINTS_PATH, filename)
        # NOTE(security): torch.load unpickles the file. Only point
        # CHECKPOINTS_PATH at trusted checkpoints, and consider passing
        # weights_only=True where the installed torch version supports it.
        state = torch.load(path, map_location=device)
        model.load_state_dict(state)
        model.eval()

    @staticmethod
    def _autocast():
        """Return a mixed-precision context that is active only on CUDA.

        ``torch.cuda.amp.autocast()`` is deprecated and warns when CUDA is not
        available; ``torch.autocast`` with an explicit ``enabled`` flag keeps
        the same effective behaviour (half precision on GPU, no-op on CPU).
        """
        return torch.autocast(device_type=device.type, enabled=device.type == "cuda")

    @staticmethod
    def _to_probs(logits):
        """Softmax over the class dimension, computed in float32."""
        return torch.nn.functional.softmax(logits.float(), dim=1)

    @staticmethod
    def _not_loaded():
        """Error payload returned when the models could not be loaded."""
        return {"error": "Модели не загружены"}

    @torch.no_grad()
    def predict_efficientnet(self, image):
        """Classify *image* with EfficientNet; return ``{label: probability}``."""
        if not self.models_loaded:
            return self._not_loaded()

        inputs = transform_image(image)
        with self._autocast():
            logits = self.efficientnet(inputs)
        return self._format_predictions(self._to_probs(logits))

    @torch.no_grad()
    def predict_deit(self, image):
        """Classify *image* with DeiT; return ``{label: probability}``."""
        if not self.models_loaded:
            return self._not_loaded()

        inputs = transform_image(image)
        with self._autocast():
            logits = self.deit(inputs).logits
        return self._format_predictions(self._to_probs(logits))

    @torch.no_grad()
    def predict_ensemble(self, image):
        """Classify *image* with the mean of both models' probabilities."""
        if not self.models_loaded:
            return self._not_loaded()

        inputs = transform_image(image)
        with self._autocast():
            eff_logits = self.efficientnet(inputs)
            deit_logits = self.deit(inputs).logits
        ensemble_probs = (self._to_probs(eff_logits) + self._to_probs(deit_logits)) / 2
        return self._format_predictions(ensemble_probs)

    def _format_predictions(self, probs, top_k=5):
        """Turn a probability batch into an ordered ``{label: probability}`` dict.

        Only the first row of *probs* is used. At most *top_k* entries are
        returned, highest probability first. *top_k* is clamped to the number
        of classes so small heads do not make ``torch.topk`` fail. Indices
        missing from ``label_mapping`` are shown as ``"Класс <idx>"``.
        """
        k = min(top_k, probs.shape[-1])
        top_probs, top_indices = torch.topk(probs, k)
        return {
            label_mapping.get(idx, f"Класс {idx}"): float(prob)
            for prob, idx in zip(top_probs[0].tolist(), top_indices[0].tolist())
        }
|
|
|
|
|
|
|
# Single shared handler. Constructing it builds both models and loads their
# checkpoints at import time (see ModelHandler.__init__).
model_handler: ModelHandler = ModelHandler()
|
|
|
|
|
def predict_efficientnet(image):
    """Gradio callback: classify *image* with the EfficientNet model.

    Returns a ``{label: probability}`` dict for ``gr.Label``, or a plain
    warning string when no image was supplied or the models are unavailable.
    """
    if image is None:
        return "⚠️ Загрузите изображение"

    result = model_handler.predict_efficientnet(image)
    # The handler reports failure as {"error": "<text>"}. gr.Label expects
    # numeric confidences in a dict, so show the message as a plain string,
    # the same way the "no image" warning above is shown.
    if isinstance(result, dict) and "error" in result:
        return f"⚠️ {result['error']}"
    return result
|
|
|
def predict_deit(image):
    """Gradio callback: classify *image* with the DeiT model.

    Returns a ``{label: probability}`` dict for ``gr.Label``, or a plain
    warning string when no image was supplied or the models are unavailable.
    """
    if image is None:
        return "⚠️ Загрузите изображение"

    result = model_handler.predict_deit(image)
    # The handler reports failure as {"error": "<text>"}. gr.Label expects
    # numeric confidences in a dict, so show the message as a plain string,
    # the same way the "no image" warning above is shown.
    if isinstance(result, dict) and "error" in result:
        return f"⚠️ {result['error']}"
    return result
|
|
|
def predict_ensemble(image):
    """Gradio callback: classify *image* with the two-model ensemble.

    Returns a ``{label: probability}`` dict for ``gr.Label``, or a plain
    warning string when no image was supplied or the models are unavailable.
    """
    if image is None:
        return "⚠️ Загрузите изображение"

    result = model_handler.predict_ensemble(image)
    # The handler reports failure as {"error": "<text>"}. gr.Label expects
    # numeric confidences in a dict, so show the message as a plain string,
    # the same way the "no image" warning above is shown.
    if isinstance(result, dict) and "error" in result:
        return f"⚠️ {result['error']}"
    return result
|
|
|
|
|
def create_interface():
    """Build the Gradio Blocks UI with one tab per prediction mode.

    The page shows a title, a line reporting whether the models were loaded,
    and three tabs (EfficientNet, DeiT, ensemble). Every tab has the same
    layout: an image input, a predict button, a label output and a shared set
    of example images.

    Returns:
        The assembled ``gr.Blocks`` object (not yet launched).
    """
    example_images = ["examples/akiec.jpg", "examples/bcc.jpg", "examples/df.jpg"]
    # (tab title, callback) pairs. The three tabs differ only in these two
    # values, so they are generated in a loop instead of being copy-pasted.
    tab_specs = (
        ("EfficientNet", predict_efficientnet),
        ("DeiT", predict_deit),
        ("Ансамблевая модель", predict_ensemble),
    )

    with gr.Blocks() as demo:
        gr.Markdown("# Диагностика кожных поражений (HAM10K)")
        status = (
            "✅ Модели готовы к предсказанию"
            if model_handler.models_loaded
            else "⚠️ Предупреждение: Модели не загружены"
        )
        gr.Markdown(f"**Состояние моделей:** {status}")

        with gr.Tabs():
            for title, predict_fn in tab_specs:
                with gr.TabItem(title):
                    img = gr.Image(label="Загрузите изображение", type="pil")
                    btn = gr.Button("Предсказать", variant="primary")
                    out = gr.Label(label="Результаты")
                    btn.click(predict_fn, inputs=img, outputs=out)
                    # Give each tab its own list so no state is shared.
                    gr.Examples(examples=list(example_images), inputs=img)

    return demo
|
|
|
|
|
# Script entry point: build the UI and serve it.
if __name__ == "__main__":
    app = create_interface()
    print("🚀 Запуск интерфейса...")
    # share=True asks Gradio for a public share link in addition to the local server.
    app.launch(share=True)
|
|