"""Gradio demo that classifies skin images as leprosy / no leprosy with a custom ViT."""

import logging
import os
from typing import Dict, Tuple

import cv2
import gradio as gr
import numpy as np
import requests
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

logger = logging.getLogger(__name__)

# Seconds to wait for each example-image download before giving up.
DOWNLOAD_TIMEOUT_S = 30

# Input pipeline: resize to the model's 224x224 input and normalise with the
# standard ImageNet channel statistics.
# NOTE(review): presumably this mirrors the training-time preprocessing -- confirm.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# CustomViT model definition
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and project each to ``embed_dim``."""

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        # A conv whose kernel equals its stride is a per-patch linear projection.
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        """Map ``(B, C, H, W)`` images to ``(B, n_patches, embed_dim)`` tokens."""
        x = self.proj(x)       # (B, embed_dim, H / patch, W / patch)
        x = x.flatten(2)       # (B, embed_dim, n_patches)
        x = x.transpose(1, 2)  # (B, n_patches, embed_dim)
        return x


class Attention(nn.Module):
    """Multi-head self-attention with a fused query/key/value projection."""

    def __init__(self, dim, n_heads=12, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        if dim % n_heads != 0:
            # forward() reshapes the channel axis into (n_heads, dim // n_heads),
            # so fail here with a clear message instead of inside reshape.
            raise ValueError(
                f"dim ({dim}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.scale = (dim // n_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to ``(B, N, C)`` tokens; the output has the same shape."""
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, head_dim) so q, k, v split along axis 0.
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.n_heads, C // self.n_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge the heads back into the channel axis.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block: attention and MLP, each with a residual."""

    def __init__(self, dim, n_heads, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(
            dim,
            n_heads=n_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        # Layer order fixes the Sequential indices used in checkpoint keys
        # (mlp.0 / mlp.3), so it must not change.
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(drop),
        )

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class CustomViT(nn.Module):
    """Vision Transformer classifier that predicts from the class token."""

    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=2,
                 embed_dim=768, depth=12, n_heads=12, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0.):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One positional embedding per patch plus one for the class token.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, 1 + self.patch_embed.n_patches, embed_dim)
        )
        self.pos_drop = nn.Dropout(p=drop_rate)
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, n_heads, mlp_ratio, qkv_bias, drop_rate, drop_rate)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return ``(B, num_classes)`` logits for a batch of ``(B, C, H, W)`` images."""
        B = x.shape[0]
        x = self.patch_embed(x)

        # Prepend the class token to every sequence in the batch.
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)
        x = x[:, 0]  # keep only the class-token representation
        x = self.head(x)
        return x


# Helper functions
def load_model(model_path: str) -> CustomViT:
    """Build a two-class ``CustomViT``, load weights from ``model_path``, set eval mode.

    The model is placed on CUDA when available, otherwise on the CPU.

    Raises whatever ``torch.load`` / ``load_state_dict`` raise for a missing file
    or mismatched keys.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CustomViT(num_classes=2)

    # SECURITY: torch.load unpickles the file; only load checkpoints from a
    # trusted source (newer torch releases offer weights_only=True).
    state_dict = torch.load(model_path, map_location=device)

    # Remove 'module.' prefix if present (checkpoints saved from a wrapped model).
    prefix = 'module.'
    if all(key.startswith(prefix) for key in state_dict.keys()):
        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def _as_rgb_array(image) -> np.ndarray:
    """Return a fresh, writable 3-channel RGB array for an ndarray or PIL image."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # convert() handles grayscale / RGBA inputs; np.array() copies, so drawing
    # on the result never touches the caller's data.
    return np.array(image.convert("RGB"))


def preprocess_image(image: np.ndarray) -> torch.Tensor:
    """Convert an image into a normalised ``(1, 3, 224, 224)`` tensor.

    Accepts a numpy array or a PIL image; the input is converted to RGB first so
    grayscale and RGBA images do not break the 3-channel normalisation.
    """
    # Convert numpy array to PIL Image
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    return _PREPROCESS(image.convert("RGB")).unsqueeze(0)


def predict_image(image: np.ndarray, model: CustomViT) -> Tuple[np.ndarray, Dict[str, float]]:
    """Classify ``image`` and return an annotated copy plus class probabilities.

    Args:
        image: RGB image (channel order as delivered by ``gr.Image``).
        model: Loaded classifier; inference runs on the model's own device.

    Returns:
        ``(visualization, labels)`` where ``visualization`` is an RGB array with
        the predicted label and confidence drawn on it, and ``labels`` maps
        "Leprosy" / "No Leprosy" to their softmax probabilities.

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    if image is None:
        raise ValueError("No image provided for prediction.")

    device = next(model.parameters()).device
    rgb_image = _as_rgb_array(image)

    # Preprocess the image
    image_tensor = preprocess_image(rgb_image)

    # Make prediction
    with torch.no_grad():
        outputs = model(image_tensor.to(device))
        probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]

    # NOTE(review): index 0 is treated as "Leprosy" and index 1 as "No Leprosy";
    # confirm this matches the label order used in training.
    leprosy_prob = float(probabilities[0])
    no_leprosy_prob = float(probabilities[1])
    is_leprosy = leprosy_prob > no_leprosy_prob

    # Add prediction overlay
    result = "Leprosy" if is_leprosy else "No Leprosy"
    confidence = leprosy_prob if is_leprosy else no_leprosy_prob

    # The frame is RGB end to end, so colours are given in RGB order and no
    # BGR<->RGB conversion is applied (that would swap red and blue in the image).
    color = (255, 0, 0) if is_leprosy else (0, 255, 0)
    visualization = rgb_image
    cv2.putText(
        visualization,
        f"{result}: {confidence:.2%}",
        (10, 30),
        cv2.FONT_HERSHEY_SIMPLEX,
        1,
        color,
        2,
    )

    # Prepare labels dictionary
    labels = {
        "Leprosy": leprosy_prob,
        "No Leprosy": no_leprosy_prob,
    }
    return visualization, labels


