zorba111 committed
Commit 2ad48f3 · verified · 1 Parent(s): f280342

Upload folder using huggingface_hub
.dockerignore ADDED
@@ -0,0 +1,57 @@
+ # The .dockerignore file excludes files from the container build process.
+ #
+ # https://docs.docker.com/engine/reference/builder/#dockerignore-file
+
+ # Exclude Git files
+ **/.git
+ **/.github
+ **/.gitignore
+
+ # Exclude Python cache files
+ __pycache__
+ .mypy_cache
+ .pytest_cache
+ .ruff_cache
+
+ # Exclude Python virtual environment
+ /venv
+
+
+ # Python
+ __pycache__
+ *.pyc
+ *.pyo
+ *.pyd
+ .Python
+ env/
+ venv/
+ .env/
+ .venv/
+ pip-log.txt
+ pip-delete-this-directory.txt
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.log
+ .pytest_cache/
+ .python-version
+
+ # Editor directories and files
+ .idea
+ .vscode
+ *.swp
+ *.swo
+ *~
+
+ # OS generated files
+ .DS_Store
+ .DS_Store?
+ ._*
+ .Spotlight-V100
+ .Trashes
+ ehthumbs.db
+ Thumbs.db
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ imgs/windows_home.png filter=lfs diff=lfs merge=lfs -text
.github/workflows/push.yaml ADDED
@@ -0,0 +1,54 @@
+ name: Push to Replicate
+
+ on:
+   # Workflow dispatch allows you to manually trigger the workflow from GitHub.com
+   # Go to your repo, click "Actions", click "Push to Replicate", click "Run workflow"
+   workflow_dispatch:
+     inputs:
+       model_name:
+         description: 'Enter the model name, like "alice/bunny-detector". If unset, this will default to the value of `image` in cog.yaml.'
+   # # Uncomment these lines to trigger the workflow on every push to the main branch
+   # push:
+   #   branches:
+   #     - main
+
+ jobs:
+   push_to_replicate:
+     name: Push to Replicate
+
+     # If your model is large, the default GitHub Actions runner may not
+     # have enough disk space. If you need more space you can set up a
+     # bigger runner on GitHub.
+     runs-on: ubuntu-latest
+
+     steps:
+       # This action cleans up disk space to make more room for your
+       # model code, weights, etc.
+       - name: Free disk space
+         uses: jlumbroso/[email protected]
+         with:
+           tool-cache: false
+           docker-images: false
+
+       - name: Checkout
+         uses: actions/checkout@v4
+
+       # This action installs Docker buildx and Cog (and optionally CUDA)
+       - name: Setup Cog
+         uses: replicate/setup-cog@v2
+         with:
+           # If you set REPLICATE_API_TOKEN in your GitHub repository secrets,
+           # the action will authenticate with Replicate automatically so you
+           # can push your model
+           token: ${{ secrets.REPLICATE_API_TOKEN }}
+
+       # If you trigger the workflow manually, you can specify the model name.
+       # If you leave it blank (or if the workflow is triggered by a push), the
+       # model name will be derived from the `image` value in cog.yaml.
+       - name: Push to Replicate
+         run: |
+           if [ -n "${{ inputs.model_name }}" ]; then
+             cog push r8.im/${{ inputs.model_name }}
+           else
+             cog push
+           fi
.github/workflows/update_space.yml ADDED
@@ -0,0 +1,28 @@
+ name: Run Python script
+
+ on:
+   push:
+     branches:
+       - master
+
+ jobs:
+   build:
+     runs-on: ubuntu-latest
+
+     steps:
+       - name: Checkout
+         uses: actions/checkout@v2
+
+       - name: Set up Python
+         uses: actions/setup-python@v2
+         with:
+           python-version: '3.9'
+
+       - name: Install Gradio
+         run: python -m pip install gradio
+
+       - name: Log in to Hugging Face
+         run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")'
+
+       - name: Deploy to Spaces
+         run: gradio deploy
.gitignore ADDED
@@ -0,0 +1,48 @@
+ weights/icon_caption_blip2
+ weights/icon_caption_florence
+ weights/icon_detect/
+ .gradio
+ __pycache__
+
+ .venv
+ __pycache__
+
+ # Python virtual environment
+ .venv/
+ venv/
+ ENV/
+ env/
+
+ # Python bytecode
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # Distribution/packaging
+ dist/
+ build/
+ *.egg-info/
+ *.egg
+
+ # IDE settings
+ .vscode/
+ .idea/
+
+ # Environment variables
+ .env
+
+ # Pip logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Testing
+ .pytest_cache/
+ .coverage
+ htmlcov/
+
+ # macOS
+ .DS_Store
+
+ .cog
+
+ omni
.huggingface/space-config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "runtime": "python3.12",
+   "sdk": "gradio",
+   "app_file": "gradio_demo.py",
+   "python_packages": ["requirements.txt"]
+ }
LICENSE ADDED
@@ -0,0 +1,395 @@
1
+ Attribution 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More_considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution 4.0 International Public License
58
+
59
+ By exercising the Licensed Rights (defined below), You accept and agree
60
+ to be bound by the terms and conditions of this Creative Commons
61
+ Attribution 4.0 International Public License ("Public License"). To the
62
+ extent this Public License may be interpreted as a contract, You are
63
+ granted the Licensed Rights in consideration of Your acceptance of
64
+ these terms and conditions, and the Licensor grants You such rights in
65
+ consideration of benefits the Licensor receives from making the
66
+ Licensed Material available under these terms and conditions.
67
+
68
+
69
+ Section 1 -- Definitions.
70
+
71
+ a. Adapted Material means material subject to Copyright and Similar
72
+ Rights that is derived from or based upon the Licensed Material
73
+ and in which the Licensed Material is translated, altered,
74
+ arranged, transformed, or otherwise modified in a manner requiring
75
+ permission under the Copyright and Similar Rights held by the
76
+ Licensor. For purposes of this Public License, where the Licensed
77
+ Material is a musical work, performance, or sound recording,
78
+ Adapted Material is always produced where the Licensed Material is
79
+ synched in timed relation with a moving image.
80
+
81
+ b. Adapter's License means the license You apply to Your Copyright
82
+ and Similar Rights in Your contributions to Adapted Material in
83
+ accordance with the terms and conditions of this Public License.
84
+
85
+ c. Copyright and Similar Rights means copyright and/or similar rights
86
+ closely related to copyright including, without limitation,
87
+ performance, broadcast, sound recording, and Sui Generis Database
88
+ Rights, without regard to how the rights are labeled or
89
+ categorized. For purposes of this Public License, the rights
90
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
91
+ Rights.
92
+
93
+ d. Effective Technological Measures means those measures that, in the
94
+ absence of proper authority, may not be circumvented under laws
95
+ fulfilling obligations under Article 11 of the WIPO Copyright
96
+ Treaty adopted on December 20, 1996, and/or similar international
97
+ agreements.
98
+
99
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
100
+ any other exception or limitation to Copyright and Similar Rights
101
+ that applies to Your use of the Licensed Material.
102
+
103
+ f. Licensed Material means the artistic or literary work, database,
104
+ or other material to which the Licensor applied this Public
105
+ License.
106
+
107
+ g. Licensed Rights means the rights granted to You subject to the
108
+ terms and conditions of this Public License, which are limited to
109
+ all Copyright and Similar Rights that apply to Your use of the
110
+ Licensed Material and that the Licensor has authority to license.
111
+
112
+ h. Licensor means the individual(s) or entity(ies) granting rights
113
+ under this Public License.
114
+
115
+ i. Share means to provide material to the public by any means or
116
+ process that requires permission under the Licensed Rights, such
117
+ as reproduction, public display, public performance, distribution,
118
+ dissemination, communication, or importation, and to make material
119
+ available to the public including in ways that members of the
120
+ public may access the material from a place and at a time
121
+ individually chosen by them.
122
+
123
+ j. Sui Generis Database Rights means rights other than copyright
124
+ resulting from Directive 96/9/EC of the European Parliament and of
125
+ the Council of 11 March 1996 on the legal protection of databases,
126
+ as amended and/or succeeded, as well as other essentially
127
+ equivalent rights anywhere in the world.
128
+
129
+ k. You means the individual or entity exercising the Licensed Rights
130
+ under this Public License. Your has a corresponding meaning.
131
+
132
+
133
+ Section 2 -- Scope.
134
+
135
+ a. License grant.
136
+
137
+ 1. Subject to the terms and conditions of this Public License,
138
+ the Licensor hereby grants You a worldwide, royalty-free,
139
+ non-sublicensable, non-exclusive, irrevocable license to
140
+ exercise the Licensed Rights in the Licensed Material to:
141
+
142
+ a. reproduce and Share the Licensed Material, in whole or
143
+ in part; and
144
+
145
+ b. produce, reproduce, and Share Adapted Material.
146
+
147
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
148
+ Exceptions and Limitations apply to Your use, this Public
149
+ License does not apply, and You do not need to comply with
150
+ its terms and conditions.
151
+
152
+ 3. Term. The term of this Public License is specified in Section
153
+ 6(a).
154
+
155
+ 4. Media and formats; technical modifications allowed. The
156
+ Licensor authorizes You to exercise the Licensed Rights in
157
+ all media and formats whether now known or hereafter created,
158
+ and to make technical modifications necessary to do so. The
159
+ Licensor waives and/or agrees not to assert any right or
160
+ authority to forbid You from making technical modifications
161
+ necessary to exercise the Licensed Rights, including
162
+ technical modifications necessary to circumvent Effective
163
+ Technological Measures. For purposes of this Public License,
164
+ simply making modifications authorized by this Section 2(a)
165
+ (4) never produces Adapted Material.
166
+
167
+ 5. Downstream recipients.
168
+
169
+ a. Offer from the Licensor -- Licensed Material. Every
170
+ recipient of the Licensed Material automatically
171
+ receives an offer from the Licensor to exercise the
172
+ Licensed Rights under the terms and conditions of this
173
+ Public License.
174
+
175
+ b. No downstream restrictions. You may not offer or impose
176
+ any additional or different terms or conditions on, or
177
+ apply any Effective Technological Measures to, the
178
+ Licensed Material if doing so restricts exercise of the
179
+ Licensed Rights by any recipient of the Licensed
180
+ Material.
181
+
182
+ 6. No endorsement. Nothing in this Public License constitutes or
183
+ may be construed as permission to assert or imply that You
184
+ are, or that Your use of the Licensed Material is, connected
185
+ with, or sponsored, endorsed, or granted official status by,
186
+ the Licensor or others designated to receive attribution as
187
+ provided in Section 3(a)(1)(A)(i).
188
+
189
+ b. Other rights.
190
+
191
+ 1. Moral rights, such as the right of integrity, are not
192
+ licensed under this Public License, nor are publicity,
193
+ privacy, and/or other similar personality rights; however, to
194
+ the extent possible, the Licensor waives and/or agrees not to
195
+ assert any such rights held by the Licensor to the limited
196
+ extent necessary to allow You to exercise the Licensed
197
+ Rights, but not otherwise.
198
+
199
+ 2. Patent and trademark rights are not licensed under this
200
+ Public License.
201
+
202
+ 3. To the extent possible, the Licensor waives any right to
203
+ collect royalties from You for the exercise of the Licensed
204
+ Rights, whether directly or through a collecting society
205
+ under any voluntary or waivable statutory or compulsory
206
+ licensing scheme. In all other cases the Licensor expressly
207
+ reserves any right to collect such royalties.
208
+
209
+
210
+ Section 3 -- License Conditions.
211
+
212
+ Your exercise of the Licensed Rights is expressly made subject to the
213
+ following conditions.
214
+
215
+ a. Attribution.
216
+
217
+ 1. If You Share the Licensed Material (including in modified
218
+ form), You must:
219
+
220
+ a. retain the following if it is supplied by the Licensor
221
+ with the Licensed Material:
222
+
223
+ i. identification of the creator(s) of the Licensed
224
+ Material and any others designated to receive
225
+ attribution, in any reasonable manner requested by
226
+ the Licensor (including by pseudonym if
227
+ designated);
228
+
229
+ ii. a copyright notice;
230
+
231
+ iii. a notice that refers to this Public License;
232
+
233
+ iv. a notice that refers to the disclaimer of
234
+ warranties;
235
+
236
+ v. a URI or hyperlink to the Licensed Material to the
237
+ extent reasonably practicable;
238
+
239
+ b. indicate if You modified the Licensed Material and
240
+ retain an indication of any previous modifications; and
241
+
242
+ c. indicate the Licensed Material is licensed under this
243
+ Public License, and include the text of, or the URI or
244
+ hyperlink to, this Public License.
245
+
246
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
247
+ reasonable manner based on the medium, means, and context in
248
+ which You Share the Licensed Material. For example, it may be
249
+ reasonable to satisfy the conditions by providing a URI or
250
+ hyperlink to a resource that includes the required
251
+ information.
252
+
253
+ 3. If requested by the Licensor, You must remove any of the
254
+ information required by Section 3(a)(1)(A) to the extent
255
+ reasonably practicable.
256
+
257
+ 4. If You Share Adapted Material You produce, the Adapter's
258
+ License You apply must not prevent recipients of the Adapted
259
+ Material from complying with this Public License.
260
+
261
+
262
+ Section 4 -- Sui Generis Database Rights.
263
+
264
+ Where the Licensed Rights include Sui Generis Database Rights that
265
+ apply to Your use of the Licensed Material:
266
+
267
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
268
+ to extract, reuse, reproduce, and Share all or a substantial
269
+ portion of the contents of the database;
270
+
271
+ b. if You include all or a substantial portion of the database
272
+ contents in a database in which You have Sui Generis Database
273
+ Rights, then the database in which You have Sui Generis Database
274
+ Rights (but not its individual contents) is Adapted Material; and
275
+
276
+ c. You must comply with the conditions in Section 3(a) if You Share
277
+ all or a substantial portion of the contents of the database.
278
+
279
+ For the avoidance of doubt, this Section 4 supplements and does not
280
+ replace Your obligations under this Public License where the Licensed
281
+ Rights include other Copyright and Similar Rights.
282
+
283
+
284
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
285
+
286
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
287
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
288
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
289
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
290
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
291
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
292
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
293
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
294
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
295
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
296
+
297
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
298
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
299
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
300
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
301
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
302
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
303
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
304
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
305
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
306
+
307
+ c. The disclaimer of warranties and limitation of liability provided
308
+ above shall be interpreted in a manner that, to the extent
309
+ possible, most closely approximates an absolute disclaimer and
310
+ waiver of all liability.
311
+
312
+
313
+ Section 6 -- Term and Termination.
314
+
315
+ a. This Public License applies for the term of the Copyright and
316
+ Similar Rights licensed here. However, if You fail to comply with
317
+ this Public License, then Your rights under this Public License
318
+ terminate automatically.
319
+
320
+ b. Where Your right to use the Licensed Material has terminated under
321
+ Section 6(a), it reinstates:
322
+
323
+ 1. automatically as of the date the violation is cured, provided
324
+ it is cured within 30 days of Your discovery of the
325
+ violation; or
326
+
327
+ 2. upon express reinstatement by the Licensor.
328
+
329
+ For the avoidance of doubt, this Section 6(b) does not affect any
330
+ right the Licensor may have to seek remedies for Your violations
331
+ of this Public License.
332
+
333
+ c. For the avoidance of doubt, the Licensor may also offer the
334
+ Licensed Material under separate terms or conditions or stop
335
+ distributing the Licensed Material at any time; however, doing so
336
+ will not terminate this Public License.
337
+
338
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
339
+ License.
340
+
341
+
342
+ Section 7 -- Other Terms and Conditions.
343
+
344
+ a. The Licensor shall not be bound by any additional or different
345
+ terms or conditions communicated by You unless expressly agreed.
346
+
347
+ b. Any arrangements, understandings, or agreements regarding the
348
+ Licensed Material not stated herein are separate from and
349
+ independent of the terms and conditions of this Public License.
350
+
351
+
352
+ Section 8 -- Interpretation.
353
+
354
+ a. For the avoidance of doubt, this Public License does not, and
355
+ shall not be interpreted to, reduce, limit, restrict, or impose
356
+ conditions on any use of the Licensed Material that could lawfully
357
+ be made without permission under this Public License.
358
+
359
+ b. To the extent possible, if any provision of this Public License is
360
+ deemed unenforceable, it shall be automatically reformed to the
361
+ minimum extent necessary to make it enforceable. If the provision
362
+ cannot be reformed, it shall be severed from this Public License
363
+ without affecting the enforceability of the remaining terms and
364
+ conditions.
365
+
366
+ c. No term or condition of this Public License will be waived and no
367
+ failure to comply consented to unless expressly agreed to by the
368
+ Licensor.
369
+
370
+ d. Nothing in this Public License constitutes or may be interpreted
371
+ as a limitation upon, or waiver of, any privileges and immunities
372
+ that apply to the Licensor or You, including from the legal
373
+ processes of any jurisdiction or authority.
374
+
375
+
376
+ =======================================================================
377
+
378
+ Creative Commons is not a party to its public
379
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
380
+ its public licenses to material it publishes and in those instances
381
+ will be considered the “Licensor.” The text of the Creative Commons
382
+ public licenses is dedicated to the public domain under the CC0 Public
383
+ Domain Dedication. Except for the limited purpose of indicating that
384
+ material is shared under a Creative Commons public license or as
385
+ otherwise permitted by the Creative Commons policies published at
386
+ creativecommons.org/policies, Creative Commons does not authorize the
387
+ use of the trademark "Creative Commons" or any other trademark or logo
388
+ of Creative Commons without its prior written consent including,
389
+ without limitation, in connection with any unauthorized modifications
390
+ to any of its public licenses or any other arrangements,
391
+ understandings, or agreements concerning use of licensed material. For
392
+ the avoidance of doubt, this paragraph does not form part of the
393
+ public licenses.
394
+
395
+ Creative Commons may be contacted at creativecommons.org.
README.md CHANGED
@@ -1,12 +1,62 @@
  ---
- title: Ui Coordinates Finder
- emoji: 📊
- colorFrom: yellow
- colorTo: green
+ title: ui-coordinates-finder
+ app_file: gradio_demo.py
  sdk: gradio
  sdk_version: 5.4.0
- app_file: app.py
- pinned: false
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # OmniParser: Screen Parsing tool for Pure Vision Based GUI Agent
+
+ <p align="center">
+   <img src="imgs/logo.png" alt="Logo">
+ </p>
+
+ [![arXiv](https://img.shields.io/badge/Paper-green)](https://arxiv.org/abs/2408.00203)
+ [![License](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
+
+ 📢 [[Project Page](https://microsoft.github.io/OmniParser/)] [[Blog Post](https://www.microsoft.com/en-us/research/articles/omniparser-for-pure-vision-based-gui-agent/)] [[Models](https://huggingface.co/microsoft/OmniParser)]
+
+ **OmniParser** is a comprehensive method for parsing user interface screenshots into structured, easy-to-understand elements, which significantly enhances the ability of GPT-4V to generate actions that can be accurately grounded in the corresponding regions of the interface.
+
+ ## News
+ - [2024/10] Both the interactive region detection model and the icon functional description model are released! [Hugging Face models](https://huggingface.co/microsoft/OmniParser)
+ - [2024/09] OmniParser achieves the best performance on [Windows Agent Arena](https://microsoft.github.io/WindowsAgentArena/)!
+
+ ## Install
+ Install the environment:
+ ```bash
+ conda create -n "omni" python==3.12
+ conda activate omni
+ pip install -r requirements.txt
+ ```
+
+ Then download the model checkpoint files from https://huggingface.co/microsoft/OmniParser and put them under `weights/`. The default folder structure is `weights/icon_detect`, `weights/icon_caption_florence`, and `weights/icon_caption_blip2`.
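+
+ The checkpoints can also be fetched programmatically. Below is a minimal sketch using `huggingface_hub`, assuming the repository keeps the three folders listed above at its top level:
+ ```python
+ # Sketch: download the OmniParser checkpoints into weights/.
+ # Assumes the folder names below match the layout on the model page.
+ from huggingface_hub import snapshot_download
+
+ snapshot_download(
+     repo_id="microsoft/OmniParser",
+     local_dir="weights",
+     allow_patterns=["icon_detect/*", "icon_caption_florence/*", "icon_caption_blip2/*"],
+ )
+ ```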
+
+ Finally, convert the safetensors checkpoint to a `.pt` file:
+ ```bash
+ python weights/convert_safetensor_to_pt.py
+ ```
+
+ ## Examples
+ We put together a few simple examples in `demo.ipynb`.
+
+ ## Gradio Demo
+ To run the Gradio demo, simply run:
+ ```bash
+ python gradio_demo.py
+ ```
+
+
+ ## 📚 Citation
+ Our technical report can be found [here](https://arxiv.org/abs/2408.00203).
+ If you find our work useful, please consider citing it:
+ ```
+ @misc{lu2024omniparserpurevisionbased,
+       title={OmniParser for Pure Vision Based GUI Agent},
+       author={Yadong Lu and Jianwei Yang and Yelong Shen and Ahmed Awadallah},
+       year={2024},
+       eprint={2408.00203},
+       archivePrefix={arXiv},
+       primaryClass={cs.CV},
+       url={https://arxiv.org/abs/2408.00203},
+ }
+ ```
SECURITY.md ADDED
@@ -0,0 +1,41 @@
1
+ <!-- BEGIN MICROSOFT SECURITY.MD V0.0.9 BLOCK -->
2
+
3
+ ## Security
4
+
5
+ Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
6
+
7
+ If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
8
+
9
+ ## Reporting Security Issues
10
+
11
+ **Please do not report security vulnerabilities through public GitHub issues.**
12
+
13
+ Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
14
+
15
+ If you prefer to submit without logging in, send email to [[email protected]](mailto:[email protected]). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
16
+
17
+ You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18
+
19
+ Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20
+
21
+ * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22
+ * Full paths of source file(s) related to the manifestation of the issue
23
+ * The location of the affected source code (tag/branch/commit or direct URL)
24
+ * Any special configuration required to reproduce the issue
25
+ * Step-by-step instructions to reproduce the issue
26
+ * Proof-of-concept or exploit code (if possible)
27
+ * Impact of the issue, including how an attacker might exploit the issue
28
+
29
+ This information will help us triage your report more quickly.
30
+
31
+ If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
32
+
33
+ ## Preferred Languages
34
+
35
+ We prefer all communications to be in English.
36
+
37
+ ## Policy
38
+
39
+ Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
40
+
41
+ <!-- END MICROSOFT SECURITY.MD BLOCK -->
api.py ADDED
@@ -0,0 +1,105 @@
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ from pydantic import BaseModel
3
+ from PIL import Image
4
+ import io
5
+ import torch
6
+ from slowapi import Limiter, _rate_limit_exceeded_handler
7
+ from slowapi.util import get_remote_address
8
+ from slowapi.errors import RateLimitExceeded
9
+
10
+ # Import your existing utilities and models
11
+ from utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
12
+
13
+ # Initialize FastAPI app
14
+ app = FastAPI(title="OmniParser API")
15
+ app.state.limiter = Limiter(key_func=get_remote_address)
16
+ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
17
+
18
+ # Load models at startup (reusing your existing code)
19
+ yolo_model = get_yolo_model(model_path='weights/icon_detect/best.pt')
20
+ caption_model_processor = get_caption_model_processor(
21
+ model_name="florence2",
22
+ model_name_or_path="weights/icon_caption_florence"
23
+ )
24
+
25
+ # Define request model
26
+ class ProcessRequest(BaseModel):
27
+ box_threshold: float = 0.05
28
+ iou_threshold: float = 0.1
29
+ screen_width: int = 1920
30
+ screen_height: int = 1080
31
+
32
+ @app.post("/process")
33
+ @app.state.limiter.limit("5/minute") # Limit to 5 requests per minute per IP
34
+ async def process_image(
35
+ file: UploadFile = File(...),
36
+ params: ProcessRequest = None
37
+ ):
38
+ # Read image from request
39
+ image_bytes = await file.read()
40
+ image = Image.open(io.BytesIO(image_bytes))
41
+
42
+ # Save image temporarily (reusing your existing logic)
43
+ temp_path = 'imgs/temp_image.png'
44
+ image.save(temp_path)
45
+
46
+ # Process image using your existing functions
47
+ ocr_bbox_rslt, _ = check_ocr_box(
48
+ temp_path,
49
+ display_img=False,
50
+ output_bb_format='xyxy',
51
+ goal_filtering=None,
52
+ easyocr_args={'paragraph': False, 'text_threshold':0.9}
53
+ )
54
+
55
+ text, ocr_bbox = ocr_bbox_rslt
56
+
57
+ dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
58
+ temp_path,
59
+ yolo_model,
60
+ BOX_TRESHOLD=params.box_threshold,
61
+ output_coord_in_ratio=True,
62
+ ocr_bbox=ocr_bbox,
63
+ draw_bbox_config={
64
+ 'text_scale': 0.8,
65
+ 'text_thickness': 2,
66
+ 'text_padding': 2,
67
+ 'thickness': 2,
68
+ },
69
+ caption_model_processor=caption_model_processor,
70
+ ocr_text=text,
71
+ iou_threshold=params.iou_threshold
72
+ )
73
+
74
+ # Format output (similar to your existing code)
75
+ output_text = []
76
+ for i, (element_id, coords) in enumerate(label_coordinates.items()):
77
+ x, y, w, h = coords
78
+ center_x_norm = x + (w/2)
79
+ center_y_norm = y + (h/2)
80
+ screen_x = int(center_x_norm * params.screen_width)
81
+ screen_y = int(center_y_norm * params.screen_height)
82
+ screen_w = int(w * params.screen_width)
83
+ screen_h = int(h * params.screen_height)
84
+
85
+ element_desc = parsed_content_list[i] if i < len(parsed_content_list) else f"Icon {i}"
86
+ output_text.append({
87
+ "description": element_desc,
88
+ "normalized_coordinates": {
89
+ "x": center_x_norm,
90
+ "y": center_y_norm
91
+ },
92
+ "screen_coordinates": {
93
+ "x": screen_x,
94
+ "y": screen_y
95
+ },
96
+ "dimensions": {
97
+ "width": screen_w,
98
+ "height": screen_h
99
+ }
100
+ })
101
+
102
+ return {
103
+ "processed_image": dino_labled_img, # Base64 encoded image
104
+ "elements": output_text
105
+ }
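
Once this API is running (for example via `uvicorn api:app --port 8000`), the `/process` endpoint can be exercised with a small client. Below is a minimal sketch using `requests`; the host, port, and screenshot path are assumptions, and note that the handler above dereferences `params`, so the `ProcessRequest` values may need to be supplied explicitly rather than relying on their defaults:

```python
# Sketch of a client for the /process endpoint defined above.
# Assumes the API is reachable at localhost:8000 and that a sample
# screenshot exists at the given path.
import requests

with open("imgs/windows_home.png", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/process",
        files={"file": ("screenshot.png", f, "image/png")},
    )
resp.raise_for_status()
for element in resp.json()["elements"]:
    print(element["description"], element["screen_coordinates"])
```

The response also carries `processed_image`, the base64-encoded annotated screenshot, which can be decoded with `base64.b64decode` and opened with PIL, mirroring what `gradio_demo.py` does.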
cog.yaml ADDED
@@ -0,0 +1,48 @@
+ # Configuration for Cog ⚙️
+ # Reference: https://cog.run/yaml
+
+ build:
+   # set to true if your model requires a GPU
+   gpu: true
+
+   # a list of ubuntu apt packages to install
+   system_packages:
+     - 'libgl1-mesa-glx'
+     - 'libglib2.0-0'
+     - 'libsm6'
+     - 'libxext6'
+     - 'libxrender-dev'
+     - 'libgomp1'
+     - 'wget'
+     - 'git'
+
+   # python version in the form '3.11' or '3.11.4'
+   python_version: '3.12'
+
+   # a list of packages in the format <package-name>==<version>
+   python_packages:
+     - --extra-index-url https://download.pytorch.org/whl/cu121
+     - 'torch==2.3.1'
+     - pillow==10.2.0
+     - numpy==1.26.4
+     - opencv-python-headless==4.9.0.80
+     - easyocr==1.7.1
+     - transformers==4.37.2
+     - ultralytics==8.1.2
+     - python-bidi==0.4.2
+     - PyYAML>=5.3.1
+     - scipy>=1.7.1
+     - ninja>=1.10.2
+
+   # commands to run after the environment is setup
+   run:
+     - pip3 cache purge
+     - python3 -m pip install --upgrade pip
+     - export PYTHONUNBUFFERED=1
+     - export CUDA_VISIBLE_DEVICES=0
+     - export TORCH_HOME=/src/.torch
+     - export PIP_DISABLE_PIP_VERSION_CHECK=1
+     - export PIP_NO_CACHE_DIR=1
+
+ # predict.py defines how predictions are run on your model
+ predict: 'predict.py:Predictor'
demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
gradio_demo.py ADDED
@@ -0,0 +1,185 @@
1
+ from typing import Optional
2
+
3
+ import gradio as gr
4
+ import numpy as np
5
+ import torch
6
+ from PIL import Image
7
+ import io
8
+
9
+
10
+ import base64, os
11
+ from utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
12
+ import torch
13
+ from PIL import Image
14
+
15
+ yolo_model = get_yolo_model(model_path='weights/icon_detect/best.pt')
16
+ caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="weights/icon_caption_florence")
17
+ platform = 'pc'
18
+ if platform == 'pc':
19
+ draw_bbox_config = {
20
+ 'text_scale': 0.8,
21
+ 'text_thickness': 2,
22
+ 'text_padding': 2,
23
+ 'thickness': 2,
24
+ }
25
+ elif platform == 'web':
26
+ draw_bbox_config = {
27
+ 'text_scale': 0.8,
28
+ 'text_thickness': 2,
29
+ 'text_padding': 3,
30
+ 'thickness': 3,
31
+ }
32
+ elif platform == 'mobile':
33
+ draw_bbox_config = {
34
+ 'text_scale': 0.8,
35
+ 'text_thickness': 2,
36
+ 'text_padding': 3,
37
+ 'thickness': 3,
38
+ }
39
+
40
+
41
+
42
+ MARKDOWN = """
43
+ # OmniParser for Pure Vision Based General GUI Agent 🔥
44
+ <div>
45
+ <a href="https://arxiv.org/pdf/2408.00203">
46
+ <img src="https://img.shields.io/badge/arXiv-2408.00203-b31b1b.svg" alt="Arxiv" style="display:inline-block;">
47
+ </a>
48
+ </div>
49
+
50
+ OmniParser is a screen parsing tool to convert general GUI screen to structured elements.
51
+ """
52
+
53
+ DEVICE = torch.device('cuda')
54
+
55
+ # @spaces.GPU
56
+ # @torch.inference_mode()
57
+ # @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
58
+ def process(
59
+ image_input,
60
+ box_threshold,
61
+ iou_threshold,
62
+ screen_width,
63
+ screen_height
64
+ ) -> Optional[Image.Image]:
65
+ """
66
+ Process the image and return both normalized and screen coordinates
67
+
68
+ Args:
69
+ image_input: Input image
70
+ box_threshold: Confidence threshold for box detection
71
+ iou_threshold: IOU threshold for overlap detection
72
+ screen_width: Actual screen width in pixels
73
+ screen_height: Actual screen height in pixels
74
+ """
75
+ image_save_path = 'imgs/saved_image_demo.png'
76
+ image_input.save(image_save_path)
77
+
78
+ # Get image dimensions
79
+ image_width = image_input.width
80
+ image_height = image_input.height
81
+
82
+ ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_save_path, display_img=False, output_bb_format='xyxy',
83
+ goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9})
84
+ text, ocr_bbox = ocr_bbox_rslt
85
+
86
+ dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
87
+ image_save_path, yolo_model, BOX_TRESHOLD=box_threshold,
88
+ output_coord_in_ratio=True, ocr_bbox=ocr_bbox,
89
+ draw_bbox_config=draw_bbox_config,
90
+ caption_model_processor=caption_model_processor,
91
+ ocr_text=text, iou_threshold=iou_threshold
92
+ )
93
+ image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
94
+
95
+ # Format the output to include both normalized and screen coordinates
96
+ output_text = []
97
+ for i, (element_id, coords) in enumerate(label_coordinates.items()):
98
+ x, y, w, h = coords
99
+
100
+ # Calculate center points (normalized)
101
+ center_x_norm = x + (w/2)
102
+ center_y_norm = y + (h/2)
103
+
104
+ # Calculate screen coordinates
105
+ screen_x = int(center_x_norm * screen_width)
106
+ screen_y = int(center_y_norm * screen_height)
107
+
108
+ # Calculate element dimensions on screen
109
+ screen_w = int(w * screen_width)
110
+ screen_h = int(h * screen_height)
111
+
112
+ if i < len(parsed_content_list):
113
+ # For text elements
114
+ element_desc = parsed_content_list[i]
115
+ output_text.append(
116
+ f"{element_desc}\n"
117
+ f" Normalized coordinates: ({center_x_norm:.3f}, {center_y_norm:.3f})\n"
118
+ f" Screen coordinates: ({screen_x}, {screen_y})\n"
119
+ f" Dimensions: {screen_w}x{screen_h} pixels"
120
+ )
121
+ else:
122
+ # For icon elements without text
123
+ output_text.append(
124
+ f"Icon {i}\n"
125
+ f" Normalized coordinates: ({center_x_norm:.3f}, {center_y_norm:.3f})\n"
126
+ f" Screen coordinates: ({screen_x}, {screen_y})\n"
127
+ f" Dimensions: {screen_w}x{screen_h} pixels"
128
+ )
129
+
130
+ parsed_content = '\n\n'.join(output_text)
131
+ return image, parsed_content
132
+
133
+
134
+
135
+ with gr.Blocks() as demo:
136
+ gr.Markdown(MARKDOWN)
137
+ with gr.Row():
138
+ with gr.Column():
139
+ image_input_component = gr.Image(
140
+ type='pil', label='Upload image')
141
+
142
+ with gr.Row():
143
+ # Screen dimension inputs
144
+ screen_width_component = gr.Number(
145
+ label='Screen Width (pixels)',
146
+ value=1920, # Default value
147
+ precision=0
148
+ )
149
+ screen_height_component = gr.Number(
150
+ label='Screen Height (pixels)',
151
+ value=1080, # Default value
152
+ precision=0
153
+ )
154
+
155
+ # Threshold sliders
156
+ box_threshold_component = gr.Slider(
157
+ label='Box Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.05)
158
+ iou_threshold_component = gr.Slider(
159
+ label='IOU Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.1)
160
+
161
+ submit_button_component = gr.Button(
162
+ value='Submit', variant='primary')
163
+
164
+ with gr.Column():
165
+ image_output_component = gr.Image(type='pil', label='Image Output')
166
+ text_output_component = gr.Textbox(
167
+ label='Parsed screen elements',
168
+ placeholder='Text Output',
169
+ lines=10 # Increased to show more content
170
+ )
171
+
172
+ submit_button_component.click(
173
+ fn=process,
174
+ inputs=[
175
+ image_input_component,
176
+ box_threshold_component,
177
+ iou_threshold_component,
178
+ screen_width_component,
179
+ screen_height_component
180
+ ],
181
+ outputs=[image_output_component, text_output_component]
182
+ )
183
+
184
+ # demo.launch(debug=False, show_error=True, share=True)
185
+ demo.launch(share=True, server_port=7861, server_name='0.0.0.0')
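
The coordinate handling in `process` above is a straight linear mapping: `get_som_labeled_img` is called with `output_coord_in_ratio=True`, so each box comes back as normalized `(x, y, w, h)`, the box center is `(x + w/2, y + h/2)`, and multiplying by the user-supplied screen size gives pixel coordinates. A small self-contained sketch of that conversion (the box values are made up for illustration):

```python
# Illustrative only: mirrors the center/scale arithmetic used in process() above.
def to_screen_coords(box, screen_width, screen_height):
    """box is (x, y, w, h) in normalized [0, 1] image coordinates."""
    x, y, w, h = box
    center_x_norm = x + w / 2
    center_y_norm = y + h / 2
    return (
        int(center_x_norm * screen_width),   # screen x of the element center
        int(center_y_norm * screen_height),  # screen y of the element center
        int(w * screen_width),               # element width in pixels
        int(h * screen_height),              # element height in pixels
    )

# A hypothetical detection covering the top-left 10% x 5% of a 1920x1080 screen:
print(to_screen_coords((0.0, 0.0, 0.10, 0.05), 1920, 1080))  # -> (96, 27, 192, 54)
```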
imgs/google_page.png ADDED
imgs/logo.png ADDED
imgs/saved_image_demo.png ADDED
imgs/windows_home.png ADDED

Git LFS Details

  • SHA256: 036008abc32379393876e722fedab2bd02bda9b667b957bc150c2f83c725ebac
  • Pointer size: 132 Bytes
  • Size of remote file: 6.1 MB
imgs/windows_multitab.png ADDED
omniparser.py ADDED
@@ -0,0 +1,60 @@
1
+ from utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_dino_model, get_yolo_model
2
+ import torch
3
+ from ultralytics import YOLO
4
+ from PIL import Image
5
+ from typing import Dict, Tuple, List
6
+ import io
7
+ import base64
8
+
9
+
10
+ config = {
11
+ 'som_model_path': 'finetuned_icon_detect.pt',
12
+ 'device': 'cpu',
13
+ 'caption_model_path': 'Salesforce/blip2-opt-2.7b',
14
+ 'draw_bbox_config': {
15
+ 'text_scale': 0.8,
16
+ 'text_thickness': 2,
17
+ 'text_padding': 3,
18
+ 'thickness': 3,
19
+ },
20
+ 'BOX_TRESHOLD': 0.05
21
+ }
22
+
23
+
24
+ class Omniparser(object):
25
+ def __init__(self, config: Dict):
26
+ self.config = config
27
+
28
+ self.som_model = get_yolo_model(model_path=config['som_model_path'])
29
+ # self.caption_model_processor = get_caption_model_processor(config['caption_model_path'], device=config['device'])
30
+ # self.caption_model_processor['model'].to(torch.float32)
31
+
32
+ def parse(self, image_path: str):
33
+ print('Parsing image:', image_path)
34
+ ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9})
35
+ text, ocr_bbox = ocr_bbox_rslt
36
+
37
+ draw_bbox_config = self.config['draw_bbox_config']
38
+ BOX_TRESHOLD = self.config['BOX_TRESHOLD']
39
+ dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, self.som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=False, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=None, ocr_text=text,use_local_semantics=False)
40
+
41
+ image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
42
+ # format the output
43
+ return_list = [{'from': 'omniparser', 'shape': {'x':coord[0], 'y':coord[1], 'width':coord[2], 'height':coord[3]},
44
+ 'text': parsed_content_list[i].split(': ')[1], 'type':'text'} for i, (k, coord) in enumerate(label_coordinates.items()) if i < len(parsed_content_list)]
45
+ return_list.extend(
46
+ [{'from': 'omniparser', 'shape': {'x':coord[0], 'y':coord[1], 'width':coord[2], 'height':coord[3]},
47
+ 'text': 'None', 'type':'icon'} for i, (k, coord) in enumerate(label_coordinates.items()) if i >= len(parsed_content_list)]
48
+ )
49
+
50
+ return [image, return_list]
51
+
52
+ parser = Omniparser(config)
53
+ image_path = 'examples/pc_1.png'
54
+
55
+ # time the parser
56
+ import time
57
+ s = time.time()
58
+ image, parsed_content_list = parser.parse(image_path)
59
+ device = config['device']
60
+ print(f'Time taken for Omniparser on {device}:', time.time() - s)
predict.py ADDED
@@ -0,0 +1,135 @@
1
+ # Prediction interface for Cog ⚙️
2
+ # https://cog.run/python
3
+
4
+ from cog import BasePredictor, Input, Path
5
+ from PIL import Image
6
+ from utils import (
7
+ check_ocr_box,
8
+ get_yolo_model,
9
+ get_caption_model_processor,
10
+ get_som_labeled_img
11
+ )
12
+
13
+
14
+ class Predictor(BasePredictor):
15
+ def setup(self):
16
+ """Load the model into memory"""
17
+ self.yolo_model = get_yolo_model(model_path='weights/icon_detect/best.pt')
18
+ self.caption_model_processor = get_caption_model_processor(
19
+ model_name="florence2",
20
+ model_name_or_path="weights/icon_caption_florence"
21
+ )
22
+ self.draw_bbox_config = {
23
+ 'text_scale': 0.8,
24
+ 'text_thickness': 2,
25
+ 'text_padding': 2,
26
+ 'thickness': 2,
27
+ }
28
+
29
+ def predict(
30
+ self,
31
+ image: Path = Input(description="Screenshot of the screen"),
32
+ screen_width: int = Input(
33
+ description="Screen width in pixels",
34
+ default=1920,
35
+ ge=800, # Setting minimum reasonable screen width
36
+ le=7680, # Supporting up to 8K displays
37
+ ),
38
+ screen_height: int = Input(
39
+ description="Screen height in pixels",
40
+ default=1080,
41
+ ge=600, # Setting minimum reasonable screen height
42
+ le=4320, # Supporting up to 8K displays
43
+ ),
44
+ box_threshold: float = Input(
45
+ description="Confidence threshold for box detection",
46
+ default=0.05,
47
+ ge=0.01,
48
+ le=1.0,
49
+ ),
50
+ iou_threshold: float = Input(
51
+ description="IOU threshold for overlap detection",
52
+ default=0.1,
53
+ ge=0.01,
54
+ le=1.0,
55
+ ),
56
+ ) -> dict:
57
+ """Run object detection on a screenshot and return coordinates"""
58
+
59
+ # Ensure the input image exists and is valid
60
+ if not image.exists():
61
+ raise ValueError("Input image file does not exist")
62
+
63
+ # Open and validate the image
64
+ try:
65
+ input_image = Image.open(image)
66
+ input_image.verify() # Verify it's a valid image
67
+ except Exception as e:
68
+ raise ValueError(f"Invalid image file: {str(e)}")
69
+
70
+ # Save input image temporarily
71
+ image_save_path = '/tmp/input_image.png'
72
+ input_image = Image.open(image)
73
+ input_image.save(image_save_path)
74
+
75
+ # Get OCR results
76
+ ocr_bbox_rslt, _ = check_ocr_box(
77
+ image_save_path,
78
+ display_img=False,
79
+ output_bb_format='xyxy',
80
+ goal_filtering=None,
81
+ easyocr_args={'paragraph': False, 'text_threshold': 0.9}
82
+ )
83
+ text, ocr_bbox = ocr_bbox_rslt
84
+
85
+ # Get labeled image and coordinates
86
+ dino_labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
87
+ image_save_path,
88
+ self.yolo_model,
89
+ BOX_TRESHOLD=box_threshold,
90
+ output_coord_in_ratio=True,
91
+ ocr_bbox=ocr_bbox,
92
+ draw_bbox_config=self.draw_bbox_config,
93
+ caption_model_processor=self.caption_model_processor,
94
+ ocr_text=text,
95
+ iou_threshold=iou_threshold
96
+ )
97
+
98
+ # Format output
99
+ elements = []
100
+ for i, (element_id, coords) in enumerate(label_coordinates.items()):
101
+ x, y, w, h = coords
102
+
103
+ # Calculate center points (normalized)
104
+ center_x_norm = x + (w/2)
105
+ center_y_norm = y + (h/2)
106
+
107
+ # Calculate screen coordinates
108
+ screen_x = int(center_x_norm * screen_width)
109
+ screen_y = int(center_y_norm * screen_height)
110
+
111
+ # Calculate element dimensions on screen
112
+ screen_w = int(w * screen_width)
113
+ screen_h = int(h * screen_height)
114
+
115
+ element = {
116
+ "description": parsed_content_list[i] if i < len(parsed_content_list) else f"Icon {i}",
117
+ "normalized_coordinates": {
118
+ "x": center_x_norm,
119
+ "y": center_y_norm
120
+ },
121
+ "screen_coordinates": {
122
+ "x": screen_x,
123
+ "y": screen_y
124
+ },
125
+ "dimensions": {
126
+ "width": screen_w,
127
+ "height": screen_h
128
+ }
129
+ }
130
+ elements.append(element)
131
+
132
+ return {
133
+ "image": dino_labeled_img, # Base64 encoded image
134
+ "elements": elements
135
+ }
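
Outside of a full Cog build, the predictor can be smoke-tested directly from Python. A minimal sketch, assuming the weights described in the README are already under `weights/`, a CUDA environment matching `cog.yaml` is available, and a sample screenshot such as `imgs/windows_home.png` exists:

```python
# Sketch: call the Cog predictor directly for a quick local check.
from cog import Path

from predict import Predictor

predictor = Predictor()
predictor.setup()  # loads the YOLO detector and the Florence-2 captioner

result = predictor.predict(
    image=Path("imgs/windows_home.png"),
    screen_width=1920,
    screen_height=1080,
    box_threshold=0.05,
    iou_threshold=0.1,
)
print(len(result["elements"]), "elements detected")
```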
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ torch
+ easyocr
+ torchvision
+ supervision==0.18.0
+ openai==1.3.5
+ transformers
+ ultralytics==8.1.24
+ azure-identity
+ numpy
+ opencv-python
+ opencv-python-headless
+ gradio
+ dill
+ accelerate
+ timm
+ einops==0.8.0
util/__init__.py ADDED
File without changes
util/action_matching.py ADDED
@@ -0,0 +1,425 @@
1
+ '''
2
+ Adapted from https://github.com/google-research/google-research/tree/master/android_in_the_wild
3
+ '''
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import numpy as np
8
+
9
+ # import action_type as action_type_lib
10
+ import enum
11
+
12
+ class ActionType(enum.IntEnum):
13
+ # Placeholders for unused enum values
14
+ UNUSED_0 = 0
15
+ UNUSED_1 = 1
16
+ UNUSED_2 = 2
17
+ UNUSED_8 = 8
18
+ UNUSED_9 = 9
19
+
20
+ ########### Agent actions ###########
21
+
22
+ # A type action that sends text to the emulator. Note that this simply sends
23
+ # text and does not perform any clicks for element focus or enter presses for
24
+ # submitting text.
25
+ TYPE = 3
26
+
27
+ # The dual point action used to represent all gestures.
28
+ DUAL_POINT = 4
29
+
30
+ # These actions differentiate pressing the home and back button from touches.
31
+ # They represent explicit presses of back and home performed using ADB.
32
+ PRESS_BACK = 5
33
+ PRESS_HOME = 6
34
+
35
+ # An action representing that ADB command for hitting enter was performed.
36
+ PRESS_ENTER = 7
37
+
38
+ ########### Episode status actions ###########
39
+
40
+ # An action used to indicate the desired task has been completed and resets
41
+ # the environment. This action should also be used in the case that the task
42
+ # has already been completed and there is nothing to do.
43
+ # e.g. The task is to turn on the Wi-Fi when it is already on
44
+ STATUS_TASK_COMPLETE = 10
45
+
46
+ # An action used to indicate that desired task is impossible to complete and
47
+ # resets the environment. This can be a result of many different things
48
+ # including UI changes, Android version differences, etc.
49
+ STATUS_TASK_IMPOSSIBLE = 11
50
+
51
+
52
+ _TAP_DISTANCE_THRESHOLD = 0.14 # Fraction of the screen
53
+ ANNOTATION_WIDTH_AUGMENT_FRACTION = 1.4
54
+ ANNOTATION_HEIGHT_AUGMENT_FRACTION = 1.4
55
+
56
+ # Interval determining if an action is a tap or a swipe.
57
+ _SWIPE_DISTANCE_THRESHOLD = 0.04
58
+
59
+
60
+ def _yx_in_bounding_boxes(
61
+ yx, bounding_boxes
62
+ ):
63
+ """Check if the (y,x) point is contained in each bounding box.
64
+
65
+ Args:
66
+ yx: The (y, x) coordinate in pixels of the point.
67
+ bounding_boxes: A 2D int array of shape (num_bboxes, 4), where each row
68
+ represents a bounding box: (y_top_left, x_top_left, box_height,
69
+ box_width). Note: containment is inclusive of the bounding box edges.
70
+
71
+ Returns:
72
+ is_inside: A 1D bool array where each element specifies if the point is
73
+ contained within the respective box.
74
+ """
75
+ y, x = yx
76
+
77
+ # `bounding_boxes` has shape (n_elements, 4); we extract each array along the
78
+ # last axis into shape (n_elements, 1), then squeeze unneeded dimension.
79
+ top, left, height, width = [
80
+ jnp.squeeze(v, axis=-1) for v in jnp.split(bounding_boxes, 4, axis=-1)
81
+ ]
82
+
83
+ # The y-axis is inverted for AndroidEnv, so bottom = top + height.
84
+ bottom, right = top + height, left + width
85
+
86
+ return jnp.logical_and(y >= top, y <= bottom) & jnp.logical_and(
87
+ x >= left, x <= right)
88
+
89
+
90
+ def _resize_annotation_bounding_boxes(
91
+ annotation_positions, annotation_width_augment_fraction,
92
+ annotation_height_augment_fraction):
93
+ """Resize the bounding boxes by the given fractions.
94
+
95
+ Args:
96
+ annotation_positions: Array of shape (N, 4), where each row represents the
97
+ (y, x, height, width) of the bounding boxes.
98
+ annotation_width_augment_fraction: The fraction to augment the box widths,
99
+ E.g., 1.4 == 240% total increase.
100
+ annotation_height_augment_fraction: Same as described for width, but for box
101
+ height.
102
+
103
+ Returns:
104
+ Resized bounding box.
105
+
106
+ """
107
+ height_change = (
108
+ annotation_height_augment_fraction * annotation_positions[:, 2])
109
+ width_change = (
110
+ annotation_width_augment_fraction * annotation_positions[:, 3])
111
+
112
+ # Limit bounding box positions to the screen.
113
+ resized_annotations = jnp.stack([
114
+ jnp.maximum(0, annotation_positions[:, 0] - (height_change / 2)),
115
+ jnp.maximum(0, annotation_positions[:, 1] - (width_change / 2)),
116
+ jnp.minimum(1, annotation_positions[:, 2] + height_change),
117
+ jnp.minimum(1, annotation_positions[:, 3] + width_change),
118
+ ],
119
+ axis=1)
120
+ return resized_annotations
121
+
122
+
123
+ def is_tap_action(normalized_start_yx,
124
+ normalized_end_yx):
125
+ distance = jnp.linalg.norm(
126
+ jnp.array(normalized_start_yx) - jnp.array(normalized_end_yx))
127
+ return distance <= _SWIPE_DISTANCE_THRESHOLD
128
+
129
+
130
+ def _is_non_dual_point_action(action_type):
131
+ return jnp.not_equal(action_type, ActionType.DUAL_POINT)
132
+
133
+
134
+ def _check_tap_actions_match(
135
+ tap_1_yx,
136
+ tap_2_yx,
137
+ annotation_positions,
138
+ matching_tap_distance_threshold_screen_percentage,
139
+ annotation_width_augment_fraction,
140
+ annotation_height_augment_fraction,
141
+ ):
142
+ """Determines if two tap actions are the same."""
143
+ resized_annotation_positions = _resize_annotation_bounding_boxes(
144
+ annotation_positions,
145
+ annotation_width_augment_fraction,
146
+ annotation_height_augment_fraction,
147
+ )
148
+
149
+ # Check if the ground truth tap action falls in an annotation's bounding box.
150
+ tap1_in_box = _yx_in_bounding_boxes(tap_1_yx, resized_annotation_positions)
151
+ tap2_in_box = _yx_in_bounding_boxes(tap_2_yx, resized_annotation_positions)
152
+ both_in_box = jnp.max(tap1_in_box & tap2_in_box)
153
+
154
+ # If the ground-truth tap action falls outside any of the annotation
155
+ # bounding boxes or one of the actions is inside a bounding box and the other
156
+ # is outside bounding box or vice versa, compare the points using Euclidean
157
+ # distance.
158
+ within_threshold = (
159
+ jnp.linalg.norm(jnp.array(tap_1_yx) - jnp.array(tap_2_yx))
160
+ <= matching_tap_distance_threshold_screen_percentage
161
+ )
162
+ return jnp.logical_or(both_in_box, within_threshold)
163
+
164
+
165
+ def _check_drag_actions_match(
166
+ drag_1_touch_yx,
167
+ drag_1_lift_yx,
168
+ drag_2_touch_yx,
169
+ drag_2_lift_yx,
170
+ ):
171
+ """Determines if two drag actions are the same."""
172
+ # Store drag deltas (the change in the y and x coordinates from touch to
173
+ # lift), magnitudes, and the index of the main axis, which is the axis with
174
+ # the greatest change in coordinate value (e.g. a drag starting at (0, 0) and
175
+ # ending at (0.3, 0.5) has a main axis index of 1).
176
+ drag_1_deltas = drag_1_lift_yx - drag_1_touch_yx
177
+ drag_1_magnitudes = jnp.abs(drag_1_deltas)
178
+ drag_1_main_axis = np.argmax(drag_1_magnitudes)
179
+ drag_2_deltas = drag_2_lift_yx - drag_2_touch_yx
180
+ drag_2_magnitudes = jnp.abs(drag_2_deltas)
181
+ drag_2_main_axis = np.argmax(drag_2_magnitudes)
182
+
183
+ return jnp.equal(drag_1_main_axis, drag_2_main_axis)
184
+
185
+
186
+ def check_actions_match(
187
+ action_1_touch_yx,
188
+ action_1_lift_yx,
189
+ action_1_action_type,
190
+ action_2_touch_yx,
191
+ action_2_lift_yx,
192
+ action_2_action_type,
193
+ annotation_positions,
194
+ tap_distance_threshold = _TAP_DISTANCE_THRESHOLD,
195
+ annotation_width_augment_fraction = ANNOTATION_WIDTH_AUGMENT_FRACTION,
196
+ annotation_height_augment_fraction = ANNOTATION_HEIGHT_AUGMENT_FRACTION,
197
+ ):
198
+ """Determines if two actions are considered to be the same.
199
+
200
+ Two actions being "the same" is defined here as two actions that would result
201
+ in a similar screen state.
202
+
203
+ Args:
204
+ action_1_touch_yx: The (y, x) coordinates of the first action's touch.
205
+ action_1_lift_yx: The (y, x) coordinates of the first action's lift.
206
+ action_1_action_type: The action type of the first action.
207
+ action_2_touch_yx: The (y, x) coordinates of the second action's touch.
208
+ action_2_lift_yx: The (y, x) coordinates of the second action's lift.
209
+ action_2_action_type: The action type of the second action.
210
+ annotation_positions: The positions of the UI annotations for the screen. It
211
+      is a 2D int array of shape (num_bboxes, 4), where each row represents a
212
+ bounding box: (y_top_left, x_top_left, box_height, box_width). Note that
213
+ containment is inclusive of the bounding box edges.
214
+ tap_distance_threshold: The threshold that determines if two taps result in
215
+      a matching screen state if they don't fall in the same bounding boxes.
216
+ annotation_width_augment_fraction: The fraction to increase the width of the
217
+ bounding box by.
218
+ annotation_height_augment_fraction: The fraction to increase the height of
219
+      the bounding box by.
220
+
221
+ Returns:
222
+ A boolean representing whether the two given actions are the same or not.
223
+ """
224
+ action_1_touch_yx = jnp.asarray(action_1_touch_yx)
225
+ action_1_lift_yx = jnp.asarray(action_1_lift_yx)
226
+ action_2_touch_yx = jnp.asarray(action_2_touch_yx)
227
+ action_2_lift_yx = jnp.asarray(action_2_lift_yx)
228
+
229
+ # Checks if at least one of the actions is global (i.e. not DUAL_POINT),
230
+ # because if that is the case, only the actions' types need to be compared.
231
+ has_non_dual_point_action = jnp.logical_or(
232
+ _is_non_dual_point_action(action_1_action_type),
233
+ _is_non_dual_point_action(action_2_action_type),
234
+ )
235
+ #print("non dual point: "+str(has_non_dual_point_action))
236
+
237
+ different_dual_point_types = jnp.logical_xor(
238
+ is_tap_action(action_1_touch_yx, action_1_lift_yx),
239
+ is_tap_action(action_2_touch_yx, action_2_lift_yx),
240
+ )
241
+ #print("different dual type: "+str(different_dual_point_types))
242
+
243
+ is_tap = jnp.logical_and(
244
+ is_tap_action(action_1_touch_yx, action_1_lift_yx),
245
+ is_tap_action(action_2_touch_yx, action_2_lift_yx),
246
+ )
247
+ #print("is tap: "+str(is_tap))
248
+
249
+ taps_match = _check_tap_actions_match(
250
+ action_1_touch_yx,
251
+ action_2_touch_yx,
252
+ annotation_positions,
253
+ tap_distance_threshold,
254
+ annotation_width_augment_fraction,
255
+ annotation_height_augment_fraction,
256
+ )
257
+ #print("tap match: "+str(taps_match))
258
+
259
+ taps_match = jnp.logical_and(is_tap, taps_match)
260
+ #print("tap match: "+str(taps_match))
261
+
262
+ drags_match = _check_drag_actions_match(
263
+ action_1_touch_yx, action_1_lift_yx, action_2_touch_yx, action_2_lift_yx
264
+ )
265
+ drags_match = jnp.where(is_tap, False, drags_match)
266
+ #print("drag match: "+str(drags_match))
267
+
268
+ return jnp.where(
269
+ has_non_dual_point_action,
270
+ jnp.equal(action_1_action_type, action_2_action_type),
271
+ jnp.where(
272
+ different_dual_point_types,
273
+ False,
274
+ jnp.logical_or(taps_match, drags_match),
275
+ ),
276
+ )
277
+
278
+
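For orientation, the dual-point logic above treats a gesture as a tap when its normalized touch and lift points lie within `_SWIPE_DISTANCE_THRESHOLD` (0.04) of each other, and as a swipe otherwise. A minimal, self-contained sketch of that check (it re-implements `is_tap_action` with `jax.numpy` rather than importing this file):

```python
import jax.numpy as jnp

SWIPE_DISTANCE_THRESHOLD = 0.04  # same value as _SWIPE_DISTANCE_THRESHOLD above

def is_tap(start_yx, end_yx):
    # Euclidean distance between the normalized (y, x) touch and lift points.
    distance = jnp.linalg.norm(jnp.array(start_yx) - jnp.array(end_yx))
    return distance <= SWIPE_DISTANCE_THRESHOLD

print(is_tap([0.50, 0.50], [0.51, 0.50]))  # True  -> counted as a tap
print(is_tap([0.50, 0.80], [0.50, 0.20]))  # False -> counted as a swipe
```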
279
+ def action_2_format(step_data):
280
+    # Convert actions from the test dataset into the format used to compute the matching score.
281
+ action_type = step_data["action_type_id"]
282
+
283
+ if action_type == 4:
284
+        if step_data["action_type_text"] == 'click':  # click
285
+ touch_point = step_data["touch"]
286
+ lift_point = step_data["lift"]
287
+        else:  # scroll up / down / left / right
288
+ if step_data["action_type_text"] == 'scroll down':
289
+ touch_point = [0.5, 0.8]
290
+ lift_point = [0.5, 0.2]
291
+ elif step_data["action_type_text"] == 'scroll up':
292
+ touch_point = [0.5, 0.2]
293
+ lift_point = [0.5, 0.8]
294
+ elif step_data["action_type_text"] == 'scroll left':
295
+ touch_point = [0.2, 0.5]
296
+ lift_point = [0.8, 0.5]
297
+ elif step_data["action_type_text"] == 'scroll right':
298
+ touch_point = [0.8, 0.5]
299
+ lift_point = [0.2, 0.5]
300
+ else:
301
+ touch_point = [-1.0, -1.0]
302
+ lift_point = [-1.0, -1.0]
303
+
304
+ if action_type == 3:
305
+ typed_text = step_data["type_text"]
306
+ else:
307
+ typed_text = ""
308
+
309
+ action = {"action_type": action_type, "touch_point": touch_point, "lift_point": lift_point,
310
+ "typed_text": typed_text}
311
+
312
+ action["touch_point"] = [action["touch_point"][1], action["touch_point"][0]]
313
+ action["lift_point"] = [action["lift_point"][1], action["lift_point"][0]]
314
+ action["typed_text"] = action["typed_text"].lower()
315
+
316
+ return action
317
+
318
+
319
+ def pred_2_format(step_data):
320
+    # Convert the model output into the format used to compute action matching.
321
+ action_type = step_data["action_type"]
322
+
323
+    if action_type == 4:  # click
324
+ action_type_new = 4
325
+ touch_point = step_data["click_point"]
326
+ lift_point = step_data["click_point"]
327
+ typed_text = ""
328
+ elif action_type == 0:
329
+ action_type_new = 4
330
+ touch_point = [0.5, 0.8]
331
+ lift_point = [0.5, 0.2]
332
+ typed_text = ""
333
+ elif action_type == 1:
334
+ action_type_new = 4
335
+ touch_point = [0.5, 0.2]
336
+ lift_point = [0.5, 0.8]
337
+ typed_text = ""
338
+ elif action_type == 8:
339
+ action_type_new = 4
340
+ touch_point = [0.2, 0.5]
341
+ lift_point = [0.8, 0.5]
342
+ typed_text = ""
343
+ elif action_type == 9:
344
+ action_type_new = 4
345
+ touch_point = [0.8, 0.5]
346
+ lift_point = [0.2, 0.5]
347
+ typed_text = ""
348
+ else:
349
+ action_type_new = action_type
350
+ touch_point = [-1.0, -1.0]
351
+ lift_point = [-1.0, -1.0]
352
+ typed_text = ""
353
+ if action_type_new == 3:
354
+ typed_text = step_data["typed_text"]
355
+
356
+ action = {"action_type": action_type_new, "touch_point": touch_point, "lift_point": lift_point,
357
+ "typed_text": typed_text}
358
+
359
+ action["touch_point"] = [action["touch_point"][1], action["touch_point"][0]]
360
+ action["lift_point"] = [action["lift_point"][1], action["lift_point"][0]]
361
+ action["typed_text"] = action["typed_text"].lower()
362
+
363
+ return action
364
+
365
+
366
+ def pred_2_format_simplified(step_data):
367
+    # Convert the (simplified) model output into the format used to compute action matching.
368
+ action_type = step_data["action_type"]
369
+
370
+    if action_type == 'click':  # click
371
+ action_type_new = 4
372
+ touch_point = step_data["click_point"]
373
+ lift_point = step_data["click_point"]
374
+ typed_text = ""
375
+ elif action_type == 'scroll' and step_data["direction"] == 'down':
376
+ action_type_new = 4
377
+ touch_point = [0.5, 0.8]
378
+ lift_point = [0.5, 0.2]
379
+ typed_text = ""
380
+ elif action_type == 'scroll' and step_data["direction"] == 'up':
381
+ action_type_new = 4
382
+ touch_point = [0.5, 0.2]
383
+ lift_point = [0.5, 0.8]
384
+ typed_text = ""
385
+ elif action_type == 'scroll' and step_data["direction"] == 'left':
386
+ action_type_new = 4
387
+ touch_point = [0.2, 0.5]
388
+ lift_point = [0.8, 0.5]
389
+ typed_text = ""
390
+ elif action_type == 'scroll' and step_data["direction"] == 'right':
391
+ action_type_new = 4
392
+ touch_point = [0.8, 0.5]
393
+ lift_point = [0.2, 0.5]
394
+ typed_text = ""
395
+ elif action_type == 'type':
396
+ action_type_new = 3
397
+ touch_point = [-1.0, -1.0]
398
+ lift_point = [-1.0, -1.0]
399
+ typed_text = step_data["text"]
400
+ elif action_type == 'navigate_back':
401
+ action_type_new = 5
402
+ touch_point = [-1.0, -1.0]
403
+ lift_point = [-1.0, -1.0]
404
+ typed_text = ""
405
+ elif action_type == 'navigate_home':
406
+ action_type_new = 6
407
+ touch_point = [-1.0, -1.0]
408
+ lift_point = [-1.0, -1.0]
409
+ typed_text = ""
410
+ else:
411
+ action_type_new = action_type
412
+ touch_point = [-1.0, -1.0]
413
+ lift_point = [-1.0, -1.0]
414
+ typed_text = ""
415
+ # if action_type_new == 'type':
416
+ # typed_text = step_data["text"]
417
+
418
+ action = {"action_type": action_type_new, "touch_point": touch_point, "lift_point": lift_point,
419
+ "typed_text": typed_text}
420
+
421
+ action["touch_point"] = [action["touch_point"][1], action["touch_point"][0]]
422
+ action["lift_point"] = [action["lift_point"][1], action["lift_point"][0]]
423
+ action["typed_text"] = action["typed_text"].lower()
424
+
425
+ return action
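Taken together, the converters and the matcher above can be wired up roughly as follows. This is a hypothetical sketch: the module name `action_matching`, the field values, and the annotation box are made up for illustration, and `annotation_positions` is assumed to be in the normalized (y, x, height, width) form that `check_actions_match` expects.

```python
import numpy as np
import action_matching  # assumed module name for the file above

# Hypothetical ground-truth step and model prediction.
gt_step = {"action_type_id": 4, "action_type_text": "click",
           "touch": [0.21, 0.43], "lift": [0.21, 0.43], "type_text": ""}
pred_step = {"action_type": "click", "click_point": [0.22, 0.44]}

gt = action_matching.action_2_format(gt_step)
pred = action_matching.pred_2_format_simplified(pred_step)

# Hypothetical UI element boxes in normalized (y, x, height, width) form.
annotation_positions = np.array([[0.18, 0.40, 0.08, 0.10]])

match = action_matching.check_actions_match(
    gt["touch_point"], gt["lift_point"], gt["action_type"],
    pred["touch_point"], pred["lift_point"], pred["action_type"],
    annotation_positions,
)
print(bool(match))
```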
util/action_type.py ADDED
@@ -0,0 +1,45 @@
1
+ '''
2
+ Adapted from https://github.com/google-research/google-research/tree/master/android_in_the_wild
3
+ '''
4
+
5
+ import enum
6
+
7
+ class ActionType(enum.IntEnum):
8
+
9
+ # Placeholders for unused enum values
10
+ UNUSED_0 = 0
11
+ UNUSED_1 = 1
12
+ UNUSED_2 = 2
13
+ UNUSED_8 = 8
14
+ UNUSED_9 = 9
15
+
16
+ ########### Agent actions ###########
17
+
18
+ # A type action that sends text to the emulator. Note that this simply sends
19
+ # text and does not perform any clicks for element focus or enter presses for
20
+ # submitting text.
21
+ TYPE = 3
22
+
23
+ # The dual point action used to represent all gestures.
24
+ DUAL_POINT = 4
25
+
26
+ # These actions differentiate pressing the home and back button from touches.
27
+ # They represent explicit presses of back and home performed using ADB.
28
+ PRESS_BACK = 5
29
+ PRESS_HOME = 6
30
+
31
+ # An action representing that ADB command for hitting enter was performed.
32
+ PRESS_ENTER = 7
33
+
34
+ ########### Episode status actions ###########
35
+
36
+ # An action used to indicate the desired task has been completed and resets
37
+ # the environment. This action should also be used in the case that the task
38
+ # has already been completed and there is nothing to do.
39
+ # e.g. The task is to turn on the Wi-Fi when it is already on
40
+ STATUS_TASK_COMPLETE = 10
41
+
42
+ # An action used to indicate that desired task is impossible to complete and
43
+ # resets the environment. This can be a result of many different things
44
+ # including UI changes, Android version differences, etc.
45
+ STATUS_TASK_IMPOSSIBLE = 11
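A quick illustration of how these values interact with the matcher above: any action other than `DUAL_POINT` is compared by type alone, while `DUAL_POINT` gestures go through the tap/swipe checks. (Sketch only; it assumes this file is importable as `util.action_type`.)

```python
from util.action_type import ActionType

gt_type, pred_type = ActionType.PRESS_BACK, ActionType.PRESS_BACK

if gt_type != ActionType.DUAL_POINT or pred_type != ActionType.DUAL_POINT:
    # Non-gesture ("global") actions match if and only if their types are equal.
    print("match:", gt_type == pred_type)
else:
    # Gesture actions are compared via their touch/lift points instead.
    print("compare touch/lift points")
```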
util/box_annotator.py ADDED
@@ -0,0 +1,262 @@
1
+ from typing import List, Optional, Union, Tuple
2
+
3
+ import cv2
4
+ import numpy as np
5
+
6
+ from supervision.detection.core import Detections
7
+ from supervision.draw.color import Color, ColorPalette
8
+
9
+
10
+ class BoxAnnotator:
11
+ """
12
+ A class for drawing bounding boxes on an image using detections provided.
13
+
14
+ Attributes:
15
+ color (Union[Color, ColorPalette]): The color to draw the bounding box,
16
+ can be a single color or a color palette
17
+ thickness (int): The thickness of the bounding box lines, default is 2
18
+ text_color (Color): The color of the text on the bounding box, default is white
19
+ text_scale (float): The scale of the text on the bounding box, default is 0.5
20
+ text_thickness (int): The thickness of the text on the bounding box,
21
+ default is 1
22
+ text_padding (int): The padding around the text on the bounding box,
23
+ default is 5
24
+
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
30
+ thickness: int = 3, # 1 for seeclick 2 for mind2web and 3 for demo
31
+ text_color: Color = Color.BLACK,
32
+ text_scale: float = 0.5, # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
33
+ text_thickness: int = 2, #1, # 2 for demo
34
+ text_padding: int = 10,
35
+ avoid_overlap: bool = True,
36
+ ):
37
+ self.color: Union[Color, ColorPalette] = color
38
+ self.thickness: int = thickness
39
+ self.text_color: Color = text_color
40
+ self.text_scale: float = text_scale
41
+ self.text_thickness: int = text_thickness
42
+ self.text_padding: int = text_padding
43
+ self.avoid_overlap: bool = avoid_overlap
44
+
45
+ def annotate(
46
+ self,
47
+ scene: np.ndarray,
48
+ detections: Detections,
49
+ labels: Optional[List[str]] = None,
50
+ skip_label: bool = False,
51
+ image_size: Optional[Tuple[int, int]] = None,
52
+ ) -> np.ndarray:
53
+ """
54
+ Draws bounding boxes on the frame using the detections provided.
55
+
56
+ Args:
57
+ scene (np.ndarray): The image on which the bounding boxes will be drawn
58
+ detections (Detections): The detections for which the
59
+ bounding boxes will be drawn
60
+ labels (Optional[List[str]]): An optional list of labels
61
+                corresponding to each detection. If `labels` is not provided,
62
+                the corresponding `class_id` will be used as the label.
63
+            skip_label (bool): If set to `True`, skips bounding box label annotation.
64
+ Returns:
65
+ np.ndarray: The image with the bounding boxes drawn on it
66
+
67
+ Example:
68
+ ```python
69
+ import supervision as sv
70
+
71
+ classes = ['person', ...]
72
+ image = ...
73
+ detections = sv.Detections(...)
74
+
75
+ box_annotator = sv.BoxAnnotator()
76
+ labels = [
77
+ f"{classes[class_id]} {confidence:0.2f}"
78
+ for _, _, confidence, class_id, _ in detections
79
+ ]
80
+ annotated_frame = box_annotator.annotate(
81
+ scene=image.copy(),
82
+ detections=detections,
83
+ labels=labels
84
+ )
85
+ ```
86
+ """
87
+ font = cv2.FONT_HERSHEY_SIMPLEX
88
+ for i in range(len(detections)):
89
+ x1, y1, x2, y2 = detections.xyxy[i].astype(int)
90
+ class_id = (
91
+ detections.class_id[i] if detections.class_id is not None else None
92
+ )
93
+ idx = class_id if class_id is not None else i
94
+ color = (
95
+ self.color.by_idx(idx)
96
+ if isinstance(self.color, ColorPalette)
97
+ else self.color
98
+ )
99
+ cv2.rectangle(
100
+ img=scene,
101
+ pt1=(x1, y1),
102
+ pt2=(x2, y2),
103
+ color=color.as_bgr(),
104
+ thickness=self.thickness,
105
+ )
106
+ if skip_label:
107
+ continue
108
+
109
+ text = (
110
+ f"{class_id}"
111
+ if (labels is None or len(detections) != len(labels))
112
+ else labels[i]
113
+ )
114
+
115
+ text_width, text_height = cv2.getTextSize(
116
+ text=text,
117
+ fontFace=font,
118
+ fontScale=self.text_scale,
119
+ thickness=self.text_thickness,
120
+ )[0]
121
+
122
+ if not self.avoid_overlap:
123
+ text_x = x1 + self.text_padding
124
+ text_y = y1 - self.text_padding
125
+
126
+ text_background_x1 = x1
127
+ text_background_y1 = y1 - 2 * self.text_padding - text_height
128
+
129
+ text_background_x2 = x1 + 2 * self.text_padding + text_width
130
+ text_background_y2 = y1
131
+ # text_x = x1 - self.text_padding - text_width
132
+ # text_y = y1 + self.text_padding + text_height
133
+ # text_background_x1 = x1 - 2 * self.text_padding - text_width
134
+ # text_background_y1 = y1
135
+ # text_background_x2 = x1
136
+ # text_background_y2 = y1 + 2 * self.text_padding + text_height
137
+ else:
138
+ text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 = get_optimal_label_pos(self.text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size)
139
+
140
+ cv2.rectangle(
141
+ img=scene,
142
+ pt1=(text_background_x1, text_background_y1),
143
+ pt2=(text_background_x2, text_background_y2),
144
+ color=color.as_bgr(),
145
+ thickness=cv2.FILLED,
146
+ )
147
+ # import pdb; pdb.set_trace()
148
+ box_color = color.as_rgb()
149
+ luminance = 0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2]
150
+ text_color = (0,0,0) if luminance > 160 else (255,255,255)
151
+ cv2.putText(
152
+ img=scene,
153
+ text=text,
154
+ org=(text_x, text_y),
155
+ fontFace=font,
156
+ fontScale=self.text_scale,
157
+ # color=self.text_color.as_rgb(),
158
+ color=text_color,
159
+ thickness=self.text_thickness,
160
+ lineType=cv2.LINE_AA,
161
+ )
162
+ return scene
163
+
164
+
165
+ def box_area(box):
166
+ return (box[2] - box[0]) * (box[3] - box[1])
167
+
168
+ def intersection_area(box1, box2):
169
+ x1 = max(box1[0], box2[0])
170
+ y1 = max(box1[1], box2[1])
171
+ x2 = min(box1[2], box2[2])
172
+ y2 = min(box1[3], box2[3])
173
+ return max(0, x2 - x1) * max(0, y2 - y1)
174
+
175
+ def IoU(box1, box2, return_max=True):
176
+ intersection = intersection_area(box1, box2)
177
+ union = box_area(box1) + box_area(box2) - intersection
178
+ if box_area(box1) > 0 and box_area(box2) > 0:
179
+ ratio1 = intersection / box_area(box1)
180
+ ratio2 = intersection / box_area(box2)
181
+ else:
182
+ ratio1, ratio2 = 0, 0
183
+ if return_max:
184
+ return max(intersection / union, ratio1, ratio2)
185
+ else:
186
+ return intersection / union
187
+
188
+
189
+ def get_optimal_label_pos(text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size):
190
+    """ Check overlap between the label's text background and the detection boxes, and return an optimal label position.
191
+    Candidate positions are tried in order: 'top left', 'outer left', 'outer right', 'top right'; if every candidate overlaps a detection or falls outside the image, the last candidate is returned.
192
+    Overlap threshold (IoU): 0.3.
193
+ """
194
+
195
+ def get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size):
196
+ is_overlap = False
197
+ for i in range(len(detections)):
198
+ detection = detections.xyxy[i].astype(int)
199
+ if IoU([text_background_x1, text_background_y1, text_background_x2, text_background_y2], detection) > 0.3:
200
+ is_overlap = True
201
+ break
202
+ # check if the text is out of the image
203
+ if text_background_x1 < 0 or text_background_x2 > image_size[0] or text_background_y1 < 0 or text_background_y2 > image_size[1]:
204
+ is_overlap = True
205
+ return is_overlap
206
+
207
+ # if pos == 'top left':
208
+ text_x = x1 + text_padding
209
+ text_y = y1 - text_padding
210
+
211
+ text_background_x1 = x1
212
+ text_background_y1 = y1 - 2 * text_padding - text_height
213
+
214
+ text_background_x2 = x1 + 2 * text_padding + text_width
215
+ text_background_y2 = y1
216
+ is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
217
+ if not is_overlap:
218
+ return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
219
+
220
+ # elif pos == 'outer left':
221
+ text_x = x1 - text_padding - text_width
222
+ text_y = y1 + text_padding + text_height
223
+
224
+ text_background_x1 = x1 - 2 * text_padding - text_width
225
+ text_background_y1 = y1
226
+
227
+ text_background_x2 = x1
228
+ text_background_y2 = y1 + 2 * text_padding + text_height
229
+ is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
230
+ if not is_overlap:
231
+ return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
232
+
233
+
234
+ # elif pos == 'outer right':
235
+ text_x = x2 + text_padding
236
+ text_y = y1 + text_padding + text_height
237
+
238
+ text_background_x1 = x2
239
+ text_background_y1 = y1
240
+
241
+ text_background_x2 = x2 + 2 * text_padding + text_width
242
+ text_background_y2 = y1 + 2 * text_padding + text_height
243
+
244
+ is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
245
+ if not is_overlap:
246
+ return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
247
+
248
+ # elif pos == 'top right':
249
+ text_x = x2 - text_padding - text_width
250
+ text_y = y1 - text_padding
251
+
252
+ text_background_x1 = x2 - 2 * text_padding - text_width
253
+ text_background_y1 = y1 - 2 * text_padding - text_height
254
+
255
+ text_background_x2 = x2
256
+ text_background_y2 = y1
257
+
258
+ is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
259
+ if not is_overlap:
260
+ return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
261
+
262
+ return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
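A minimal usage sketch for this annotator, assuming the `supervision` package is installed; the image is a blank dummy frame and the box coordinates are arbitrary pixel values:

```python
import numpy as np
import supervision as sv
from util.box_annotator import BoxAnnotator

image = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy H x W x 3 frame
detections = sv.Detections(xyxy=np.array([[50, 60, 200, 160],
                                          [220, 80, 380, 200]], dtype=float))

annotator = BoxAnnotator(text_scale=0.5, thickness=2)
annotated = annotator.annotate(scene=image.copy(),
                               detections=detections,
                               labels=["0", "1"],
                               image_size=(640, 480))  # (w, h); used to keep labels on-screen
print(annotated.shape)
```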
utils.py ADDED
@@ -0,0 +1,436 @@
1
+ # from ultralytics import YOLO
2
+ import os
3
+ import io
4
+ import base64
5
+ import time
6
+ from PIL import Image, ImageDraw, ImageFont
7
+ import json
8
+ import requests
9
+ # utility function
10
+ import os
11
+ from openai import AzureOpenAI
12
+
13
+ import json
14
+ import sys
15
+ import os
16
+ import cv2
17
+ import numpy as np
18
+ # %matplotlib inline
19
+ from matplotlib import pyplot as plt
20
+ import easyocr
21
+ reader = easyocr.Reader(['en'])
22
+ import time
23
+ import base64
24
+
25
+ import os
26
+ import ast
27
+ import torch
28
+ from typing import Tuple, List
29
+ from torchvision.ops import box_convert
30
+ import re
31
+ from torchvision.transforms import ToPILImage
32
+ import supervision as sv
33
+ import torchvision.transforms as T
34
+
35
+
36
+ def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
37
+ if not device:
38
+ device = "cuda" if torch.cuda.is_available() else "cpu"
39
+ try:
40
+ if model_name == "blip2":
41
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
42
+ processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
43
+ if device == 'cpu':
44
+ model = Blip2ForConditionalGeneration.from_pretrained(
45
+ model_name_or_path, device_map=None, torch_dtype=torch.float32
46
+ )
47
+ else:
48
+ model = Blip2ForConditionalGeneration.from_pretrained(
49
+ model_name_or_path, device_map=None, torch_dtype=torch.float16
50
+ ).to(device)
51
+ elif model_name == "florence2":
52
+ from transformers import AutoProcessor, AutoModelForCausalLM
53
+ processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
54
+ # Try loading with safetensors first
55
+ try:
56
+ if device == 'cpu':
57
+ model = AutoModelForCausalLM.from_pretrained(
58
+ model_name_or_path,
59
+ torch_dtype=torch.float32,
60
+ trust_remote_code=True,
61
+ use_safetensors=True
62
+ )
63
+ else:
64
+ model = AutoModelForCausalLM.from_pretrained(
65
+ model_name_or_path,
66
+ torch_dtype=torch.float16,
67
+ trust_remote_code=True,
68
+ use_safetensors=True
69
+ ).to(device)
70
+ except Exception as e:
71
+ print(f"Failed to load with safetensors: {e}")
72
+ # Fallback to regular loading
73
+ if device == 'cpu':
74
+ model = AutoModelForCausalLM.from_pretrained(
75
+ model_name_or_path,
76
+ torch_dtype=torch.float32,
77
+ trust_remote_code=True,
78
+ use_safetensors=False
79
+ )
80
+ else:
81
+ model = AutoModelForCausalLM.from_pretrained(
82
+ model_name_or_path,
83
+ torch_dtype=torch.float16,
84
+ trust_remote_code=True,
85
+ use_safetensors=False
86
+ ).to(device)
87
+ return {'model': model.to(device), 'processor': processor}
88
+ except Exception as e:
89
+ print(f"Error loading model: {e}")
90
+ raise
91
+
92
+
93
+ def get_yolo_model(model_path):
94
+ from ultralytics import YOLO
95
+ # Load the model.
96
+ model = YOLO(model_path)
97
+ return model
98
+
99
+
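The two loaders above are typically used together: the YOLO checkpoint provides the icon/box detector and the caption model describes each cropped box. A sketch, assuming the checkpoints live under `weights/` (the caption checkpoint path is a placeholder):

```python
from utils import get_yolo_model, get_caption_model_processor

som_model = get_yolo_model("weights/icon_detect/best.pt")  # produced by the convert script below
caption_model_processor = get_caption_model_processor(
    model_name="florence2",
    model_name_or_path="weights/icon_caption_florence",    # placeholder caption checkpoint path
    device="cpu",                                           # or "cuda"
)
print(type(som_model), list(caption_model_processor.keys()))
```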
100
+ def get_parsed_content_icon(filtered_boxes, ocr_bbox, image_source, caption_model_processor, prompt=None):
101
+ to_pil = ToPILImage()
102
+ if ocr_bbox:
103
+ non_ocr_boxes = filtered_boxes[len(ocr_bbox):]
104
+ else:
105
+ non_ocr_boxes = filtered_boxes
106
+ croped_pil_image = []
107
+ for i, coord in enumerate(non_ocr_boxes):
108
+ xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
109
+ ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
110
+ cropped_image = image_source[ymin:ymax, xmin:xmax, :]
111
+ croped_pil_image.append(to_pil(cropped_image))
112
+
113
+ model, processor = caption_model_processor['model'], caption_model_processor['processor']
114
+ if not prompt:
115
+ if 'florence' in model.config.name_or_path:
116
+ prompt = "<CAPTION>"
117
+ else:
118
+ prompt = "The image shows"
119
+
120
+ batch_size = 10 # Number of samples per batch
121
+ generated_texts = []
122
+ device = model.device
123
+
124
+ for i in range(0, len(croped_pil_image), batch_size):
125
+ batch = croped_pil_image[i:i+batch_size]
126
+ if model.device.type == 'cuda':
127
+ inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device, dtype=torch.float16)
128
+ else:
129
+ inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
130
+ if 'florence' in model.config.name_or_path:
131
+ generated_ids = model.generate(input_ids=inputs["input_ids"],pixel_values=inputs["pixel_values"],max_new_tokens=1024,num_beams=3, do_sample=False)
132
+ else:
133
+ generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1) # temperature=0.01, do_sample=True,
134
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
135
+ generated_text = [gen.strip() for gen in generated_text]
136
+ generated_texts.extend(generated_text)
137
+
138
+ return generated_texts
139
+
140
+
141
+
142
+ def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
143
+ to_pil = ToPILImage()
144
+ if ocr_bbox:
145
+ non_ocr_boxes = filtered_boxes[len(ocr_bbox):]
146
+ else:
147
+ non_ocr_boxes = filtered_boxes
148
+ croped_pil_image = []
149
+ for i, coord in enumerate(non_ocr_boxes):
150
+ xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
151
+ ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
152
+ cropped_image = image_source[ymin:ymax, xmin:xmax, :]
153
+ croped_pil_image.append(to_pil(cropped_image))
154
+
155
+ model, processor = caption_model_processor['model'], caption_model_processor['processor']
156
+ device = model.device
157
+ messages = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}]
158
+ prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
159
+
160
+ batch_size = 5 # Number of samples per batch
161
+ generated_texts = []
162
+
163
+ for i in range(0, len(croped_pil_image), batch_size):
164
+ images = croped_pil_image[i:i+batch_size]
165
+ image_inputs = [processor.image_processor(x, return_tensors="pt") for x in images]
166
+ inputs ={'input_ids': [], 'attention_mask': [], 'pixel_values': [], 'image_sizes': []}
167
+ texts = [prompt] * len(images)
168
+ for i, txt in enumerate(texts):
169
+ input = processor._convert_images_texts_to_inputs(image_inputs[i], txt, return_tensors="pt")
170
+ inputs['input_ids'].append(input['input_ids'])
171
+ inputs['attention_mask'].append(input['attention_mask'])
172
+ inputs['pixel_values'].append(input['pixel_values'])
173
+ inputs['image_sizes'].append(input['image_sizes'])
174
+ max_len = max([x.shape[1] for x in inputs['input_ids']])
175
+ for i, v in enumerate(inputs['input_ids']):
176
+ inputs['input_ids'][i] = torch.cat([processor.tokenizer.pad_token_id * torch.ones(1, max_len - v.shape[1], dtype=torch.long), v], dim=1)
177
+ inputs['attention_mask'][i] = torch.cat([torch.zeros(1, max_len - v.shape[1], dtype=torch.long), inputs['attention_mask'][i]], dim=1)
178
+ inputs_cat = {k: torch.concatenate(v).to(device) for k, v in inputs.items()}
179
+
180
+ generation_args = {
181
+ "max_new_tokens": 25,
182
+ "temperature": 0.01,
183
+ "do_sample": False,
184
+ }
185
+ generate_ids = model.generate(**inputs_cat, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
186
+ # # remove input tokens
187
+ generate_ids = generate_ids[:, inputs_cat['input_ids'].shape[1]:]
188
+ response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
189
+ response = [res.strip('\n').strip() for res in response]
190
+ generated_texts.extend(response)
191
+
192
+ return generated_texts
193
+
194
+ def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
195
+ assert ocr_bbox is None or isinstance(ocr_bbox, List)
196
+
197
+ def box_area(box):
198
+ return (box[2] - box[0]) * (box[3] - box[1])
199
+
200
+ def intersection_area(box1, box2):
201
+ x1 = max(box1[0], box2[0])
202
+ y1 = max(box1[1], box2[1])
203
+ x2 = min(box1[2], box2[2])
204
+ y2 = min(box1[3], box2[3])
205
+ return max(0, x2 - x1) * max(0, y2 - y1)
206
+
207
+ def IoU(box1, box2):
208
+ intersection = intersection_area(box1, box2)
209
+ union = box_area(box1) + box_area(box2) - intersection + 1e-6
210
+ if box_area(box1) > 0 and box_area(box2) > 0:
211
+ ratio1 = intersection / box_area(box1)
212
+ ratio2 = intersection / box_area(box2)
213
+ else:
214
+ ratio1, ratio2 = 0, 0
215
+ return max(intersection / union, ratio1, ratio2)
216
+
217
+ boxes = boxes.tolist()
218
+ filtered_boxes = []
219
+ if ocr_bbox:
220
+ filtered_boxes.extend(ocr_bbox)
221
+ # print('ocr_bbox!!!', ocr_bbox)
222
+ for i, box1 in enumerate(boxes):
223
+ # if not any(IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2) for j, box2 in enumerate(boxes) if i != j):
224
+ is_valid_box = True
225
+ for j, box2 in enumerate(boxes):
226
+ if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
227
+ is_valid_box = False
228
+ break
229
+ if is_valid_box:
230
+ # add the following 2 lines to include ocr bbox
231
+ if ocr_bbox:
232
+ if not any(IoU(box1, box3) > iou_threshold for k, box3 in enumerate(ocr_bbox)):
233
+ filtered_boxes.append(box1)
234
+ else:
235
+ filtered_boxes.append(box1)
236
+ return torch.tensor(filtered_boxes)
237
+
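A small worked example of the filter above: of two nearly identical detector boxes, the larger one is dropped because it covers the smaller one beyond the IoU threshold, while OCR boxes are kept unconditionally and also suppress detector boxes that duplicate them. The coordinates are arbitrary normalized xyxy values.

```python
import torch
from utils import remove_overlap

yolo_boxes = torch.tensor([[0.10, 0.10, 0.50, 0.50],   # larger box, dropped
                           [0.12, 0.12, 0.48, 0.48]])  # near-duplicate smaller box, kept
ocr_boxes = [[0.60, 0.60, 0.80, 0.80]]                 # OCR box, always kept

kept = remove_overlap(yolo_boxes, iou_threshold=0.9, ocr_bbox=ocr_boxes)
print(kept)
# -> the OCR box followed by the smaller detector box (modulo float formatting)
```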
238
+ def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
239
+ transform = T.Compose(
240
+ [
241
+ T.RandomResize([800], max_size=1333),
242
+ T.ToTensor(),
243
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
244
+ ]
245
+ )
246
+ image_source = Image.open(image_path).convert("RGB")
247
+ image = np.asarray(image_source)
248
+ image_transformed, _ = transform(image_source, None)
249
+ return image, image_transformed
250
+
251
+
252
+ def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str], text_scale: float,
253
+ text_padding=5, text_thickness=2, thickness=3) -> np.ndarray:
254
+ """
255
+ This function annotates an image with bounding boxes and labels.
256
+
257
+ Parameters:
258
+ image_source (np.ndarray): The source image to be annotated.
259
+ boxes (torch.Tensor): A tensor containing bounding box coordinates. in cxcywh format, pixel scale
260
+ logits (torch.Tensor): A tensor containing confidence scores for each bounding box.
261
+ phrases (List[str]): A list of labels for each bounding box.
262
+ text_scale (float): The scale of the text to be displayed. 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
263
+
264
+ Returns:
265
+ np.ndarray: The annotated image.
266
+ """
267
+ h, w, _ = image_source.shape
268
+ boxes = boxes * torch.Tensor([w, h, w, h])
269
+ xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
270
+ xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").numpy()
271
+ detections = sv.Detections(xyxy=xyxy)
272
+
273
+ labels = [f"{phrase}" for phrase in range(boxes.shape[0])]
274
+
275
+ from util.box_annotator import BoxAnnotator
276
+ box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,text_thickness=text_thickness,thickness=thickness) # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
277
+ annotated_frame = image_source.copy()
278
+ annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w,h))
279
+
280
+ label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
281
+ return annotated_frame, label_coordinates
282
+
283
+
284
+ def predict(model, image, caption, box_threshold, text_threshold):
285
+    """ Run a Hugging Face grounded object-detection model (processor + model dict) in place of the original detector; returns boxes, confidence scores, and label phrases.
286
+ """
287
+ model, processor = model['model'], model['processor']
288
+ device = model.device
289
+
290
+ inputs = processor(images=image, text=caption, return_tensors="pt").to(device)
291
+ with torch.no_grad():
292
+ outputs = model(**inputs)
293
+
294
+ results = processor.post_process_grounded_object_detection(
295
+ outputs,
296
+ inputs.input_ids,
297
+ box_threshold=box_threshold, # 0.4,
298
+ text_threshold=text_threshold, # 0.3,
299
+ target_sizes=[image.size[::-1]]
300
+ )[0]
301
+ boxes, logits, phrases = results["boxes"], results["scores"], results["labels"]
302
+ return boxes, logits, phrases
303
+
304
+
305
+ def predict_yolo(model, image_path, box_threshold):
306
+    """ Run the Ultralytics YOLO detector on the image; returns pixel-space xyxy boxes, confidence scores, and numeric index labels.
307
+ """
308
+ # model = model['model']
309
+
310
+ result = model.predict(
311
+ source=image_path,
312
+ conf=box_threshold,
313
+ # iou=0.5, # default 0.7
314
+ )
315
+ boxes = result[0].boxes.xyxy#.tolist() # in pixel space
316
+ conf = result[0].boxes.conf
317
+ phrases = [str(i) for i in range(len(boxes))]
318
+
319
+ return boxes, conf, phrases
320
+
321
+
322
+ def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD = 0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None):
323
+ """ ocr_bbox: list of xyxy format bbox
324
+ """
325
+ TEXT_PROMPT = "clickable buttons on the screen"
326
+ # BOX_TRESHOLD = 0.02 # 0.05/0.02 for web and 0.1 for mobile
327
+ TEXT_TRESHOLD = 0.01 # 0.9 # 0.01
328
+ image_source = Image.open(img_path).convert("RGB")
329
+ w, h = image_source.size
330
+ # import pdb; pdb.set_trace()
331
+ if False: # TODO
332
+ xyxy, logits, phrases = predict(model=model, image=image_source, caption=TEXT_PROMPT, box_threshold=BOX_TRESHOLD, text_threshold=TEXT_TRESHOLD)
333
+ else:
334
+ xyxy, logits, phrases = predict_yolo(model=model, image_path=img_path, box_threshold=BOX_TRESHOLD)
335
+ xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
336
+ image_source = np.asarray(image_source)
337
+ phrases = [str(i) for i in range(len(phrases))]
338
+
339
+ # annotate the image with labels
340
+ h, w, _ = image_source.shape
341
+ if ocr_bbox:
342
+ ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
343
+ ocr_bbox=ocr_bbox.tolist()
344
+ else:
345
+ print('no ocr bbox!!!')
346
+ ocr_bbox = None
347
+ filtered_boxes = remove_overlap(boxes=xyxy, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox)
348
+
349
+ # get parsed icon local semantics
350
+ if use_local_semantics:
351
+ caption_model = caption_model_processor['model']
352
+ if 'phi3_v' in caption_model.config.model_type:
353
+ parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor)
354
+ else:
355
+ parsed_content_icon = get_parsed_content_icon(filtered_boxes, ocr_bbox, image_source, caption_model_processor, prompt=prompt)
356
+ ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
357
+ icon_start = len(ocr_text)
358
+ parsed_content_icon_ls = []
359
+ for i, txt in enumerate(parsed_content_icon):
360
+ parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
361
+ parsed_content_merged = ocr_text + parsed_content_icon_ls
362
+ else:
363
+ ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
364
+ parsed_content_merged = ocr_text
365
+
366
+ filtered_boxes = box_convert(boxes=filtered_boxes, in_fmt="xyxy", out_fmt="cxcywh")
367
+
368
+ phrases = [i for i in range(len(filtered_boxes))]
369
+
370
+ # draw boxes
371
+ if draw_bbox_config:
372
+ annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, **draw_bbox_config)
373
+ else:
374
+ annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, text_scale=text_scale, text_padding=text_padding)
375
+
376
+ pil_img = Image.fromarray(annotated_frame)
377
+ buffered = io.BytesIO()
378
+ pil_img.save(buffered, format="PNG")
379
+ encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
380
+ if output_coord_in_ratio:
381
+ # h, w, _ = image_source.shape
382
+ label_coordinates = {k: [v[0]/w, v[1]/h, v[2]/w, v[3]/h] for k, v in label_coordinates.items()}
383
+ assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]
384
+
385
+ return encoded_image, label_coordinates, parsed_content_merged
386
+
387
+
388
+ def get_xywh(input):
389
+ x, y, w, h = input[0][0], input[0][1], input[2][0] - input[0][0], input[2][1] - input[0][1]
390
+ x, y, w, h = int(x), int(y), int(w), int(h)
391
+ return x, y, w, h
392
+
393
+ def get_xyxy(input):
394
+ x, y, xp, yp = input[0][0], input[0][1], input[2][0], input[2][1]
395
+ x, y, xp, yp = int(x), int(y), int(xp), int(yp)
396
+ return x, y, xp, yp
397
+
398
+ def get_xywh_yolo(input):
399
+ x, y, w, h = input[0], input[1], input[2] - input[0], input[3] - input[1]
400
+ x, y, w, h = int(x), int(y), int(w), int(h)
401
+ return x, y, w, h
402
+
403
+
404
+
405
+ def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None):
406
+ if easyocr_args is None:
407
+ easyocr_args = {}
408
+ result = reader.readtext(image_path, **easyocr_args)
409
+ is_goal_filtered = False
410
+ # print('goal filtering pred:', result[-5:])
411
+ coord = [item[0] for item in result]
412
+ text = [item[1] for item in result]
413
+ # read the image using cv2
414
+ if display_img:
415
+ opencv_img = cv2.imread(image_path)
416
+ opencv_img = cv2.cvtColor(opencv_img, cv2.COLOR_RGB2BGR)
417
+ bb = []
418
+ for item in coord:
419
+ x, y, a, b = get_xywh(item)
420
+ # print(x, y, a, b)
421
+ bb.append((x, y, a, b))
422
+ cv2.rectangle(opencv_img, (x, y), (x+a, y+b), (0, 255, 0), 2)
423
+
424
+ # Display the image
425
+ plt.imshow(opencv_img)
426
+ else:
427
+ if output_bb_format == 'xywh':
428
+ bb = [get_xywh(item) for item in coord]
429
+ elif output_bb_format == 'xyxy':
430
+ bb = [get_xyxy(item) for item in coord]
431
+ # print('bounding box!!!', bb)
432
+ return (text, bb), is_goal_filtered
433
+
434
+
435
+
436
+
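Putting the pieces together, a typical parsing pass over a screenshot runs OCR first and then the set-of-mark labeling. A sketch under a few assumptions: the screenshot path points at an image in the repo (`imgs/windows_home.png` is used here), the detector weights sit at `weights/icon_detect/best.pt` (see the conversion script below), and the Florence-2 caption weights at `weights/icon_caption_florence` (placeholder path).

```python
from utils import (check_ocr_box, get_yolo_model,
                   get_caption_model_processor, get_som_labeled_img)

image_path = "imgs/windows_home.png"

som_model = get_yolo_model("weights/icon_detect/best.pt")
caption_model_processor = get_caption_model_processor(
    "florence2", "weights/icon_caption_florence", device="cpu")

# OCR pass: recognized strings plus their boxes in pixel xyxy format.
(text, ocr_bbox), _ = check_ocr_box(
    image_path, display_img=False, output_bb_format="xyxy",
    easyocr_args={"paragraph": False, "text_threshold": 0.9})

# Detection + captioning pass: returns a base64 PNG of the annotated image,
# the labeled box coordinates, and the parsed text/icon descriptions.
encoded_image, label_coordinates, parsed_content = get_som_labeled_img(
    image_path, model=som_model, BOX_TRESHOLD=0.05,
    output_coord_in_ratio=True, ocr_bbox=ocr_bbox,
    caption_model_processor=caption_model_processor,
    ocr_text=text, use_local_semantics=True, iou_threshold=0.9)

print(parsed_content[:5])
```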
weights/convert_safetensor_to_pt.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from ultralytics.nn.tasks import DetectionModel
3
+ from safetensors.torch import load_file
4
+
5
+ tensor_dict = load_file("weights/icon_detect/model.safetensors")
6
+
7
+ model = DetectionModel('weights/icon_detect/model.yaml')
8
+ model.load_state_dict(tensor_dict)
9
+ torch.save({'model':model}, 'weights/icon_detect/best.pt')
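After this conversion has been run once, the resulting `weights/icon_detect/best.pt` can be loaded through the helper in `utils.py` (a quick sanity-check sketch):

```python
from utils import get_yolo_model

som_model = get_yolo_model("weights/icon_detect/best.pt")
print(type(som_model))  # Ultralytics YOLO wrapper around the converted detector
```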