Gizachew committed on
Commit 739fe18 · verified · 1 Parent(s): aeea157

Create app.py

Files changed (1): app.py (+201 −0)
app.py ADDED
@@ -0,0 +1,201 @@
+ # app.py
+
+ import os
+ import torch
+ import torch.nn as nn
+ from torchvision import transforms
+ from PIL import Image
+ import numpy as np
+ import gradio as gr
+ import timm
+ import matplotlib.pyplot as plt
+ import matplotlib.patches as patches
+
+ # Optional: If integrating OCR
+ # import pytesseract
+
+ # Define the Detection Model Architecture
+ class ViTDetectionModel(nn.Module):
+     def __init__(self, num_queries=100, hidden_dim=768):
+         """
+         Initializes the ViTDetectionModel.
+
+         Args:
+             num_queries (int, optional): Number of detection queries. Defaults to 100.
+             hidden_dim (int, optional): Hidden dimension size. Defaults to 768.
+         """
+         super(ViTDetectionModel, self).__init__()
+         # Configure the ViT model to output features only
+         self.vit = timm.create_model(
+             'vit_base_patch16_224',
+             pretrained=False,    # Set to False since we are loading a trained model
+             num_classes=0,       # Disable classification head
+             features_only=True,  # Return feature maps
+             out_indices=(11,)    # Get the last feature map
+         )
+         self.query_embed = nn.Embedding(num_queries, hidden_dim)
+         self.fc_bbox = nn.Linear(hidden_dim, 8)   # 4 points (x, y) for quadrilateral
+         self.fc_class = nn.Linear(hidden_dim, 1)  # Binary classification
+
+     def forward(self, x):
+         """
+         Forward pass of the detection model.
+
+         Args:
+             x (Tensor): Input images [batch, 3, H, W].
+
+         Returns:
+             Tuple[Tensor, Tensor]: Predicted bounding boxes and class scores.
+         """
+         # Retrieve the feature map
+         features = self.vit(x)[0]  # [batch, hidden_dim, H*W]
+
+         if features.dim() == 3:
+             batch_size, hidden_dim, num_patches = features.shape
+             grid_size = int(np.sqrt(num_patches))
+             if grid_size * grid_size != num_patches:
+                 raise ValueError(f"Number of patches {num_patches} is not a perfect square.")
+             H, W = grid_size, grid_size
+             features = features.view(batch_size, hidden_dim, H, W)
+         elif features.dim() == 4:
+             batch_size, hidden_dim, H, W = features.shape
+         else:
+             raise ValueError(f"Unexpected feature dimensions: {features.dim()}, expected 3 or 4.")
+
+         # Flatten the spatial dimensions
+         features = features.flatten(2).transpose(1, 2)  # [batch, H*W, hidden_dim]
+
+         # Prepare query embeddings
+         queries = self.query_embed.weight.unsqueeze(0).repeat(batch_size, 1, 1)  # [batch, num_queries, hidden_dim]
+
+         # Compute attention weights
+         attn = torch.matmul(features, queries.transpose(-1, -2))  # [batch, H*W, num_queries]
+         attn = torch.softmax(attn, dim=1)  # Softmax over patches
+
+         # Aggregate features based on attention
+         output = torch.matmul(attn.transpose(-1, -2), features)  # [batch, num_queries, hidden_dim]
+
+         # Predict bounding boxes and classes
+         bboxes = self.fc_bbox(output)    # [batch, num_queries, 8]
+         classes = self.fc_class(output)  # [batch, num_queries, 1]
+
+         return bboxes, classes
+
+ # Function to Load the Trained Model
+ def load_model(model_path, device):
+     """
+     Loads the trained detection model.
+
+     Args:
+         model_path (str): Path to the saved model state dictionary.
+         device (torch.device): Device to load the model on.
+
+     Returns:
+         nn.Module: Loaded detection model.
+     """
+     model = ViTDetectionModel(num_queries=100, hidden_dim=768).to(device)
+     model.load_state_dict(torch.load(model_path, map_location=device))
+     model.eval()
+     return model
+
+ # Function to Perform Text Detection on an Image
+ def detect_text(image, model, device, max_boxes=100, confidence_threshold=0.5):
+     """
+     Detects text in the input image using the detection model.
+
+     Args:
+         image (PIL Image): Input image.
+         model (nn.Module): Trained detection model.
+         device (torch.device): Device to run the model on.
+         max_boxes (int, optional): Maximum number of bounding boxes to return. Defaults to 100.
+         confidence_threshold (float, optional): Threshold to filter detections. Defaults to 0.5.
+
+     Returns:
+         PIL Image: Image with detected bounding boxes drawn.
+     """
+     # Define transformation
+     transform = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+     ])
+
+     # Preprocess the image
+     input_tensor = transform(image).unsqueeze(0).to(device)  # [1, 3, 224, 224]
+
+     # Perform detection
+     with torch.no_grad():
+         pred_bboxes, pred_classes = model(input_tensor)  # [1, num_queries, 8], [1, num_queries, 1]
+
+     # Process predictions
+     pred_bboxes = pred_bboxes.squeeze(0)    # [num_queries, 8]
+     pred_classes = pred_classes.squeeze(0)  # [num_queries, 1]
+     pred_classes_sigmoid = torch.sigmoid(pred_classes)
+     high_conf_indices = (pred_classes_sigmoid > confidence_threshold).squeeze(1).nonzero(as_tuple=False).squeeze(1)
+     selected_indices = high_conf_indices[:max_boxes]
+     selected_bboxes = pred_bboxes[selected_indices]  # [selected, 8]
+
+     # Denormalize bounding boxes to original image size
+     width, height = image.size
+     scale_x = width / 224
+     scale_y = height / 224
+     boxes = selected_bboxes.cpu().numpy() * np.array([scale_x, scale_y] * 4)  # [selected, 8]
+
+     # Draw bounding boxes on the image
+     fig, ax = plt.subplots(1, figsize=(12, 12))
+     ax.imshow(image)
+
+     for box in boxes:
+         polygon = patches.Polygon(box.reshape(-1, 2), linewidth=2, edgecolor='r', facecolor='none')
+         ax.add_patch(polygon)
+
+     plt.axis('off')
+     # Convert Matplotlib figure to PIL Image
+     fig.canvas.draw()
+     img_with_boxes = Image.fromarray(np.asarray(fig.canvas.buffer_rgba())).convert('RGB')  # buffer_rgba() avoids the deprecated tostring_rgb()
+     plt.close(fig)
+
+     return img_with_boxes
+
+ # Optional: If integrating OCR with pytesseract
+ # def detect_and_recognize_text(image, model, device, max_boxes=100, confidence_threshold=0.5):
+ #     # Similar to detect_text but includes OCR steps
+ #     pass
+
+ # Initialize the model
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ model_path = "detection_model.pth"  # Ensure this path matches where the model is stored
+ model = load_model(model_path, device)
+ print("Model loaded successfully.")
+
+ # Define the Gradio Interface Function
+ def gradio_detect(image):
+     """
+     Gradio interface function for text detection.
+
+     Args:
+         image (PIL Image): Uploaded image.
+
+     Returns:
+         PIL Image: Image with detected bounding boxes.
+     """
+     result_image = detect_text(image, model, device)
+     return result_image
+
+ # Create Gradio Interface
+ iface = gr.Interface(
+     fn=gradio_detect,
+     inputs=gr.Image(type="pil"),
+     outputs=gr.Image(type="pil"),
+     title="Text Detection with ViT",
+     description="Upload an image, and the model will detect and highlight text regions.",
+     examples=[
+         # You can add URLs or paths to example images here
+         # "https://example.com/image1.jpg",
+         # "https://example.com/image2.jpg",
+     ],
+     allow_flagging="never"
+ )
+
+ # Launch the Gradio App (Optional for local testing)
+ # if __name__ == "__main__":
+ #     iface.launch()
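
For local testing, the commented-out block at the end can be uncommented so that `iface.launch()` starts the Gradio app. The detector can also be exercised without the UI; the snippet below is a minimal sketch, not part of the commit. The file names `local_test.py`, `sample.jpg`, and `sample_with_boxes.png` are hypothetical placeholders, and it assumes `detection_model.pth` is present so that importing `app` succeeds.

# local_test.py -- hypothetical smoke test for the detector defined in app.py
from PIL import Image

from app import detect_text, model, device  # reuses the model loaded when app.py is imported

if __name__ == "__main__":
    # Placeholder input image; replace with a real test image.
    img = Image.open("sample.jpg").convert("RGB")
    annotated = detect_text(img, model, device, max_boxes=50, confidence_threshold=0.6)
    annotated.save("sample_with_boxes.png")
    print("Saved sample_with_boxes.png")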