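"""Gradio demo: leprosy detection from skin images with a from-scratch Vision Transformer.

Defines a ViT-Base style model (patch embedding, multi-head self-attention,
pre-norm encoder blocks), loads a fine-tuned two-class checkpoint, and serves
a Leprosy / No Leprosy image classifier through a Gradio interface.
"""
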
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import cv2
import numpy as np
import requests
import os
from typing import Tuple, Dict


class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and linearly embed them."""

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        # A strided convolution is equivalent to patchify + linear projection.
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)       # (B, embed_dim, H/P, W/P)
        x = x.flatten(2)       # (B, embed_dim, n_patches)
        x = x.transpose(1, 2)  # (B, n_patches, embed_dim)
        return x


class Attention(nn.Module):
    """Multi-head self-attention."""

    def __init__(self, dim, n_heads=12, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.n_heads = n_heads
        self.scale = (dim // n_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # Project to Q, K, V in one pass, then split heads: (3, B, n_heads, N, head_dim).
        qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, C // self.n_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        # Scaled dot-product attention.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block: attention and MLP, each with a residual connection."""

    def __init__(self, dim, n_heads, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, n_heads=n_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(drop)
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class CustomViT(nn.Module):
    """Vision Transformer (ViT-Base configuration by default) with a classification head."""

    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=2,
                 embed_dim=768, depth=12, n_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.patch_embed.n_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, n_heads, mlp_ratio, qkv_bias, drop_rate, drop_rate)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        # Prepend the learnable [CLS] token and add positional embeddings.
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        # Classify from the [CLS] token only.
        x = x[:, 0]
        x = self.head(x)
        return x


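# Illustrative shape check (not part of the app): with the defaults above, a
# 224x224 input yields 14*14 = 196 patch tokens plus one [CLS] token, and
#   CustomViT(num_classes=2)(torch.randn(1, 3, 224, 224))  # -> torch.Size([1, 2])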
def load_model(model_path: str) -> CustomViT:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CustomViT(num_classes=2)
    state_dict = torch.load(model_path, map_location=device)

    # Strip the 'module.' prefix left behind by nn.DataParallel checkpoints.
    if all(k.startswith('module.') for k in state_dict.keys()):
        state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def preprocess_image(image: np.ndarray) -> torch.Tensor:
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Force three channels in case the input is grayscale or RGBA.
    image = image.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # ImageNet mean/std, the usual normalization for ViT backbones.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return transform(image).unsqueeze(0)


def predict_image(image: np.ndarray, model: CustomViT) -> Tuple[np.ndarray, Dict[str, float]]:
    device = next(model.parameters()).device

    image_tensor = preprocess_image(image)

    with torch.no_grad():
        outputs = model(image_tensor.to(device))
        probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]

    # Gradio supplies images as RGB arrays, so draw with RGB color tuples and
    # skip any BGR<->RGB conversion; converting here would swap the channels.
    visualization = image.copy()

    result = "Leprosy" if probabilities[0] > probabilities[1] else "No Leprosy"
    confidence = float(probabilities[0] if result == "Leprosy" else probabilities[1])

    color = (255, 0, 0) if result == "Leprosy" else (0, 255, 0)  # red / green in RGB
    cv2.putText(visualization, f"{result}: {confidence:.2%}",
                (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)

    labels = {
        "Leprosy": float(probabilities[0]),
        "No Leprosy": float(probabilities[1])
    }

    return visualization, labels


# Example images (direct-download Dropbox links, dl=1).
file_urls = [
    'https://www.dropbox.com/scl/fi/onrg1u9tqegh64nsfmxgr/lp2.jpg?rlkey=2vgw5n6abqmyismg16mdd1v3n&dl=1',
    'https://www.dropbox.com/scl/fi/xq103ic7ovuuei3l9e8jf/lp1.jpg?rlkey=g7d9khyyc6wplv0ljd4mcha60&dl=1',
    'https://www.dropbox.com/scl/fi/fagkh3gnio2pefdje7fb9/Non_Leprosy_210823_86_jpg.rf.5bb80a7704ecc6c8615574cad5d074c5.jpg?rlkey=ks8afue5gsx5jqvxj3u9mbjmg&dl=1',
]


def download_example_images():
    """Fetch the example images once and cache them next to the script."""
    examples = []
    for i, url in enumerate(file_urls):
        filename = f"example_{i}.jpg"
        if not os.path.exists(filename):
            response = requests.get(url, timeout=30)
            response.raise_for_status()  # fail loudly on a bad download
            with open(filename, 'wb') as f:
                f.write(response.content)
        examples.append(filename)
    return examples


def create_gradio_interface():
    # Expects the fine-tuned checkpoint alongside this script.
    model = load_model('best_custom_vit_mo50.pth')

    def inference(image):
        return predict_image(image, model)

    examples = download_example_images()

    interface = gr.Interface(
        fn=inference,
        inputs=gr.Image(),
        outputs=[
            gr.Image(label="Prediction Visualization"),
            gr.Label(label="Classification Probabilities")
        ],
        title="Leprosy Detection using Vision Transformer",
        description="Upload an image to detect signs of leprosy using a Vision Transformer model.",
        examples=examples,
        cache_examples=False
    )

    return interface


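# Typical usage: run this script directly; launch() serves the app locally
# (Gradio defaults to http://127.0.0.1:7860).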
if __name__ == "__main__":
    interface = create_gradio_interface()
    interface.launch()