# app.py
import os
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import gradio as gr
import timm
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# Optional: If integrating OCR
# import pytesseract
# Define the Detection Model Architecture
class ViTDetectionModel(nn.Module):
def __init__(self, num_queries=100, hidden_dim=768):
"""
Initializes the ViTDetectionModel.
Args:
num_queries (int, optional): Number of detection queries. Defaults to 100.
hidden_dim (int, optional): Hidden dimension size. Defaults to 768.
"""
super(ViTDetectionModel, self).__init__()
# Configure the ViT model to output features only
self.vit = timm.create_model(
'vit_base_patch16_224',
pretrained=False, # Set to False since we are loading a trained model
num_classes=0, # Disable classification head
features_only=True, # Return feature maps
out_indices=(11,) # Get the last feature map
)
self.query_embed = nn.Embedding(num_queries, hidden_dim)
self.fc_bbox = nn.Linear(hidden_dim, 8) # 4 points (x, y) for quadrilateral
self.fc_class = nn.Linear(hidden_dim, 1) # Binary classification
def forward(self, x):
"""
Forward pass of the detection model.
Args:
x (Tensor): Input images [batch, 3, H, W].
Returns:
Tuple[Tensor, Tensor]: Predicted bounding boxes and class scores.
"""
# Retrieve the feature map
        features = self.vit(x)[0]  # [batch, hidden_dim, H, W] or [batch, hidden_dim, H*W], depending on the timm version
if features.dim() == 3:
batch_size, hidden_dim, num_patches = features.shape
grid_size = int(np.sqrt(num_patches))
if grid_size * grid_size != num_patches:
raise ValueError(f"Number of patches {num_patches} is not a perfect square.")
H, W = grid_size, grid_size
features = features.view(batch_size, hidden_dim, H, W)
elif features.dim() == 4:
batch_size, hidden_dim, H, W = features.shape
else:
raise ValueError(f"Unexpected feature dimensions: {features.dim()}, expected 3 or 4.")
# Flatten the spatial dimensions
features = features.flatten(2).transpose(1, 2) # [batch, H*W, hidden_dim]
# Prepare query embeddings
queries = self.query_embed.weight.unsqueeze(0).repeat(batch_size, 1, 1) # [batch, num_queries, hidden_dim]
# Compute attention weights
attn = torch.matmul(features, queries.transpose(-1, -2)) # [batch, H*W, num_queries]
attn = torch.softmax(attn, dim=1) # Softmax over patches
# Aggregate features based on attention
output = torch.matmul(attn.transpose(-1, -2), features) # [batch, num_queries, hidden_dim]
# Predict bounding boxes and classes
bboxes = self.fc_bbox(output) # [batch, num_queries, 8]
classes = self.fc_class(output) # [batch, num_queries, 1]
return bboxes, classes
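# Illustrative shape check (a sketch, not part of the app flow; uses the defaults above):
#   m = ViTDetectionModel()
#   boxes, scores = m(torch.randn(1, 3, 224, 224))
#   # boxes: [1, 100, 8], scores: [1, 100, 1]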
# Function to Load the Trained Model
def load_model(model_path, device):
"""
Loads the trained detection model.
Args:
model_path (str): Path to the saved model state dictionary.
device (torch.device): Device to load the model on.
Returns:
nn.Module: Loaded detection model.
"""
model = ViTDetectionModel(num_queries=100, hidden_dim=768).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
return model
# Function to Perform Text Detection on an Image
def detect_text(image, model, device, max_boxes=100, confidence_threshold=0.5):
"""
Detects text in the input image using the detection model.
Args:
image (PIL Image): Input image.
model (nn.Module): Trained detection model.
device (torch.device): Device to run the model on.
max_boxes (int, optional): Maximum number of bounding boxes to return. Defaults to 100.
confidence_threshold (float, optional): Threshold to filter detections. Defaults to 0.5.
Returns:
PIL Image: Image with detected bounding boxes drawn.
"""
# Define transformation
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])
# Preprocess the image
input_tensor = transform(image).unsqueeze(0).to(device) # [1, 3, 224, 224]
# Perform detection
with torch.no_grad():
pred_bboxes, pred_classes = model(input_tensor) # [1, num_queries, 8], [1, num_queries, 1]
# Process predictions
pred_bboxes = pred_bboxes.squeeze(0) # [num_queries, 8]
pred_classes = pred_classes.squeeze(0) # [num_queries, 1]
pred_classes_sigmoid = torch.sigmoid(pred_classes)
high_conf_indices = (pred_classes_sigmoid > confidence_threshold).squeeze(1).nonzero(as_tuple=False).squeeze(1)
selected_indices = high_conf_indices[:max_boxes]
selected_bboxes = pred_bboxes[selected_indices] # [selected, 8]
# Denormalize bounding boxes to original image size
width, height = image.size
scale_x = width / 224
scale_y = height / 224
boxes = selected_bboxes.cpu().numpy() * np.array([scale_x, scale_y] * 4) # [selected, 8]
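    # Note: this scaling assumes the model predicts coordinates in the 224x224 input space.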
# Draw bounding boxes on the image
fig, ax = plt.subplots(1, figsize=(12, 12))
ax.imshow(image)
for box in boxes:
polygon = patches.Polygon(box.reshape(-1, 2), linewidth=2, edgecolor='r', facecolor='none')
ax.add_patch(polygon)
plt.axis('off')
    # Convert the Matplotlib figure to a PIL Image via the RGBA buffer
    # (canvas.tostring_rgb is deprecated/removed in recent Matplotlib releases)
    fig.canvas.draw()
    img_with_boxes = Image.fromarray(np.asarray(fig.canvas.buffer_rgba())).convert('RGB')
plt.close(fig)
return img_with_boxes
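# Example usage (illustrative; "sample.jpg" is a placeholder path):
#   annotated = detect_text(Image.open("sample.jpg").convert("RGB"), model, device)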
# Optional: If integrating OCR with pytesseract
# def detect_and_recognize_text(image, model, device, max_boxes=100, confidence_threshold=0.5):
# # Similar to detect_text but includes OCR steps
# pass
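# A minimal sketch of how that OCR step could look (commented out; assumes pytesseract
# plus a Tesseract install with the relevant language data, e.g. lang='amh' for Amharic):
#
# def detect_and_recognize_text(image, model, device, max_boxes=100, confidence_threshold=0.5):
#     """Runs the same detection as detect_text, then OCR on each box's bounding rectangle."""
#     boxes = ...  # obtain denormalized boxes exactly as in detect_text
#     texts = []
#     for box in boxes:
#         pts = box.reshape(-1, 2)
#         crop = image.crop((pts[:, 0].min(), pts[:, 1].min(), pts[:, 0].max(), pts[:, 1].max()))
#         texts.append(pytesseract.image_to_string(crop, lang='amh'))
#     return texts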
# Initialize the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_path = "detection_model.pth" # Ensure this path matches where the model is stored
model = load_model(model_path, device)
print("Model loaded successfully.")
# Define the Gradio Interface Function
def gradio_detect(image):
"""
Gradio interface function for text detection.
Args:
image (PIL Image): Uploaded image.
Returns:
PIL Image: Image with detected bounding boxes.
"""
result_image = detect_text(image, model, device)
return result_image
# Create Gradio Interface
iface = gr.Interface(
fn=gradio_detect,
inputs=gr.Image(type="pil"),
outputs=gr.Image(type="pil"),
title="Text Detection with ViT",
description="Upload an image, and the model will detect and highlight text regions.",
examples=[
# You can add URLs or paths to example images here
# "https://example.com/image1.jpg",
# "https://example.com/image2.jpg",
],
allow_flagging="never"
)
# Launch the Gradio App
if __name__ == "__main__":
    iface.launch()