# Spaces: Build error
# Prediction interface for Cog ⚙️
# https://cog.run/python
import os
import tempfile

from cog import BasePredictor, Input, Path
from PIL import Image

from utils import (
    check_ocr_box,
    get_yolo_model,
    get_caption_model_processor,
    get_som_labeled_img,
)
class Predictor(BasePredictor):
    """Cog predictor that locates UI elements in a screenshot.

    ``setup`` loads the detection and captioning models once; ``predict``
    runs OCR plus detection on one screenshot and maps every detected box
    to coordinates on a screen of the requested size.
    """

    def setup(self):
        """Load the model into memory"""
        self.yolo_model = get_yolo_model(model_path='weights/icon_detect/best.pt')
        self.caption_model_processor = get_caption_model_processor(
            model_name="florence2",
            model_name_or_path="weights/icon_caption_florence",
        )
        # Drawing options forwarded unchanged to get_som_labeled_img().
        self.draw_bbox_config = {
            'text_scale': 0.8,
            'text_thickness': 2,
            'text_padding': 2,
            'thickness': 2,
        }

    def predict(
        self,
        image: Path = Input(description="Screenshot of the screen"),
        screen_width: int = Input(
            description="Screen width in pixels",
            default=1920,
            ge=800,  # Setting minimum reasonable screen width
            le=7680,  # Supporting up to 8K displays
        ),
        screen_height: int = Input(
            description="Screen height in pixels",
            default=1080,
            ge=600,  # Setting minimum reasonable screen height
            le=4320,  # Supporting up to 8K displays
        ),
        box_threshold: float = Input(
            description="Confidence threshold for box detection",
            default=0.05,
            ge=0.01,
            le=1.0,
        ),
        iou_threshold: float = Input(
            description="IOU threshold for overlap detection",
            default=0.1,
            ge=0.01,
            le=1.0,
        ),
    ) -> dict:
        """Run object detection on a screenshot and return coordinates.

        Args:
            image: Path to the screenshot to analyse.
            screen_width: Width, in pixels, of the screen that the
                normalized coordinates are scaled to.
            screen_height: Height, in pixels, of the screen that the
                normalized coordinates are scaled to.
            box_threshold: Passed to ``get_som_labeled_img`` as
                ``BOX_TRESHOLD``.
            iou_threshold: Passed to ``get_som_labeled_img`` as
                ``iou_threshold``.

        Returns:
            A dict with two keys: ``"image"``, the labeled image returned by
            ``get_som_labeled_img``, and ``"elements"``, a list with one
            entry per detected box (description, normalized center, center
            in screen pixels, and size in screen pixels).

        Raises:
            ValueError: If ``image`` does not exist or fails verification.
        """
        # Ensure the input image exists and is valid
        if not image.exists():
            raise ValueError("Input image file does not exist")

        # Image.verify() requires the file to be reopened before further
        # use, so it runs on a throwaway handle that is closed right away.
        try:
            with Image.open(image) as probe:
                probe.verify()
        except Exception as e:
            raise ValueError(f"Invalid image file: {str(e)}") from e

        # Stage the image in a private temporary directory: a fixed path
        # under /tmp would be overwritten by any other prediction running at
        # the same time, and would never be cleaned up.
        with tempfile.TemporaryDirectory() as work_dir:
            image_save_path = os.path.join(work_dir, 'input_image.png')
            with Image.open(image) as input_image:
                input_image.save(image_save_path)

            # Get OCR results
            ocr_bbox_rslt, _ = check_ocr_box(
                image_save_path,
                display_img=False,
                output_bb_format='xyxy',
                goal_filtering=None,
                easyocr_args={'paragraph': False, 'text_threshold': 0.9},
            )
            text, ocr_bbox = ocr_bbox_rslt

            # Get labeled image and coordinates
            dino_labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
                image_save_path,
                self.yolo_model,
                BOX_TRESHOLD=box_threshold,
                output_coord_in_ratio=True,
                ocr_bbox=ocr_bbox,
                draw_bbox_config=self.draw_bbox_config,
                caption_model_processor=self.caption_model_processor,
                ocr_text=text,
                iou_threshold=iou_threshold,
            )

        # Format output. Descriptions are matched to boxes by position; a
        # box without a matching description gets a generic "Icon <i>" label.
        elements = []
        for i, coords in enumerate(label_coordinates.values()):
            if i < len(parsed_content_list):
                description = parsed_content_list[i]
            else:
                description = f"Icon {i}"
            elements.append(
                self._build_element(description, coords, screen_width, screen_height)
            )

        return {
            "image": dino_labeled_img,  # Base64 encoded image
            "elements": elements,
        }

    @staticmethod
    def _build_element(description, coords, screen_width, screen_height):
        """Convert one normalized box into the output entry for an element.

        ``coords`` is unpacked as ``(x, y, w, h)``.
        NOTE(review): this assumes (x, y) is the top-left corner of the box
        and all four values are ratios of the image size -- confirm against
        get_som_labeled_img(output_coord_in_ratio=True).
        """
        x, y, w, h = coords

        # Calculate center points (normalized). The scalar type of coords is
        # decided by the helper that produced them; float() keeps the
        # payload made of built-in numbers, like the int() casts below.
        center_x_norm = float(x + (w / 2))
        center_y_norm = float(y + (h / 2))

        return {
            "description": description,
            "normalized_coordinates": {
                "x": center_x_norm,
                "y": center_y_norm,
            },
            # Center of the element in screen pixels.
            "screen_coordinates": {
                "x": int(center_x_norm * screen_width),
                "y": int(center_y_norm * screen_height),
            },
            # Size of the element in screen pixels.
            "dimensions": {
                "width": int(w * screen_width),
                "height": int(h * screen_height),
            },
        }