import os

import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import gradio as gr
import timm
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so figure rendering works without a display
import matplotlib.pyplot as plt
import matplotlib.patches as patches


class ViTDetectionModel(nn.Module):
    def __init__(self, num_queries=100, hidden_dim=768):
        """
        Initializes the ViTDetectionModel.

        Args:
            num_queries (int, optional): Number of detection queries. Defaults to 100.
            hidden_dim (int, optional): Hidden dimension size. Defaults to 768.
        """
        super(ViTDetectionModel, self).__init__()

        # ViT-Base/16 backbone used purely as a feature extractor; only the
        # output of the last transformer block (index 11) is returned.
        self.vit = timm.create_model(
            'vit_base_patch16_224',
            pretrained=False,
            num_classes=0,
            features_only=True,
            out_indices=(11,)
        )
        # Learned detection queries, one embedding per candidate box.
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        # 8 outputs per query: the (x, y) coordinates of a 4-point polygon.
        self.fc_bbox = nn.Linear(hidden_dim, 8)
        # 1 output per query: text / no-text confidence logit.
        self.fc_class = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """
        Forward pass of the detection model.

        Args:
            x (Tensor): Input images [batch, 3, H, W].

        Returns:
            Tuple[Tensor, Tensor]: Predicted bounding boxes and class scores.
        """
        # features_only returns a list of feature maps; take the single requested stage.
        features = self.vit(x)[0]

        if features.dim() == 3:
            # Token-style output, assumed as [batch, hidden_dim, num_patches];
            # reshape it into a square spatial grid.
            batch_size, hidden_dim, num_patches = features.shape
            grid_size = int(np.sqrt(num_patches))
            if grid_size * grid_size != num_patches:
                raise ValueError(f"Number of patches {num_patches} is not a perfect square.")
            H, W = grid_size, grid_size
            features = features.view(batch_size, hidden_dim, H, W)
        elif features.dim() == 4:
            # Standard NCHW feature map.
            batch_size, hidden_dim, H, W = features.shape
        else:
            raise ValueError(f"Unexpected feature dimensions: {features.dim()}, expected 3 or 4.")

        # Flatten the spatial grid into a patch sequence: [batch, H*W, hidden_dim].
        features = features.flatten(2).transpose(1, 2)

        # Broadcast the learned queries across the batch: [batch, num_queries, hidden_dim].
        queries = self.query_embed.weight.unsqueeze(0).repeat(batch_size, 1, 1)

        # Cross-attention weights between patches and queries, normalized over the patches.
        attn = torch.matmul(features, queries.transpose(-1, -2))
        attn = torch.softmax(attn, dim=1)

        # Aggregate patch features per query: [batch, num_queries, hidden_dim].
        output = torch.matmul(attn.transpose(-1, -2), features)

        # Per-query polygon coordinates and class logits.
        bboxes = self.fc_bbox(output)
        classes = self.fc_class(output)

        return bboxes, classes


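# Quick shape check (illustrative sketch only; with random weights the outputs
# are not meaningful detections, and this snippet is not part of the app flow):
#
#   m = ViTDetectionModel(num_queries=100, hidden_dim=768)
#   bboxes, classes = m(torch.randn(1, 3, 224, 224))
#   # bboxes: [1, 100, 8] (four (x, y) corners per query), classes: [1, 100, 1]

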
def load_model(model_path, device):
    """
    Loads the trained detection model.

    Args:
        model_path (str): Path to the saved model state dictionary.
        device (torch.device): Device to load the model on.

    Returns:
        nn.Module: Loaded detection model.
    """
    model = ViTDetectionModel(num_queries=100, hidden_dim=768).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model


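# Note (assumption): the checkpoint at model_path is expected to be a plain
# state_dict saved from a trained ViTDetectionModel with matching
# num_queries/hidden_dim, e.g.:
#
#   torch.save(trained_model.state_dict(), "detection_model.pth")

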
def detect_text(image, model, device, max_boxes=100, confidence_threshold=0.5):
    """
    Detects text in the input image using the detection model.

    Args:
        image (PIL Image): Input image.
        model (nn.Module): Trained detection model.
        device (torch.device): Device to run the model on.
        max_boxes (int, optional): Maximum number of bounding boxes to return. Defaults to 100.
        confidence_threshold (float, optional): Threshold to filter detections. Defaults to 0.5.

    Returns:
        PIL Image: Image with detected bounding boxes drawn.
    """
    # Resize to the ViT input size and convert to a [0, 1] tensor.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    input_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        pred_bboxes, pred_classes = model(input_tensor)

    # Drop the batch dimension and keep only the confident queries.
    pred_bboxes = pred_bboxes.squeeze(0)
    pred_classes = pred_classes.squeeze(0)
    pred_classes_sigmoid = torch.sigmoid(pred_classes)
    high_conf_indices = (pred_classes_sigmoid > confidence_threshold).squeeze(1).nonzero(as_tuple=False).squeeze(1)
    selected_indices = high_conf_indices[:max_boxes]
    selected_bboxes = pred_bboxes[selected_indices]

    # Each box is a 4-point polygon (x1, y1, ..., x4, y4) predicted in the
    # 224x224 input space; rescale it back to the original image size.
    width, height = image.size
    scale_x = width / 224
    scale_y = height / 224
    boxes = selected_bboxes.cpu().numpy() * np.array([scale_x, scale_y] * 4)

    # Draw the polygons on top of the original image with matplotlib.
    fig, ax = plt.subplots(1, figsize=(12, 12))
    ax.imshow(image)
    for box in boxes:
        polygon = patches.Polygon(box.reshape(-1, 2), linewidth=2, edgecolor='r', facecolor='none')
        ax.add_patch(polygon)
    plt.axis('off')

    # Render the figure and convert the canvas to a PIL image (requires an
    # Agg-based canvas, set via matplotlib.use("Agg") above).
    fig.canvas.draw()
    img_with_boxes = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
    plt.close(fig)

    return img_with_boxes


# Set up the compute device and load the trained weights once at startup.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_path = "detection_model.pth"
model = load_model(model_path, device)
print("Model loaded successfully.")

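# Optional local check (hedged sketch): run detection directly on an image file
# without the Gradio UI. "sample.jpg" is a placeholder path, not part of the app.
#
#   img = Image.open("sample.jpg").convert("RGB")
#   annotated = detect_text(img, model, device, confidence_threshold=0.5)
#   annotated.save("sample_detected.jpg")
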

def gradio_detect(image):
    """
    Gradio interface function for text detection.

    Args:
        image (PIL Image): Uploaded image.

    Returns:
        PIL Image: Image with detected bounding boxes.
    """
    result_image = detect_text(image, model, device)
    return result_image


iface = gr.Interface(
    fn=gradio_detect,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Text Detection with ViT",
    description="Upload an image, and the model will detect and highlight text regions.",
    allow_flagging="never"
)

if __name__ == "__main__":
    iface.launch()