import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import cv2
import numpy as np
import requests
import os
from typing import Tuple, Dict
# CustomViT model definition
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2)
        x = x.transpose(1, 2)
        return x
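
# Shape walkthrough for PatchEmbedding with the defaults (224x224 input,
# 16x16 patches): the strided Conv2d maps (B, 3, 224, 224) -> (B, 768, 14, 14),
# flatten(2) gives (B, 768, 196), and the transpose yields (B, 196, 768),
# i.e. 196 patch tokens of dimension 768.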
class Attention(nn.Module):
    def __init__(self, dim, n_heads=12, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.n_heads = n_heads
        self.scale = (dim // n_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, C // self.n_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
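
# Shape notes for Attention: the fused qkv projection maps (B, N, C) to
# (B, N, 3C); the reshape/permute/unbind split it into three tensors of shape
# (B, n_heads, N, C // n_heads). The attention matrix is then
# (B, n_heads, N, N) before being applied to v and merged back to (B, N, C).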
class TransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, n_heads=n_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(drop)
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
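
# TransformerBlock uses the pre-norm residual layout (LayerNorm applied before
# attention and the MLP), matching the original ViT design; pre-norm generally
# trains more stably at depth than post-norm.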
class CustomViT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=2,
                 embed_dim=768, depth=12, n_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.patch_embed.n_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, n_heads, mlp_ratio, qkv_bias, drop_rate, drop_rate)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        x = x[:, 0]
        x = self.head(x)
        return x
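
# Only the [CLS] token (x[:, 0]) feeds the classification head after the final
# LayerNorm. Quick shape sanity check (illustrative only, so it is left
# commented out and never runs when the Space starts):
#   model = CustomViT(num_classes=2)
#   logits = model(torch.randn(1, 3, 224, 224))
#   assert logits.shape == (1, 2)  # one logit per class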
# Helper functions
def load_model(model_path: str) -> CustomViT:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CustomViT(num_classes=2)
    state_dict = torch.load(model_path, map_location=device)
    # Remove the 'module.' prefix left by nn.DataParallel, if present
    if all(k.startswith('module.') for k in state_dict.keys()):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
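
# Note: on PyTorch >= 1.13 you may prefer torch.load(model_path,
# map_location=device, weights_only=True), which restricts unpickling to plain
# tensors; the default call above is kept for compatibility with older
# PyTorch versions.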
def preprocess_image(image: np.ndarray) -> torch.Tensor:
    # Convert numpy array to PIL Image
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Force 3 channels so grayscale or RGBA uploads don't break the 3-channel Conv2d stem
    image = image.convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return transform(image).unsqueeze(0)
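
# The mean/std above are the standard ImageNet normalization statistics, which
# assumes the ViT was trained (or fine-tuned) with the same preprocessing.
# preprocess_image returns a (1, 3, 224, 224) tensor ready for a single
# forward pass.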
def predict_image(image: np.ndarray, model: CustomViT) -> Tuple[np.ndarray, Dict[str, float]]:
    device = next(model.parameters()).device
    # Preprocess the image
    image_tensor = preprocess_image(image)
    # Make prediction
    with torch.no_grad():
        outputs = model(image_tensor.to(device))
        probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
    # Create visualization. Gradio supplies images as RGB numpy arrays, so we
    # stay in RGB throughout; no BGR<->RGB conversion is needed (converting an
    # already-RGB array would swap the channels).
    visualization = image.copy()
    # Add prediction overlay (class 0 = Leprosy, class 1 = No Leprosy)
    result = "Leprosy" if probabilities[0] > probabilities[1] else "No Leprosy"
    confidence = float(probabilities[0] if result == "Leprosy" else probabilities[1])
    # Add text to image: red for Leprosy, green for No Leprosy (RGB tuples)
    color = (255, 0, 0) if result == "Leprosy" else (0, 255, 0)
    cv2.putText(visualization, f"{result}: {confidence:.2%}",
                (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
    # Prepare labels dictionary
    labels = {
        "Leprosy": float(probabilities[0]),
        "No Leprosy": float(probabilities[1])
    }
    return visualization, labels
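
# Example of exercising the pipeline outside Gradio (a sketch, assuming the
# checkpoint and a local RGB test image exist; left commented out):
#   model = load_model('best_custom_vit_mo50.pth')
#   img = np.array(Image.open('example_0.jpg').convert('RGB'))
#   vis, probs = predict_image(img, model)
#   print(probs)  # {'Leprosy': ..., 'No Leprosy': ...}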
# Download example images
file_urls = [
    'https://www.dropbox.com/scl/fi/onrg1u9tqegh64nsfmxgr/lp2.jpg?rlkey=2vgw5n6abqmyismg16mdd1v3n&dl=1',
    'https://www.dropbox.com/scl/fi/xq103ic7ovuuei3l9e8jf/lp1.jpg?rlkey=g7d9khyyc6wplv0ljd4mcha60&dl=1',
    'https://www.dropbox.com/scl/fi/fagkh3gnio2pefdje7fb9/Non_Leprosy_210823_86_jpg.rf.5bb80a7704ecc6c8615574cad5d074c5.jpg?rlkey=ks8afue5gsx5jqvxj3u9mbjmg&dl=1',
]
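
# dl=1 in the query string makes Dropbox serve the raw file rather than its
# HTML preview page, so requests.get() below receives the image bytes directly.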
def download_example_images():
    examples = []
    for i, url in enumerate(file_urls):
        filename = f"example_{i}.jpg"
        if not os.path.exists(filename):
            response = requests.get(url, timeout=30)
            response.raise_for_status()  # fail loudly on a bad download instead of caching an error page
            with open(filename, 'wb') as f:
                f.write(response.content)
        examples.append(filename)
    return examples
# Main Gradio interface
def create_gradio_interface():
    # Load the model
    model = load_model('best_custom_vit_mo50.pth')

    # Create inference function
    def inference(image):
        return predict_image(image, model)

    # Download example images
    examples = download_example_images()

    # Create Gradio interface
    interface = gr.Interface(
        fn=inference,
        inputs=gr.Image(),
        outputs=[
            gr.Image(label="Prediction Visualization"),
            gr.Label(label="Classification Probabilities")
        ],
        title="Leprosy Detection using Vision Transformer",
        description="Upload an image to detect signs of leprosy using a Vision Transformer model.",
        examples=examples,
        cache_examples=False
    )
    return interface
if __name__ == "__main__":
    interface = create_gradio_interface()
    interface.launch()