# Prediction interface for Cog ⚙️
# https://cog.run/python
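#
# Detects interactive UI elements in a screenshot with a YOLO icon-detection model,
# captions them with a Florence-2 caption model, merges in EasyOCR text boxes, and
# returns each element's normalized and absolute screen coordinates.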
from cog import BasePredictor, Input, Path
from PIL import Image
from utils import (
    check_ocr_box,
    get_yolo_model,
    get_caption_model_processor,
    get_som_labeled_img
)


class Predictor(BasePredictor):
    def setup(self):
        """Load the detection and captioning models into memory"""
        self.yolo_model = get_yolo_model(model_path='weights/icon_detect/best.pt')
        self.caption_model_processor = get_caption_model_processor(
            model_name="florence2",
            model_name_or_path="weights/icon_caption_florence"
        )
        self.draw_bbox_config = {
            'text_scale': 0.8,
            'text_thickness': 2,
            'text_padding': 2,
            'thickness': 2,
        }

    def predict(
        self,
        image: Path = Input(description="Screenshot of the screen"),
        screen_width: int = Input(
            description="Screen width in pixels",
            default=1920,
            ge=800,   # Minimum reasonable screen width
            le=7680,  # Supporting up to 8K displays
        ),
        screen_height: int = Input(
            description="Screen height in pixels",
            default=1080,
            ge=600,   # Minimum reasonable screen height
            le=4320,  # Supporting up to 8K displays
        ),
        box_threshold: float = Input(
            description="Confidence threshold for box detection",
            default=0.05,
            ge=0.01,
            le=1.0,
        ),
        iou_threshold: float = Input(
            description="IOU threshold for overlap detection",
            default=0.1,
            ge=0.01,
            le=1.0,
        ),
    ) -> dict:
"""Run object detection on a screenshot and return coordinates"""
# Ensure the input image exists and is valid
if not image.exists():
raise ValueError("Input image file does not exist")
# Open and validate the image
try:
input_image = Image.open(image)
input_image.verify() # Verify it's a valid image
except Exception as e:
raise ValueError(f"Invalid image file: {str(e)}")
# Save input image temporarily
image_save_path = '/tmp/input_image.png'
input_image = Image.open(image)
input_image.save(image_save_path)
# Get OCR results
ocr_bbox_rslt, _ = check_ocr_box(
image_save_path,
display_img=False,
output_bb_format='xyxy',
goal_filtering=None,
easyocr_args={'paragraph': False, 'text_threshold': 0.9}
)
text, ocr_bbox = ocr_bbox_rslt
# Get labeled image and coordinates
dino_labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
image_save_path,
self.yolo_model,
BOX_TRESHOLD=box_threshold,
output_coord_in_ratio=True,
ocr_bbox=ocr_bbox,
draw_bbox_config=self.draw_bbox_config,
caption_model_processor=self.caption_model_processor,
ocr_text=text,
iou_threshold=iou_threshold
)

        # Format output
        elements = []
        for i, (element_id, coords) in enumerate(label_coordinates.items()):
            x, y, w, h = coords
            # Calculate center points (normalized)
            center_x_norm = x + (w / 2)
            center_y_norm = y + (h / 2)
            # Calculate screen coordinates
            screen_x = int(center_x_norm * screen_width)
            screen_y = int(center_y_norm * screen_height)
            # Calculate element dimensions on screen
            screen_w = int(w * screen_width)
            screen_h = int(h * screen_height)
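            # Worked example of the conversion above (illustrative values, not model
            # output): a box at x=0.25, y=0.5 with w=0.1, h=0.05 on a 1920x1080 screen
            # has its center at (0.30, 0.525) normalized, which maps to pixel
            # (576, 567) and spans 192x54 pixels on screen.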
            element = {
                "description": parsed_content_list[i] if i < len(parsed_content_list) else f"Icon {i}",
                "normalized_coordinates": {
                    "x": center_x_norm,
                    "y": center_y_norm
                },
                "screen_coordinates": {
                    "x": screen_x,
                    "y": screen_y
                },
                "dimensions": {
                    "width": screen_w,
                    "height": screen_h
                }
            }
            elements.append(element)

        return {
            "image": dino_labeled_img,  # Base64 encoded image
            "elements": elements
        }
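

# Example local invocation with the Cog CLI (a sketch; it assumes Cog is installed and
# that the weights referenced in setup() have been downloaded into weights/):
#
#   cog predict \
#       -i image=@screenshot.png \
#       -i screen_width=2560 \
#       -i screen_height=1440
#
# The result is a JSON object containing the base64-encoded labeled image and the list
# of detected elements with their normalized and absolute screen coordinates.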