asdfaman commited on
Commit
00dc169
·
verified ·
1 Parent(s): 884b9f4

Create detector.py

Browse files
Files changed (1) hide show
  1. detector.py +365 -0
detector.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+
4
+ import cv2
5
+ import matplotlib.pyplot as plt
6
+ from PIL import Image
7
+ import numpy as np
8
+ import onnxruntime as ort
9
+ import pandas as pd
10
+ from typing import Tuple
11
+ from huggingface_hub import hf_hub_download
12
+
13
+ from constants import REPO_ID, FILENAME, MODEL_DIR, MODEL_PATH
14
+ from metrics_storage import MetricsStorage
15
+
16
+
17
+ def download_model():
18
+ """Download the model using Hugging Face Hub"""
19
+ # Ensure model directory exists
20
+ os.makedirs(MODEL_DIR, exist_ok=True)
21
+
22
+ try:
23
+ print(f"Downloading model from {REPO_ID}...")
24
+ # Download the model file from Hugging Face Hub
25
+ model_path = hf_hub_download(
26
+ repo_id=REPO_ID,
27
+ filename=FILENAME,
28
+ local_dir=MODEL_DIR,
29
+ force_download=True,
30
+ cache_dir=None,
31
+ )
32
+
33
+ # Move the file to the correct location if it's not there already
34
+ if os.path.exists(model_path) and model_path != MODEL_PATH:
35
+ os.rename(model_path, MODEL_PATH)
36
+
37
+ # Remove empty directories if they exist
38
+ empty_dir = os.path.join(MODEL_DIR, "tune")
39
+ if os.path.exists(empty_dir):
40
+ import shutil
41
+
42
+ shutil.rmtree(empty_dir)
43
+
44
+ print("Model downloaded successfully!")
45
+ return MODEL_PATH
46
+
47
+ except Exception as e:
48
+ print(f"Error downloading model: {e}")
49
+ raise e
50
+
51
+
52
+ class SignatureDetector:
53
+ def __init__(self, model_path: str = MODEL_PATH):
54
+ self.model_path = model_path
55
+ self.classes = ["signature"]
56
+ self.input_width = 640
57
+ self.input_height = 640
58
+
59
+ # Initialize ONNX Runtime session
60
+ options = ort.SessionOptions()
61
+ options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
62
+ self.session = ort.InferenceSession(self.model_path, options)
63
+ self.session.set_providers(
64
+ ["OpenVINOExecutionProvider"], [{"device_type": "CPU"}]
65
+ )
66
+
67
+ self.metrics_storage = MetricsStorage()
68
+
69
+ def update_metrics(self, inference_time: float) -> None:
70
+ """
71
+ Updates metrics in persistent storage.
72
+ Args:
73
+ inference_time (float): The time taken for inference in milliseconds.
74
+ """
75
+ self.metrics_storage.add_metric(inference_time)
76
+
77
+ def get_metrics(self) -> dict:
78
+ """
79
+ Retrieves current metrics from storage.
80
+ Returns:
81
+ dict: A dictionary containing times, total inferences, average time, and start index.
82
+ """
83
+ times = self.metrics_storage.get_recent_metrics()
84
+ total = self.metrics_storage.get_total_inferences()
85
+ avg = self.metrics_storage.get_average_time()
86
+
87
+ start_index = max(0, total - len(times))
88
+
89
+ return {
90
+ "times": times,
91
+ "total_inferences": total,
92
+ "avg_time": avg,
93
+ "start_index": start_index,
94
+ }
95
+
96
+ def load_initial_metrics(
97
+ self,
98
+ ) -> Tuple[None, str, plt.Figure, plt.Figure, str, str]:
99
+ """
100
+ Loads initial metrics for display.
101
+ Returns:
102
+ tuple: A tuple containing None, total inferences, histogram figure, line figure, average time, and last time.
103
+ """
104
+ metrics = self.get_metrics()
105
+
106
+ if not metrics["times"]:
107
+ return None, None, None, None, None, None
108
+
109
+ hist_data = pd.DataFrame({"Time (ms)": metrics["times"]})
110
+ indices = range(
111
+ metrics["start_index"], metrics["start_index"] + len(metrics["times"])
112
+ )
113
+
114
+ line_data = pd.DataFrame(
115
+ {
116
+ "Inference": indices,
117
+ "Time (ms)": metrics["times"],
118
+ "Mean": [metrics["avg_time"]] * len(metrics["times"]),
119
+ }
120
+ )
121
+
122
+ hist_fig, line_fig = self.create_plots(hist_data, line_data)
123
+
124
+ return (
125
+ None,
126
+ f"{metrics['total_inferences']}",
127
+ hist_fig,
128
+ line_fig,
129
+ f"{metrics['avg_time']:.2f}",
130
+ f"{metrics['times'][-1]:.2f}",
131
+ )
132
+
133
+ def create_plots(
134
+ self, hist_data: pd.DataFrame, line_data: pd.DataFrame
135
+ ) -> Tuple[plt.Figure, plt.Figure]:
136
+ """
137
+ Helper method to create plots.
138
+ Args:
139
+ hist_data (pd.DataFrame): Data for histogram plot.
140
+ line_data (pd.DataFrame): Data for line plot.
141
+ Returns:
142
+ tuple: A tuple containing histogram figure and line figure.
143
+ """
144
+ plt.style.use("dark_background")
145
+
146
+ # Histogram plot
147
+ hist_fig, hist_ax = plt.subplots(figsize=(8, 4), facecolor="#f0f0f5")
148
+ hist_ax.set_facecolor("#f0f0f5")
149
+ hist_data.hist(
150
+ bins=20, ax=hist_ax, color="#4F46E5", alpha=0.7, edgecolor="white"
151
+ )
152
+ hist_ax.set_title(
153
+ "Distribution of Inference Times",
154
+ pad=15,
155
+ fontsize=12,
156
+ color="#1f2937",
157
+ )
158
+ hist_ax.set_xlabel("Time (ms)", color="#374151")
159
+ hist_ax.set_ylabel("Frequency", color="#374151")
160
+ hist_ax.tick_params(colors="#4b5563")
161
+ hist_ax.grid(True, linestyle="--", alpha=0.3)
162
+
163
+ # Line plot
164
+ line_fig, line_ax = plt.subplots(figsize=(8, 4), facecolor="#f0f0f5")
165
+ line_ax.set_facecolor("#f0f0f5")
166
+ line_data.plot(
167
+ x="Inference",
168
+ y="Time (ms)",
169
+ ax=line_ax,
170
+ color="#4F46E5",
171
+ alpha=0.7,
172
+ label="Time",
173
+ )
174
+ line_data.plot(
175
+ x="Inference",
176
+ y="Mean",
177
+ ax=line_ax,
178
+ color="#DC2626",
179
+ linestyle="--",
180
+ label="Mean",
181
+ )
182
+ line_ax.set_title(
183
+ "Inference Time per Execution", pad=15, fontsize=12, color="#1f2937"
184
+ )
185
+ line_ax.set_xlabel("Inference Number", color="#374151")
186
+ line_ax.set_ylabel("Time (ms)", color="#374151")
187
+ line_ax.tick_params(colors="#4b5563")
188
+ line_ax.grid(True, linestyle="--", alpha=0.3)
189
+ line_ax.legend(
190
+ frameon=True, facecolor="#f0f0f5", edgecolor="white", labelcolor="black"
191
+ )
192
+
193
+ hist_fig.tight_layout()
194
+ line_fig.tight_layout()
195
+
196
+ plt.close(hist_fig)
197
+ plt.close(line_fig)
198
+
199
+ return hist_fig, line_fig
200
+
201
+ def preprocess(self, img: Image.Image) -> Tuple[np.ndarray, np.ndarray]:
202
+ """
203
+ Preprocesses the image for inference.
204
+ Args:
205
+ img: The image to process.
206
+ Returns:
207
+ tuple: A tuple containing the processed image data and the original image.
208
+ """
209
+ # Convert PIL Image to cv2 format
210
+ img_cv2 = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
211
+
212
+ self.img_height, self.img_width = img_cv2.shape[:2]
213
+
214
+ # Convert back to RGB for processing
215
+ img_rgb = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB)
216
+
217
+ # Resize
218
+ img_resized = cv2.resize(img_rgb, (self.input_width, self.input_height))
219
+
220
+ # Normalize and transpose
221
+ image_data = np.array(img_resized) / 255.0
222
+ image_data = np.transpose(image_data, (2, 0, 1))
223
+ image_data = np.expand_dims(image_data, axis=0).astype(np.float32)
224
+
225
+ return image_data, img_cv2
226
+
227
+ def draw_detections(
228
+ self, img: np.ndarray, box: list, score: float, class_id: int
229
+ ) -> None:
230
+ """
231
+ Draws the detections on the image.
232
+ Args:
233
+ img: The image to draw on.
234
+ box (list): The bounding box coordinates.
235
+ score (float): The confidence score.
236
+ class_id (int): The class ID.
237
+ """
238
+ x1, y1, w, h = box
239
+ self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))
240
+ color = self.color_palette[class_id]
241
+
242
+ cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)
243
+
244
+ label = f"{self.classes[class_id]}: {score:.2f}"
245
+ (label_width, label_height), _ = cv2.getTextSize(
246
+ label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
247
+ )
248
+
249
+ label_x = x1
250
+ label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10
251
+
252
+ cv2.rectangle(
253
+ img,
254
+ (int(label_x), int(label_y - label_height)),
255
+ (int(label_x + label_width), int(label_y + label_height)),
256
+ color,
257
+ cv2.FILLED,
258
+ )
259
+
260
+ cv2.putText(
261
+ img,
262
+ label,
263
+ (int(label_x), int(label_y)),
264
+ cv2.FONT_HERSHEY_SIMPLEX,
265
+ 0.5,
266
+ (0, 0, 0),
267
+ 1,
268
+ cv2.LINE_AA,
269
+ )
270
+
271
+ def postprocess(
272
+ self,
273
+ input_image: np.ndarray,
274
+ output: np.ndarray,
275
+ conf_thres: float,
276
+ iou_thres: float,
277
+ ) -> np.ndarray:
278
+ """
279
+ Postprocesses the output from inference.
280
+ Args:
281
+ input_image: The input image.
282
+ output: The output from inference.
283
+ conf_thres (float): Confidence threshold for detection.
284
+ iou_thres (float): Intersection over Union threshold for detection.
285
+ Returns:
286
+ np.ndarray: The output image with detections drawn
287
+ """
288
+ outputs = np.transpose(np.squeeze(output[0]))
289
+ rows = outputs.shape[0]
290
+
291
+ boxes = []
292
+ scores = []
293
+ class_ids = []
294
+
295
+ x_factor = self.img_width / self.input_width
296
+ y_factor = self.img_height / self.input_height
297
+
298
+ for i in range(rows):
299
+ classes_scores = outputs[i][4:]
300
+ max_score = np.amax(classes_scores)
301
+
302
+ if max_score >= conf_thres:
303
+ class_id = np.argmax(classes_scores)
304
+ x, y, w, h = outputs[i][0], outputs[i][1], outputs[i][2], outputs[i][3]
305
+
306
+ left = int((x - w / 2) * x_factor)
307
+ top = int((y - h / 2) * y_factor)
308
+ width = int(w * x_factor)
309
+ height = int(h * y_factor)
310
+
311
+ class_ids.append(class_id)
312
+ scores.append(max_score)
313
+ boxes.append([left, top, width, height])
314
+
315
+ indices = cv2.dnn.NMSBoxes(boxes, scores, conf_thres, iou_thres)
316
+
317
+ for i in indices:
318
+ box = boxes[i]
319
+ score = scores[i]
320
+ class_id = class_ids[i]
321
+ self.draw_detections(input_image, box, score, class_id)
322
+
323
+ return cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
324
+
325
+ def detect(
326
+ self, image: Image.Image, conf_thres: float = 0.25, iou_thres: float = 0.5
327
+ ) -> Tuple[Image.Image, dict]:
328
+ """
329
+ Detects signatures in the given image.
330
+ Args:
331
+ image: The image to process.
332
+ conf_thres (float): Confidence threshold for detection.
333
+ iou_thres (float): Intersection over Union threshold for detection.
334
+ Returns:
335
+ tuple: A tuple containing the output image and metrics.
336
+ """
337
+ # Preprocess the image
338
+ img_data, original_image = self.preprocess(image)
339
+
340
+ # Run inference
341
+ start_time = time.time()
342
+ outputs = self.session.run(None, {self.session.get_inputs()[0].name: img_data})
343
+ inference_time = (time.time() - start_time) * 1000 # Convert to milliseconds
344
+
345
+ # Postprocess the results
346
+ output_image = self.postprocess(original_image, outputs, conf_thres, iou_thres)
347
+
348
+ self.update_metrics(inference_time)
349
+
350
+ return output_image, self.get_metrics()
351
+
352
+ def detect_example(
353
+ self, image: Image.Image, conf_thres: float = 0.25, iou_thres: float = 0.5
354
+ ) -> Image.Image:
355
+ """
356
+ Wrapper method for examples that returns only the image.
357
+ Args:
358
+ image: The image to process.
359
+ conf_thres (float): Confidence threshold for detection.
360
+ iou_thres (float): Intersection over Union threshold for detection.
361
+ Returns:
362
+ The output image.
363
+ """
364
+ output_image, _ = self.detect(image, conf_thres, iou_thres)
365
+ return output_image