ammariii08 commited on
Commit
f32744f
·
verified ·
1 Parent(s): 6109c93

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +578 -0
app.py ADDED
@@ -0,0 +1,578 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import os
3
+ import gc
4
+ import base64
5
+ import io
6
+ import time
7
+ import shutil
8
+ import numpy as np
9
+ import torch
10
+ import cv2
11
+ import ezdxf
12
+ import gradio as gr
13
+ from PIL import Image, ImageEnhance
14
+ from pathlib import Path
15
+ from typing import List, Union
16
+ from ultralytics import YOLOWorld, YOLO
17
+ from ultralytics.engine.results import Results
18
+ from ultralytics.utils.plotting import save_one_box
19
+ from transformers import AutoModelForImageSegmentation
20
+ from torchvision import transforms
21
+ from scalingtestupdated import calculate_scaling_factor
22
+ from shapely.geometry import Polygon, Point, MultiPolygon
23
+ from scipy.interpolate import splprep, splev
24
+ from scipy.ndimage import gaussian_filter1d
25
+ from u2net import U2NETP
26
+
27
+ # ---------------------
28
+ # Create a cache folder for models
29
+ # ---------------------
30
+ CACHE_DIR = os.path.join(os.path.dirname(__file__), ".cache")
31
+ os.makedirs(CACHE_DIR, exist_ok=True)
32
+
33
+ # ---------------------
34
+ # Custom Exceptions
35
+ # ---------------------
36
+ class DrawerNotDetectedError(Exception):
37
+ """Raised when the drawer cannot be detected in the image"""
38
+ pass
39
+
40
+ class ReferenceBoxNotDetectedError(Exception):
41
+ """Raised when the reference box cannot be detected in the image"""
42
+ pass
43
+
44
+ # ---------------------
45
+ # Global Model Initialization with caching and print statements
46
+ # ---------------------
47
+ print("Loading YOLOWorld model...")
48
+ start_time = time.time()
49
+ yolo_model_path = os.path.join(CACHE_DIR, "yolov8x-worldv2.pt")
50
+ if not os.path.exists(yolo_model_path):
51
+ print("Caching YOLOWorld model to", yolo_model_path)
52
+ shutil.copy("yolov8x-worldv2.pt", yolo_model_path)
53
+ drawer_detector_global = YOLOWorld(yolo_model_path)
54
+ drawer_detector_global.set_classes(["box"])
55
+ print("YOLOWorld model loaded in {:.2f} seconds".format(time.time() - start_time))
56
+
57
+ print("Loading YOLO reference model...")
58
+ start_time = time.time()
59
+ reference_model_path = os.path.join(CACHE_DIR, "last.pt")
60
+ if not os.path.exists(reference_model_path):
61
+ print("Caching YOLO reference model to", reference_model_path)
62
+ shutil.copy("last.pt", reference_model_path)
63
+ reference_detector_global = YOLO(reference_model_path)
64
+ print("YOLO reference model loaded in {:.2f} seconds".format(time.time() - start_time))
65
+
66
+ print("Loading U²-Net model for reference background removal (U2NETP)...")
67
+ start_time = time.time()
68
+ u2net_model_path = os.path.join(CACHE_DIR, "u2netp.pth")
69
+ if not os.path.exists(u2net_model_path):
70
+ print("Caching U²-Net model to", u2net_model_path)
71
+ shutil.copy("u2netp.pth", u2net_model_path)
72
+ u2net_global = U2NETP(3, 1)
73
+ u2net_global.load_state_dict(torch.load(u2net_model_path, map_location="cpu"))
74
+ device = "cpu"
75
+ u2net_global.to(device)
76
+ u2net_global.eval()
77
+ print("U²-Net model loaded in {:.2f} seconds".format(time.time() - start_time))
78
+
79
+ print("Loading BiRefNet model...")
80
+ start_time = time.time()
81
+ birefnet_global = AutoModelForImageSegmentation.from_pretrained(
82
+ "zhengpeng7/BiRefNet", trust_remote_code=True, cache_dir=CACHE_DIR
83
+ )
84
+ torch.set_float32_matmul_precision("high")
85
+ birefnet_global.to(device)
86
+ birefnet_global.eval()
87
+ print("BiRefNet model loaded in {:.2f} seconds".format(time.time() - start_time))
88
+
89
+ # Define transform for BiRefNet
90
+ transform_image_global = transforms.Compose([
91
+ transforms.Resize((1024, 1024)),
92
+ transforms.ToTensor(),
93
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
94
+ ])
95
+
96
+ # ---------------------
97
+ # Model Reload Function (if needed)
98
+ # ---------------------
99
+ def unload_and_reload_models():
100
+ global drawer_detector_global, reference_detector_global, birefnet_global, u2net_global
101
+ print("Reloading models...")
102
+ start_time = time.time()
103
+ del drawer_detector_global, reference_detector_global, birefnet_global, u2net_global
104
+ gc.collect()
105
+ if torch.cuda.is_available():
106
+ torch.cuda.empty_cache()
107
+ gc.collect()
108
+ new_drawer_detector = YOLOWorld(os.path.join(CACHE_DIR, "yolov8x-worldv2.pt"))
109
+ new_drawer_detector.set_classes(["box"])
110
+ new_reference_detector = YOLO(os.path.join(CACHE_DIR, "last.pt"))
111
+ new_birefnet = AutoModelForImageSegmentation.from_pretrained(
112
+ "zhengpeng7/BiRefNet", trust_remote_code=True, cache_dir=CACHE_DIR
113
+ )
114
+ new_birefnet.to(device)
115
+ new_birefnet.eval()
116
+ new_u2net = U2NETP(3, 1)
117
+ new_u2net.load_state_dict(torch.load(os.path.join(CACHE_DIR, "u2netp.pth"), map_location="cpu"))
118
+ new_u2net.to(device)
119
+ new_u2net.eval()
120
+ drawer_detector_global = new_drawer_detector
121
+ reference_detector_global = new_reference_detector
122
+ birefnet_global = new_birefnet
123
+ u2net_global = new_u2net
124
+ print("Models reloaded in {:.2f} seconds".format(time.time() - start_time))
125
+
126
+ # ---------------------
127
+ # Helper Function: resize_img (defined once)
128
+ # ---------------------
129
+ def resize_img(img: np.ndarray, resize_dim):
130
+ return np.array(Image.fromarray(img).resize(resize_dim))
131
+
132
+ # ---------------------
133
+ # Other Helper Functions for Detection & Processing
134
+ # ---------------------
135
+ def yolo_detect(image: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor]) -> np.ndarray:
136
+ t = time.time()
137
+ results: List[Results] = drawer_detector_global.predict(image)
138
+ if not results or len(results) == 0 or len(results[0].boxes) == 0:
139
+ raise DrawerNotDetectedError("Drawer not detected in the image.")
140
+ print("Drawer detection completed in {:.2f} seconds".format(time.time() - t))
141
+ return save_one_box(results[0].cpu().boxes.xyxy, im=results[0].orig_img, save=False)
142
+
143
+ def detect_reference_square(img: np.ndarray):
144
+ t = time.time()
145
+ res = reference_detector_global.predict(img, conf=0.45)
146
+ if not res or len(res) == 0 or len(res[0].boxes) == 0:
147
+ raise ReferenceBoxNotDetectedError("Reference box not detected in the image.")
148
+ print("Reference detection completed in {:.2f} seconds".format(time.time() - t))
149
+ return (
150
+ save_one_box(res[0].cpu().boxes.xyxy, res[0].orig_img, save=False),
151
+ res[0].cpu().boxes.xyxy[0]
152
+ )
153
+
154
+ # Use U2NETP for reference background removal.
155
+ def remove_bg_u2netp(image: np.ndarray) -> np.ndarray:
156
+ t = time.time()
157
+ image_pil = Image.fromarray(image)
158
+ transform_u2netp = transforms.Compose([
159
+ transforms.Resize((320, 320)),
160
+ transforms.ToTensor(),
161
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
162
+ ])
163
+ input_tensor = transform_u2netp(image_pil).unsqueeze(0).to("cpu")
164
+ with torch.no_grad():
165
+ outputs = u2net_global(input_tensor)
166
+ pred = outputs[0]
167
+ pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
168
+ pred_np = pred.squeeze().cpu().numpy()
169
+ pred_np = cv2.resize(pred_np, (image_pil.width, image_pil.height))
170
+ pred_np = (pred_np * 255).astype(np.uint8)
171
+ print("U2NETP background removal completed in {:.2f} seconds".format(time.time() - t))
172
+ return pred_np
173
+
174
+ # Use BiRefNet for main object background removal.
175
+ def remove_bg(image: np.ndarray) -> np.ndarray:
176
+ t = time.time()
177
+ image_pil = Image.fromarray(image)
178
+ input_images = transform_image_global(image_pil).unsqueeze(0).to("cpu")
179
+ with torch.no_grad():
180
+ preds = birefnet_global(input_images)[-1].sigmoid().cpu()
181
+ pred = preds[0].squeeze()
182
+ pred_pil = transforms.ToPILImage()(pred)
183
+ scale_ratio = 1024 / max(image_pil.size)
184
+ scaled_size = (int(image_pil.size[0] * scale_ratio), int(image_pil.size[1] * scale_ratio))
185
+ result = np.array(pred_pil.resize(scaled_size))
186
+ print("BiRefNet background removal completed in {:.2f} seconds".format(time.time() - t))
187
+ return result
188
+
189
+ def make_square(img: np.ndarray):
190
+ height, width = img.shape[:2]
191
+ max_dim = max(height, width)
192
+ pad_height = (max_dim - height) // 2
193
+ pad_width = (max_dim - width) // 2
194
+ pad_height_extra = max_dim - height - 2 * pad_height
195
+ pad_width_extra = max_dim - width - 2 * pad_width
196
+ if len(img.shape) == 3:
197
+ padded = np.pad(img, ((pad_height, pad_height + pad_height_extra),
198
+ (pad_width, pad_width + pad_width_extra),
199
+ (0, 0)), mode="edge")
200
+ else:
201
+ padded = np.pad(img, ((pad_height, pad_height + pad_height_extra),
202
+ (pad_width, pad_width + pad_width_extra)), mode="edge")
203
+ return padded
204
+
205
+ def shrink_bbox(image: np.ndarray, shrink_factor: float):
206
+ height, width = image.shape[:2]
207
+ center_x, center_y = width // 2, height // 2
208
+ new_width = int(width * shrink_factor)
209
+ new_height = int(height * shrink_factor)
210
+ x1 = max(center_x - new_width // 2, 0)
211
+ y1 = max(center_y - new_height // 2, 0)
212
+ x2 = min(center_x + new_width // 2, width)
213
+ y2 = min(center_y + new_height // 2, height)
214
+ return image[y1:y2, x1:x2]
215
+
216
+ def exclude_scaling_box(image: np.ndarray, bbox: np.ndarray, orig_size: tuple, processed_size: tuple, expansion_factor: float = 1.2) -> np.ndarray:
217
+ x_min, y_min, x_max, y_max = map(int, bbox)
218
+ scale_x = processed_size[1] / orig_size[1]
219
+ scale_y = processed_size[0] / orig_size[0]
220
+ x_min = int(x_min * scale_x)
221
+ x_max = int(x_max * scale_x)
222
+ y_min = int(y_min * scale_y)
223
+ y_max = int(y_max * scale_y)
224
+ box_width = x_max - x_min
225
+ box_height = y_max - y_min
226
+ expanded_x_min = max(0, int(x_min - (expansion_factor - 1) * box_width / 2))
227
+ expanded_x_max = min(image.shape[1], int(x_max + (expansion_factor - 1) * box_width / 2))
228
+ expanded_y_min = max(0, int(y_min - (expansion_factor - 1) * box_height / 2))
229
+ expanded_y_max = min(image.shape[0], int(y_max + (expansion_factor - 1) * box_height / 2))
230
+ image[expanded_y_min:expanded_y_max, expanded_x_min:expanded_x_max] = 0
231
+ return image
232
+
233
+ def resample_contour(contour):
234
+ num_points = 1000
235
+ smoothing_factor = 5
236
+ spline_degree = 3
237
+ if len(contour) < spline_degree + 1:
238
+ raise ValueError(f"Contour must have at least {spline_degree + 1} points, but has {len(contour)} points.")
239
+ contour = contour[:, 0, :]
240
+ tck, _ = splprep([contour[:, 0], contour[:, 1]], s=smoothing_factor)
241
+ u = np.linspace(0, 1, num_points)
242
+ resampled_points = splev(u, tck)
243
+ smoothed_x = gaussian_filter1d(resampled_points[0], sigma=1)
244
+ smoothed_y = gaussian_filter1d(resampled_points[1], sigma=1)
245
+ return np.array([smoothed_x, smoothed_y]).T
246
+
247
+ # ---------------------
248
+ # Add the missing extract_outlines function
249
+ # ---------------------
250
+ def extract_outlines(binary_image: np.ndarray) -> (np.ndarray, list):
251
+ contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
252
+ outline_image = np.zeros_like(binary_image)
253
+ cv2.drawContours(outline_image, contours, -1, (255), thickness=2)
254
+ return cv2.bitwise_not(outline_image), contours
255
+
256
+ # ---------------------
257
+ # Functions for Finger Cut Clearance
258
+ # ---------------------
259
+ def union_tool_and_circle(tool_polygon: Polygon, center_inch, circle_diameter=1.0):
260
+ radius = circle_diameter / 2.0
261
+ circle_poly = Point(center_inch).buffer(radius, resolution=64)
262
+ union_poly = tool_polygon.union(circle_poly)
263
+ return union_poly
264
+
265
+ def build_tool_polygon(points_inch):
266
+ return Polygon(points_inch)
267
+
268
+ def polygon_to_exterior_coords(poly: Polygon):
269
+ if poly.geom_type == "MultiPolygon":
270
+ biggest = max(poly.geoms, key=lambda g: g.area)
271
+ poly = biggest
272
+ if not poly.exterior:
273
+ return []
274
+ return list(poly.exterior.coords)
275
+
276
+ def place_finger_cut_randomly(tool_polygon, points_inch, existing_centers, all_polygons, circle_diameter=1.0, min_gap=0.25, max_attempts=30):
277
+ import random
278
+ needed_center_distance = circle_diameter + min_gap
279
+ radius = circle_diameter / 2.0
280
+ for _ in range(max_attempts):
281
+ idx = random.randint(0, len(points_inch) - 1)
282
+ cx, cy = points_inch[idx]
283
+ too_close = False
284
+ for (ex_x, ex_y) in existing_centers:
285
+ if np.hypot(cx - ex_x, cy - ex_y) < needed_center_distance:
286
+ too_close = True
287
+ break
288
+ if too_close:
289
+ continue
290
+ circle_poly = Point((cx, cy)).buffer(radius, resolution=64)
291
+ union_poly = tool_polygon.union(circle_poly)
292
+ overlap_with_others = False
293
+ too_close_to_others = False
294
+ for poly in all_polygons:
295
+ if union_poly.intersects(poly):
296
+ overlap_with_others = True
297
+ break
298
+ if circle_poly.buffer(min_gap).intersects(poly):
299
+ too_close_to_others = True
300
+ break
301
+ if overlap_with_others or too_close_to_others:
302
+ continue
303
+ existing_centers.append((cx, cy))
304
+ return union_poly, (cx, cy)
305
+ print("Warning: Could not place a finger cut circle meeting all spacing requirements.")
306
+ return None, None
307
+
308
+ # ---------------------
309
+ # DXF Spline and Boundary Functions
310
+ # ---------------------
311
+ def save_dxf_spline(inflated_contours, scaling_factor, height, finger_clearance=False):
312
+ degree = 3
313
+ closed = True
314
+ doc = ezdxf.new(units=0)
315
+ doc.units = ezdxf.units.IN
316
+ doc.header["$INSUNITS"] = ezdxf.units.IN
317
+ msp = doc.modelspace()
318
+ finger_cut_centers = []
319
+ final_polygons_inch = []
320
+ for contour in inflated_contours:
321
+ try:
322
+ resampled_contour = resample_contour(contour)
323
+ points_inch = [(x * scaling_factor, (height - y) * scaling_factor) for x, y in resampled_contour]
324
+ if len(points_inch) < 3:
325
+ continue
326
+ if np.linalg.norm(np.array(points_inch[0]) - np.array(points_inch[-1])) > 1e-2:
327
+ points_inch.append(points_inch[0])
328
+ tool_polygon = build_tool_polygon(points_inch)
329
+ if finger_clearance:
330
+ union_poly, center = place_finger_cut_randomly(tool_polygon, points_inch, finger_cut_centers, final_polygons_inch, circle_diameter=1.0, min_gap=0.25, max_attempts=30)
331
+ if union_poly is not None:
332
+ tool_polygon = union_poly
333
+ exterior_coords = polygon_to_exterior_coords(tool_polygon)
334
+ if len(exterior_coords) < 3:
335
+ continue
336
+ msp.add_spline(exterior_coords, degree=degree, dxfattribs={"layer": "TOOLS"})
337
+ final_polygons_inch.append(tool_polygon)
338
+ except ValueError as e:
339
+ print(f"Skipping contour: {e}")
340
+ return doc, final_polygons_inch
341
+
342
+ def add_rectangular_boundary(doc, polygons_inch, boundary_length, boundary_width, boundary_unit):
343
+ msp = doc.modelspace()
344
+ if boundary_unit == "mm":
345
+ boundary_length_in = boundary_length / 25.4
346
+ boundary_width_in = boundary_width / 25.4
347
+ else:
348
+ boundary_length_in = boundary_length
349
+ boundary_width_in = boundary_width
350
+ min_x = float("inf")
351
+ min_y = float("inf")
352
+ max_x = -float("inf")
353
+ max_y = -float("inf")
354
+ for poly in polygons_inch:
355
+ b = poly.bounds
356
+ min_x = min(min_x, b[0])
357
+ min_y = min(min_y, b[1])
358
+ max_x = max(max_x, b[2])
359
+ max_y = max(max_y, b[3])
360
+ if min_x == float("inf"):
361
+ print("No tool polygons found, skipping boundary.")
362
+ return None
363
+ shape_cx = (min_x + max_x) / 2
364
+ shape_cy = (min_y + max_y) / 2
365
+ half_w = boundary_width_in / 2.0
366
+ half_l = boundary_length_in / 2.0
367
+ left = shape_cx - half_w
368
+ right = shape_cx + half_w
369
+ bottom = shape_cy - half_l
370
+ top = shape_cy + half_l
371
+ rect_coords = [(left, bottom), (right, bottom), (right, top), (left, top), (left, bottom)]
372
+ from shapely.geometry import Polygon as ShapelyPolygon
373
+ boundary_polygon = ShapelyPolygon(rect_coords)
374
+ msp.add_lwpolyline(rect_coords, close=True, dxfattribs={"layer": "BOUNDARY"})
375
+ return boundary_polygon
376
+
377
+ def draw_polygons_inch(polygons_inch, image_rgb, scaling_factor, image_height, color=(0,0,255), thickness=2):
378
+ for poly in polygons_inch:
379
+ if poly.geom_type == "MultiPolygon":
380
+ for subpoly in poly.geoms:
381
+ draw_single_polygon(subpoly, image_rgb, scaling_factor, image_height, color, thickness)
382
+ else:
383
+ draw_single_polygon(poly, image_rgb, scaling_factor, image_height, color, thickness)
384
+
385
+ def draw_single_polygon(poly, image_rgb, scaling_factor, image_height, color=(0,0,255), thickness=2):
386
+ ext = list(poly.exterior.coords)
387
+ if len(ext) < 3:
388
+ return
389
+ pts_px = []
390
+ for (x_in, y_in) in ext:
391
+ px = int(x_in / scaling_factor)
392
+ py = int(image_height - (y_in / scaling_factor))
393
+ pts_px.append([px, py])
394
+ pts_px = np.array(pts_px, dtype=np.int32)
395
+ cv2.polylines(image_rgb, [pts_px], isClosed=True, color=color, thickness=thickness, lineType=cv2.LINE_AA)
396
+
397
+ # ---------------------
398
+ # Main Predict Function with Finger Cut Clearance, Boundary Box, Annotation and Sharpness Enhancement
399
+ # ---------------------
400
+ def predict(
401
+ image: Union[str, bytes, np.ndarray],
402
+ offset_inches: float,
403
+ finger_clearance: str, # "Yes" or "No"
404
+ add_boundary: str, # "Yes" or "No"
405
+ boundary_length: float,
406
+ boundary_width: float,
407
+ boundary_unit: str,
408
+ annotation_text: str
409
+ ):
410
+ overall_start = time.time()
411
+ # Convert image to NumPy array if needed.
412
+ if isinstance(image, str):
413
+ if os.path.exists(image):
414
+ image = np.array(Image.open(image).convert("RGB"))
415
+ else:
416
+ try:
417
+ image = np.array(Image.open(io.BytesIO(base64.b64decode(image))).convert("RGB"))
418
+ except Exception:
419
+ raise ValueError("Invalid base64 image data")
420
+ # Apply sharpness enhancement if image is a NumPy array.
421
+ if isinstance(image, np.ndarray):
422
+ pil_image = Image.fromarray(image)
423
+ enhanced_image = ImageEnhance.Sharpness(pil_image).enhance(1.5)
424
+ image = np.array(enhanced_image)
425
+ try:
426
+ t = time.time()
427
+ drawer_img = yolo_detect(image)
428
+ print("Drawer detection completed in {:.2f} seconds".format(time.time() - t))
429
+ t = time.time()
430
+ shrunked_img = make_square(shrink_bbox(drawer_img, 0.90))
431
+ del drawer_img
432
+ gc.collect()
433
+ print("Image shrinking completed in {:.2f} seconds".format(time.time() - t))
434
+ except DrawerNotDetectedError:
435
+ raise DrawerNotDetectedError("Drawer not detected! Please take another picture with a drawer.")
436
+ try:
437
+ t = time.time()
438
+ reference_obj_img, scaling_box_coords = detect_reference_square(shrunked_img)
439
+ print("Reference square detection completed in {:.2f} seconds".format(time.time() - t))
440
+ except ReferenceBoxNotDetectedError:
441
+ raise ReferenceBoxNotDetectedError("Reference box not detected! Please take another picture with a reference box.")
442
+ t = time.time()
443
+ reference_obj_img = make_square(reference_obj_img)
444
+ reference_square_mask = remove_bg_u2netp(reference_obj_img)
445
+ print("Reference image processing completed in {:.2f} seconds".format(time.time() - t))
446
+ t = time.time()
447
+ try:
448
+ cv2.imwrite("mask.jpg", cv2.cvtColor(reference_obj_img, cv2.COLOR_RGB2GRAY))
449
+ scaling_factor = calculate_scaling_factor(
450
+ reference_image_path="./Reference_ScalingBox.jpg",
451
+ target_image=reference_square_mask,
452
+ feature_detector="ORB",
453
+ )
454
+ except ZeroDivisionError:
455
+ scaling_factor = None
456
+ print("Error calculating scaling factor: Division by zero")
457
+ except Exception as e:
458
+ scaling_factor = None
459
+ print(f"Error calculating scaling factor: {e}")
460
+ if scaling_factor is None or scaling_factor == 0:
461
+ scaling_factor = 1.0
462
+ print("Using default scaling factor of 1.0 due to calculation error")
463
+ gc.collect()
464
+ print("Scaling factor determined: {}".format(scaling_factor))
465
+ t = time.time()
466
+ orig_size = shrunked_img.shape[:2]
467
+ objects_mask = remove_bg(shrunked_img)
468
+ processed_size = objects_mask.shape[:2]
469
+ objects_mask = exclude_scaling_box(objects_mask, scaling_box_coords, orig_size, processed_size, expansion_factor=1.2)
470
+ objects_mask = resize_img(objects_mask, (shrunked_img.shape[1], shrunked_img.shape[0]))
471
+ del scaling_box_coords
472
+ gc.collect()
473
+ print("Object masking completed in {:.2f} seconds".format(time.time() - t))
474
+ t = time.time()
475
+ offset_pixels = (offset_inches / scaling_factor) * 2 + 1 if scaling_factor != 0 else 1
476
+ dilated_mask = cv2.dilate(objects_mask, np.ones((int(offset_pixels), int(offset_pixels)), np.uint8))
477
+ del objects_mask
478
+ gc.collect()
479
+ print("Mask dilation completed in {:.2f} seconds".format(time.time() - t))
480
+ Image.fromarray(dilated_mask).save("./outputs/scaled_mask_new.jpg")
481
+ t = time.time()
482
+ outlines, contours = extract_outlines(dilated_mask)
483
+ shrunked_img_contours = cv2.drawContours(shrunked_img.copy(), contours, -1, (0, 0, 255), thickness=2)
484
+ del shrunked_img
485
+ gc.collect()
486
+ print("Outline extraction completed in {:.2f} seconds".format(time.time() - t))
487
+ t = time.time()
488
+ use_finger_clearance = True if finger_clearance.lower() == "yes" else False
489
+ doc, final_polygons_inch = save_dxf_spline(contours, scaling_factor, processed_size[0], finger_clearance=use_finger_clearance)
490
+ del contours
491
+ gc.collect()
492
+ print("DXF generation completed in {:.2f} seconds".format(time.time() - t))
493
+ boundary_polygon = None
494
+ if add_boundary.lower() == "yes":
495
+ boundary_polygon = add_rectangular_boundary(doc, final_polygons_inch, boundary_length, boundary_width, boundary_unit)
496
+ if boundary_polygon is not None:
497
+ final_polygons_inch.append(boundary_polygon)
498
+ # --- Annotation Text Placement (Bottom-Right) ---
499
+ min_x = float("inf")
500
+ min_y = float("inf")
501
+ max_x = -float("inf")
502
+ max_y = -float("inf")
503
+ for poly in final_polygons_inch:
504
+ b = poly.bounds
505
+ if b[0] < min_x:
506
+ min_x = b[0]
507
+ if b[1] < min_y:
508
+ min_y = b[1]
509
+ if b[2] > max_x:
510
+ max_x = b[2]
511
+ if b[3] > max_y:
512
+ max_y = b[3]
513
+ margin = 0.5
514
+ text_x = (min_x + max_x) / 2
515
+ text_y = min_y - margin
516
+ msp = doc.modelspace()
517
+ if annotation_text.strip():
518
+ text_entity = msp.add_text(
519
+ annotation_text.strip(),
520
+ dxfattribs={
521
+ "height": 0.25,
522
+ "layer": "ANNOTATION"
523
+ }
524
+ )
525
+ text_entity.dxf.insert = (text_x, text_y)
526
+ dxf_filepath = os.path.join("./outputs", "out.dxf")
527
+ doc.saveas(dxf_filepath)
528
+ # --- End Annotation Text Placement ---
529
+ draw_polygons_inch(final_polygons_inch, shrunked_img_contours, scaling_factor, processed_size[0], color=(0,0,255), thickness=2)
530
+ outlines_bgr = cv2.cvtColor(outlines, cv2.COLOR_GRAY2BGR)
531
+ draw_polygons_inch(final_polygons_inch, outlines_bgr, scaling_factor, processed_size[0], color=(0,0,255), thickness=2)
532
+ if annotation_text.strip():
533
+ text_px = int(text_x / scaling_factor)
534
+ text_py = int(processed_size[0] - (text_y / scaling_factor))
535
+ cv2.putText(shrunked_img_contours, annotation_text.strip(), (text_px, text_py), cv2.FONT_HERSHEY_SIMPLEX, 1, (0,0,255), 2, cv2.LINE_AA)
536
+ cv2.putText(outlines_bgr, annotation_text.strip(), (text_px, text_py), cv2.FONT_HERSHEY_SIMPLEX, 1, (0,0,255), 2, cv2.LINE_AA)
537
+ outlines_color = cv2.cvtColor(outlines_bgr, cv2.COLOR_BGR2RGB)
538
+ print("Total prediction time: {:.2f} seconds".format(time.time() - overall_start))
539
+ return (
540
+ cv2.cvtColor(shrunked_img_contours, cv2.COLOR_BGR2RGB),
541
+ outlines_color,
542
+ dxf_filepath,
543
+ dilated_mask,
544
+ str(scaling_factor)
545
+ )
546
+
547
+ # ---------------------
548
+ # Gradio Interface
549
+ # ---------------------
550
+ if __name__ == "__main__":
551
+ os.makedirs("./outputs", exist_ok=True)
552
+ def gradio_predict(img, offset, finger_clearance, add_boundary, boundary_length, boundary_width, boundary_unit, annotation_text):
553
+ return predict(img, offset, finger_clearance, add_boundary, boundary_length, boundary_width, boundary_unit, annotation_text)
554
+ iface = gr.Interface(
555
+ fn=gradio_predict,
556
+ inputs=[
557
+ gr.Image(label="Input Image"),
558
+ gr.Number(label="Offset value for Mask (inches)", value=0.075),
559
+ gr.Dropdown(label="Add Finger Clearance?", choices=["Yes", "No"], value="No"),
560
+ gr.Dropdown(label="Add Rectangular Boundary?", choices=["Yes", "No"], value="No"),
561
+ gr.Number(label="Boundary Length", value=300.0, precision=2),
562
+ gr.Number(label="Boundary Width", value=200.0, precision=2),
563
+ gr.Dropdown(label="Boundary Unit", choices=["mm", "inches"], value="mm"),
564
+ gr.Textbox(label="Annotation (max 20 chars)", max_length=20, placeholder="Type up to 20 characters")
565
+ ],
566
+ outputs=[
567
+ gr.Image(label="Output Image"),
568
+ gr.Image(label="Outlines of Objects"),
569
+ gr.File(label="DXF file"),
570
+ gr.Image(label="Mask"),
571
+ gr.Textbox(label="Scaling Factor (inches/pixel)")
572
+ ],
573
+ examples=[
574
+ ["./examples/Test20.jpg", 0.075, "No", "No", 300.0, 200.0, "mm", "MyTool"],
575
+ ["./examples/Test21.jpg", 0.075, "Yes", "Yes", 300.0, 200.0, "mm", "Tool2"]
576
+ ]
577
+ )
578
+ iface.launch(share=True)