ammariii08 committed on
Commit
cdcfb3e
·
verified ·
1 Parent(s): 86be3ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -67
app.py CHANGED
@@ -159,27 +159,23 @@ def detect_reference_square(img: np.ndarray):
159
  res[0].cpu().boxes.xyxy[0]
160
  )
161
 
162
- # Use U2NETP for reference background removal.
163
def remove_bg_u2netp(image: np.ndarray) -> np.ndarray:
    """Produce a background-removal mask for the reference image using U2NETP.

    The image is resized to 320x320, converted to a tensor and normalized
    before being passed to ``u2net_global`` with gradient tracking disabled.
    The first model output is min-max normalized, resized back to the input
    image's width and height, and scaled to the 0-255 range.

    Args:
        image: Input image array, converted with ``Image.fromarray``.

    Returns:
        A ``uint8`` mask with the same width and height as the input image.
    """
    started = time.time()
    pil_input = Image.fromarray(image)

    preprocess = transforms.Compose([
        transforms.Resize((320, 320)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    batch = preprocess(pil_input).unsqueeze(0).to("cpu")

    with torch.no_grad():
        saliency = u2net_global(batch)[0]

    # Min-max normalize; the epsilon guards against a constant prediction map.
    lowest, highest = saliency.min(), saliency.max()
    saliency = (saliency - lowest) / (highest - lowest + 1e-8)

    mask = cv2.resize(
        saliency.squeeze().cpu().numpy(),
        (pil_input.width, pil_input.height),
    )
    mask = (mask * 255).astype(np.uint8)
    print("U2NETP background removal completed in {:.2f} seconds".format(time.time() - started))
    return mask
181
-
182
- # Use BiRefNet for main object background removal.
183
  def remove_bg(image: np.ndarray) -> np.ndarray:
184
  t = time.time()
185
  image_pil = Image.fromarray(image)
@@ -191,7 +187,7 @@ def remove_bg(image: np.ndarray) -> np.ndarray:
191
  scale_ratio = 1024 / max(image_pil.size)
192
  scaled_size = (int(image_pil.size[0] * scale_ratio), int(image_pil.size[1] * scale_ratio))
193
  result = np.array(pred_pil.resize(scaled_size))
194
- print("BiRefNet background removal completed in {:.2f} seconds".format(time.time() - t))
195
  return result
196
 
197
  def make_square(img: np.ndarray):
@@ -473,7 +469,6 @@ def predict(
473
  print("Drawer detection completed in {:.2f} seconds".format(time.time() - t))
474
  except DrawerNotDetectedError as e:
475
  return None, None, None, None, f"Error: {str(e)}"
476
- # Ensure that shrunked_img is defined only after successful detection.
477
  t = time.time()
478
  shrunked_img = make_square(shrink_bbox(drawer_img, 0.90))
479
  del drawer_img
@@ -495,9 +490,9 @@ def predict(
495
  # ---------------------
496
  t = time.time()
497
  reference_obj_img = make_square(reference_obj_img)
498
- reference_square_mask = remove_bg_u2netp(reference_obj_img)
 
499
  print("Reference image processing completed in {:.2f} seconds".format(time.time() - t))
500
-
501
  t = time.time()
502
  try:
503
  cv2.imwrite("mask.jpg", cv2.cvtColor(reference_obj_img, cv2.COLOR_RGB2GRAY))
@@ -570,7 +565,6 @@ def predict(
570
  del objects_mask
571
  gc.collect()
572
  print("Mask dilation completed in {:.2f} seconds".format(time.time() - t))
573
-
574
  Image.fromarray(dilated_mask).save("./outputs/scaled_mask_new.jpg")
575
 
576
  # ---------------------
@@ -579,16 +573,12 @@ def predict(
579
  t = time.time()
580
  outlines, contours = extract_outlines(dilated_mask)
581
  print("Outline extraction completed in {:.2f} seconds".format(time.time() - t))
582
-
583
  output_img = shrunked_img.copy()
584
  del shrunked_img
585
  gc.collect()
586
-
587
  t = time.time()
588
  use_finger_clearance = True if finger_clearance.lower() == "yes" else False
589
- doc, final_polygons_inch = save_dxf_spline(
590
- contours, scaling_factor, processed_size[0], finger_clearance=use_finger_clearance
591
- )
592
  del contours
593
  gc.collect()
594
  print("DXF generation completed in {:.2f} seconds".format(time.time() - t))
@@ -633,14 +623,8 @@ def predict(
633
  text_x = (inner_min_x + inner_max_x) / 2.0
634
  text_height_dxf = 0.5
635
  text_y_dxf = inner_min_y - 0.125 - text_height_dxf
636
- text_entity = msp.add_text(
637
- annotation_text.strip(),
638
- dxfattribs={
639
- "height": text_height_dxf,
640
- "layer": "ANNOTATION",
641
- "style": "Bold"
642
- }
643
- )
644
  text_entity.dxf.insert = (text_x, text_y_dxf)
645
 
646
  # Save the DXF
@@ -660,27 +644,8 @@ def predict(
660
  text_y_in = inner_min_y - 0.125 - text_height_cv
661
  text_y_img = int(processed_size[0] - (text_y_in / scaling_factor))
662
  org = (text_x_img - int(len(annotation_text.strip()) * 6), text_y_img)
663
-
664
- cv2.putText(
665
- output_img,
666
- annotation_text.strip(),
667
- org,
668
- cv2.FONT_HERSHEY_SIMPLEX,
669
- 1.3,
670
- (0, 0, 255),
671
- 3,
672
- cv2.LINE_AA
673
- )
674
- cv2.putText(
675
- new_outlines,
676
- annotation_text.strip(),
677
- org,
678
- cv2.FONT_HERSHEY_SIMPLEX,
679
- 1.3,
680
- (0, 0, 255),
681
- 3,
682
- cv2.LINE_AA
683
- )
684
 
685
  # Restore brightness for display purposes:
686
  # Since we reduced brightness by 0.5 during preprocessing,
@@ -691,14 +656,11 @@ def predict(
691
 
692
  outlines_color = cv2.cvtColor(new_outlines, cv2.COLOR_BGR2RGB)
693
  print("Total prediction time: {:.2f} seconds".format(time.time() - overall_start))
694
-
695
- return (
696
- cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB),
697
- outlines_color,
698
- dxf_filepath,
699
- dilated_mask,
700
- str(scaling_factor)
701
- )
702
 
703
  # ---------------------
704
  # Gradio Interface
@@ -734,5 +696,4 @@ if __name__ == "__main__":
734
  ["./Test21.jpg", 0.075, "inches", "Yes", "Yes", 300.0, 200.0, "Tool2"]
735
  ]
736
  )
737
- iface.launch(share=True)
738
-
 
159
  res[0].cpu().boxes.xyxy[0]
160
  )
161
 
162
+ # For reference background removal, we now use BiRefNet.
163
def remove_bg_reference(image: np.ndarray) -> np.ndarray:
    """Produce a background-removal mask for the reference image using BiRefNet.

    The image is preprocessed with ``transform_image_global`` and passed to
    ``birefnet_global`` with gradient tracking disabled. The sigmoid of the
    model's last output is converted to a PIL image and resized by the factor
    ``1024 / max(width, height)`` of the input image.

    Args:
        image: Input image array, converted with ``Image.fromarray``.

    Returns:
        The resized prediction mask as a NumPy array.
    """
    started = time.time()
    pil_input = Image.fromarray(image)
    batch = transform_image_global(pil_input).unsqueeze(0).to("cpu")

    with torch.no_grad():
        last_output = birefnet_global(batch)[-1]
        probabilities = last_output.sigmoid().cpu()

    mask_pil = transforms.ToPILImage()(probabilities[0].squeeze())

    # Scale both sides by the same ratio, derived from the longer side.
    width, height = pil_input.size
    ratio = 1024 / max(width, height)
    target_size = (int(width * ratio), int(height * ratio))

    mask = np.array(mask_pil.resize(target_size))
    print("BiRefNet (reference) background removal completed in {:.2f} seconds".format(time.time() - started))
    return mask
177
+
178
+ # The main background removal for objects still uses BiRefNet.
179
  def remove_bg(image: np.ndarray) -> np.ndarray:
180
  t = time.time()
181
  image_pil = Image.fromarray(image)
 
187
  scale_ratio = 1024 / max(image_pil.size)
188
  scaled_size = (int(image_pil.size[0] * scale_ratio), int(image_pil.size[1] * scale_ratio))
189
  result = np.array(pred_pil.resize(scaled_size))
190
+ print("BiRefNet (object) background removal completed in {:.2f} seconds".format(time.time() - t))
191
  return result
192
 
193
  def make_square(img: np.ndarray):
 
469
  print("Drawer detection completed in {:.2f} seconds".format(time.time() - t))
470
  except DrawerNotDetectedError as e:
471
  return None, None, None, None, f"Error: {str(e)}"
 
472
  t = time.time()
473
  shrunked_img = make_square(shrink_bbox(drawer_img, 0.90))
474
  del drawer_img
 
490
  # ---------------------
491
  t = time.time()
492
  reference_obj_img = make_square(reference_obj_img)
493
+ # Use BiRefNet for reference background removal instead of U2NETP:
494
+ reference_square_mask = remove_bg_reference(reference_obj_img)
495
  print("Reference image processing completed in {:.2f} seconds".format(time.time() - t))
 
496
  t = time.time()
497
  try:
498
  cv2.imwrite("mask.jpg", cv2.cvtColor(reference_obj_img, cv2.COLOR_RGB2GRAY))
 
565
  del objects_mask
566
  gc.collect()
567
  print("Mask dilation completed in {:.2f} seconds".format(time.time() - t))
 
568
  Image.fromarray(dilated_mask).save("./outputs/scaled_mask_new.jpg")
569
 
570
  # ---------------------
 
573
  t = time.time()
574
  outlines, contours = extract_outlines(dilated_mask)
575
  print("Outline extraction completed in {:.2f} seconds".format(time.time() - t))
 
576
  output_img = shrunked_img.copy()
577
  del shrunked_img
578
  gc.collect()
 
579
  t = time.time()
580
  use_finger_clearance = True if finger_clearance.lower() == "yes" else False
581
+ doc, final_polygons_inch = save_dxf_spline(contours, scaling_factor, processed_size[0], finger_clearance=use_finger_clearance)
 
 
582
  del contours
583
  gc.collect()
584
  print("DXF generation completed in {:.2f} seconds".format(time.time() - t))
 
623
  text_x = (inner_min_x + inner_max_x) / 2.0
624
  text_height_dxf = 0.5
625
  text_y_dxf = inner_min_y - 0.125 - text_height_dxf
626
+ text_entity = msp.add_text(annotation_text.strip(),
627
+ dxfattribs={"height": text_height_dxf, "layer": "ANNOTATION", "style": "Bold"})
 
 
 
 
 
 
628
  text_entity.dxf.insert = (text_x, text_y_dxf)
629
 
630
  # Save the DXF
 
644
  text_y_in = inner_min_y - 0.125 - text_height_cv
645
  text_y_img = int(processed_size[0] - (text_y_in / scaling_factor))
646
  org = (text_x_img - int(len(annotation_text.strip()) * 6), text_y_img)
647
+ cv2.putText(output_img, annotation_text.strip(), org, cv2.FONT_HERSHEY_SIMPLEX, 1.3, (0, 0, 255), 3, cv2.LINE_AA)
648
+ cv2.putText(new_outlines, annotation_text.strip(), org, cv2.FONT_HERSHEY_SIMPLEX, 1.3, (0, 0, 255), 3, cv2.LINE_AA)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
649
 
650
  # Restore brightness for display purposes:
651
  # Since we reduced brightness by 0.5 during preprocessing,
 
656
 
657
  outlines_color = cv2.cvtColor(new_outlines, cv2.COLOR_BGR2RGB)
658
  print("Total prediction time: {:.2f} seconds".format(time.time() - overall_start))
659
+ return (cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB),
660
+ outlines_color,
661
+ dxf_filepath,
662
+ dilated_mask,
663
+ str(scaling_factor))
 
 
 
664
 
665
  # ---------------------
666
  # Gradio Interface
 
696
  ["./Test21.jpg", 0.075, "inches", "Yes", "Yes", 300.0, 200.0, "Tool2"]
697
  ]
698
  )
699
+ iface.launch(share=True)