huytofu92 committed on
Commit
3098bb4
·
1 Parent(s): 5f9faf7

Fix image shape issue

Browse files
Files changed (1) hide show
  1. vlm_tools.py +46 -20
vlm_tools.py CHANGED
@@ -11,14 +11,14 @@ from PIL import Image
11
  from langchain_core.tools import tool as langchain_tool
12
  from smolagents.tools import Tool, tool
13
 
14
- def pre_processing(image: str, input_size=(416, 416))->np.ndarray:
15
  """
16
  Pre-process an image for YOLO model
17
  Args:
18
  image: The image in base64 format to process
19
  input_size: The size to which the image should be resized
20
  Returns:
21
- The pre-processed image as a numpy array
22
  """
23
  try:
24
  # Decode base64 image
@@ -41,9 +41,16 @@ def pre_processing(image: str, input_size=(416, 416))->np.ndarray:
41
  if img is None:
42
  raise ValueError("Failed to resize image")
43
 
 
 
 
 
 
 
44
  # Convert BGR to RGB and normalize
45
- img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to CHW
46
- img = np.expand_dims(img, axis=0)
 
47
  img = img.astype(np.float32) / 255.0 # Normalize to [0, 1]
48
 
49
  return img, original_shape
@@ -121,12 +128,12 @@ def get_image_from_file_path(file_path: str)->str:
121
  """
122
  try:
123
  # Debug prints for original path
124
- print(f"Original file_path: {file_path}")
125
- print(f"Original path exists: {os.path.exists(file_path)}")
126
- if os.path.exists(file_path):
127
- print(f"Original path is file: {os.path.isfile(file_path)}")
128
- print(f"Original path permissions: {oct(os.stat(file_path).st_mode)[-3:]}")
129
- print(f"Original path absolute: {os.path.abspath(file_path)}")
130
 
131
  # Try reading with cv2
132
  img = cv2.imread(file_path)
@@ -148,12 +155,12 @@ def get_image_from_file_path(file_path: str)->str:
148
  adjusted_path = os.path.join(current_file_dir, file_path)
149
 
150
  # Debug prints for adjusted path
151
- print(f"Adjusted file_path: {adjusted_path}")
152
- print(f"Adjusted path exists: {os.path.exists(adjusted_path)}")
153
- if os.path.exists(adjusted_path):
154
- print(f"Adjusted path is file: {os.path.isfile(adjusted_path)}")
155
- print(f"Adjusted path permissions: {oct(os.stat(adjusted_path).st_mode)[-3:]}")
156
- print(f"Adjusted path absolute: {os.path.abspath(adjusted_path)}")
157
 
158
  # Try reading with cv2
159
  img = cv2.imread(adjusted_path)
@@ -305,6 +312,11 @@ class ObjectDetectionTool(Tool):
305
  self.onnx_path = onnx_path
306
  self.onnx_model = onnxruntime.InferenceSession(self.onnx_path)
307
 
 
 
 
 
 
308
  # Load class labels
309
  self.classes = [
310
  'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
@@ -332,14 +344,28 @@ class ObjectDetectionTool(Tool):
332
  # Preprocess the image
333
  img, original_shape = pre_processing(image)
334
 
 
 
 
 
335
  # Create blob and run inference
336
- blob = dnn.blobFromImage(img[0], 0.00392, (416, 416), (0, 0, 0), True, crop=False)
337
- onnx_input = {self.onnx_model.get_inputs()[0].name: blob}
 
 
 
 
 
 
 
 
 
338
  onnx_output = self.onnx_model.run(None, onnx_input)
339
 
340
  # Handle shape mismatch by transposing if needed
341
- if onnx_output[0].shape[1] == 255: # If in NCHW format
342
- onnx_output = [onnx_output[0].transpose(0, 2, 3, 1)] # Convert to NHWC
 
343
 
344
  # Post-process the output
345
  objects = post_processing(onnx_output, self.classes, original_shape)
 
11
  from langchain_core.tools import tool as langchain_tool
12
  from smolagents.tools import Tool, tool
13
 
14
+ def pre_processing(image: str, input_size=(416, 416))->tuple:
15
  """
