# app.py
import os
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import gradio as gr
import timm
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# Optional: If integrating OCR
# import pytesseract
# Define the Detection Model Architecture
class ViTDetectionModel(nn.Module):
    def __init__(self, num_queries=100, hidden_dim=768):
        """
        Initializes the ViTDetectionModel.

        Args:
            num_queries (int, optional): Number of detection queries. Defaults to 100.
            hidden_dim (int, optional): Hidden dimension size. Defaults to 768.
        """
        super(ViTDetectionModel, self).__init__()
        # Configure the ViT model to output features only
        self.vit = timm.create_model(
            'vit_base_patch16_224',
            pretrained=False,    # Set to False since we are loading a trained model
            num_classes=0,       # Disable classification head
            features_only=True,  # Return feature maps
            out_indices=(11,)    # Get the last feature map
        )
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        self.fc_bbox = nn.Linear(hidden_dim, 8)   # 4 points (x, y) for quadrilateral
        self.fc_class = nn.Linear(hidden_dim, 1)  # Binary classification
    def forward(self, x):
        """
        Forward pass of the detection model.

        Args:
            x (Tensor): Input images [batch, 3, H, W].

        Returns:
            Tuple[Tensor, Tensor]: Predicted bounding boxes and class scores.
        """
        # Retrieve the last feature map; shape is [batch, hidden_dim, H*W] or
        # [batch, hidden_dim, H, W] depending on the timm version
        features = self.vit(x)[0]
        if features.dim() == 3:
            batch_size, hidden_dim, num_patches = features.shape
            grid_size = int(np.sqrt(num_patches))
            if grid_size * grid_size != num_patches:
                raise ValueError(f"Number of patches {num_patches} is not a perfect square.")
            H, W = grid_size, grid_size
            features = features.view(batch_size, hidden_dim, H, W)
        elif features.dim() == 4:
            batch_size, hidden_dim, H, W = features.shape
        else:
            raise ValueError(f"Unexpected feature dimensions: {features.dim()}, expected 3 or 4.")
        # Flatten the spatial dimensions
        features = features.flatten(2).transpose(1, 2)  # [batch, H*W, hidden_dim]
        # Prepare query embeddings
        queries = self.query_embed.weight.unsqueeze(0).repeat(batch_size, 1, 1)  # [batch, num_queries, hidden_dim]
        # Compute attention weights between patches and queries
        attn = torch.matmul(features, queries.transpose(-1, -2))  # [batch, H*W, num_queries]
        attn = torch.softmax(attn, dim=1)  # Softmax over patches
        # Aggregate patch features per query based on attention
        output = torch.matmul(attn.transpose(-1, -2), features)  # [batch, num_queries, hidden_dim]
        # Predict bounding boxes and classes
        bboxes = self.fc_bbox(output)    # [batch, num_queries, 8]
        classes = self.fc_class(output)  # [batch, num_queries, 1]
        return bboxes, classes
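
# Quick shape sanity check (illustrative; kept commented out so it does not run at import time):
# with a [1, 3, 224, 224] input, the vit_base_patch16_224 backbone yields a 14x14 grid of
# 768-d patch features, so the model returns bboxes of shape [1, 100, 8] and class logits
# of shape [1, 100, 1].
#   model = ViTDetectionModel()
#   bboxes, classes = model(torch.randn(1, 3, 224, 224))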
# Function to Load the Trained Model
def load_model(model_path, device):
    """
    Loads the trained detection model.

    Args:
        model_path (str): Path to the saved model state dictionary.
        device (torch.device): Device to load the model on.

    Returns:
        nn.Module: Loaded detection model.
    """
    model = ViTDetectionModel(num_queries=100, hidden_dim=768).to(device)
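    # Note: this assumes the checkpoint stores a bare state_dict; if it was saved as a
    # wrapper dict (e.g. {"model_state_dict": ...}), unpack that key before loading.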
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model
# Function to Perform Text Detection on an Image
def detect_text(image, model, device, max_boxes=100, confidence_threshold=0.5):
"""
Detects text in the input image using the detection model.
Args:
image (PIL Image): Input image.
model (nn.Module): Trained detection model.
device (torch.device): Device to run the model on.
max_boxes (int, optional): Maximum number of bounding boxes to return. Defaults to 100.
confidence_threshold (float, optional): Threshold to filter detections. Defaults to 0.5.
Returns:
PIL Image: Image with detected bounding boxes drawn.
"""
# Define transformation
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])
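    # Note: no normalization is applied here; whatever preprocessing (e.g. ImageNet
    # mean/std normalization) was used during training must be mirrored at inference.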
    # Preprocess the image
    input_tensor = transform(image).unsqueeze(0).to(device)  # [1, 3, 224, 224]
    # Perform detection
    with torch.no_grad():
        pred_bboxes, pred_classes = model(input_tensor)  # [1, num_queries, 8], [1, num_queries, 1]
    # Process predictions
    pred_bboxes = pred_bboxes.squeeze(0)    # [num_queries, 8]
    pred_classes = pred_classes.squeeze(0)  # [num_queries, 1]
    pred_classes_sigmoid = torch.sigmoid(pred_classes)
    high_conf_indices = (pred_classes_sigmoid > confidence_threshold).squeeze(1).nonzero(as_tuple=False).squeeze(1)
    selected_indices = high_conf_indices[:max_boxes]
    selected_bboxes = pred_bboxes[selected_indices]  # [selected, 8]
    # Denormalize bounding boxes to the original image size
    width, height = image.size
    scale_x = width / 224
    scale_y = height / 224
    boxes = selected_bboxes.cpu().numpy() * np.array([scale_x, scale_y] * 4)  # [selected, 8]
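    # Assumption: the model regresses corner coordinates in 224x224 pixel space; if it
    # instead predicts normalized [0, 1] coordinates, scale by the full width/height.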
    # Draw bounding boxes on the image
    fig, ax = plt.subplots(1, figsize=(12, 12))
    ax.imshow(image)
    for box in boxes:
        polygon = patches.Polygon(box.reshape(-1, 2), linewidth=2, edgecolor='r', facecolor='none')
        ax.add_patch(polygon)
    plt.axis('off')
    # Convert the Matplotlib figure to a PIL Image
    fig.canvas.draw()
    # buffer_rgba() is used because tostring_rgb() has been removed from recent Matplotlib releases
    img_with_boxes = Image.fromarray(np.asarray(fig.canvas.buffer_rgba())).convert('RGB')
    plt.close(fig)
    return img_with_boxes
# Optional: If integrating OCR with pytesseract
# def detect_and_recognize_text(image, model, device, max_boxes=100, confidence_threshold=0.5):
#     # Similar to detect_text but includes OCR steps
#     pass
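#
# A possible sketch (assumes pytesseract plus the Tesseract binary are installed, and reuses
# the box selection/denormalization logic from detect_text to obtain `boxes`): crop the
# axis-aligned bounding rectangle of each detected quadrilateral and run OCR on the crop.
#     texts = []
#     for box in boxes:
#         xs, ys = box[0::2], box[1::2]
#         crop = image.crop((xs.min(), ys.min(), xs.max(), ys.max()))
#         texts.append(pytesseract.image_to_string(crop))
#     return texts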
# Initialize the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_path = "detection_model.pth" # Ensure this path matches where the model is stored
model = load_model(model_path, device)
print("Model loaded successfully.")
# Define the Gradio Interface Function
def gradio_detect(image):
    """
    Gradio interface function for text detection.

    Args:
        image (PIL Image): Uploaded image.

    Returns:
        PIL Image: Image with detected bounding boxes.
    """
    result_image = detect_text(image, model, device)
    return result_image
# Create Gradio Interface
iface = gr.Interface(
    fn=gradio_detect,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Text Detection with ViT",
    description="Upload an image, and the model will detect and highlight text regions.",
    examples=[
        # You can add URLs or paths to example images here
        # "https://example.com/image1.jpg",
        # "https://example.com/image2.jpg",
    ],
    allow_flagging="never"
)
# Launch the Gradio App (Optional for local testing)
# if __name__ == "__main__":
# iface.launch()