"""Gradio demo that classifies skin-lesion images (HAM10K labels).

Three predictors are exposed, one per UI tab: a fine-tuned EfficientNetV2-S,
a fine-tuned DeiT, and an ensemble that averages the class probabilities of
both. Checkpoints are read from the directory named by the
``CHECKPOINTS_PATH`` environment variable (default: current directory).
"""

import os

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torch import nn
# Kept so existing importers of this name keep working; inference itself now
# goes through the device-agnostic ``torch.autocast`` (see _inference_autocast).
from torch.cuda.amp import autocast  # noqa: F401
from torchvision import models
from torchvision import transforms
from transformers import ViTForImageClassification

# Global configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Label mapping (HAM10K): class index -> label shown in the UI.
label_mapping = {
    0: "Меланома",
    1: "Меланоцитарный невус",
    2: "Базальноклеточная карцинома",
    3: "Актинический кератоз",
    4: "Доброкачественная кератоза",
    5: "Дерматофиброма",
    6: "Сосудистые поражения",
}
NUM_CLASSES = len(label_mapping)

# Model paths
CHECKPOINTS_PATH = os.getenv("CHECKPOINTS_PATH", "./")

# Example images offered under every tab.
EXAMPLE_IMAGES = ("examples/akiec.jpg", "examples/bcc.jpg", "examples/df.jpg")

# Preprocessing shared by both models. Built once at import time instead of
# on every request.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


# Model definitions
def get_efficientnet():
    """Return an EfficientNetV2-S with a ``NUM_CLASSES``-way head on ``device``.

    The backbone is created with the ``IMAGENET1K_V1`` weights; the final
    linear layer is replaced by a freshly initialised one.
    """
    model = models.efficientnet_v2_s(weights="IMAGENET1K_V1")
    # Read the feature width from the existing head rather than hard-coding it.
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, NUM_CLASSES)
    return model.to(device)


def get_deit():
    """Return a DeiT-base classifier with a ``NUM_CLASSES``-way head on ``device``."""
    model = ViTForImageClassification.from_pretrained(
        'facebook/deit-base-patch16-224',
        num_labels=NUM_CLASSES,
        # The pretrained head has a different number of outputs; let
        # transformers re-initialise it instead of raising.
        ignore_mismatched_sizes=True,
    )
    return model.to(device)


# Transforms
def transform_image(image):
    """Transform PIL image to model input format.

    Args:
        image: A ``PIL.Image.Image`` in any mode.

    Returns:
        A normalised tensor with a leading batch dimension of 1, on ``device``.
    """
    # Normalize() above is 3-channel; RGBA / grayscale / palette inputs would
    # otherwise fail with a channel mismatch.
    if image.mode != "RGB":
        image = image.convert("RGB")
    return _PREPROCESS(image).unsqueeze(0).to(device)


def _inference_autocast():
    """Return a mixed-precision context for the active device.

    Autocast is enabled only on CUDA. On CPU the context is a no-op, which
    avoids the warning ``torch.cuda.amp.autocast()`` emits without a GPU.
    """
    return torch.autocast(device_type=device.type, enabled=device.type == "cuda")


