ammariii08 committed
Commit f858b25 · verified · 1 Parent(s): cdcfb3e

Update app.py

Files changed (1)
  1. app.py +87 -49
app.py CHANGED
@@ -71,18 +71,18 @@ if not os.path.exists(reference_model_path):
 reference_detector_global = YOLO(reference_model_path)
 print("YOLO reference model loaded in {:.2f} seconds".format(time.time() - start_time))
 
-print("Loading U²-Net model for reference background removal (U2NETP)...")
-start_time = time.time()
-u2net_model_path = os.path.join(CACHE_DIR, "u2netp.pth")
-if not os.path.exists(u2net_model_path):
-    print("Caching U²-Net model to", u2net_model_path)
-    shutil.copy("u2netp.pth", u2net_model_path)
-u2net_global = U2NETP(3, 1)
-u2net_global.load_state_dict(torch.load(u2net_model_path, map_location="cpu"))
-device = "cpu"
-u2net_global.to(device)
-u2net_global.eval()
-print("U²-Net model loaded in {:.2f} seconds".format(time.time() - start_time))
+# print("Loading U²-Net model for reference background removal (U2NETP)...")
+# start_time = time.time()
+# u2net_model_path = os.path.join(CACHE_DIR, "u2netp.pth")
+# if not os.path.exists(u2net_model_path):
+#     print("Caching U²-Net model to", u2net_model_path)
+#     shutil.copy("u2netp.pth", u2net_model_path)
+# u2net_global = U2NETP(3, 1)
+# u2net_global.load_state_dict(torch.load(u2net_model_path, map_location="cpu"))
+# device = "cpu"
+# u2net_global.to(device)
+# u2net_global.eval()
+# print("U²-Net model loaded in {:.2f} seconds".format(time.time() - start_time))
 
 print("Loading BiRefNet model...")
 start_time = time.time()
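
The block commented out above follows a copy-once, load-from-cache pattern worth keeping in mind if U²-Net is ever re-enabled. A minimal sketch of that pattern, assuming `u2netp.pth` ships in the repo root and `U2NETP` is importable as in app.py:

```python
import os
import shutil
import torch

def load_u2netp_cached(cache_dir: str, weights: str = "u2netp.pth"):
    """Copy the bundled checkpoint into cache_dir once, then always load from there."""
    cached_path = os.path.join(cache_dir, weights)
    if not os.path.exists(cached_path):
        shutil.copy(weights, cached_path)   # first run: seed the cache
    model = U2NETP(3, 1)                    # assumes U2NETP is imported, as in app.py
    model.load_state_dict(torch.load(cached_path, map_location="cpu"))
    model.eval()                            # CPU inference only, matching the app
    return model
```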
@@ -119,16 +119,16 @@ def unload_and_reload_models():
     new_birefnet = AutoModelForImageSegmentation.from_pretrained(
         "zhengpeng7/BiRefNet", trust_remote_code=True, cache_dir=CACHE_DIR
     )
-    new_birefnet.to(device)
-    new_birefnet.eval()
-    new_u2net = U2NETP(3, 1)
-    new_u2net.load_state_dict(torch.load(os.path.join(CACHE_DIR, "u2netp.pth"), map_location="cpu"))
-    new_u2net.to(device)
-    new_u2net.eval()
+    # new_birefnet.to(device)
+    # new_birefnet.eval()
+    # new_u2net = U2NETP(3, 1)
+    # new_u2net.load_state_dict(torch.load(os.path.join(CACHE_DIR, "u2netp.pth"), map_location="cpu"))
+    # new_u2net.to(device)
+    # new_u2net.eval()
     drawer_detector_global = new_drawer_detector
     reference_detector_global = new_reference_detector
     birefnet_global = new_birefnet
-    u2net_global = new_u2net
+    u2net_global = new_birefnet
     print("Models reloaded in {:.2f} seconds".format(time.time() - start_time))
 
 # ---------------------
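
Note the last changed line: the reload path now points `u2net_global` at the BiRefNet instance instead of a fresh `U2NETP`, so any remaining U²-Net call sites keep working. A sketch of that build-first, swap-after pattern, using names from the diff (the `global` declarations are an assumption about the enclosing function):

```python
import gc

def swap_in(new_birefnet):
    """Swap the freshly built model into the globals only after loading succeeds."""
    global birefnet_global, u2net_global
    birefnet_global = new_birefnet
    u2net_global = new_birefnet  # alias: U2NETP loading is commented out in this commit
    gc.collect()                 # let Python reclaim the old model objects
```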
@@ -159,23 +159,27 @@ def detect_reference_square(img: np.ndarray):
         res[0].cpu().boxes.xyxy[0]
     )
 
-# For reference background removal, we now use BiRefNet.
-def remove_bg_reference(image: np.ndarray) -> np.ndarray:
-    # Use the same BiRefNet method as for the main object.
-    t = time.time()
-    image_pil = Image.fromarray(image)
-    input_images = transform_image_global(image_pil).unsqueeze(0).to("cpu")
-    with torch.no_grad():
-        preds = birefnet_global(input_images)[-1].sigmoid().cpu()
-    pred = preds[0].squeeze()
-    pred_pil = transforms.ToPILImage()(pred)
-    scale_ratio = 1024 / max(image_pil.size)
-    scaled_size = (int(image_pil.size[0] * scale_ratio), int(image_pil.size[1] * scale_ratio))
-    result = np.array(pred_pil.resize(scaled_size))
-    print("BiRefNet (reference) background removal completed in {:.2f} seconds".format(time.time() - t))
-    return result
-
-# The main background removal for objects still uses BiRefNet.
+# Use U2NETP for reference background removal.
+# def remove_bg_u2netp(image: np.ndarray) -> np.ndarray:
+#     t = time.time()
+#     image_pil = Image.fromarray(image)
+#     transform_u2netp = transforms.Compose([
+#         transforms.Resize((320, 320)),
+#         transforms.ToTensor(),
+#         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+#     ])
+#     input_tensor = transform_u2netp(image_pil).unsqueeze(0).to("cpu")
+#     with torch.no_grad():
+#         outputs = u2net_global(input_tensor)
+#     pred = outputs[0]
+#     pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
+#     pred_np = pred.squeeze().cpu().numpy()
+#     pred_np = cv2.resize(pred_np, (image_pil.width, image_pil.height))
+#     pred_np = (pred_np * 255).astype(np.uint8)
+#     print("U2NETP background removal completed in {:.2f} seconds".format(time.time() - t))
+#     return pred_np
+
+# Use BiRefNet for main object background removal.
 def remove_bg(image: np.ndarray) -> np.ndarray:
     t = time.time()
     image_pil = Image.fromarray(image)
@@ -187,7 +191,7 @@ def remove_bg(image: np.ndarray) -> np.ndarray:
     scale_ratio = 1024 / max(image_pil.size)
     scaled_size = (int(image_pil.size[0] * scale_ratio), int(image_pil.size[1] * scale_ratio))
     result = np.array(pred_pil.resize(scaled_size))
-    print("BiRefNet (object) background removal completed in {:.2f} seconds".format(time.time() - t))
+    print("BiRefNet background removal completed in {:.2f} seconds".format(time.time() - t))
     return result
 
 def make_square(img: np.ndarray):
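
`remove_bg` returns a grayscale mask resized so its long side is 1024 px, not the input size. A hedged usage sketch of how such a mask can be binarized and applied back to the original frame (the 127 threshold is an assumption, not a value from app.py):

```python
import cv2
import numpy as np

def apply_mask(image_rgb: np.ndarray, mask_gray: np.ndarray) -> np.ndarray:
    """Resize the 1024-px mask back to the image and cut the object out."""
    h, w = image_rgb.shape[:2]
    mask_gray = cv2.resize(mask_gray, (w, h))
    _, binary = cv2.threshold(mask_gray, 127, 255, cv2.THRESH_BINARY)
    return cv2.bitwise_and(image_rgb, image_rgb, mask=binary)
```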
@@ -469,6 +473,7 @@ def predict(
         print("Drawer detection completed in {:.2f} seconds".format(time.time() - t))
     except DrawerNotDetectedError as e:
         return None, None, None, None, f"Error: {str(e)}"
+    # Ensure that shrunked_img is defined only after successful detection.
     t = time.time()
     shrunked_img = make_square(shrink_bbox(drawer_img, 0.90))
     del drawer_img
@@ -490,9 +495,9 @@ def predict(
     # ---------------------
     t = time.time()
     reference_obj_img = make_square(reference_obj_img)
-    # Use BiRefNet for reference background removal instead of U2NETP:
-    reference_square_mask = remove_bg_reference(reference_obj_img)
+    reference_square_mask = remove_bg(reference_obj_img)
     print("Reference image processing completed in {:.2f} seconds".format(time.time() - t))
+
     t = time.time()
     try:
         cv2.imwrite("mask.jpg", cv2.cvtColor(reference_obj_img, cv2.COLOR_RGB2GRAY))
@@ -565,6 +570,7 @@ def predict(
         del objects_mask
         gc.collect()
         print("Mask dilation completed in {:.2f} seconds".format(time.time() - t))
+
     Image.fromarray(dilated_mask).save("./outputs/scaled_mask_new.jpg")
 
     # ---------------------
@@ -573,12 +579,16 @@ def predict(
     t = time.time()
     outlines, contours = extract_outlines(dilated_mask)
     print("Outline extraction completed in {:.2f} seconds".format(time.time() - t))
+
     output_img = shrunked_img.copy()
     del shrunked_img
     gc.collect()
+
     t = time.time()
     use_finger_clearance = True if finger_clearance.lower() == "yes" else False
-    doc, final_polygons_inch = save_dxf_spline(contours, scaling_factor, processed_size[0], finger_clearance=use_finger_clearance)
+    doc, final_polygons_inch = save_dxf_spline(
+        contours, scaling_factor, processed_size[0], finger_clearance=use_finger_clearance
+    )
     del contours
     gc.collect()
     print("DXF generation completed in {:.2f} seconds".format(time.time() - t))
@@ -623,8 +633,14 @@ def predict(
     text_x = (inner_min_x + inner_max_x) / 2.0
     text_height_dxf = 0.5
     text_y_dxf = inner_min_y - 0.125 - text_height_dxf
-    text_entity = msp.add_text(annotation_text.strip(),
-                               dxfattribs={"height": text_height_dxf, "layer": "ANNOTATION", "style": "Bold"})
+    text_entity = msp.add_text(
+        annotation_text.strip(),
+        dxfattribs={
+            "height": text_height_dxf,
+            "layer": "ANNOTATION",
+            "style": "Bold"
+        }
+    )
     text_entity.dxf.insert = (text_x, text_y_dxf)
 
     # Save the DXF
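
One caveat with the reformatted `add_text` call: the `"Bold"` text style must actually exist in the document's style table, or most viewers silently fall back to STANDARD. A hedged sketch of guarding for that (the font file name and the literal text/coordinates are assumptions):

```python
import ezdxf

doc = ezdxf.new()
if "Bold" not in doc.styles:
    doc.styles.new("Bold", dxfattribs={"font": "arialbd.ttf"})  # assumed font file
msp = doc.modelspace()
text_entity = msp.add_text(
    "12.5 x 8.0 in",  # hypothetical annotation text
    dxfattribs={"height": 0.5, "layer": "ANNOTATION", "style": "Bold"},
)
text_entity.dxf.insert = (6.25, -0.625)  # hypothetical placement
```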
@@ -644,8 +660,27 @@ def predict(
     text_y_in = inner_min_y - 0.125 - text_height_cv
     text_y_img = int(processed_size[0] - (text_y_in / scaling_factor))
     org = (text_x_img - int(len(annotation_text.strip()) * 6), text_y_img)
-    cv2.putText(output_img, annotation_text.strip(), org, cv2.FONT_HERSHEY_SIMPLEX, 1.3, (0, 0, 255), 3, cv2.LINE_AA)
-    cv2.putText(new_outlines, annotation_text.strip(), org, cv2.FONT_HERSHEY_SIMPLEX, 1.3, (0, 0, 255), 3, cv2.LINE_AA)
+
+    cv2.putText(
+        output_img,
+        annotation_text.strip(),
+        org,
+        cv2.FONT_HERSHEY_SIMPLEX,
+        1.3,
+        (0, 0, 255),
+        3,
+        cv2.LINE_AA
+    )
+    cv2.putText(
+        new_outlines,
+        annotation_text.strip(),
+        org,
+        cv2.FONT_HERSHEY_SIMPLEX,
+        1.3,
+        (0, 0, 255),
+        3,
+        cv2.LINE_AA
+    )
 
     # Restore brightness for display purposes:
     # Since we reduced brightness by 0.5 during preprocessing,
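
The `org` computation above centers the text with a rough `len(text) * 6` pixel estimate. `cv2.getTextSize` would give the exact rendered width for the same font, scale, and thickness; a small sketch of that alternative (not what this commit does):

```python
import cv2

def centered_org(text: str, center_x: int, baseline_y: int):
    """Exact horizontal centering for cv2.putText with the diff's font settings."""
    (width, _), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 1.3, 3)
    return (center_x - width // 2, baseline_y)
```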
@@ -656,11 +691,14 @@ def predict(
 
     outlines_color = cv2.cvtColor(new_outlines, cv2.COLOR_BGR2RGB)
     print("Total prediction time: {:.2f} seconds".format(time.time() - overall_start))
-    return (cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB),
-            outlines_color,
-            dxf_filepath,
-            dilated_mask,
-            str(scaling_factor))
+
+    return (
+        cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB),
+        outlines_color,
+        dxf_filepath,
+        dilated_mask,
+        str(scaling_factor)
+    )
 
 # ---------------------
 # Gradio Interface
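
The diff ends just before the Gradio interface, so the wiring below is guesswork from `predict`'s five-value return (annotated image, outlines, DXF path, mask, scaling factor or error string); the input components in particular are pure assumptions.

```python
import gradio as gr

demo = gr.Interface(
    fn=predict,  # the app's predict() from above
    inputs=[
        gr.Image(type="numpy", label="Drawer photo"),       # assumed input
        gr.Radio(["Yes", "No"], label="Finger clearance"),  # assumed input
    ],
    outputs=[
        gr.Image(label="Annotated drawer"),
        gr.Image(label="Outlines"),
        gr.File(label="DXF file"),
        gr.Image(label="Dilated mask"),
        gr.Textbox(label="Scaling factor / status"),
    ],
)

if __name__ == "__main__":
    demo.launch()
```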
 