huytofu92 committed on
Commit
e9a4266
·
1 Parent(s): 8615d8c

Fix object detection tool

Browse files
Files changed (1) hide show
  1. vlm_tools.py +82 -42
vlm_tools.py CHANGED
@@ -1,4 +1,5 @@
1
  import cv2
 
2
  import numpy as np
3
  import pytesseract
4
  import requests
@@ -11,7 +12,6 @@ from langchain_core.tools import tool as langchain_tool
11
  from smolagents.tools import Tool, tool
12
 
13
  def pre_processing(image: str, input_size=(416, 416))->np.ndarray:
14
-
15
  """
16
  Pre-process an image for YOLO model
17
  Args:
@@ -20,16 +20,35 @@ def pre_processing(image: str, input_size=(416, 416))->np.ndarray:
20
  Returns:
21
  The pre-processed image as a numpy array
22
  """
23
- image_data = base64.b64decode(image)
24
- np_image = np.frombuffer(image_data, np.uint8)
25
- img = cv2.imdecode(np_image, cv2.IMREAD_COLOR)
26
-
27
- # Resize and normalize the image
28
- img = cv2.resize(img, input_size)
29
- img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to CHW
30
- img = np.expand_dims(img, axis=0)
31
- img = img.astype(np.float32) / 255.0 # Normalize to [0, 1]
32
- return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  def post_processing(onnx_output, classes, original_shape, conf_threshold=0.5, nms_threshold=0.4)->list:
35
  """
@@ -62,7 +81,7 @@ def post_processing(onnx_output, classes, original_shape, conf_threshold=0.5, nm
62
  class_ids.append(class_id)
63
 
64
  # Apply non-max suppression
65
- indices = cv2.dnn.NMSBoxes(boxes, confidences, conf_threshold, nms_threshold)
66
  detected_objects = []
67
  for i in indices:
68
  i = i[0]
@@ -246,38 +265,59 @@ class ObjectDetectionTool(Tool):
246
  output_type = "any"
247
 
248
  def setup(self):
249
- # Load ONNX model
250
- self.onnx_path = onnx_path
251
- self.onnx_model = onnxruntime.InferenceSession(self.onnx_path)
252
-
253
- # Load class labels - using a predefined list since we can't use open()
254
- # These are the standard COCO dataset classes that YOLOv3 uses
255
- self.classes = [
256
- 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
257
- 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
258
- 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
259
- 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
260
- 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
261
- 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
262
- 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
263
- 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
264
- 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
265
- 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
266
- ]
 
 
267
 
268
  def forward(self, images: any)->any:
269
- detected_objects = []
270
- for image in images:
271
- img = pre_processing(image)
272
-
273
- # Preprocess the image
274
- blob = cv2.dnn.blobFromImage(img, 0.00392, (416, 416), (0, 0, 0), True, crop=False)
275
- onnx_input = {self.onnx_model.get_inputs()[0].name: blob}
276
- onnx_output = self.onnx_model.run(None, onnx_input)
277
-
278
- detected_objects.append(post_processing(onnx_output, self.classes, img.shape))
279
-
280
- return detected_objects
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
 
282
  class OCRTool(Tool):
283
  description = """
 
1
  import cv2
2
+ from cv2 import dnn
3
  import numpy as np
4
  import pytesseract
5
  import requests
 
12
  from smolagents.tools import Tool, tool
13
 
14
  def pre_processing(image: str, input_size=(416, 416))->np.ndarray:
 
15
  """
16
  Pre-process an image for YOLO model
17
  Args:
 
20
  Returns:
21
  The pre-processed image as a numpy array
22
  """
23
+ try:
24
+ # Decode base64 image
25
+ image_data = base64.b64decode(image)
26
+ np_image = np.frombuffer(image_data, np.uint8)
27
+ img = cv2.imdecode(np_image, cv2.IMREAD_COLOR)
28
+
29
+ if img is None:
30
+ raise ValueError("Failed to decode image")
31
+
32
+ # Store original shape for post-processing
33
+ original_shape = img.shape[:2] # (height, width)
34
+
35
+ # Ensure input_size is valid
36
+ if not isinstance(input_size, tuple) or len(input_size) != 2:
37
+ input_size = (416, 416)
38
+
39
+ # Resize and normalize the image
40
+ img = cv2.resize(img, input_size, interpolation=cv2.INTER_LINEAR)
41
+ if img is None:
42
+ raise ValueError("Failed to resize image")
43
+
44
+ # Convert BGR to RGB and normalize
45
+ img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to CHW
46
+ img = np.expand_dims(img, axis=0)
47
+ img = img.astype(np.float32) / 255.0 # Normalize to [0, 1]
48
+
49
+ return img, original_shape
50
+ except Exception as e:
51
+ raise ValueError(f"Error in pre_processing: {str(e)}")
52
 
53
  def post_processing(onnx_output, classes, original_shape, conf_threshold=0.5, nms_threshold=0.4)->list:
54
  """
 
81
  class_ids.append(class_id)
82
 
83
  # Apply non-max suppression
84
+ indices = dnn.NMSBoxes(boxes, confidences, conf_threshold, nms_threshold)
85
  detected_objects = []
86
  for i in indices:
87
  i = i[0]
 
265
  output_type = "any"
266
 
267
  def setup(self):
268
+ try:
269
+ # Load ONNX model
270
+ self.onnx_path = onnx_path
271
+ self.onnx_model = onnxruntime.InferenceSession(self.onnx_path)
272
+
273
+ # Load class labels
274
+ self.classes = [
275
+ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
276
+ 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
277
+ 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
278
+ 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
279
+ 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
280
+ 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
281
+ 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
282
+ 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
283
+ 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
284
+ 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
285
+ ]
286
+ except Exception as e:
287
+ raise RuntimeError(f"Error in setup: {str(e)}")
288
 
289
  def forward(self, images: any)->any:
290
+ try:
291
+ if not isinstance(images, list):
292
+ images = [images] # Convert single image to list
293
+
294
+ detected_objects = []
295
+ for image in images:
296
+ try:
297
+ # Preprocess the image
298
+ img, original_shape = pre_processing(image)
299
+
300
+ # Create blob and run inference
301
+ blob = dnn.blobFromImage(img[0], 0.00392, (416, 416), (0, 0, 0), True, crop=False)
302
+ onnx_input = {self.onnx_model.get_inputs()[0].name: blob}
303
+ onnx_output = self.onnx_model.run(None, onnx_input)
304
+
305
+ # Handle shape mismatch by transposing if needed
306
+ if onnx_output[0].shape[1] == 255: # If in NCHW format
307
+ onnx_output = [onnx_output[0].transpose(0, 2, 3, 1)] # Convert to NHWC
308
+
309
+ # Post-process the output
310
+ objects = post_processing(onnx_output, self.classes, original_shape)
311
+ detected_objects.append(objects)
312
+
313
+ except Exception as e:
314
+ print(f"Error processing image: {str(e)}")
315
+ detected_objects.append([]) # Add empty list for failed image
316
+
317
+ return detected_objects
318
+
319
+ except Exception as e:
320
+ raise RuntimeError(f"Error in forward pass: {str(e)}")
321
 
322
  class OCRTool(Tool):
323
  description = """