# Download example images
file_urls = [
    'https://www.dropbox.com/scl/fi/onrg1u9tqegh64nsfmxgr/lp2.jpg?rlkey=2vgw5n6abqmyismg16mdd1v3n&dl=1',
    'https://www.dropbox.com/scl/fi/xq103ic7ovuuei3l9e8jf/lp1.jpg?rlkey=g7d9khyyc6wplv0ljd4mcha60&dl=1',
    'https://www.dropbox.com/scl/fi/fagkh3gnio2pefdje7fb9/Non_Leprosy_210823_86_jpg.rf.5bb80a7704ecc6c8615574cad5d074c5.jpg?rlkey=ks8afue5gsx5jqvxj3u9mbjmg&dl=1',
]


def download_example_images():
    """Fetch the example images into the working directory and return their paths.

    Files that already exist are reused. A download that fails (network error,
    timeout, or non-2xx status) is logged and skipped, so the returned list only
    contains files that are actually present.
    """
    examples = []
    for i, url in enumerate(file_urls):
        filename = f"example_{i}.jpg"
        if not os.path.exists(filename):
            try:
                response = requests.get(url, timeout=DOWNLOAD_TIMEOUT_S)
                # Without this an HTTP error page would be saved as a ".jpg".
                response.raise_for_status()
            except requests.RequestException as exc:
                logger.warning("Could not download example image %s: %s", url, exc)
                continue
            with open(filename, 'wb') as f:
                f.write(response.content)
        examples.append(filename)
    return examples


# Main Gradio interface
def create_gradio_interface():
    """Load the model and example images, and return the configured ``gr.Interface``."""
    # Load the model
    model = load_model('best_custom_vit_mo50.pth')

    # Create inference function
    def inference(image):
        if image is None:
            # Shown to the user in the UI instead of an opaque traceback.
            raise gr.Error("Please upload an image before submitting.")
        return predict_image(image, model)

    # Download example images
    examples = download_example_images()

    # Create Gradio interface
    interface = gr.Interface(
        fn=inference,
        inputs=gr.Image(),
        outputs=[
            gr.Image(label="Prediction Visualization"),
            gr.Label(label="Classification Probabilities"),
        ],
        title="Leprosy Detection using Vision Transformer",
        description="Upload an image to detect signs of leprosy using a Vision Transformer model.",
        # Pass None rather than an empty list when no example could be fetched.
        examples=examples or None,
        cache_examples=False,
    )
    return interface


if __name__ == "__main__":
    interface = create_gradio_interface()
    interface.launch()