# app.py

import os
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import gradio as gr
import timm
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Optional: If integrating OCR
# import pytesseract

# Define the Detection Model Architecture
class ViTDetectionModel(nn.Module):
    def __init__(self, num_queries=100, hidden_dim=768):
        """
        Initializes the ViTDetectionModel.
        
        Args:
            num_queries (int, optional): Number of detection queries. Defaults to 100.
            hidden_dim (int, optional): Hidden dimension size. Defaults to 768.
        """
        super(ViTDetectionModel, self).__init__()
        # Configure the ViT model to output features only
        self.vit = timm.create_model(
            'vit_base_patch16_224',
            pretrained=False,  # Set to False since we are loading a trained model
            num_classes=0,               # Disable classification head
            features_only=True,          # Return feature maps
            out_indices=(11,)            # Get the last feature map
        )
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        self.fc_bbox = nn.Linear(hidden_dim, 8)  # 4 points (x, y) for quadrilateral
        self.fc_class = nn.Linear(hidden_dim, 1)  # Binary classification

    def forward(self, x):
        """
        Forward pass of the detection model.
        
        Args:
            x (Tensor): Input images [batch, 3, H, W].
        
        Returns:
            Tuple[Tensor, Tensor]: Predicted bounding boxes and class scores.
        """
        # Retrieve the feature map
        features = self.vit(x)[0]  # [batch, hidden_dim, H, W] or [batch, hidden_dim, H*W], depending on the timm version
        
        if features.dim() == 3:
            batch_size, hidden_dim, num_patches = features.shape
            grid_size = int(np.sqrt(num_patches))
            if grid_size * grid_size != num_patches:
                raise ValueError(f"Number of patches {num_patches} is not a perfect square.")
            H, W = grid_size, grid_size
            features = features.view(batch_size, hidden_dim, H, W)
        elif features.dim() == 4:
            batch_size, hidden_dim, H, W = features.shape
        else:
            raise ValueError(f"Unexpected feature dimensions: {features.dim()}, expected 3 or 4.")
        
        # Flatten the spatial dimensions
        features = features.flatten(2).transpose(1, 2)  # [batch, H*W, hidden_dim]

        # Prepare query embeddings
        queries = self.query_embed.weight.unsqueeze(0).repeat(batch_size, 1, 1)  # [batch, num_queries, hidden_dim]

        # Compute attention weights
        attn = torch.matmul(features, queries.transpose(-1, -2))  # [batch, H*W, num_queries]
        attn = torch.softmax(attn, dim=1)  # Softmax over patches

        # Aggregate features based on attention
        output = torch.matmul(attn.transpose(-1, -2), features)  # [batch, num_queries, hidden_dim]

        # Predict bounding boxes and classes
        bboxes = self.fc_bbox(output)  # [batch, num_queries, 8]
        classes = self.fc_class(output)  # [batch, num_queries, 1]

        return bboxes, classes
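
# Quick shape check (illustrative only; the model below is randomly initialized
# and the exact feature handling depends on the installed timm version):
# _model = ViTDetectionModel()
# _bboxes, _classes = _model(torch.randn(1, 3, 224, 224))
# _bboxes.shape -> torch.Size([1, 100, 8]); _classes.shape -> torch.Size([1, 100, 1])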

# Function to Load the Trained Model
def load_model(model_path, device):
    """
    Loads the trained detection model.
    
    Args:
        model_path (str): Path to the saved model state dictionary.
        device (torch.device): Device to load the model on.
    
    Returns:
        nn.Module: Loaded detection model.
    """
    model = ViTDetectionModel(num_queries=100, hidden_dim=768).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model

# Function to Perform Text Detection on an Image
def detect_text(image, model, device, max_boxes=100, confidence_threshold=0.5):
    """
    Detects text in the input image using the detection model.
    
    Args:
        image (PIL Image): Input image.
        model (nn.Module): Trained detection model.
        device (torch.device): Device to run the model on.
        max_boxes (int, optional): Maximum number of bounding boxes to return. Defaults to 100.
        confidence_threshold (float, optional): Threshold to filter detections. Defaults to 0.5.
    
    Returns:
        PIL Image: Image with detected bounding boxes drawn.
    """
    # Define transformation
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    
    # Preprocess the image
    input_tensor = transform(image).unsqueeze(0).to(device)  # [1, 3, 224, 224]
    
    # Perform detection
    with torch.no_grad():
        pred_bboxes, pred_classes = model(input_tensor)  # [1, num_queries, 8], [1, num_queries, 1]
    
    # Process predictions
    pred_bboxes = pred_bboxes.squeeze(0)  # [num_queries, 8]
    pred_classes = pred_classes.squeeze(0)  # [num_queries, 1]
    pred_classes_sigmoid = torch.sigmoid(pred_classes)
    high_conf_indices = (pred_classes_sigmoid > confidence_threshold).squeeze(1).nonzero(as_tuple=False).squeeze(1)
    selected_indices = high_conf_indices[:max_boxes]
    selected_bboxes = pred_bboxes[selected_indices]  # [selected, 8]
    
    # Denormalize bounding boxes to original image size
    width, height = image.size
    scale_x = width / 224
    scale_y = height / 224
    boxes = selected_bboxes.cpu().numpy() * np.array([scale_x, scale_y] * 4)  # [selected, 8]
    
    # Draw bounding boxes on the image
    fig, ax = plt.subplots(1, figsize=(12, 12))
    ax.imshow(image)
    
    for box in boxes:
        polygon = patches.Polygon(box.reshape(-1, 2), linewidth=2, edgecolor='r', facecolor='none')
        ax.add_patch(polygon)
    
    plt.axis('off')
    # Convert the Matplotlib figure to a PIL Image via the RGBA buffer
    # (canvas.tostring_rgb has been removed in recent Matplotlib releases)
    fig.canvas.draw()
    img_with_boxes = Image.fromarray(np.asarray(fig.canvas.buffer_rgba())).convert('RGB')
    plt.close(fig)
    
    return img_with_boxes

# Optional: If integrating OCR with pytesseract
# def detect_and_recognize_text(image, model, device, max_boxes=100, confidence_threshold=0.5):
#     # Similar to detect_text but includes OCR steps
#     pass
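
# A minimal sketch of the OCR variant flagged above (illustrative, not part of
# the original pipeline): it assumes pytesseract and the Tesseract binary are
# installed, reuses the same preprocessing and filtering as detect_text, and
# runs OCR on each quadrilateral's axis-aligned bounding rectangle.
def detect_and_recognize_text(image, model, device, max_boxes=100, confidence_threshold=0.5):
    import pytesseract  # imported lazily so the app still runs without OCR support

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    input_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        pred_bboxes, pred_classes = model(input_tensor)

    pred_bboxes = pred_bboxes.squeeze(0)                  # [num_queries, 8]
    pred_scores = torch.sigmoid(pred_classes.squeeze(0))  # [num_queries, 1]
    keep = (pred_scores > confidence_threshold).squeeze(1).nonzero(as_tuple=False).squeeze(1)
    selected_bboxes = pred_bboxes[keep[:max_boxes]]

    # Scale predictions from the 224x224 input back to the original image size
    width, height = image.size
    boxes = selected_bboxes.cpu().numpy() * np.array([width / 224, height / 224] * 4)

    results = []
    for box in boxes:
        xs, ys = box[0::2], box[1::2]
        # Clamp the axis-aligned crop to the image bounds before running OCR
        left, right = max(int(xs.min()), 0), min(int(xs.max()), width)
        top, bottom = max(int(ys.min()), 0), min(int(ys.max()), height)
        if right <= left or bottom <= top:
            continue
        crop = image.crop((left, top, right, bottom))
        text = pytesseract.image_to_string(crop).strip()
        results.append({"box": box.reshape(-1, 2).tolist(), "text": text})
    return results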

# Initialize the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_path = "detection_model.pth"  # Ensure this path matches where the model is stored
model = load_model(model_path, device)
print("Model loaded successfully.")

# Define the Gradio Interface Function
def gradio_detect(image):
    """
    Gradio interface function for text detection.
    
    Args:
        image (PIL Image): Uploaded image.
    
    Returns:
        PIL Image: Image with detected bounding boxes.
    """
    result_image = detect_text(image, model, device)
    return result_image

# Create Gradio Interface
iface = gr.Interface(
    fn=gradio_detect,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Text Detection with ViT",
    description="Upload an image, and the model will detect and highlight text regions.",
    examples=[
        # You can add URLs or paths to example images here
        # "https://example.com/image1.jpg",
        # "https://example.com/image2.jpg",
    ],
    allow_flagging="never"
)

# Launch the Gradio App (Optional for local testing)
# if __name__ == "__main__":
#     iface.launch()