File size: 6,270 Bytes
2ad48f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
from typing import Optional

import gradio as gr
import numpy as np
import torch
from PIL import Image
import io


import base64, os
from utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
import torch
from PIL import Image

# Detection model that proposes icon bounding boxes on the screenshot.
yolo_model = get_yolo_model(model_path='weights/icon_detect/best.pt')
# Captioning model (Florence-2) that produces a text description per icon.
caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="weights/icon_caption_florence")

# Per-platform drawing settings for the annotated output image; 'pc' uses
# slightly tighter padding/line thickness than 'web'/'mobile'.
DRAW_BBOX_CONFIGS = {
    'pc': {
        'text_scale': 0.8,
        'text_thickness': 2,
        'text_padding': 2,
        'thickness': 2,
    },
    'web': {
        'text_scale': 0.8,
        'text_thickness': 2,
        'text_padding': 3,
        'thickness': 3,
    },
    'mobile': {
        'text_scale': 0.8,
        'text_thickness': 2,
        'text_padding': 3,
        'thickness': 3,
    },
}

platform = 'pc'
# Fall back to the 'pc' config so draw_bbox_config is always bound — the
# original if/elif chain left it undefined for unrecognized platform values,
# which would raise NameError at the first parse request.
draw_bbox_config = DRAW_BBOX_CONFIGS.get(platform, DRAW_BBOX_CONFIGS['pc'])



# Intro banner rendered at the top of the Gradio app.
MARKDOWN = """
# OmniParser for Pure Vision Based General GUI Agent 🔥
<div>
    <a href="https://arxiv.org/pdf/2408.00203">
        <img src="https://img.shields.io/badge/arXiv-2408.00203-b31b1b.svg" alt="Arxiv" style="display:inline-block;">
    </a>
</div>

OmniParser is a screen parsing tool to convert general GUI screen to structured elements. 
"""

# Prefer GPU but fall back to CPU — the original unconditionally requested
# 'cuda', which raises on CPU-only machines.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# @spaces.GPU
# @torch.inference_mode()
# @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process(
    image_input,
    box_threshold,
    iou_threshold,
    screen_width,
    screen_height
) -> "tuple[Image.Image, str]":
    """
    Parse a GUI screenshot into an annotated image plus a text listing of
    every detected element, with both normalized and screen coordinates.

    Args:
        image_input: Input PIL image (the screenshot to parse).
        box_threshold: Confidence threshold for icon box detection.
        iou_threshold: IOU threshold used to suppress overlapping boxes.
        screen_width: Actual screen width in pixels, used to convert
            normalized coordinates to screen pixels.
        screen_height: Actual screen height in pixels.

    Returns:
        A ``(labeled_image, parsed_content)`` pair: the image with drawn
        bounding boxes, and a human-readable description of the elements.
        (The original annotation of ``Optional[Image.Image]`` was wrong —
        the function has always returned a two-tuple.)
    """
    # The OCR / detection helpers take a file path, so persist the uploaded
    # image to disk first.
    image_save_path = 'imgs/saved_image_demo.png'
    image_input.save(image_save_path)

    ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
        image_save_path, display_img=False, output_bb_format='xyxy',
        goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold': 0.9})
    text, ocr_bbox = ocr_bbox_rslt

    dino_labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
        image_save_path, yolo_model, BOX_TRESHOLD=box_threshold,
        output_coord_in_ratio=True, ocr_bbox=ocr_bbox,
        draw_bbox_config=draw_bbox_config,
        caption_model_processor=caption_model_processor,
        ocr_text=text, iou_threshold=iou_threshold
    )
    # get_som_labeled_img returns the annotated image base64-encoded.
    image = Image.open(io.BytesIO(base64.b64decode(dino_labeled_img)))

    # Describe each element with normalized and absolute screen coordinates.
    output_text = []
    for i, (element_id, coords) in enumerate(label_coordinates.items()):
        x, y, w, h = coords  # normalized (x, y, width, height) in [0, 1]

        # Center point in normalized coordinates.
        center_x_norm = x + (w / 2)
        center_y_norm = y + (h / 2)

        # Map normalized values onto the caller-supplied screen resolution.
        screen_x = int(center_x_norm * screen_width)
        screen_y = int(center_y_norm * screen_height)
        screen_w = int(w * screen_width)
        screen_h = int(h * screen_height)

        # Boxes beyond the parsed captions are unlabeled icons; the two
        # original branches differed only in this description line.
        element_desc = parsed_content_list[i] if i < len(parsed_content_list) else f"Icon {i}"
        output_text.append(
            f"{element_desc}\n"
            f"  Normalized coordinates: ({center_x_norm:.3f}, {center_y_norm:.3f})\n"
            f"  Screen coordinates: ({screen_x}, {screen_y})\n"
            f"  Dimensions: {screen_w}x{screen_h} pixels"
        )

    parsed_content = '\n\n'.join(output_text)
    return image, parsed_content



# Build the Gradio UI: image + thresholds on the left, annotated image and
# parsed element text on the right.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            # Screenshot to be parsed.
            image_input_component = gr.Image(
                type='pil', label='Upload image')
            
            with gr.Row():
                # Physical screen dimensions used by `process` to convert
                # normalized element coordinates into absolute pixels.
                screen_width_component = gr.Number(
                    label='Screen Width (pixels)', 
                    value=1920,  # Default value
                    precision=0
                )
                screen_height_component = gr.Number(
                    label='Screen Height (pixels)', 
                    value=1080,  # Default value
                    precision=0
                )
            
            # Threshold sliders: detection confidence and overlap suppression.
            box_threshold_component = gr.Slider(
                label='Box Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.05)
            iou_threshold_component = gr.Slider(
                label='IOU Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.1)
            
            submit_button_component = gr.Button(
                value='Submit', variant='primary')
                
        with gr.Column():
            # Outputs: the screenshot with drawn boxes, plus the textual
            # element listing produced by `process`.
            image_output_component = gr.Image(type='pil', label='Image Output')
            text_output_component = gr.Textbox(
                label='Parsed screen elements', 
                placeholder='Text Output',
                lines=10  # Increased to show more content
            )

    # Wire the Submit button: inputs must stay in the same order as the
    # `process` function's parameters.
    submit_button_component.click(
        fn=process,
        inputs=[
            image_input_component,
            box_threshold_component,
            iou_threshold_component,
            screen_width_component,
            screen_height_component
        ],
        outputs=[image_output_component, text_output_component]
    )

# Entry-point guard so importing this module does not immediately start the
# server. NOTE(review): share=True opens a public gradio.live tunnel and
# server_name='0.0.0.0' binds all interfaces — confirm both are intended
# before deploying.
if __name__ == '__main__':
    demo.launch(share=True, server_port=7861, server_name='0.0.0.0')