zorba111 committed
Commit 36a599e · verified · 1 Parent(s): 2ad48f3

Upload folder using huggingface_hub

Files changed (7):
  1. Dockerfile +26 -0
  2. README.md +33 -10
  3. api.py +19 -98
  4. modal_app.py +144 -0
  5. requirements.txt +29 -16
  6. test-api.py +30 -0
  7. test_api.py +43 -0
Dockerfile ADDED
@@ -0,0 +1,26 @@
+ FROM python:3.12-slim
+
+ WORKDIR /app
+
+ # Copy requirements first to leverage Docker cache
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy the rest of the application
+ COPY . .
+
+ # Install system dependencies for PIL and torch
+ RUN apt-get update && apt-get install -y \
+     libgl1-mesa-glx \
+     libglib2.0-0 \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Set environment variables
+ ENV GRADIO_SERVER_NAME=0.0.0.0
+ ENV GRADIO_SERVER_PORT=7860
+
+ # Expose the port
+ EXPOSE 7860
+
+ # Run the application
+ CMD ["python", "gradio_demo.py"]
README.md CHANGED
@@ -4,6 +4,7 @@ app_file: gradio_demo.py
  sdk: gradio
  sdk_version: 5.4.0
  ---
+
  # OmniParser: Screen Parsing tool for Pure Vision Based GUI Agent

  <p align="center">
@@ -13,50 +14,72 @@ sdk_version: 5.4.0
  [![arXiv](https://img.shields.io/badge/Paper-green)](https://arxiv.org/abs/2408.00203)
  [![License](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)

- 📢 [[Project Page](https://microsoft.github.io/OmniParser/)] [[Blog Post](https://www.microsoft.com/en-us/research/articles/omniparser-for-pure-vision-based-gui-agent/)] [[Models](https://huggingface.co/microsoft/OmniParser)]
+ 📢 [[Project Page](https://microsoft.github.io/OmniParser/)] [[Blog Post](https://www.microsoft.com/en-us/research/articles/omniparser-for-pure-vision-based-gui-agent/)] [[Models](https://huggingface.co/microsoft/OmniParser)]

- **OmniParser** is a comprehensive method for parsing user interface screenshots into structured and easy-to-understand elements, which significantly enhances the ability of GPT-4V to generate actions that can be accurately grounded in the corresponding regions of the interface.
+ **OmniParser** is a comprehensive method for parsing user interface screenshots into structured and easy-to-understand elements, which significantly enhances the ability of GPT-4V to generate actions that can be accurately grounded in the corresponding regions of the interface.

  ## News
+
  - [2024/10] Both Interactive Region Detection Model and Icon functional description model are released! [Hugginface models](https://huggingface.co/microsoft/OmniParser)
- - [2024/09] OmniParser achieves the best performance on [Windows Agent Arena](https://microsoft.github.io/WindowsAgentArena/)!
+ - [2024/09] OmniParser achieves the best performance on [Windows Agent Arena](https://microsoft.github.io/WindowsAgentArena/)!
+
+ ## Install

- ## Install
  Install environment:
+
  ```python
  conda create -n "omni" python==3.12
  conda activate omni
  pip install -r requirements.txt
  ```

- Then download the model ckpts files in: https://huggingface.co/microsoft/OmniParser, and put them under weights/, default folder structure is: weights/icon_detect, weights/icon_caption_florence, weights/icon_caption_blip2.
+ Then download the model ckpts files in: https://huggingface.co/microsoft/OmniParser, and put them under weights/, default folder structure is: weights/icon_detect, weights/icon_caption_florence, weights/icon_caption_blip2.
+
+ Finally, convert the safetensor to .pt file.

- Finally, convert the safetensor to .pt file.
  ```python
  python weights/convert_safetensor_to_pt.py
  ```

  ## Examples:
- We put together a few simple examples in the demo.ipynb.
+
+ We put together a few simple examples in the demo.ipynb.

  ## Gradio Demo
+
  To run gradio demo, simply run:
+
  ```python
  python gradio_demo.py
  ```

-
  ## 📚 Citation
+
  Our technical report can be found [here](https://arxiv.org/abs/2408.00203).
  If you find our work useful, please consider citing our work:
+
  ```
  @misc{lu2024omniparserpurevisionbased,
- title={OmniParser for Pure Vision Based GUI Agent},
+ title={OmniParser for Pure Vision Based GUI Agent},
  author={Yadong Lu and Jianwei Yang and Yelong Shen and Ahmed Awadallah},
  year={2024},
  eprint={2408.00203},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
- url={https://arxiv.org/abs/2408.00203},
+ url={https://arxiv.org/abs/2408.00203},
  }
  ```
+
+ title: Ui Element Coordinates Finder
+ emoji: 🏢
+ colorFrom: pink
+ colorTo: red
+ sdk: gradio
+ sdk_version: 5.4.0
+ app_file: app.py
+ pinned: false
+ license: mit
+
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
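Side note (illustrative, not part of this diff): the README's install section tells the reader to download the checkpoint files and place them under weights/, but gives no command. A minimal sketch using huggingface_hub, the library named in the commit message, assuming the model repo's folder layout matches the directories the README lists:

```python
# Illustrative only: pull the OmniParser checkpoints into weights/ so that
# weights/icon_detect, weights/icon_caption_florence and weights/icon_caption_blip2
# exist before running weights/convert_safetensor_to_pt.py.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="microsoft/OmniParser", local_dir="weights")
```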
api.py CHANGED
@@ -1,105 +1,26 @@
- from fastapi import FastAPI, UploadFile, File, HTTPException
- from pydantic import BaseModel
- from PIL import Image
- import io
- import torch
+ from fastapi import FastAPI, File, UploadFile, Request
  from slowapi import Limiter, _rate_limit_exceeded_handler
  from slowapi.util import get_remote_address
  from slowapi.errors import RateLimitExceeded
+ from fastapi.responses import JSONResponse

- # Import your existing utilities and models
- from utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
-
- # Initialize FastAPI app
- app = FastAPI(title="OmniParser API")
- app.state.limiter = Limiter(key_func=get_remote_address)
+ app = FastAPI()
+ limiter = Limiter(key_func=get_remote_address)
+ app.state.limiter = limiter
  app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

- # Load models at startup (reusing your existing code)
- yolo_model = get_yolo_model(model_path='weights/icon_detect/best.pt')
- caption_model_processor = get_caption_model_processor(
-     model_name="florence2",
-     model_name_or_path="weights/icon_caption_florence"
- )
-
- # Define request model
- class ProcessRequest(BaseModel):
-     box_threshold: float = 0.05
-     iou_threshold: float = 0.1
-     screen_width: int = 1920
-     screen_height: int = 1080
-
  @app.post("/process")
- @app.state.limiter.limit("5/minute") # Limit to 5 requests per minute per IP
- async def process_image(
-     file: UploadFile = File(...),
-     params: ProcessRequest = None
- ):
-     # Read image from request
-     image_bytes = await file.read()
-     image = Image.open(io.BytesIO(image_bytes))
-
-     # Save image temporarily (reusing your existing logic)
-     temp_path = 'imgs/temp_image.png'
-     image.save(temp_path)
-
-     # Process image using your existing functions
-     ocr_bbox_rslt, _ = check_ocr_box(
-         temp_path,
-         display_img=False,
-         output_bb_format='xyxy',
-         goal_filtering=None,
-         easyocr_args={'paragraph': False, 'text_threshold':0.9}
-     )
-
-     text, ocr_bbox = ocr_bbox_rslt
-
-     dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
-         temp_path,
-         yolo_model,
-         BOX_TRESHOLD=params.box_threshold,
-         output_coord_in_ratio=True,
-         ocr_bbox=ocr_bbox,
-         draw_bbox_config={
-             'text_scale': 0.8,
-             'text_thickness': 2,
-             'text_padding': 2,
-             'thickness': 2,
-         },
-         caption_model_processor=caption_model_processor,
-         ocr_text=text,
-         iou_threshold=params.iou_threshold
-     )
-
-     # Format output (similar to your existing code)
-     output_text = []
-     for i, (element_id, coords) in enumerate(label_coordinates.items()):
-         x, y, w, h = coords
-         center_x_norm = x + (w/2)
-         center_y_norm = y + (h/2)
-         screen_x = int(center_x_norm * params.screen_width)
-         screen_y = int(center_y_norm * params.screen_height)
-         screen_w = int(w * params.screen_width)
-         screen_h = int(h * params.screen_height)
-
-         element_desc = parsed_content_list[i] if i < len(parsed_content_list) else f"Icon {i}"
-         output_text.append({
-             "description": element_desc,
-             "normalized_coordinates": {
-                 "x": center_x_norm,
-                 "y": center_y_norm
-             },
-             "screen_coordinates": {
-                 "x": screen_x,
-                 "y": screen_y
-             },
-             "dimensions": {
-                 "width": screen_w,
-                 "height": screen_h
-             }
-         })
-
-     return {
-         "processed_image": dino_labled_img, # Base64 encoded image
-         "elements": output_text
-     }
+ @limiter.limit("5/minute")
+ async def process_image(request: Request, file: UploadFile = File(...)):
+     try:
+         contents = await file.read()
+         # Your processing logic here
+         return JSONResponse(
+             status_code=200,
+             content={"message": "Success", "filename": file.filename}
+         )
+     except Exception as e:
+         return JSONResponse(
+             status_code=500,
+             content={"error": str(e)}
+         )
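Note (illustrative, not part of the commit): the rewritten api.py keeps the rate-limited /process route but leaves its body as the "# Your processing logic here" placeholder. One way that placeholder could be filled is by reusing the model setup and utility calls from the removed version of the file; the sketch below does exactly that, except that it returns the raw parsed_content_list instead of rebuilding the old per-element coordinate objects.

```python
# Illustrative sketch only: restores the removed processing logic inside the new
# rate-limited handler. Assumes utils.py and the weights/ layout are unchanged.
import io

from fastapi import FastAPI, File, Request, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

from utils import check_ocr_box, get_caption_model_processor, get_som_labeled_img, get_yolo_model

app = FastAPI()
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Load the models once at startup, as the removed version of api.py did.
yolo_model = get_yolo_model(model_path="weights/icon_detect/best.pt")
caption_model_processor = get_caption_model_processor(
    model_name="florence2", model_name_or_path="weights/icon_caption_florence"
)

@app.post("/process")
@limiter.limit("5/minute")
async def process_image(request: Request, file: UploadFile = File(...)):
    try:
        # Persist the upload so the utils functions can read it from disk.
        contents = await file.read()
        temp_path = "imgs/temp_image.png"
        Image.open(io.BytesIO(contents)).save(temp_path)

        # OCR pass, with the same arguments the removed code used.
        ocr_bbox_rslt, _ = check_ocr_box(
            temp_path,
            display_img=False,
            output_bb_format="xyxy",
            goal_filtering=None,
            easyocr_args={"paragraph": False, "text_threshold": 0.9},
        )
        text, ocr_bbox = ocr_bbox_rslt

        # Detection + captioning pass producing the annotated image and labels.
        dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
            temp_path,
            yolo_model,
            BOX_TRESHOLD=0.05,
            output_coord_in_ratio=True,
            ocr_bbox=ocr_bbox,
            draw_bbox_config={
                "text_scale": 0.8,
                "text_thickness": 2,
                "text_padding": 2,
                "thickness": 2,
            },
            caption_model_processor=caption_model_processor,
            ocr_text=text,
            iou_threshold=0.1,
        )
        return JSONResponse(
            status_code=200,
            content={"processed_image": dino_labled_img, "elements": parsed_content_list},
        )
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
```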
modal_app.py ADDED
@@ -0,0 +1,144 @@
+ import modal
+ from fastapi import FastAPI, File, UploadFile, Request
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import JSONResponse
+ from PIL import Image
+ import io
+ import base64
+ from typing import Optional
+ import traceback
+
+ # Create app and web app
+ app = modal.App("ui-coordinates-finder")
+ web_app = FastAPI()
+
+ # Add your model initialization to the app
+ @app.function(gpu="T4")
+ def init_models():
+     from utils import get_yolo_model, get_caption_model_processor
+
+     yolo_model = get_yolo_model(model_path='weights/icon_detect/best.pt')
+     caption_model_processor = get_caption_model_processor(
+         model_name="florence2",
+         model_name_or_path="weights/icon_caption_florence"
+     )
+     return yolo_model, caption_model_processor
+
+ # Configure CORS
+ web_app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ @app.function(gpu="T4", timeout=300)
+ @web_app.post("/process")
+ async def process_image_endpoint(
+     request: Request,
+     file: UploadFile = File(...),
+     box_threshold: float = 0.05,
+     iou_threshold: float = 0.1,
+     screen_width: int = 1920,
+     screen_height: int = 1080
+ ):
+     try:
+         # Add logging for debugging
+         print(f"Processing file: {file.filename}")
+
+         # Read and process the image
+         contents = await file.read()
+         print("File read successfully")
+
+         # Save image temporarily
+         image_save_path = '/tmp/saved_image_demo.png'
+         image = Image.open(io.BytesIO(contents))
+         image.save(image_save_path)
+
+         # Initialize models
+         yolo_model, caption_model_processor = init_models()
+
+         # Process with OCR and detection
+         from utils import check_ocr_box, get_som_labeled_img
+
+         draw_bbox_config = {
+             'text_scale': 0.8,
+             'text_thickness': 2,
+             'text_padding': 2,
+             'thickness': 2,
+         }
+
+         ocr_bbox_rslt, _ = check_ocr_box(
+             image_save_path,
+             display_img=False,
+             output_bb_format='xyxy',
+             goal_filtering=None,
+             easyocr_args={'paragraph': False, 'text_threshold': 0.9}
+         )
+         text, ocr_bbox = ocr_bbox_rslt
+
+         dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
+             image_save_path,
+             yolo_model,
+             BOX_TRESHOLD=box_threshold,
+             output_coord_in_ratio=True,
+             ocr_bbox=ocr_bbox,
+             draw_bbox_config=draw_bbox_config,
+             caption_model_processor=caption_model_processor,
+             ocr_text=text,
+             iou_threshold=iou_threshold
+         )
+
+         # Format the output similar to Gradio demo
+         output_text = []
+         for i, (element_id, coords) in enumerate(label_coordinates.items()):
+             x, y, w, h = coords
+
+             # Calculate center points (normalized)
+             center_x_norm = x + (w/2)
+             center_y_norm = y + (h/2)
+
+             # Calculate screen coordinates
+             screen_x = int(center_x_norm * screen_width)
+             screen_y = int(center_y_norm * screen_height)
+             screen_w = int(w * screen_width)
+             screen_h = int(h * screen_height)
+
+             if i < len(parsed_content_list):
+                 element_desc = parsed_content_list[i]
+                 output_text.append({
+                     "description": element_desc,
+                     "normalized_coords": (center_x_norm, center_y_norm),
+                     "screen_coords": (screen_x, screen_y),
+                     "dimensions": (screen_w, screen_h)
+                 })
+
+         return JSONResponse(
+             status_code=200,
+             content={
+                 "message": "Success",
+                 "filename": file.filename,
+                 "processed_image": dino_labled_img,  # Base64 encoded image
+                 "elements": output_text
+             }
+         )
+
+     except Exception as e:
+         error_details = traceback.format_exc()
+         print(f"Error processing request: {error_details}")
+         return JSONResponse(
+             status_code=500,
+             content={
+                 "error": str(e),
+                 "details": error_details
+             }
+         )
+
+ @app.function()
+ @modal.asgi_app()
+ def fastapi_app():
+     return web_app
+
+ if __name__ == "__main__":
+     app.serve()
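For context (illustrative, not from the repo): the `@modal.asgi_app()` wrapper is what exposes web_app as a hosted endpoint, so once the app is published with the Modal CLI (e.g. `modal deploy modal_app.py`) the /process route becomes reachable at a *.modal.run URL of the kind the test scripts below target.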
requirements.txt CHANGED
@@ -1,16 +1,29 @@
- torch
- easyocr
- torchvision
- supervision==0.18.0
- openai==1.3.5
- transformers
- ultralytics==8.1.24
- azure-identity
- numpy
- opencv-python
- opencv-python-headless
- gradio
- dill
- accelerate
- timm
- einops==0.8.0
+ # Use Python 3.12 as base image
+ FROM python:3.12-slim
+
+ # Install system dependencies required for OpenCV and other packages
+ RUN apt-get update && apt-get install -y \
+     libgl1-mesa-glx \
+     libglib2.0-0 \
+     build-essential \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Set working directory
+ WORKDIR /app
+
+ # Copy requirements and app files
+ COPY requirements.txt .
+ COPY . .
+
+ # Install Python dependencies
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Set environment variables
+ ENV GRADIO_SERVER_NAME="0.0.0.0"
+ ENV GRADIO_SERVER_PORT=7860
+
+ # Expose the port Gradio will run on
+ EXPOSE 7860
+
+ # Command to run the application
+ CMD ["python", "gradio_demo.py"]
test-api.py ADDED
@@ -0,0 +1,30 @@
+ import requests
+ import json
+
+ def test_api():
+     url = "https://zorba11--ui-coordinates-finder-fastapi-app.modal.run/process"
+
+     headers = {
+         'Accept': 'application/json',
+     }
+
+     try:
+         files = {
+             'file': ('screen-1.png', open('/Users/zorba11/Desktop/screen-1.png', 'rb'), 'image/png')
+         }
+
+         response = requests.post(
+             url,
+             files=files,
+             headers=headers
+         )
+
+         print(f"Status Code: {response.status_code}")
+         print(f"Response Headers: {dict(response.headers)}")
+         print(f"Response Content: {response.content.decode()}")
+
+     except Exception as e:
+         print(f"Error: {str(e)}")
+
+ if __name__ == "__main__":
+     test_api()
test_api.py ADDED
@@ -0,0 +1,43 @@
+ import requests
+ from PIL import Image
+ import base64
+ import io
+
+ def test_api():
+     url = "https://zorba11--ui-coordinates-finder-fastapi-app.modal.run/process"
+
+     # Parameters matching your Gradio demo
+     params = {
+         'box_threshold': 0.05,
+         'iou_threshold': 0.1,
+         'screen_width': 1920,
+         'screen_height': 1080
+     }
+
+     files = {
+         'file': ('screen-1.png', open('/Users/zorba11/Desktop/screen-1.png', 'rb'), 'image/png')
+     }
+
+     response = requests.post(url, files=files, params=params)
+
+     if response.status_code == 200:
+         result = response.json()
+
+         # Convert base64 image back to PIL Image
+         img_data = base64.b64decode(result['processed_image'])
+         processed_image = Image.open(io.BytesIO(img_data))
+
+         # Save the processed image
+         processed_image.save('processed_output.png')
+
+         # Print the detected elements
+         for element in result['elements']:
+             print("\nElement:", element['description'])
+             print("Normalized coordinates:", element['normalized_coords'])
+             print("Screen coordinates:", element['screen_coords'])
+             print("Dimensions:", element['dimensions'])
+     else:
+         print("Error:", response.text)
+
+ if __name__ == "__main__":
+     test_api()