# Model Handler
class ModelHandler:
    """Owns both fine-tuned models and produces label -> probability dicts.

    Attributes:
        efficientnet: The EfficientNet model, or ``None`` before loading.
        deit: The DeiT model, or ``None`` before loading.
        models_loaded: ``True`` only when both checkpoints loaded successfully.
    """

    def __init__(self):
        self.efficientnet = None
        self.deit = None
        self.models_loaded = False
        self.load_models()

    def load_models(self):
        """Load both checkpoints from ``CHECKPOINTS_PATH`` and switch to eval mode.

        Any failure is reported on stdout and leaves ``models_loaded`` False so
        the UI can still start and show a warning instead of crashing.
        """
        try:
            # NOTE(security): torch.load unpickles the file; only point
            # CHECKPOINTS_PATH at checkpoints you trust.
            # Load EfficientNet
            self.efficientnet = get_efficientnet()
            efficientnet_path = os.path.join(CHECKPOINTS_PATH, "efficientnet_best.pth")
            self.efficientnet.load_state_dict(
                torch.load(efficientnet_path, map_location=device)
            )
            self.efficientnet.eval()

            # Load DeiT
            self.deit = get_deit()
            deit_path = os.path.join(CHECKPOINTS_PATH, "deit_best.pth")
            self.deit.load_state_dict(torch.load(deit_path, map_location=device))
            self.deit.eval()

            self.models_loaded = True
            print("✅ Models loaded successfully")
        except Exception as e:
            # Deliberate best-effort: the app stays up and reports the state.
            print(f"❌ Error loading models: {str(e)}")
            self.models_loaded = False

    def _efficientnet_probs(self, inputs):
        """Return float32 class probabilities from EfficientNet for ``inputs``."""
        with _inference_autocast():
            logits = self.efficientnet(inputs)
        # Softmax in float32 so single-model and ensemble outputs share precision.
        return torch.nn.functional.softmax(logits.float(), dim=1)

    def _deit_probs(self, inputs):
        """Return float32 class probabilities from DeiT for ``inputs``."""
        with _inference_autocast():
            logits = self.deit(inputs).logits
        return torch.nn.functional.softmax(logits.float(), dim=1)

    @torch.no_grad()
    def predict_efficientnet(self, image):
        """Classify ``image`` with EfficientNet.

        Returns:
            A label -> probability dict, or ``{"error": ...}`` when the models
            are not loaded.
        """
        if not self.models_loaded:
            return {"error": "Модели не загружены"}
        inputs = transform_image(image)
        return self._format_predictions(self._efficientnet_probs(inputs))

    @torch.no_grad()
    def predict_deit(self, image):
        """Classify ``image`` with DeiT.

        Returns:
            A label -> probability dict, or ``{"error": ...}`` when the models
            are not loaded.
        """
        if not self.models_loaded:
            return {"error": "Модели не загружены"}
        inputs = transform_image(image)
        return self._format_predictions(self._deit_probs(inputs))

    @torch.no_grad()
    def predict_ensemble(self, image):
        """Classify ``image`` with the mean of both models' probabilities.

        Returns:
            A label -> probability dict, or ``{"error": ...}`` when the models
            are not loaded.
        """
        if not self.models_loaded:
            return {"error": "Модели не загружены"}
        inputs = transform_image(image)
        ensemble_probs = (
            self._efficientnet_probs(inputs) + self._deit_probs(inputs)
        ) / 2
        return self._format_predictions(ensemble_probs)

    def _format_predictions(self, probs, top_k=5):
        """Convert a probability tensor into a label -> probability dict.

        Args:
            probs: Tensor of shape ``(batch, num_classes)``; only the first
                row is used.
            top_k: Maximum number of classes to report. Clamped to the number
                of classes so a smaller head cannot make ``topk`` raise.

        Returns:
            Dict mapping the label of each of the highest-scoring classes to
            its probability as a Python float.
        """
        k = min(top_k, probs.shape[1])
        top_probs, top_indices = torch.topk(probs[0], k)
        # return raw prob, not percent
        return {
            label_mapping.get(idx, f"Класс {idx}"): float(prob)
            for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
        }


# Initialize model handler
model_handler = ModelHandler()


# Prediction wrappers
def _predict_for_label(predict_fn, image):
    """Run ``predict_fn`` on ``image`` and return a value ``gr.Label`` can show.

    Missing images and handler errors are returned as plain strings. A
    ``{"error": message}`` dict must not reach the label component because the
    message is not a numeric confidence.
    """
    if image is None:
        return "⚠️ Загрузите изображение"
    result = predict_fn(image)
    if isinstance(result, dict) and "error" in result:
        return result["error"]
    return result


def predict_efficientnet(image):
    """Gradio callback for the EfficientNet tab."""
    return _predict_for_label(model_handler.predict_efficientnet, image)


def predict_deit(image):
    """Gradio callback for the DeiT tab."""
    return _predict_for_label(model_handler.predict_deit, image)


def predict_ensemble(image):
    """Gradio callback for the ensemble tab."""
    return _predict_for_label(model_handler.predict_ensemble, image)


def _build_model_tab(title, predict_fn):
    """Add one tab with an image input, a predict button, a label and examples."""
    with gr.TabItem(title):
        img = gr.Image(label="Загрузите изображение", type="pil")
        btn = gr.Button("Предсказать", variant="primary")
        out = gr.Label(label="Результаты")
        btn.click(predict_fn, inputs=img, outputs=out)
        gr.Examples(examples=list(EXAMPLE_IMAGES), inputs=img)


# Create Gradio Blocks with Tabs
def create_interface():
    """Build and return the Gradio Blocks app with one tab per predictor."""
    with gr.Blocks() as demo:
        gr.Markdown("# Диагностика кожных поражений (HAM10K)")
        status = (
            "✅ Модели готовы к предсказанию"
            if model_handler.models_loaded
            else "⚠️ Предупреждение: Модели не загружены"
        )
        gr.Markdown(f"**Состояние моделей:** {status}")
        with gr.Tabs():
            _build_model_tab("EfficientNet", predict_efficientnet)
            _build_model_tab("DeiT", predict_deit)
            _build_model_tab("Ансамблевая модель", predict_ensemble)
    return demo


# Launch interface
if __name__ == "__main__":
    interface = create_interface()
    print("🚀 Запуск интерфейса...")
    interface.launch(share=True)