File size: 4,534 Bytes
2ad48f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# Prediction interface for Cog ⚙️
# https://cog.run/python

import os
import tempfile

from cog import BasePredictor, Input, Path
from PIL import Image
from utils import (
    check_ocr_box,
    get_yolo_model,
    get_caption_model_processor,
    get_som_labeled_img
)


class Predictor(BasePredictor):
    """Cog predictor that locates UI elements in a screenshot.

    OCR text boxes (from ``check_ocr_box``) and icon detections from a YOLO
    model captioned by a Florence-2 model (via ``get_som_labeled_img``) are
    combined. Each element's centre is reported as a ratio of the image size
    and also scaled to a caller-supplied screen resolution.
    """

    # Pillow modes the PNG writer accepts. Any other mode (e.g. a CMYK JPEG)
    # would make ``Image.save(<path>.png)`` raise, so it is converted to RGB.
    _PNG_WRITABLE_MODES = frozenset({"1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"})

    def setup(self):
        """Load the icon detector, the caption model and the drawing settings."""
        self.yolo_model = get_yolo_model(model_path='weights/icon_detect/best.pt')
        self.caption_model_processor = get_caption_model_processor(
            model_name="florence2",
            model_name_or_path="weights/icon_caption_florence"
        )
        # Styling passed to get_som_labeled_img for the annotated output image.
        self.draw_bbox_config = {
            'text_scale': 0.8,
            'text_thickness': 2,
            'text_padding': 2,
            'thickness': 2,
        }

    def predict(
        self,
        image: Path = Input(description="Screenshot of the screen"),
        screen_width: int = Input(
            description="Screen width in pixels",
            default=1920,
            ge=800,  # Setting minimum reasonable screen width
            le=7680,  # Supporting up to 8K displays
        ),
        screen_height: int = Input(
            description="Screen height in pixels",
            default=1080,
            ge=600,  # Setting minimum reasonable screen height
            le=4320,  # Supporting up to 8K displays
        ),
        box_threshold: float = Input(
            description="Confidence threshold for box detection",
            default=0.05,
            ge=0.01,
            le=1.0,
        ),
        iou_threshold: float = Input(
            description="IOU threshold for overlap detection",
            default=0.1,
            ge=0.01,
            le=1.0,
        ),
    ) -> dict:
        """Run object detection on a screenshot and return coordinates.

        Args:
            image: Path to the screenshot file.
            screen_width: Width used to scale ratio coordinates to pixels.
                It only scales the output; the image itself is not resized.
            screen_height: Height used to scale ratio coordinates to pixels.
            box_threshold: Detection confidence threshold, forwarded to
                ``get_som_labeled_img`` as ``BOX_TRESHOLD``.
            iou_threshold: Overlap threshold forwarded to ``get_som_labeled_img``.

        Returns:
            A dict with two keys:

            - ``"image"``: the labeled image returned by ``get_som_labeled_img``
              (presumably base64-encoded; confirm in utils).
            - ``"elements"``: a list with one dict per detected element, each
              holding ``description``, ``normalized_coordinates``,
              ``screen_coordinates`` and ``dimensions``.

        Raises:
            ValueError: If the file does not exist or is not a readable image.
        """
        input_image = self._load_image(image)

        # A private per-call directory (instead of a fixed /tmp path) keeps
        # concurrent predictions from overwriting each other's input file,
        # and the file is deleted once the pipeline has consumed it.
        with tempfile.TemporaryDirectory() as tmp_dir:
            image_save_path = os.path.join(tmp_dir, 'input_image.png')
            input_image.save(image_save_path)

            # Get OCR results
            ocr_bbox_rslt, _ = check_ocr_box(
                image_save_path,
                display_img=False,
                output_bb_format='xyxy',
                goal_filtering=None,
                easyocr_args={'paragraph': False, 'text_threshold': 0.9}
            )
            text, ocr_bbox = ocr_bbox_rslt

            # Get labeled image and coordinates (as ratios of the image size)
            dino_labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
                image_save_path,
                self.yolo_model,
                BOX_TRESHOLD=box_threshold,
                output_coord_in_ratio=True,
                ocr_bbox=ocr_bbox,
                draw_bbox_config=self.draw_bbox_config,
                caption_model_processor=self.caption_model_processor,
                ocr_text=text,
                iou_threshold=iou_threshold
            )

        # Positional pairing with parsed_content_list mirrors the original
        # loop; the label_coordinates keys themselves are not used.
        elements = [
            self._format_element(
                index, coords, parsed_content_list, screen_width, screen_height
            )
            for index, coords in enumerate(label_coordinates.values())
        ]

        return {
            "image": dino_labeled_img,  # Base64 encoded image
            "elements": elements
        }

    @classmethod
    def _load_image(cls, image_path) -> Image.Image:
        """Validate *image_path* and return a fully decoded, PNG-writable image.

        Raises:
            ValueError: If the file is missing, fails verification, or cannot
                be decoded.
        """
        if not image_path.exists():
            raise ValueError("Input image file does not exist")

        try:
            # verify() checks integrity but leaves the Image object unusable,
            # so the file is opened a second time to decode the pixels.
            # Both opens use `with` so the file handles are closed.
            with Image.open(image_path) as probe:
                probe.verify()
            with Image.open(image_path) as opened:
                if opened.mode in cls._PNG_WRITABLE_MODES:
                    # copy() forces a full decode into memory, so the result
                    # stays usable after the file handle is closed.
                    return opened.copy()
                return opened.convert("RGB")
        except Exception as e:
            raise ValueError(f"Invalid image file: {str(e)}") from e

    @staticmethod
    def _format_element(index, coords, parsed_content_list, screen_width, screen_height) -> dict:
        """Build the output record for the *index*-th detected element.

        Assumes *coords* is ``(x, y, w, h)`` with (x, y) the top-left corner,
        all given as ratios of the image size. That matches how the original
        code used it with ``output_coord_in_ratio=True``; TODO confirm in utils.
        """
        x, y, w, h = coords

        # Calculate center points (normalized). float() turns numpy scalars
        # into plain floats so the result is always JSON-serialisable.
        center_x_norm = float(x + w / 2)
        center_y_norm = float(y + h / 2)

        # Fall back to a generic label when there are fewer captions than boxes.
        if index < len(parsed_content_list):
            description = parsed_content_list[index]
        else:
            description = f"Icon {index}"

        return {
            "description": description,
            "normalized_coordinates": {
                "x": center_x_norm,
                "y": center_y_norm
            },
            "screen_coordinates": {
                # int() truncates, matching the original pixel conversion.
                "x": int(center_x_norm * screen_width),
                "y": int(center_y_norm * screen_height)
            },
            "dimensions": {
                "width": int(w * screen_width),
                "height": int(h * screen_height)
            }
        }