16
  Pre-process an image for YOLO model
17
  Args:
18
  image: The image in base64 format to process
19
  input_size: The size to which the image should be resized
20
  Returns:
21
+ tuple: (processed_image, original_shape)
22
  """
23
  try:
24
  # Decode base64 image
 
41
  if img is None:
42
  raise ValueError("Failed to resize image")
43
 
44
+ # Ensure image is in BGR format (3 channels)
45
+ if len(img.shape) == 2: # If grayscale
46
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
47
+ elif img.shape[2] == 4: # If RGBA
48
+ img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
49
+
50
  # Convert BGR to RGB and normalize
51
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # More reliable than array slicing
52
+ img = img.transpose(2, 0, 1) # HWC to CHW
53
+ img = np.expand_dims(img, axis=0) # Add batch dimension
54
  img = img.astype(np.float32) / 255.0 # Normalize to [0, 1]
55
 
56
  return img, original_shape
 
128
  """
129
  try:
130
  # Debug prints for original path
131
+ # print(f"Original file_path: {file_path}")
132
+ # print(f"Original path exists: {os.path.exists(file_path)}")
133
+ # if os.path.exists(file_path):
134
+ # print(f"Original path is file: {os.path.isfile(file_path)}")
135
+ # print(f"Original path permissions: {oct(os.stat(file_path).st_mode)[-3:]}")
136
+ # print(f"Original path absolute: {os.path.abspath(file_path)}")
137
 
138
  # Try reading with cv2
139
  img = cv2.imread(file_path)
 
155
  adjusted_path = os.path.join(current_file_dir, file_path)
156
 
157
  # Debug prints for adjusted path
158
+ # print(f"Adjusted file_path: {adjusted_path}")
159
+ # print(f"Adjusted path exists: {os.path.exists(adjusted_path)}")
160
+ # if os.path.exists(adjusted_path):
161
+ # print(f"Adjusted path is file: {os.path.isfile(adjusted_path)}")
162
+ # print(f"Adjusted path permissions: {oct(os.stat(adjusted_path).st_mode)[-3:]}")
163
+ # print(f"Adjusted path absolute: {os.path.abspath(adjusted_path)}")
164
 
165
  # Try reading with cv2
166
  img = cv2.imread(adjusted_path)
 
312
  self.onnx_path = onnx_path
313
  self.onnx_model = onnxruntime.InferenceSession(self.onnx_path)
314
 
315
+ # Get model input details
316
+ self.input_name = self.onnx_model.get_inputs()[0].name
317
+ self.input_shape = self.onnx_model.get_inputs()[0].shape
318
+ print(f"Model input shape: {self.input_shape}")
319
+
320
  # Load class labels
321
  self.classes = [
322
  'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
 
344
  # Preprocess the image
345
  img, original_shape = pre_processing(image)
346
 
347
+ # Verify input shape
348
+ if len(img.shape) != 4: # Should be NCHW
349
+ raise ValueError(f"Invalid input shape: {img.shape}, expected NCHW format")
350
+
351
  # Create blob and run inference
352
+ blob = cv2.dnn.blobFromImage(
353
+ img[0].transpose(1, 2, 0), # Convert back to HWC for blobFromImage
354
+ 1/255.0, # Scale factor
355
+ (416, 416), # Size
356
+ (0, 0, 0), # Mean
357
+ True, # SwapRB
358
+ crop=False
359
+ )
360
+
361
+ # Run inference
362
+ onnx_input = {self.input_name: blob}
363
  onnx_output = self.onnx_model.run(None, onnx_input)
364
 
365
  # Handle shape mismatch by transposing if needed
366
+ if len(onnx_output[0].shape) == 4: # If in NCHW format
367
+ if onnx_output[0].shape[1] == 255: # If channels first
368
+ onnx_output = [onnx_output[0].transpose(0, 2, 3, 1)] # Convert to NHWC
369
 
370
  # Post-process the output
371
  objects = post_processing(onnx_output, self.classes, original_shape)