Spaces:
Sleeping
Sleeping
Fix image shape issue
Browse files- vlm_tools.py +46 -20
vlm_tools.py
CHANGED
@@ -11,14 +11,14 @@ from PIL import Image
|
|
11 |
from langchain_core.tools import tool as langchain_tool
|
12 |
from smolagents.tools import Tool, tool
|
13 |
|
14 |
-
def pre_processing(image: str, input_size=(416, 416))->
|
15 |
"""
|
16 |
Pre-process an image for YOLO model
|
17 |
Args:
|
18 |
image: The image in base64 format to process
|
19 |
input_size: The size to which the image should be resized
|
20 |
Returns:
|
21 |
-
|
22 |
"""
|
23 |
try:
|
24 |
# Decode base64 image
|
@@ -41,9 +41,16 @@ def pre_processing(image: str, input_size=(416, 416))->np.ndarray:
|
|
41 |
if img is None:
|
42 |
raise ValueError("Failed to resize image")
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
# Convert BGR to RGB and normalize
|
45 |
-
img =
|
46 |
-
img =
|
|
|
47 |
img = img.astype(np.float32) / 255.0 # Normalize to [0, 1]
|
48 |
|
49 |
return img, original_shape
|
@@ -121,12 +128,12 @@ def get_image_from_file_path(file_path: str)->str:
|
|
121 |
"""
|
122 |
try:
|
123 |
# Debug prints for original path
|
124 |
-
print(f"Original file_path: {file_path}")
|
125 |
-
print(f"Original path exists: {os.path.exists(file_path)}")
|
126 |
-
if os.path.exists(file_path):
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
|
131 |
# Try reading with cv2
|
132 |
img = cv2.imread(file_path)
|
@@ -148,12 +155,12 @@ def get_image_from_file_path(file_path: str)->str:
|
|
148 |
adjusted_path = os.path.join(current_file_dir, file_path)
|
149 |
|
150 |
# Debug prints for adjusted path
|
151 |
-
print(f"Adjusted file_path: {adjusted_path}")
|
152 |
-
print(f"Adjusted path exists: {os.path.exists(adjusted_path)}")
|
153 |
-
if os.path.exists(adjusted_path):
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
|
158 |
# Try reading with cv2
|
159 |
img = cv2.imread(adjusted_path)
|
@@ -305,6 +312,11 @@ class ObjectDetectionTool(Tool):
|
|
305 |
self.onnx_path = onnx_path
|
306 |
self.onnx_model = onnxruntime.InferenceSession(self.onnx_path)
|
307 |
|
|
|
|
|
|
|
|
|
|
|
308 |
# Load class labels
|
309 |
self.classes = [
|
310 |
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
|
@@ -332,14 +344,28 @@ class ObjectDetectionTool(Tool):
|
|
332 |
# Preprocess the image
|
333 |
img, original_shape = pre_processing(image)
|
334 |
|
|
|
|
|
|
|
|
|
335 |
# Create blob and run inference
|
336 |
-
blob = dnn.blobFromImage(
|
337 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
338 |
onnx_output = self.onnx_model.run(None, onnx_input)
|
339 |
|
340 |
# Handle shape mismatch by transposing if needed
|
341 |
-
if onnx_output[0].shape
|
342 |
-
|
|
|
343 |
|
344 |
# Post-process the output
|
345 |
objects = post_processing(onnx_output, self.classes, original_shape)
|
|
|
11 |
from langchain_core.tools import tool as langchain_tool
|
12 |
from smolagents.tools import Tool, tool
|
13 |
|
14 |
+
def pre_processing(image: str, input_size=(416, 416))->tuple:
|
15 |
"""
|
16 |
Pre-process an image for YOLO model
|
17 |
Args:
|
18 |
image: The image in base64 format to process
|
19 |
input_size: The size to which the image should be resized
|
20 |
Returns:
|
21 |
+
tuple: (processed_image, original_shape)
|
22 |
"""
|
23 |
try:
|
24 |
# Decode base64 image
|
|
|
41 |
if img is None:
|
42 |
raise ValueError("Failed to resize image")
|
43 |
|
44 |
+
# Ensure image is in BGR format (3 channels)
|
45 |
+
if len(img.shape) == 2: # If grayscale
|
46 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
47 |
+
elif img.shape[2] == 4: # If RGBA
|
48 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
|
49 |
+
|
50 |
# Convert BGR to RGB and normalize
|
51 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # More reliable than array slicing
|
52 |
+
img = img.transpose(2, 0, 1) # HWC to CHW
|
53 |
+
img = np.expand_dims(img, axis=0) # Add batch dimension
|
54 |
img = img.astype(np.float32) / 255.0 # Normalize to [0, 1]
|
55 |
|
56 |
return img, original_shape
|
|
|
128 |
"""
|
129 |
try:
|
130 |
# Debug prints for original path (disabled; uncomment to troubleshoot path resolution)
|
131 |
+
# print(f"Original file_path: {file_path}")
|
132 |
+
# print(f"Original path exists: {os.path.exists(file_path)}")
|
133 |
+
# if os.path.exists(file_path):
|
134 |
+
# print(f"Original path is file: {os.path.isfile(file_path)}")
|
135 |
+
# print(f"Original path permissions: {oct(os.stat(file_path).st_mode)[-3:]}")
|
136 |
+
# print(f"Original path absolute: {os.path.abspath(file_path)}")
|
137 |
|
138 |
# Try reading with cv2
|
139 |
img = cv2.imread(file_path)
|
|
|
155 |
adjusted_path = os.path.join(current_file_dir, file_path)
|
156 |
|
157 |
# Debug prints for adjusted path (disabled; uncomment to troubleshoot path resolution)
|
158 |
+
# print(f"Adjusted file_path: {adjusted_path}")
|
159 |
+
# print(f"Adjusted path exists: {os.path.exists(adjusted_path)}")
|
160 |
+
# if os.path.exists(adjusted_path):
|
161 |
+
# print(f"Adjusted path is file: {os.path.isfile(adjusted_path)}")
|
162 |
+
# print(f"Adjusted path permissions: {oct(os.stat(adjusted_path).st_mode)[-3:]}")
|
163 |
+
# print(f"Adjusted path absolute: {os.path.abspath(adjusted_path)}")
|
164 |
|
165 |
# Try reading with cv2
|
166 |
img = cv2.imread(adjusted_path)
|
|
|
312 |
self.onnx_path = onnx_path
|
313 |
self.onnx_model = onnxruntime.InferenceSession(self.onnx_path)
|
314 |
|
315 |
+
# Get model input details
|
316 |
+
self.input_name = self.onnx_model.get_inputs()[0].name
|
317 |
+
self.input_shape = self.onnx_model.get_inputs()[0].shape
|
318 |
+
print(f"Model input shape: {self.input_shape}")
|
319 |
+
|
320 |
# Load class labels
|
321 |
self.classes = [
|
322 |
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
|
|
|
344 |
# Preprocess the image
|
345 |
img, original_shape = pre_processing(image)
|
346 |
|
347 |
+
# Verify input shape
|
348 |
+
if len(img.shape) != 4: # Should be NCHW
|
349 |
+
raise ValueError(f"Invalid input shape: {img.shape}, expected NCHW format")
|
350 |
+
|
351 |
# Create the input blob (inference itself runs in the next step)
|
352 |
+
blob = cv2.dnn.blobFromImage(
|
353 |
+
img[0].transpose(1, 2, 0), # Convert back to HWC for blobFromImage
|
354 |
+
1/255.0, # Scale factor
|
355 |
+
(416, 416), # Size
|
356 |
+
(0, 0, 0), # Mean
|
357 |
+
True, # SwapRB
|
358 |
+
crop=False
|
359 |
+
)
|
360 |
+
|
361 |
+
# Run inference
|
362 |
+
onnx_input = {self.input_name: blob}
|
363 |
onnx_output = self.onnx_model.run(None, onnx_input)
|
364 |
|
365 |
# Handle shape mismatch by transposing if needed
|
366 |
+
if len(onnx_output[0].shape) == 4: # If in NCHW format
|
367 |
+
if onnx_output[0].shape[1] == 255: # If channels first
|
368 |
+
onnx_output = [onnx_output[0].transpose(0, 2, 3, 1)] # Convert to NHWC
|
369 |
|
370 |
# Post-process the output
|
371 |
objects = post_processing(onnx_output, self.classes, original_shape)
|