Update app.py
app.py CHANGED
@@ -86,9 +86,8 @@ paper_model_path = os.path.join(CACHE_DIR, "paper_detector.pt") # You'll need t
 u2net_model_path = os.path.join(CACHE_DIR, "u2netp.pth")
 
 # Global variable for YOLOWorld
-yolo_world_global = None
-yolo_world_model_path = os.path.join(CACHE_DIR, "yolov8s_world.pt")
-
+yolo_v8_global = None
+yolo_v8_model_path = os.path.join(CACHE_DIR, "yolov8s.pt")  # Adjust path as needed
 
 
 # Device configuration
@@ -108,11 +107,7 @@ def ensure_model_files():
         shutil.copy("u2netp.pth", u2net_model_path)
     else:
         raise FileNotFoundError("u2netp.pth model file not found")
-    if not os.path.exists(yolo_world_model_path):
-        if os.path.exists("yolov8s_world.pt"):  # Adjust to match your file name
-            shutil.copy("yolov8s_world.pt", yolo_world_model_path)
-        else:
-            logger.warning("yolov8s-world.pt model file not found - falling back to full image processing")
+    logger.info("YOLOv8 will auto-download if not present")
 
 ensure_model_files()
 
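This hunk replaces the manual copy-into-cache step with ultralytics' built-in weight resolution: passing a weights filename to YOLO() fetches the official checkpoint on first use if it is not already on disk. A minimal sketch of that behavior (assumes the ultralytics package is installed; the bare "yolov8s.pt" name mirrors the diff):

from ultralytics import YOLO

# "yolov8s.pt" is resolved by ultralytics: if no local file matches,
# the official checkpoint is downloaded automatically on first use.
model = YOLO("yolov8s.pt")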
@@ -134,22 +129,18 @@ def get_paper_detector():
         logger.warning("Paper model file not found, using fallback detection")
         paper_detector_global = None
     return paper_detector_global
-def get_yolo_world():
-    """Lazy load YOLOWorld model"""
-    global yolo_world_global
-    if yolo_world_global is None:
-        logger.info("Loading YOLOWorld model...")
-
-
-
-
-
-
-
-    else:
-        logger.warning("YOLOWorld model file not found, will raise error if used")
-        yolo_world_global = None
-    return yolo_world_global
+def get_yolo_v8():
+    """Lazy load YOLOv8 model"""
+    global yolo_v8_global
+    if yolo_v8_global is None:
+        logger.info("Loading YOLOv8 model...")
+        try:
+            yolo_v8_global = YOLO(yolo_v8_model_path)  # Auto-downloads if needed
+            logger.info("YOLOv8 model loaded successfully")
+        except Exception as e:
+            logger.error(f"Failed to load YOLOv8: {e}")
+            yolo_v8_global = None
+    return yolo_v8_global
 def get_u2net():
     """Lazy load U2NETP model"""
     global u2net_global
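The loader above follows the module-level lazy-singleton pattern the file already uses for the paper detector and U2NETP. A self-contained sketch of the pattern, with generic names rather than the ones in app.py:

import logging
from ultralytics import YOLO

logger = logging.getLogger(__name__)

_model = None  # module-level cache

def get_model(weights="yolov8s.pt"):
    """Load the model once; subsequent calls return the cached instance."""
    global _model
    if _model is None:
        try:
            _model = YOLO(weights)  # auto-downloads if missing
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            _model = None  # callers must handle the None fallback
    return _model

Note that on failure the cache stays None, so every later call retries the load; callers that receive None fall back to full-image processing, which matches the behavior in the last hunk below.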
@@ -976,46 +967,43 @@ def predict_with_paper(image, paper_size, offset, offset_unit, finger_clearance=
     # Mask paper area in input image first
     masked_input_image = mask_paper_area_in_image(image, paper_contour)
 
-    # Use YOLOWorld to detect objects
-    yolo_world = get_yolo_world()
-    if yolo_world is None:
-
-        logger.warning("YOLOWorld model not available, proceeding with full image")
+    # Use YOLOv8 to detect objects
+    yolo_v8 = get_yolo_v8()
+    if yolo_v8 is None:
+        logger.warning("YOLOv8 model not available, proceeding with full image")
         cropped_image = masked_input_image
         crop_offset = (0, 0)
     else:
-
-        results =
+        # YOLOv8 detects all COCO classes by default
+        results = yolo_v8.predict(masked_input_image, conf=0.1, verbose=False)
 
         if not results or len(results) == 0 or not hasattr(results[0], 'boxes') or len(results[0].boxes) == 0:
-            logger.warning("No objects detected by YOLOWorld, proceeding with full image")
+            logger.warning("No objects detected by YOLOv8, proceeding with full image")
             cropped_image = masked_input_image
             crop_offset = (0, 0)
         else:
             boxes = results[0].boxes.xyxy.cpu().numpy()
             confidences = results[0].boxes.conf.cpu().numpy()
 
-            # Filter out
-            valid_boxes = []
+            # Filter out very large boxes (likely paper/background)
             image_area = masked_input_image.shape[0] * masked_input_image.shape[1]
+            valid_boxes = []
 
             for i, box in enumerate(boxes):
                 x_min, y_min, x_max, y_max = box
                 box_area = (x_max - x_min) * (y_max - y_min)
-
+                # Keep boxes that are 5% to 40% of image area
+                if 0.05 * image_area < box_area < 0.4 * image_area:
                     valid_boxes.append((i, confidences[i]))
 
             if not valid_boxes:
+                logger.warning("No valid objects detected, proceeding with full image")
                 cropped_image = masked_input_image
                 crop_offset = (0, 0)
             else:
                 # Get highest confidence valid box
                 best_idx = max(valid_boxes, key=lambda x: x[1])[0]
                 x_min, y_min, x_max, y_max = map(int, boxes[best_idx])
-
-                # Larger margin for small objects
-                box_size = min(x_max - x_min, y_max - y_min)
-                margin = max(30, int(box_size * 0.3))  # At least 30px margin
 
     # Remove background from cropped image
     orig_size = image.shape[:2]
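Taken together, the last hunk does: run detection at a low confidence threshold (conf=0.1), discard boxes outside 5% to 40% of the image area (too large likely means the paper sheet or background), then crop to the highest-confidence survivor, falling back to the full masked image at every failure point. A self-contained sketch of that logic; the function name, signature, and return convention are illustrative, not from app.py:

import numpy as np
from ultralytics import YOLO

def crop_to_best_detection(image: np.ndarray, model: YOLO):
    """Return (cropped_image, (x_offset, y_offset)); falls back to the full image."""
    results = model.predict(image, conf=0.1, verbose=False)
    if not results or len(results[0].boxes) == 0:
        return image, (0, 0)

    boxes = results[0].boxes.xyxy.cpu().numpy()
    confs = results[0].boxes.conf.cpu().numpy()
    image_area = image.shape[0] * image.shape[1]

    # Keep boxes covering 5% to 40% of the image, as in the diff
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    valid = (areas > 0.05 * image_area) & (areas < 0.4 * image_area)
    if not valid.any():
        return image, (0, 0)

    # Highest-confidence box among the valid ones
    best = np.where(valid)[0][confs[valid].argmax()]
    x_min, y_min, x_max, y_max = boxes[best].astype(int)
    return image[y_min:y_max, x_min:x_max], (x_min, y_min)

The 5%/40% bounds come straight from the diff; they assume the target object is clearly smaller than the paper sheet in frame and may need tuning for other camera setups.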