# app.py
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import gradio as gr
import timm
import matplotlib
matplotlib.use("Agg")  # Non-interactive backend so figures render off-screen (server-friendly)
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Optional: If integrating OCR
# import pytesseract


# Define the Detection Model Architecture
class ViTDetectionModel(nn.Module):
    def __init__(self, num_queries=100, hidden_dim=768):
        """
        Initializes the ViTDetectionModel.

        Args:
            num_queries (int, optional): Number of detection queries. Defaults to 100.
            hidden_dim (int, optional): Hidden dimension size. Defaults to 768.
        """
        super().__init__()
        # Configure the ViT backbone to return feature maps only
        self.vit = timm.create_model(
            'vit_base_patch16_224',
            pretrained=False,    # Weights come from the trained checkpoint, not ImageNet
            num_classes=0,       # Disable the classification head
            features_only=True,  # Return feature maps
            out_indices=(11,)    # Keep only the last feature map
        )
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        self.fc_bbox = nn.Linear(hidden_dim, 8)   # 4 (x, y) points for a quadrilateral
        self.fc_class = nn.Linear(hidden_dim, 1)  # Binary text/no-text score

    def forward(self, x):
        """
        Forward pass of the detection model.

        Args:
            x (Tensor): Input images [batch, 3, H, W].

        Returns:
            Tuple[Tensor, Tensor]: Predicted bounding boxes and class scores.
        """
        # Retrieve the last feature map
        features = self.vit(x)[0]

        if features.dim() == 3:  # [batch, hidden_dim, num_patches]
            batch_size, hidden_dim, num_patches = features.shape
            grid_size = int(np.sqrt(num_patches))
            if grid_size * grid_size != num_patches:
                raise ValueError(f"Number of patches {num_patches} is not a perfect square.")
            H, W = grid_size, grid_size
            features = features.view(batch_size, hidden_dim, H, W)
        elif features.dim() == 4:  # [batch, hidden_dim, H, W]
            batch_size, hidden_dim, H, W = features.shape
        else:
            raise ValueError(f"Unexpected feature dimensions: {features.dim()}, expected 3 or 4.")

        # Flatten the spatial dimensions
        features = features.flatten(2).transpose(1, 2)  # [batch, H*W, hidden_dim]

        # Prepare query embeddings
        queries = self.query_embed.weight.unsqueeze(0).repeat(batch_size, 1, 1)  # [batch, num_queries, hidden_dim]

        # Compute attention weights between patches and queries
        attn = torch.matmul(features, queries.transpose(-1, -2))  # [batch, H*W, num_queries]
        attn = torch.softmax(attn, dim=1)  # Softmax over patches

        # Aggregate patch features per query
        output = torch.matmul(attn.transpose(-1, -2), features)  # [batch, num_queries, hidden_dim]

        # Predict bounding boxes and class scores
        bboxes = self.fc_bbox(output)    # [batch, num_queries, 8]
        classes = self.fc_class(output)  # [batch, num_queries, 1]

        return bboxes, classes


# Function to Load the Trained Model
def load_model(model_path, device):
    """
    Loads the trained detection model.

    Args:
        model_path (str): Path to the saved model state dictionary.
        device (torch.device): Device to load the model on.

    Returns:
        nn.Module: Loaded detection model.
    """
    model = ViTDetectionModel(num_queries=100, hidden_dim=768).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model
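
# A minimal shape sanity check (hypothetical, not part of the app flow): pushing a
# random batch through an untrained model should yield [batch, 100, 8] boxes and
# [batch, 100, 1] scores. Handy for verifying the wiring before loading real weights.
#
#     _m = ViTDetectionModel()
#     _bboxes, _classes = _m(torch.randn(2, 3, 224, 224))
#     assert _bboxes.shape == (2, 100, 8) and _classes.shape == (2, 100, 1)
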
# Function to Perform Text Detection on an Image
def detect_text(image, model, device, max_boxes=100, confidence_threshold=0.5):
    """
    Detects text in the input image using the detection model.

    Args:
        image (PIL Image): Input image.
        model (nn.Module): Trained detection model.
        device (torch.device): Device to run the model on.
        max_boxes (int, optional): Maximum number of bounding boxes to return. Defaults to 100.
        confidence_threshold (float, optional): Threshold to filter detections. Defaults to 0.5.

    Returns:
        PIL Image: Image with detected bounding boxes drawn.
    """
    # Guard against RGBA or grayscale uploads
    image = image.convert("RGB")

    # Define transformation
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # Preprocess the image
    input_tensor = transform(image).unsqueeze(0).to(device)  # [1, 3, 224, 224]

    # Perform detection
    with torch.no_grad():
        pred_bboxes, pred_classes = model(input_tensor)  # [1, num_queries, 8], [1, num_queries, 1]

    # Process predictions
    pred_bboxes = pred_bboxes.squeeze(0)    # [num_queries, 8]
    pred_classes = pred_classes.squeeze(0)  # [num_queries, 1]
    pred_classes_sigmoid = torch.sigmoid(pred_classes)
    high_conf_indices = (pred_classes_sigmoid > confidence_threshold).squeeze(1).nonzero(as_tuple=False).squeeze(1)
    selected_indices = high_conf_indices[:max_boxes]
    selected_bboxes = pred_bboxes[selected_indices]  # [selected, 8]

    # Rescale bounding boxes from the 224x224 input back to the original image size
    width, height = image.size
    scale_x = width / 224
    scale_y = height / 224
    boxes = selected_bboxes.cpu().numpy() * np.array([scale_x, scale_y] * 4)  # [selected, 8]

    # Draw bounding boxes on the image
    fig, ax = plt.subplots(1, figsize=(12, 12))
    ax.imshow(image)
    for box in boxes:
        polygon = patches.Polygon(box.reshape(-1, 2), linewidth=2, edgecolor='r', facecolor='none')
        ax.add_patch(polygon)
    plt.axis('off')

    # Convert the Matplotlib figure to a PIL Image
    # (buffer_rgba replaces tostring_rgb, which was removed in newer Matplotlib)
    fig.canvas.draw()
    img_with_boxes = Image.frombytes(
        'RGBA', fig.canvas.get_width_height(), bytes(fig.canvas.buffer_rgba())
    ).convert('RGB')
    plt.close(fig)

    return img_with_boxes


# Optional: If integrating OCR with pytesseract
# def detect_and_recognize_text(image, model, device, max_boxes=100, confidence_threshold=0.5):
#     # Similar to detect_text, but additionally crops each detected region from the
#     # original image and runs pytesseract.image_to_string() on it to recognize the text.
#     pass


# Initialize the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_path = "detection_model.pth"  # Ensure this path matches where the model is stored
model = load_model(model_path, device)
print("Model loaded successfully.")


# Define the Gradio Interface Function
def gradio_detect(image):
    """
    Gradio interface function for text detection.

    Args:
        image (PIL Image): Uploaded image.

    Returns:
        PIL Image: Image with detected bounding boxes.
    """
    result_image = detect_text(image, model, device)
    return result_image


# Create Gradio Interface
iface = gr.Interface(
    fn=gradio_detect,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Text Detection with ViT",
    description="Upload an image, and the model will detect and highlight text regions.",
    examples=[
        # You can add URLs or paths to example images here
        # "https://example.com/image1.jpg",
        # "https://example.com/image2.jpg",
    ],
    allow_flagging="never"
)

# Launch the Gradio App (Optional for local testing)
# if __name__ == "__main__":
#     iface.launch()
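
# A minimal headless check (hypothetical; assumes a local "sample.jpg" exists):
# run detection directly, without starting the Gradio server, and save the result.
#
#     out = detect_text(Image.open("sample.jpg"), model, device)
#     out.save("sample_detected.jpg")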