from typing import Tuple
import base64
import io
import os

import gradio as gr
import numpy as np
import torch
from PIL import Image

from utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
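# Load the two OmniParser models: a YOLO-based detector that localizes
# interactable icons/regions, and a Florence-2 captioner that describes them.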
yolo_model = get_yolo_model(model_path='weights/icon_detect/best.pt')
caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="weights/icon_caption_florence")
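# Bounding-box drawing style for the annotated output, keyed by target
# platform; 'web' and 'mobile' use slightly thicker boxes than 'pc'.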
platform = 'pc'
if platform == 'pc':
    draw_bbox_config = {
        'text_scale': 0.8,
        'text_thickness': 2,
        'text_padding': 2,
        'thickness': 2,
    }
elif platform == 'web':
    draw_bbox_config = {
        'text_scale': 0.8,
        'text_thickness': 2,
        'text_padding': 3,
        'thickness': 3,
    }
elif platform == 'mobile':
    draw_bbox_config = {
        'text_scale': 0.8,
        'text_thickness': 2,
        'text_padding': 3,
        'thickness': 3,
    }
MARKDOWN = """
# OmniParser for Pure Vision Based General GUI Agent 🔥
<div>
<a href="https://arxiv.org/pdf/2408.00203">
<img src="https://img.shields.io/badge/arXiv-2408.00203-b31b1b.svg" alt="Arxiv" style="display:inline-block;">
</a>
</div>
OmniParser is a screen parsing tool that converts general GUI screenshots into structured elements.
"""
# Prefer the GPU when available; fall back to CPU so the demo still runs without CUDA.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# @spaces.GPU
# @torch.inference_mode()
# @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
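# The decorators above are left from the Hugging Face Spaces setup (spaces.GPU
# is the ZeroGPU hook) plus inference-mode/autocast optimizations; they can be
# re-enabled when deploying with a GPU.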
def process(
    image_input,
    box_threshold,
    iou_threshold,
    screen_width,
    screen_height
) -> Tuple[Image.Image, str]:
"""
Process the image and return both normalized and screen coordinates
Args:
image_input: Input image
box_threshold: Confidence threshold for box detection
iou_threshold: IOU threshold for overlap detection
screen_width: Actual screen width in pixels
screen_height: Actual screen height in pixels
"""
    image_save_path = 'imgs/saved_image_demo.png'
    os.makedirs(os.path.dirname(image_save_path), exist_ok=True)  # ensure the output directory exists
    image_input.save(image_save_path)
    # Get image dimensions
    image_width = image_input.width
    image_height = image_input.height
    # Run OCR first so detected text can be merged with the icon detections
    ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_save_path, display_img=False, output_bb_format='xyxy',
                                                    goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold': 0.9})
    text, ocr_bbox = ocr_bbox_rslt
    # Run icon detection + captioning and draw the labeled overlay
    # (BOX_TRESHOLD spelling matches the keyword expected by utils)
    dino_labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
        image_save_path, yolo_model, BOX_TRESHOLD=box_threshold,
        output_coord_in_ratio=True, ocr_bbox=ocr_bbox,
        draw_bbox_config=draw_bbox_config,
        caption_model_processor=caption_model_processor,
        ocr_text=text, iou_threshold=iou_threshold
    )
    # The annotated image comes back base64-encoded; decode it into a PIL image
    image = Image.open(io.BytesIO(base64.b64decode(dino_labeled_img)))
    # Format the output to include both normalized and screen coordinates
    output_text = []
    for i, (element_id, coords) in enumerate(label_coordinates.items()):
        x, y, w, h = coords
        # Center point in normalized [0, 1] coordinates
        center_x_norm = x + (w / 2)
        center_y_norm = y + (h / 2)
        # Project the normalized center onto the actual screen
        screen_x = int(center_x_norm * screen_width)
        screen_y = int(center_y_norm * screen_height)
        # Element dimensions on screen
        screen_w = int(w * screen_width)
        screen_h = int(h * screen_height)
        # Use the parsed description when one exists (text elements);
        # otherwise fall back to a generic label for uncaptioned icons
        element_desc = parsed_content_list[i] if i < len(parsed_content_list) else f"Icon {i}"
        output_text.append(
            f"{element_desc}\n"
            f"   Normalized coordinates: ({center_x_norm:.3f}, {center_y_norm:.3f})\n"
            f"   Screen coordinates: ({screen_x}, {screen_y})\n"
            f"   Dimensions: {screen_w}x{screen_h} pixels"
        )
    parsed_content = '\n\n'.join(output_text)
    return image, parsed_content
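# Worked example of the projection above (hypothetical values): a normalized
# box (x, y, w, h) = (0.25, 0.50, 0.10, 0.05) has center (0.30, 0.525), which
# on a 1920x1080 screen becomes the point (576, 567) with size 192x54 pixels.
#
# A minimal sketch of calling process() directly, bypassing the Gradio UI
# (the screenshot path is illustrative):
#   img = Image.open('screenshot.png').convert('RGB')
#   annotated, report = process(img, box_threshold=0.05, iou_threshold=0.1,
#                               screen_width=1920, screen_height=1080)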
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            image_input_component = gr.Image(
                type='pil', label='Upload image')
            with gr.Row():
                # Screen dimension inputs
                screen_width_component = gr.Number(
                    label='Screen Width (pixels)',
                    value=1920,  # Default value
                    precision=0
                )
                screen_height_component = gr.Number(
                    label='Screen Height (pixels)',
                    value=1080,  # Default value
                    precision=0
                )
            # Threshold sliders
            box_threshold_component = gr.Slider(
                label='Box Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.05)
            iou_threshold_component = gr.Slider(
                label='IOU Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.1)
            submit_button_component = gr.Button(
                value='Submit', variant='primary')
        with gr.Column():
            image_output_component = gr.Image(type='pil', label='Image Output')
            text_output_component = gr.Textbox(
                label='Parsed screen elements',
                placeholder='Text Output',
                lines=10  # Increased to show more content
            )
    submit_button_component.click(
        fn=process,
        inputs=[
            image_input_component,
            box_threshold_component,
            iou_threshold_component,
            screen_width_component,
            screen_height_component
        ],
        outputs=[image_output_component, text_output_component]
    )
# demo.launch(debug=False, show_error=True, share=True)
demo.launch(share=True, server_port=7861, server_name='0.0.0.0')
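# share=True requests a temporary public gradio.live URL; server_name='0.0.0.0'
# binds to all network interfaces so the demo is reachable from other machines.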