Update app.py
app.py CHANGED
@@ -5,9 +5,6 @@ from typing import Union, List
 from pathlib import Path
 from PIL import Image
 import torch
-from ultralytics.utils.plotting import save_one_box
-from ultralytics.engine.results import Results
-from ultralytics import YOLOWorld

 # Function to resize images

@@ -36,90 +33,15 @@ def stitch_images(image_paths, scale_percent=50):
         print(f"Stitching failed with status code: {status}")
         return None

-# YOLO detection function
-
-def yolo_detect(
-    image: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor],
-    classes: List[str],
-) -> np.ndarray:
-    detector = YOLOWorld("yolov8x-worldv2.pt")
-    detector.set_classes(classes)
-    results: List[Results] = detector.predict(image)
-
-    boxes = []
-    for result in results:
-        boxes.append(
-            save_one_box(result.cpu().boxes.xyxy, im=result.orig_img, save=False)
-        )
-
-    del detector
-    return boxes[0]
-
-# Function to predict and process image
-
-def predict(image, offset_inches):
-    try:
-        drawer_img = yolo_detect(image, ["box"])
-        shrunked_img = make_square(shrink_bbox(drawer_img, 0.8))
-        return shrunked_img
-    except:
-        raise Exception("Unable to DETECT DRAWER, please take another picture with different magnification level!")
-
-# Function to shrink bounding box
-
-def shrink_bbox(image: np.ndarray, shrink_factor: float):
-    height, width = image.shape[:2]
-    center_x, center_y = width // 2, height // 2
-
-    new_width = int(width * shrink_factor)
-    new_height = int(height * shrink_factor)
-
-    x1 = max(center_x - new_width // 2, 0)
-    y1 = max(center_y - new_height // 2, 0)
-    x2 = min(center_x + new_width // 2, width)
-    y2 = min(center_y + new_height // 2, height)
-
-    cropped_image = image[y1:y2, x1:x2]
-    return cropped_image
-
-# Function to make image square
-
-def make_square(img: np.ndarray):
-    height, width = img.shape[:2]
-    max_dim = max(height, width)
-    pad_height = (max_dim - height) // 2
-    pad_width = (max_dim - width) // 2
-    pad_height_extra = max_dim - height - 2 * pad_height
-    pad_width_extra = max_dim - width - 2 * pad_width
-
-    if len(img.shape) == 3:
-        padded = np.pad(
-            img,
-            ((pad_height, pad_height + pad_height_extra),
-             (pad_width, pad_width + pad_width_extra),
-             (0, 0)),
-            mode="edge"
-        )
-    else:
-        padded = np.pad(
-            img,
-            ((pad_height, pad_height + pad_height_extra),
-             (pad_width, pad_width + pad_width_extra)),
-            mode="edge"
-        )
-
-    return padded
-
 # Main image processing function

-def process_image(image_paths, scale_percent=50
+def process_image(image_paths, scale_percent=50):
     stitched_image = stitch_images(image_paths, scale_percent)

     if stitched_image is not None:
         try:
             stitched_image_rgb = cv2.cvtColor(stitched_image, cv2.COLOR_BGR2RGB)
-
-            return final_image
+            return stitched_image_rgb
         except Exception as e:
             print(str(e))
             return stitched_image
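For reference, a minimal sketch of how the simplified process_image could be exercised after this commit. Everything here is illustrative: the file names are placeholders, and the import of process_image from app assumes the Space's app.py is importable without side effects, which this diff does not show.

import cv2
from app import process_image  # hypothetical: assumes app.py can be imported as a module

# Hypothetical input photos; any overlapping views of the same scene would do.
paths = ["view_left.jpg", "view_right.jpg"]

result = process_image(paths, scale_percent=50)
if result is None:
    print("Stitching failed, nothing to save.")
else:
    # On the success path process_image returns an RGB array; convert back to BGR for imwrite.
    cv2.imwrite("stitched.png", cv2.cvtColor(result, cv2.COLOR_RGB2BGR))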