Spaces:
Runtime error
Runtime error
push
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +167 -0
- LICENSE +201 -0
- app.py +1 -1
- cog.yaml +24 -0
- flux +0 -1
- image_datasets/canny_dataset.py +59 -0
- image_datasets/dataset.py +45 -0
- main.py +180 -0
- models_licence/LICENSE-FLUX1-dev +42 -0
- predict.py +134 -0
- src/flux/__init__.py +11 -0
- src/flux/__main__.py +4 -0
- src/flux/annotator/canny/__init__.py +6 -0
- src/flux/annotator/ckpts/ckpts.txt +1 -0
- src/flux/annotator/dwpose/__init__.py +68 -0
- src/flux/annotator/dwpose/onnxdet.py +125 -0
- src/flux/annotator/dwpose/onnxpose.py +360 -0
- src/flux/annotator/dwpose/util.py +297 -0
- src/flux/annotator/dwpose/wholebody.py +48 -0
- src/flux/annotator/hed/__init__.py +95 -0
- src/flux/annotator/midas/LICENSE +21 -0
- src/flux/annotator/midas/__init__.py +42 -0
- src/flux/annotator/midas/api.py +168 -0
- src/flux/annotator/midas/midas/__init__.py +0 -0
- src/flux/annotator/midas/midas/base_model.py +16 -0
- src/flux/annotator/midas/midas/blocks.py +342 -0
- src/flux/annotator/midas/midas/dpt_depth.py +109 -0
- src/flux/annotator/midas/midas/midas_net.py +76 -0
- src/flux/annotator/midas/midas/midas_net_custom.py +128 -0
- src/flux/annotator/midas/midas/transforms.py +234 -0
- src/flux/annotator/midas/midas/vit.py +491 -0
- src/flux/annotator/midas/utils.py +189 -0
- src/flux/annotator/mlsd/LICENSE +201 -0
- src/flux/annotator/mlsd/__init__.py +40 -0
- src/flux/annotator/mlsd/models/mbv2_mlsd_large.py +292 -0
- src/flux/annotator/mlsd/models/mbv2_mlsd_tiny.py +275 -0
- src/flux/annotator/mlsd/utils.py +580 -0
- src/flux/annotator/tile/__init__.py +26 -0
- src/flux/annotator/tile/guided_filter.py +280 -0
- src/flux/annotator/util.py +38 -0
- src/flux/api.py +194 -0
- src/flux/cli.py +254 -0
- src/flux/controlnet.py +222 -0
- src/flux/math.py +30 -0
- src/flux/model.py +228 -0
- src/flux/modules/autoencoder.py +312 -0
- src/flux/modules/conditioner.py +38 -0
- src/flux/modules/layers.py +567 -0
- src/flux/sampling.py +242 -0
- src/flux/util.py +383 -0
.gitignore
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
Makefile
|
12 |
+
build/
|
13 |
+
develop-eggs/
|
14 |
+
dist/
|
15 |
+
downloads/
|
16 |
+
eggs/
|
17 |
+
.eggs/
|
18 |
+
lib/
|
19 |
+
lib64/
|
20 |
+
parts/
|
21 |
+
sdist/
|
22 |
+
var/
|
23 |
+
wheels/
|
24 |
+
weights/
|
25 |
+
|
26 |
+
share/python-wheels/
|
27 |
+
*.egg-info/
|
28 |
+
.installed.cfg
|
29 |
+
*.egg
|
30 |
+
MANIFEST
|
31 |
+
|
32 |
+
# PyInstaller
|
33 |
+
# Usually these files are written by a python script from a template
|
34 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
35 |
+
*.manifest
|
36 |
+
*.spec
|
37 |
+
|
38 |
+
# Installer logs
|
39 |
+
pip-log.txt
|
40 |
+
pip-delete-this-directory.txt
|
41 |
+
|
42 |
+
# Unit test / coverage reports
|
43 |
+
htmlcov/
|
44 |
+
.tox/
|
45 |
+
.nox/
|
46 |
+
.coverage
|
47 |
+
.coverage.*
|
48 |
+
.cache/
|
49 |
+
nosetests.xml
|
50 |
+
coverage.xml
|
51 |
+
*.cover
|
52 |
+
*.py,cover
|
53 |
+
.hypothesis/
|
54 |
+
.pytest_cache/
|
55 |
+
cover/
|
56 |
+
|
57 |
+
# Translations
|
58 |
+
*.mo
|
59 |
+
*.pot
|
60 |
+
|
61 |
+
# Django stuff:
|
62 |
+
*.log
|
63 |
+
local_settings.py
|
64 |
+
db.sqlite3
|
65 |
+
db.sqlite3-journal
|
66 |
+
|
67 |
+
# Flask stuff:
|
68 |
+
instance/
|
69 |
+
.webassets-cache
|
70 |
+
|
71 |
+
# Scrapy stuff:
|
72 |
+
.scrapy
|
73 |
+
|
74 |
+
# Sphinx documentation
|
75 |
+
docs/_build/
|
76 |
+
|
77 |
+
# PyBuilder
|
78 |
+
.pybuilder/
|
79 |
+
target/
|
80 |
+
|
81 |
+
# Jupyter Notebook
|
82 |
+
.ipynb_checkpoints
|
83 |
+
|
84 |
+
# IPython
|
85 |
+
profile_default/
|
86 |
+
ipython_config.py
|
87 |
+
|
88 |
+
# pyenv
|
89 |
+
# For a library or package, you might want to ignore these files since the code is
|
90 |
+
# intended to run in multiple environments; otherwise, check them in:
|
91 |
+
# .python-version
|
92 |
+
|
93 |
+
# pipenv
|
94 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
95 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
96 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
97 |
+
# install all needed dependencies.
|
98 |
+
#Pipfile.lock
|
99 |
+
|
100 |
+
# poetry
|
101 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
102 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
103 |
+
# commonly ignored for libraries.
|
104 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
105 |
+
#poetry.lock
|
106 |
+
|
107 |
+
# pdm
|
108 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
109 |
+
#pdm.lock
|
110 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
111 |
+
# in version control.
|
112 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
113 |
+
.pdm.toml
|
114 |
+
.pdm-python
|
115 |
+
.pdm-build/
|
116 |
+
|
117 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
118 |
+
__pypackages__/
|
119 |
+
|
120 |
+
# Celery stuff
|
121 |
+
celerybeat-schedule
|
122 |
+
celerybeat.pid
|
123 |
+
|
124 |
+
# SageMath parsed files
|
125 |
+
*.sage.py
|
126 |
+
|
127 |
+
# Environments
|
128 |
+
.env
|
129 |
+
.venv
|
130 |
+
env/
|
131 |
+
venv/
|
132 |
+
ENV/
|
133 |
+
env.bak/
|
134 |
+
venv.bak/
|
135 |
+
|
136 |
+
# Spyder project settings
|
137 |
+
.spyderproject
|
138 |
+
.spyproject
|
139 |
+
|
140 |
+
# Rope project settings
|
141 |
+
.ropeproject
|
142 |
+
|
143 |
+
# mkdocs documentation
|
144 |
+
/site
|
145 |
+
|
146 |
+
# mypy
|
147 |
+
.mypy_cache/
|
148 |
+
.dmypy.json
|
149 |
+
dmypy.json
|
150 |
+
|
151 |
+
# Pyre type checker
|
152 |
+
.pyre/
|
153 |
+
|
154 |
+
# pytype static type analyzer
|
155 |
+
.pytype/
|
156 |
+
|
157 |
+
# Cython debug symbols
|
158 |
+
cython_debug/
|
159 |
+
|
160 |
+
# PyCharm
|
161 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
162 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
163 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
164 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
165 |
+
#.idea/
|
166 |
+
|
167 |
+
.DS_Store
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
app.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import gradio as gr
|
2 |
from PIL import Image
|
3 |
import os
|
4 |
-
from
|
5 |
import random
|
6 |
import spaces
|
7 |
|
|
|
1 |
import gradio as gr
|
2 |
from PIL import Image
|
3 |
import os
|
4 |
+
from src.flux.xflux_pipeline import XFluxPipeline
|
5 |
import random
|
6 |
import spaces
|
7 |
|
cog.yaml
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Configuration for Cog ⚙️
|
2 |
+
# Reference: https://cog.run/yaml
|
3 |
+
|
4 |
+
build:
|
5 |
+
gpu: true
|
6 |
+
cuda: "12.1"
|
7 |
+
python_version: "3.11"
|
8 |
+
python_packages:
|
9 |
+
- "accelerate==0.30.1"
|
10 |
+
- "deepspeed==0.14.4"
|
11 |
+
- "einops==0.8.0"
|
12 |
+
- "transformers==4.43.3"
|
13 |
+
- "huggingface-hub==0.24.5"
|
14 |
+
- "einops==0.8.0"
|
15 |
+
- "pandas==2.2.2"
|
16 |
+
- "opencv-python==4.10.0.84"
|
17 |
+
- "pillow==10.4.0"
|
18 |
+
- "optimum-quanto==0.2.4"
|
19 |
+
- "sentencepiece==0.2.0"
|
20 |
+
run:
|
21 |
+
- curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
|
22 |
+
|
23 |
+
# predict.py defines how predictions are run on your model
|
24 |
+
predict: "predict.py:Predictor"
|
flux
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
Subproject commit 9e1dd391b2316b1cfc20e523e2885fd30134a2e4
|
|
|
|
image_datasets/canny_dataset.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
import torch
|
6 |
+
from torch.utils.data import Dataset, DataLoader
|
7 |
+
import json
|
8 |
+
import random
|
9 |
+
import cv2
|
10 |
+
|
11 |
+
|
12 |
+
def canny_processor(image, low_threshold=100, high_threshold=200):
|
13 |
+
image = np.array(image)
|
14 |
+
image = cv2.Canny(image, low_threshold, high_threshold)
|
15 |
+
image = image[:, :, None]
|
16 |
+
image = np.concatenate([image, image, image], axis=2)
|
17 |
+
canny_image = Image.fromarray(image)
|
18 |
+
return canny_image
|
19 |
+
|
20 |
+
|
21 |
+
def c_crop(image):
|
22 |
+
width, height = image.size
|
23 |
+
new_size = min(width, height)
|
24 |
+
left = (width - new_size) / 2
|
25 |
+
top = (height - new_size) / 2
|
26 |
+
right = (width + new_size) / 2
|
27 |
+
bottom = (height + new_size) / 2
|
28 |
+
return image.crop((left, top, right, bottom))
|
29 |
+
|
30 |
+
class CustomImageDataset(Dataset):
|
31 |
+
def __init__(self, img_dir, img_size=512):
|
32 |
+
self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i]
|
33 |
+
self.images.sort()
|
34 |
+
self.img_size = img_size
|
35 |
+
|
36 |
+
def __len__(self):
|
37 |
+
return len(self.images)
|
38 |
+
|
39 |
+
def __getitem__(self, idx):
|
40 |
+
try:
|
41 |
+
img = Image.open(self.images[idx])
|
42 |
+
img = c_crop(img)
|
43 |
+
img = img.resize((self.img_size, self.img_size))
|
44 |
+
hint = canny_processor(img)
|
45 |
+
img = torch.from_numpy((np.array(img) / 127.5) - 1)
|
46 |
+
img = img.permute(2, 0, 1)
|
47 |
+
hint = torch.from_numpy((np.array(hint) / 127.5) - 1)
|
48 |
+
hint = hint.permute(2, 0, 1)
|
49 |
+
json_path = self.images[idx].split('.')[0] + '.json'
|
50 |
+
prompt = json.load(open(json_path))['caption']
|
51 |
+
return img, hint, prompt
|
52 |
+
except Exception as e:
|
53 |
+
print(e)
|
54 |
+
return self.__getitem__(random.randint(0, len(self.images) - 1))
|
55 |
+
|
56 |
+
|
57 |
+
def loader(train_batch_size, num_workers, **args):
|
58 |
+
dataset = CustomImageDataset(**args)
|
59 |
+
return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True)
|
image_datasets/dataset.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
import torch
|
6 |
+
from torch.utils.data import Dataset, DataLoader
|
7 |
+
import json
|
8 |
+
import random
|
9 |
+
|
10 |
+
def c_crop(image):
|
11 |
+
width, height = image.size
|
12 |
+
new_size = min(width, height)
|
13 |
+
left = (width - new_size) / 2
|
14 |
+
top = (height - new_size) / 2
|
15 |
+
right = (width + new_size) / 2
|
16 |
+
bottom = (height + new_size) / 2
|
17 |
+
return image.crop((left, top, right, bottom))
|
18 |
+
|
19 |
+
class CustomImageDataset(Dataset):
|
20 |
+
def __init__(self, img_dir, img_size=512):
|
21 |
+
self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i]
|
22 |
+
self.images.sort()
|
23 |
+
self.img_size = img_size
|
24 |
+
|
25 |
+
def __len__(self):
|
26 |
+
return len(self.images)
|
27 |
+
|
28 |
+
def __getitem__(self, idx):
|
29 |
+
try:
|
30 |
+
img = Image.open(self.images[idx])
|
31 |
+
img = c_crop(img)
|
32 |
+
img = img.resize((self.img_size, self.img_size))
|
33 |
+
img = torch.from_numpy((np.array(img) / 127.5) - 1)
|
34 |
+
img = img.permute(2, 0, 1)
|
35 |
+
json_path = self.images[idx].split('.')[0] + '.json'
|
36 |
+
prompt = json.load(open(json_path))['caption']
|
37 |
+
return img, prompt
|
38 |
+
except Exception as e:
|
39 |
+
print(e)
|
40 |
+
return self.__getitem__(random.randint(0, len(self.images) - 1))
|
41 |
+
|
42 |
+
|
43 |
+
def loader(train_batch_size, num_workers, **args):
|
44 |
+
dataset = CustomImageDataset(**args)
|
45 |
+
return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True)
|
main.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from PIL import Image
|
3 |
+
import os
|
4 |
+
|
5 |
+
from src.flux.xflux_pipeline import XFluxPipeline
|
6 |
+
|
7 |
+
|
8 |
+
def create_argparser():
|
9 |
+
parser = argparse.ArgumentParser()
|
10 |
+
|
11 |
+
parser.add_argument(
|
12 |
+
"--prompt", type=str, required=True,
|
13 |
+
help="The input text prompt"
|
14 |
+
)
|
15 |
+
parser.add_argument(
|
16 |
+
"--neg_prompt", type=str, default="",
|
17 |
+
help="The input text negative prompt"
|
18 |
+
)
|
19 |
+
parser.add_argument(
|
20 |
+
"--img_prompt", type=str, default=None,
|
21 |
+
help="Path to input image prompt"
|
22 |
+
)
|
23 |
+
parser.add_argument(
|
24 |
+
"--neg_img_prompt", type=str, default=None,
|
25 |
+
help="Path to input negative image prompt"
|
26 |
+
)
|
27 |
+
parser.add_argument(
|
28 |
+
"--ip_scale", type=float, default=1.0,
|
29 |
+
help="Strength of input image prompt"
|
30 |
+
)
|
31 |
+
parser.add_argument(
|
32 |
+
"--neg_ip_scale", type=float, default=1.0,
|
33 |
+
help="Strength of negative input image prompt"
|
34 |
+
)
|
35 |
+
parser.add_argument(
|
36 |
+
"--local_path", type=str, default=None,
|
37 |
+
help="Local path to the model checkpoint (Controlnet)"
|
38 |
+
)
|
39 |
+
parser.add_argument(
|
40 |
+
"--repo_id", type=str, default=None,
|
41 |
+
help="A HuggingFace repo id to download model (Controlnet)"
|
42 |
+
)
|
43 |
+
parser.add_argument(
|
44 |
+
"--name", type=str, default=None,
|
45 |
+
help="A filename to download from HuggingFace"
|
46 |
+
)
|
47 |
+
parser.add_argument(
|
48 |
+
"--ip_repo_id", type=str, default=None,
|
49 |
+
help="A HuggingFace repo id to download model (IP-Adapter)"
|
50 |
+
)
|
51 |
+
parser.add_argument(
|
52 |
+
"--ip_name", type=str, default=None,
|
53 |
+
help="A IP-Adapter filename to download from HuggingFace"
|
54 |
+
)
|
55 |
+
parser.add_argument(
|
56 |
+
"--ip_local_path", type=str, default=None,
|
57 |
+
help="Local path to the model checkpoint (IP-Adapter)"
|
58 |
+
)
|
59 |
+
parser.add_argument(
|
60 |
+
"--lora_repo_id", type=str, default=None,
|
61 |
+
help="A HuggingFace repo id to download model (LoRA)"
|
62 |
+
)
|
63 |
+
parser.add_argument(
|
64 |
+
"--lora_name", type=str, default=None,
|
65 |
+
help="A LoRA filename to download from HuggingFace"
|
66 |
+
)
|
67 |
+
parser.add_argument(
|
68 |
+
"--lora_local_path", type=str, default=None,
|
69 |
+
help="Local path to the model checkpoint (Controlnet)"
|
70 |
+
)
|
71 |
+
parser.add_argument(
|
72 |
+
"--device", type=str, default="cuda",
|
73 |
+
help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)"
|
74 |
+
)
|
75 |
+
parser.add_argument(
|
76 |
+
"--offload", action='store_true', help="Offload model to CPU when not in use"
|
77 |
+
)
|
78 |
+
parser.add_argument(
|
79 |
+
"--use_ip", action='store_true', help="Load IP model"
|
80 |
+
)
|
81 |
+
parser.add_argument(
|
82 |
+
"--use_lora", action='store_true', help="Load Lora model"
|
83 |
+
)
|
84 |
+
parser.add_argument(
|
85 |
+
"--use_controlnet", action='store_true', help="Load Controlnet model"
|
86 |
+
)
|
87 |
+
parser.add_argument(
|
88 |
+
"--num_images_per_prompt", type=int, default=1,
|
89 |
+
help="The number of images to generate per prompt"
|
90 |
+
)
|
91 |
+
parser.add_argument(
|
92 |
+
"--image", type=str, default=None, help="Path to image"
|
93 |
+
)
|
94 |
+
parser.add_argument(
|
95 |
+
"--lora_weight", type=float, default=0.9, help="Lora model strength (from 0 to 1.0)"
|
96 |
+
)
|
97 |
+
parser.add_argument(
|
98 |
+
"--control_type", type=str, default="canny",
|
99 |
+
choices=("canny", "openpose", "depth", "hed", "hough", "tile"),
|
100 |
+
help="Name of controlnet condition, example: canny"
|
101 |
+
)
|
102 |
+
parser.add_argument(
|
103 |
+
"--model_type", type=str, default="flux-dev",
|
104 |
+
choices=("flux-dev", "flux-dev-fp8", "flux-schnell"),
|
105 |
+
help="Model type to use (flux-dev, flux-dev-fp8, flux-schnell)"
|
106 |
+
)
|
107 |
+
parser.add_argument(
|
108 |
+
"--width", type=int, default=1024, help="The width for generated image"
|
109 |
+
)
|
110 |
+
parser.add_argument(
|
111 |
+
"--height", type=int, default=1024, help="The height for generated image"
|
112 |
+
)
|
113 |
+
parser.add_argument(
|
114 |
+
"--num_steps", type=int, default=25, help="The num_steps for diffusion process"
|
115 |
+
)
|
116 |
+
parser.add_argument(
|
117 |
+
"--guidance", type=float, default=4, help="The guidance for diffusion process"
|
118 |
+
)
|
119 |
+
parser.add_argument(
|
120 |
+
"--seed", type=int, default=123456789, help="A seed for reproducible inference"
|
121 |
+
)
|
122 |
+
parser.add_argument(
|
123 |
+
"--true_gs", type=float, default=3.5, help="true guidance"
|
124 |
+
)
|
125 |
+
parser.add_argument(
|
126 |
+
"--timestep_to_start_cfg", type=int, default=5, help="timestep to start true guidance"
|
127 |
+
)
|
128 |
+
parser.add_argument(
|
129 |
+
"--save_path", type=str, default='results', help="Path to save"
|
130 |
+
)
|
131 |
+
return parser
|
132 |
+
|
133 |
+
|
134 |
+
def main(args):
|
135 |
+
if args.image:
|
136 |
+
image = Image.open(args.image)
|
137 |
+
else:
|
138 |
+
image = None
|
139 |
+
|
140 |
+
xflux_pipeline = XFluxPipeline(args.model_type, args.device, args.offload)
|
141 |
+
if args.use_ip:
|
142 |
+
print('load ip-adapter:', args.ip_local_path, args.ip_repo_id, args.ip_name)
|
143 |
+
xflux_pipeline.set_ip(args.ip_local_path, args.ip_repo_id, args.ip_name)
|
144 |
+
if args.use_lora:
|
145 |
+
print('load lora:', args.lora_local_path, args.lora_repo_id, args.lora_name)
|
146 |
+
xflux_pipeline.set_lora(args.lora_local_path, args.lora_repo_id, args.lora_name, args.lora_weight)
|
147 |
+
if args.use_controlnet:
|
148 |
+
print('load controlnet:', args.local_path, args.repo_id, args.name)
|
149 |
+
xflux_pipeline.set_controlnet(args.control_type, args.local_path, args.repo_id, args.name)
|
150 |
+
|
151 |
+
image_prompt = Image.open(args.img_prompt) if args.img_prompt else None
|
152 |
+
neg_image_prompt = Image.open(args.neg_img_prompt) if args.neg_img_prompt else None
|
153 |
+
|
154 |
+
for _ in range(args.num_images_per_prompt):
|
155 |
+
result = xflux_pipeline(
|
156 |
+
prompt=args.prompt,
|
157 |
+
controlnet_image=image,
|
158 |
+
width=args.width,
|
159 |
+
height=args.height,
|
160 |
+
guidance=args.guidance,
|
161 |
+
num_steps=args.num_steps,
|
162 |
+
seed=args.seed,
|
163 |
+
true_gs=args.true_gs,
|
164 |
+
neg_prompt=args.neg_prompt,
|
165 |
+
timestep_to_start_cfg=args.timestep_to_start_cfg,
|
166 |
+
image_prompt=image_prompt,
|
167 |
+
neg_image_prompt=neg_image_prompt,
|
168 |
+
ip_scale=args.ip_scale,
|
169 |
+
neg_ip_scale=args.neg_ip_scale,
|
170 |
+
)
|
171 |
+
if not os.path.exists(args.save_path):
|
172 |
+
os.mkdir(args.save_path)
|
173 |
+
ind = len(os.listdir(args.save_path))
|
174 |
+
result.save(os.path.join(args.save_path, f"result_{ind}.png"))
|
175 |
+
args.seed = args.seed + 1
|
176 |
+
|
177 |
+
|
178 |
+
if __name__ == "__main__":
|
179 |
+
args = create_argparser().parse_args()
|
180 |
+
main(args)
|
models_licence/LICENSE-FLUX1-dev
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FLUX.1 [dev] Non-Commercial License
|
2 |
+
Black Forest Labs, Inc. (“we” or “our” or “Company”) is pleased to make available the weights, parameters and inference code for the FLUX.1 [dev] Model (as defined below) freely available for your non-commercial and non-production use as set forth in this FLUX.1 [dev] Non-Commercial License (“License”). The “FLUX.1 [dev] Model” means the FLUX.1 [dev] text-to-image AI model and its elements which includes algorithms, software, checkpoints, parameters, source code (inference code, evaluation code, and if applicable, fine-tuning code) and any other materials associated with the FLUX.1 [dev] AI model made available by Company under this License, including if any, the technical documentation, manuals and instructions for the use and operation thereof (collectively, “FLUX.1 [dev] Model”).
|
3 |
+
By downloading, accessing, use, Distributing (as defined below), or creating a Derivative (as defined below) of the FLUX.1 [dev] Model, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to access, use, Distribute or create a Derivative of the FLUX.1 [dev] Model and you must immediately cease using the FLUX.1 [dev] Model. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to us that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the FLUX.1 [dev] Model on behalf of your employer or other entity.
|
4 |
+
1. Definitions. Capitalized terms used in this License but not defined herein have the following meanings:
|
5 |
+
a. “Derivative” means any (i) modified version of the FLUX.1 [dev] Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the FLUX.1 [dev] Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered Derivatives under this License.
|
6 |
+
b. “Distribution” or “Distribute” or “Distributing” means providing or making available, by any means, a copy of the FLUX.1 [dev] Models and/or the Derivatives as the case may be.
|
7 |
+
c. “Non-Commercial Purpose” means any of the following uses, but only so far as you do not receive any direct or indirect payment arising from the use of the model or its output: (i) personal use for research, experiment, and testing for the benefit of public knowledge, personal study, private entertainment, hobby projects, or otherwise not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities; (ii) use by commercial or for-profit entities for testing, evaluation, or non-commercial research and development in a non-production environment, (iii) use by any charitable organization for charitable purposes, or for testing or evaluation. For clarity, use for revenue-generating activity or direct interactions with or impacts on end users, or use to train, fine tune or distill other models for commercial use is not a Non-Commercial purpose.
|
8 |
+
d. “Outputs” means any content generated by the operation of the FLUX.1 [dev] Models or the Derivatives from a prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of a FLUX.1 [dev] Models, such as any fine-tuned versions of the FLUX.1 [dev] Models, the weights, or parameters.
|
9 |
+
e. “you” or “your” means the individual or entity entering into this License with Company.
|
10 |
+
2. License Grant.
|
11 |
+
a. License. Subject to your compliance with this License, Company grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license to access, use, create Derivatives of, and Distribute the FLUX.1 [dev] Models solely for your Non-Commercial Purposes. The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Company’s prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License. Any restrictions set forth herein in regarding the FLUX.1 [dev] Model also applies to any Derivative you create or that are created on your behalf.
|
12 |
+
b. Non-Commercial Use Only. You may only access, use, Distribute, or creative Derivatives of or the FLUX.1 [dev] Model or Derivatives for Non-Commercial Purposes. If You want to use a FLUX.1 [dev] Model a Derivative for any purpose that is not expressly authorized under this License, such as for a commercial activity, you must request a license from Company, which Company may grant to you in Company’s sole discretion and which additional use may be subject to a fee, royalty or other revenue share. Please contact Company at the following e-mail address if you want to discuss such a license: [email protected].
|
13 |
+
c. Reserved Rights. The grant of rights expressly set forth in this License are the complete grant of rights to you in the FLUX.1 [dev] Model, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. Company and its licensors reserve all rights not expressly granted by this License.
|
14 |
+
d. Outputs. We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs you generate and their subsequent uses in accordance with this License. You may use Output for any purpose (including for commercial purposes), except as expressly prohibited herein. You may not use the Output to train, fine-tune or distill a model that is competitive with the FLUX.1 [dev] Model.
|
15 |
+
3. Distribution. Subject to this License, you may Distribute copies of the FLUX.1 [dev] Model and/or Derivatives made by you, under the following conditions:
|
16 |
+
a. you must make available a copy of this License to third-party recipients of the FLUX.1 [dev] Models and/or Derivatives you Distribute, and specify that any rights to use the FLUX.1 [dev] Models and/or Derivatives shall be directly granted by Company to said third-party recipients pursuant to this License;
|
17 |
+
b. you must make prominently display the following notice alongside the Distribution of the FLUX.1 [dev] Model or Derivative (such as via a “Notice” text file distributed as part of such FLUX.1 [dev] Model or Derivative) (the “Attribution Notice”):
|
18 |
+
“The FLUX.1 [dev] Model is licensed by Black Forest Labs. Inc. under the FLUX.1 [dev] Non-Commercial License. Copyright Black Forest Labs. Inc.
|
19 |
+
IN NO EVENT SHALL BLACK FOREST LABS, INC. BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH USE OF THIS MODEL.”
|
20 |
+
c. in the case of Distribution of Derivatives made by you, you must also include in the Attribution Notice a statement that you have modified the applicable FLUX.1 [dev] Model; and
|
21 |
+
d. in the case of Distribution of Derivatives made by you, any terms and conditions you impose on any third-party recipients relating to Derivatives made by or for you shall neither limit such third-party recipients’ use of the FLUX.1 [dev] Model or any Derivatives made by or for Company in accordance with this License nor conflict with any of its terms and conditions.
|
22 |
+
e. In the case of Distribution of Derivatives made by you, you must not misrepresent or imply, through any means, that the Derivatives made by or for you and/or any modified version of the FLUX.1 [dev] Model you Distribute under your name and responsibility is an official product of the Company or has been endorsed, approved or validated by the Company, unless you are authorized by Company to do so in writing.
|
23 |
+
4. Restrictions. You will not, and will not permit, assist or cause any third party to
|
24 |
+
a. use, modify, copy, reproduce, create Derivatives of, or Distribute the FLUX.1 [dev] Model (or any Derivative thereof, or any data produced by the FLUX.1 [dev] Model), in whole or in part, for (i) any commercial or production purposes, (ii) military purposes, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates any third-party rights, or (vi) in any manner that violates any applicable law and violating any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Privacy Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, and any and all laws governing the processing of biometric information), as well as all amendments and successor laws to any of the foregoing;
|
25 |
+
b. alter or remove copyright and other proprietary notices which appear on or in any portion of the FLUX.1 [dev] Model;
|
26 |
+
c. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Company in connection with the FLUX.1 [dev] Model, or to circumvent or remove any usage restrictions, or to enable functionality disabled by FLUX.1 [dev] Model; or
|
27 |
+
d. offer or impose any terms on the FLUX.1 [dev] Model that alter, restrict, or are inconsistent with the terms of this License.
|
28 |
+
e. violate any applicable U.S. and non-U.S. export control and trade sanctions laws (“Export Laws”) in connection with your use or Distribution of any FLUX.1 [dev] Model;
|
29 |
+
f. directly or indirectly Distribute, export, or otherwise transfer FLUX.1 [dev] Model (a) to any individual, entity, or country prohibited by Export Laws; (b) to anyone on U.S. or non-U.S. government restricted parties lists; or (c) for any purpose prohibited by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications; 3) use or download FLUX.1 [dev] Model if you or they are (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any purpose prohibited by Export Laws; and (4) will not disguise your location through IP proxying or other methods.
|
30 |
+
5. DISCLAIMERS. THE FLUX.1 [dev] MODEL IS PROVIDED “AS IS” AND “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. COMPANY EXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE FLUX.1 [dev] MODEL, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. COMPANY MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE FLUX.1 [dev] MODEL WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS.
|
31 |
+
6. LIMITATION OF LIABILITY. TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL COMPANY BE LIABLE TO YOU OR YOUR EMPLOYEES, AFFILIATES, USERS, OFFICERS OR DIRECTORS (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF COMPANY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. THE FLUX.1 [dev] MODEL, ITS CONSTITUENT COMPONENTS, AND ANY OUTPUT (COLLECTIVELY, “MODEL MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE MODEL MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE MODEL MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE MODEL MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE.
|
32 |
+
7. INDEMNIFICATION
|
33 |
+
|
34 |
+
You will indemnify, defend and hold harmless Company and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “Company Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any Company Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to (a) your access to or use of the FLUX.1 [dev] Model (as well as any Output, results or data generated from such access or use), including any High-Risk Use (defined below); (b) your violation of this License; or (c) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Company Parties of any such Claims, and cooperate with Company Parties in defending such Claims. You will also grant the Company Parties sole control of the defense or settlement, at Company’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Company or the other Company Parties.
|
35 |
+
8. Termination; Survival.
|
36 |
+
a. This License will automatically terminate upon any breach by you of the terms of this License.
|
37 |
+
b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you.
|
38 |
+
c. If You initiate any legal action or proceedings against Company or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the FLUX.1 [dev] Model or any Derivative, or any part thereof, infringe upon intellectual property or other rights owned or licensable by you, then any licenses granted to you under this License will immediately terminate as of the date such legal action or claim is filed or initiated.
|
39 |
+
d. Upon termination of this License, you must cease all use, access or Distribution of the FLUX.1 [dev] Model and any Derivatives. The following sections survive termination of this License 2(c), 2(d), 4-11.
|
40 |
+
9. Third Party Materials. The FLUX.1 [dev] Model may contain third-party software or other components (including free and open source software) (all of the foregoing, “Third Party Materials”), which are subject to the license terms of the respective third-party licensors. Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Company does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk.
|
41 |
+
10. Trademarks. You have not been granted any trademark license as part of this License and may not use any name or mark associated with Company without the prior written permission of Company, except to the extent necessary to make the reference required in the Attribution Notice as specified above or as is reasonably necessary in describing the FLUX.1 [dev] Model and its creators.
|
42 |
+
11. General. This License will be governed and construed under the laws of the State of Delaware without regard to conflicts of law provisions. If any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Company to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. This License, together with the Documentation, contains the entire understanding between you and Company regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Company regarding such subject matter. No change or addition to any provision of this License will be binding unless it is in writing and signed by an authorized representative of both you and Company.
|
predict.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Prediction interface for Cog ⚙️
|
2 |
+
# https://cog.run/python
|
3 |
+
|
4 |
+
from cog import BasePredictor, Input, Path
|
5 |
+
import os
|
6 |
+
import time
|
7 |
+
import torch
|
8 |
+
import subprocess
|
9 |
+
from PIL import Image
|
10 |
+
from typing import List
|
11 |
+
from image_datasets.canny_dataset import canny_processor, c_crop
|
12 |
+
from src.flux.util import load_ae, load_clip, load_t5, load_flow_model, load_controlnet, load_safetensors
|
13 |
+
|
14 |
+
OUTPUT_DIR = "controlnet_results"
|
15 |
+
MODEL_CACHE = "checkpoints"
|
16 |
+
CONTROLNET_URL = "https://huggingface.co/XLabs-AI/flux-controlnet-canny/resolve/main/controlnet.safetensors"
|
17 |
+
T5_URL = "https://weights.replicate.delivery/default/black-forest-labs/FLUX.1-dev/t5-cache.tar"
|
18 |
+
CLIP_URL = "https://weights.replicate.delivery/default/black-forest-labs/FLUX.1-dev/clip-cache.tar"
|
19 |
+
HF_TOKEN = "hf_..." # Your HuggingFace token
|
20 |
+
|
21 |
+
def download_weights(url, dest):
|
22 |
+
start = time.time()
|
23 |
+
print("downloading url: ", url)
|
24 |
+
print("downloading to: ", dest)
|
25 |
+
subprocess.check_call(["pget", "-xf", url, dest], close_fds=False)
|
26 |
+
print("downloading took: ", time.time() - start)
|
27 |
+
|
28 |
+
def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
|
29 |
+
t5 = load_t5(device, max_length=256 if is_schnell else 512)
|
30 |
+
clip = load_clip(device)
|
31 |
+
model = load_flow_model(name, device="cpu" if offload else device)
|
32 |
+
ae = load_ae(name, device="cpu" if offload else device)
|
33 |
+
controlnet = load_controlnet(name, device).to(torch.bfloat16)
|
34 |
+
return model, ae, t5, clip, controlnet
|
35 |
+
|
36 |
+
class Predictor(BasePredictor):
|
37 |
+
def setup(self) -> None:
|
38 |
+
"""Load the model into memory to make running multiple predictions efficient"""
|
39 |
+
t1 = time.time()
|
40 |
+
os.system(f"huggingface-cli login --token {HF_TOKEN}")
|
41 |
+
name = "flux-dev"
|
42 |
+
self.offload = False
|
43 |
+
checkpoint = "controlnet.safetensors"
|
44 |
+
|
45 |
+
print("Checking ControlNet weights")
|
46 |
+
checkpoint = "controlnet.safetensors"
|
47 |
+
if not os.path.exists(checkpoint):
|
48 |
+
os.system(f"wget {CONTROLNET_URL}")
|
49 |
+
print("Checking T5 weights")
|
50 |
+
if not os.path.exists(MODEL_CACHE+"/models--google--t5-v1_1-xxl"):
|
51 |
+
download_weights(T5_URL, MODEL_CACHE)
|
52 |
+
print("Checking CLIP weights")
|
53 |
+
if not os.path.exists(MODEL_CACHE+"/models--openai--clip-vit-large-patch14"):
|
54 |
+
download_weights(CLIP_URL, MODEL_CACHE)
|
55 |
+
|
56 |
+
self.is_schnell = False
|
57 |
+
device = "cuda"
|
58 |
+
self.torch_device = torch.device(device)
|
59 |
+
model, ae, t5, clip, controlnet = get_models(
|
60 |
+
name,
|
61 |
+
device=self.torch_device,
|
62 |
+
offload=self.offload,
|
63 |
+
is_schnell=self.is_schnell,
|
64 |
+
)
|
65 |
+
self.ae = ae
|
66 |
+
self.t5 = t5
|
67 |
+
self.clip = clip
|
68 |
+
self.controlnet = controlnet
|
69 |
+
self.model = model.to(self.torch_device)
|
70 |
+
if '.safetensors' in checkpoint:
|
71 |
+
checkpoint1 = load_safetensors(checkpoint)
|
72 |
+
else:
|
73 |
+
checkpoint1 = torch.load(checkpoint, map_location='cpu')
|
74 |
+
|
75 |
+
controlnet.load_state_dict(checkpoint1, strict=False)
|
76 |
+
t2 = time.time()
|
77 |
+
print(f"Setup time: {t2 - t1}")
|
78 |
+
|
79 |
+
def preprocess_canny_image(self, image_path: str, width: int = 512, height: int = 512):
|
80 |
+
image = Image.open(image_path)
|
81 |
+
image = c_crop(image)
|
82 |
+
image = image.resize((width, height))
|
83 |
+
image = canny_processor(image)
|
84 |
+
return image
|
85 |
+
|
86 |
+
def predict(
|
87 |
+
self,
|
88 |
+
prompt: str = Input(description="Input prompt", default="a handsome viking man with white hair, cinematic, MM full HD"),
|
89 |
+
image: Path = Input(description="Input image", default=None),
|
90 |
+
num_inference_steps: int = Input(description="Number of inference steps", ge=1, le=64, default=28),
|
91 |
+
cfg: float = Input(description="CFG", ge=0, le=10, default=3.5),
|
92 |
+
seed: int = Input(description="Random seed", default=None)
|
93 |
+
) -> List[Path]:
|
94 |
+
"""Run a single prediction on the model"""
|
95 |
+
if seed is None:
|
96 |
+
seed = int.from_bytes(os.urandom(2), "big")
|
97 |
+
print(f"Using seed: {seed}")
|
98 |
+
|
99 |
+
# clean output dir
|
100 |
+
output_dir = "controlnet_results"
|
101 |
+
os.system(f"rm -rf {output_dir}")
|
102 |
+
|
103 |
+
input_image = str(image)
|
104 |
+
img = Image.open(input_image)
|
105 |
+
width, height = img.size
|
106 |
+
# Resize input image if it's too large
|
107 |
+
max_image_size = 1536
|
108 |
+
scale = min(max_image_size / width, max_image_size / height, 1)
|
109 |
+
if scale < 1:
|
110 |
+
width = int(width * scale)
|
111 |
+
height = int(height * scale)
|
112 |
+
print(f"Scaling image down to {width}x{height}")
|
113 |
+
img = img.resize((width, height), resample=Image.Resampling.LANCZOS)
|
114 |
+
input_image = "/tmp/resized_image.png"
|
115 |
+
img.save(input_image)
|
116 |
+
|
117 |
+
subprocess.check_call(
|
118 |
+
["python3", "main.py",
|
119 |
+
"--local_path", "controlnet.safetensors",
|
120 |
+
"--image", input_image,
|
121 |
+
"--use_controlnet",
|
122 |
+
"--control_type", "canny",
|
123 |
+
"--prompt", prompt,
|
124 |
+
"--width", str(width),
|
125 |
+
"--height", str(height),
|
126 |
+
"--num_steps", str(num_inference_steps),
|
127 |
+
"--guidance", str(cfg),
|
128 |
+
"--seed", str(seed)
|
129 |
+
], close_fds=False)
|
130 |
+
|
131 |
+
# Find the first file that begins with "controlnet_result_"
|
132 |
+
for file in os.listdir(output_dir):
|
133 |
+
if file.startswith("controlnet_result_"):
|
134 |
+
return [Path(os.path.join(output_dir, file))]
|
src/flux/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
try:
|
2 |
+
from ._version import version as __version__ # type: ignore
|
3 |
+
from ._version import version_tuple
|
4 |
+
except ImportError:
|
5 |
+
__version__ = "unknown (no version information available)"
|
6 |
+
version_tuple = (0, 0, "unknown", "noinfo")
|
7 |
+
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
PACKAGE = __package__.replace("_", "-")
|
11 |
+
PACKAGE_ROOT = Path(__file__).parent
|
src/flux/__main__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .cli import app
|
2 |
+
|
3 |
+
if __name__ == "__main__":
|
4 |
+
app()
|
src/flux/annotator/canny/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
|
3 |
+
|
4 |
+
class CannyDetector:
|
5 |
+
def __call__(self, img, low_threshold, high_threshold):
|
6 |
+
return cv2.Canny(img, low_threshold, high_threshold)
|
src/flux/annotator/ckpts/ckpts.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Weights here.
|
src/flux/annotator/dwpose/__init__.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Openpose
|
2 |
+
# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
|
3 |
+
# 2nd Edited by https://github.com/Hzzone/pytorch-openpose
|
4 |
+
# 3rd Edited by ControlNet
|
5 |
+
# 4th Edited by ControlNet (added face and correct hands)
|
6 |
+
|
7 |
+
import os
|
8 |
+
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import numpy as np
|
12 |
+
from . import util
|
13 |
+
from .wholebody import Wholebody
|
14 |
+
|
15 |
+
def draw_pose(pose, H, W):
|
16 |
+
bodies = pose['bodies']
|
17 |
+
faces = pose['faces']
|
18 |
+
hands = pose['hands']
|
19 |
+
candidate = bodies['candidate']
|
20 |
+
subset = bodies['subset']
|
21 |
+
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
|
22 |
+
|
23 |
+
canvas = util.draw_bodypose(canvas, candidate, subset)
|
24 |
+
|
25 |
+
canvas = util.draw_handpose(canvas, hands)
|
26 |
+
|
27 |
+
canvas = util.draw_facepose(canvas, faces)
|
28 |
+
|
29 |
+
return canvas
|
30 |
+
|
31 |
+
|
32 |
+
class DWposeDetector:
|
33 |
+
def __init__(self, device):
|
34 |
+
|
35 |
+
self.pose_estimation = Wholebody(device)
|
36 |
+
|
37 |
+
def __call__(self, oriImg):
|
38 |
+
oriImg = oriImg.copy()
|
39 |
+
H, W, C = oriImg.shape
|
40 |
+
with torch.no_grad():
|
41 |
+
candidate, subset = self.pose_estimation(oriImg)
|
42 |
+
nums, keys, locs = candidate.shape
|
43 |
+
candidate[..., 0] /= float(W)
|
44 |
+
candidate[..., 1] /= float(H)
|
45 |
+
body = candidate[:,:18].copy()
|
46 |
+
body = body.reshape(nums*18, locs)
|
47 |
+
score = subset[:,:18]
|
48 |
+
for i in range(len(score)):
|
49 |
+
for j in range(len(score[i])):
|
50 |
+
if score[i][j] > 0.3:
|
51 |
+
score[i][j] = int(18*i+j)
|
52 |
+
else:
|
53 |
+
score[i][j] = -1
|
54 |
+
|
55 |
+
un_visible = subset<0.3
|
56 |
+
candidate[un_visible] = -1
|
57 |
+
|
58 |
+
foot = candidate[:,18:24]
|
59 |
+
|
60 |
+
faces = candidate[:,24:92]
|
61 |
+
|
62 |
+
hands = candidate[:,92:113]
|
63 |
+
hands = np.vstack([hands, candidate[:,113:]])
|
64 |
+
|
65 |
+
bodies = dict(candidate=body, subset=score)
|
66 |
+
pose = dict(bodies=bodies, hands=hands, faces=faces)
|
67 |
+
|
68 |
+
return draw_pose(pose, H, W)
|
src/flux/annotator/dwpose/onnxdet.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
import onnxruntime
|
5 |
+
|
6 |
+
def nms(boxes, scores, nms_thr):
|
7 |
+
"""Single class NMS implemented in Numpy."""
|
8 |
+
x1 = boxes[:, 0]
|
9 |
+
y1 = boxes[:, 1]
|
10 |
+
x2 = boxes[:, 2]
|
11 |
+
y2 = boxes[:, 3]
|
12 |
+
|
13 |
+
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
14 |
+
order = scores.argsort()[::-1]
|
15 |
+
|
16 |
+
keep = []
|
17 |
+
while order.size > 0:
|
18 |
+
i = order[0]
|
19 |
+
keep.append(i)
|
20 |
+
xx1 = np.maximum(x1[i], x1[order[1:]])
|
21 |
+
yy1 = np.maximum(y1[i], y1[order[1:]])
|
22 |
+
xx2 = np.minimum(x2[i], x2[order[1:]])
|
23 |
+
yy2 = np.minimum(y2[i], y2[order[1:]])
|
24 |
+
|
25 |
+
w = np.maximum(0.0, xx2 - xx1 + 1)
|
26 |
+
h = np.maximum(0.0, yy2 - yy1 + 1)
|
27 |
+
inter = w * h
|
28 |
+
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
29 |
+
|
30 |
+
inds = np.where(ovr <= nms_thr)[0]
|
31 |
+
order = order[inds + 1]
|
32 |
+
|
33 |
+
return keep
|
34 |
+
|
35 |
+
def multiclass_nms(boxes, scores, nms_thr, score_thr):
|
36 |
+
"""Multiclass NMS implemented in Numpy. Class-aware version."""
|
37 |
+
final_dets = []
|
38 |
+
num_classes = scores.shape[1]
|
39 |
+
for cls_ind in range(num_classes):
|
40 |
+
cls_scores = scores[:, cls_ind]
|
41 |
+
valid_score_mask = cls_scores > score_thr
|
42 |
+
if valid_score_mask.sum() == 0:
|
43 |
+
continue
|
44 |
+
else:
|
45 |
+
valid_scores = cls_scores[valid_score_mask]
|
46 |
+
valid_boxes = boxes[valid_score_mask]
|
47 |
+
keep = nms(valid_boxes, valid_scores, nms_thr)
|
48 |
+
if len(keep) > 0:
|
49 |
+
cls_inds = np.ones((len(keep), 1)) * cls_ind
|
50 |
+
dets = np.concatenate(
|
51 |
+
[valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
|
52 |
+
)
|
53 |
+
final_dets.append(dets)
|
54 |
+
if len(final_dets) == 0:
|
55 |
+
return None
|
56 |
+
return np.concatenate(final_dets, 0)
|
57 |
+
|
58 |
+
def demo_postprocess(outputs, img_size, p6=False):
|
59 |
+
grids = []
|
60 |
+
expanded_strides = []
|
61 |
+
strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
|
62 |
+
|
63 |
+
hsizes = [img_size[0] // stride for stride in strides]
|
64 |
+
wsizes = [img_size[1] // stride for stride in strides]
|
65 |
+
|
66 |
+
for hsize, wsize, stride in zip(hsizes, wsizes, strides):
|
67 |
+
xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
|
68 |
+
grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
|
69 |
+
grids.append(grid)
|
70 |
+
shape = grid.shape[:2]
|
71 |
+
expanded_strides.append(np.full((*shape, 1), stride))
|
72 |
+
|
73 |
+
grids = np.concatenate(grids, 1)
|
74 |
+
expanded_strides = np.concatenate(expanded_strides, 1)
|
75 |
+
outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
|
76 |
+
outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
|
77 |
+
|
78 |
+
return outputs
|
79 |
+
|
80 |
+
def preprocess(img, input_size, swap=(2, 0, 1)):
|
81 |
+
if len(img.shape) == 3:
|
82 |
+
padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
|
83 |
+
else:
|
84 |
+
padded_img = np.ones(input_size, dtype=np.uint8) * 114
|
85 |
+
|
86 |
+
r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
|
87 |
+
resized_img = cv2.resize(
|
88 |
+
img,
|
89 |
+
(int(img.shape[1] * r), int(img.shape[0] * r)),
|
90 |
+
interpolation=cv2.INTER_LINEAR,
|
91 |
+
).astype(np.uint8)
|
92 |
+
padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
|
93 |
+
|
94 |
+
padded_img = padded_img.transpose(swap)
|
95 |
+
padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
|
96 |
+
return padded_img, r
|
97 |
+
|
98 |
+
def inference_detector(session, oriImg):
|
99 |
+
input_shape = (640,640)
|
100 |
+
img, ratio = preprocess(oriImg, input_shape)
|
101 |
+
|
102 |
+
ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
|
103 |
+
output = session.run(None, ort_inputs)
|
104 |
+
predictions = demo_postprocess(output[0], input_shape)[0]
|
105 |
+
|
106 |
+
boxes = predictions[:, :4]
|
107 |
+
scores = predictions[:, 4:5] * predictions[:, 5:]
|
108 |
+
|
109 |
+
boxes_xyxy = np.ones_like(boxes)
|
110 |
+
boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
|
111 |
+
boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
|
112 |
+
boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
|
113 |
+
boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
|
114 |
+
boxes_xyxy /= ratio
|
115 |
+
dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
|
116 |
+
if dets is not None:
|
117 |
+
final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
|
118 |
+
isscore = final_scores>0.3
|
119 |
+
iscat = final_cls_inds == 0
|
120 |
+
isbbox = [ i and j for (i, j) in zip(isscore, iscat)]
|
121 |
+
final_boxes = final_boxes[isbbox]
|
122 |
+
else:
|
123 |
+
final_boxes = np.array([])
|
124 |
+
|
125 |
+
return final_boxes
|
src/flux/annotator/dwpose/onnxpose.py
ADDED
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import onnxruntime as ort
|
6 |
+
|
7 |
+
def preprocess(
|
8 |
+
img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
|
9 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
10 |
+
"""Do preprocessing for RTMPose model inference.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
img (np.ndarray): Input image in shape.
|
14 |
+
input_size (tuple): Input image size in shape (w, h).
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
tuple:
|
18 |
+
- resized_img (np.ndarray): Preprocessed image.
|
19 |
+
- center (np.ndarray): Center of image.
|
20 |
+
- scale (np.ndarray): Scale of image.
|
21 |
+
"""
|
22 |
+
# get shape of image
|
23 |
+
img_shape = img.shape[:2]
|
24 |
+
out_img, out_center, out_scale = [], [], []
|
25 |
+
if len(out_bbox) == 0:
|
26 |
+
out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
|
27 |
+
for i in range(len(out_bbox)):
|
28 |
+
x0 = out_bbox[i][0]
|
29 |
+
y0 = out_bbox[i][1]
|
30 |
+
x1 = out_bbox[i][2]
|
31 |
+
y1 = out_bbox[i][3]
|
32 |
+
bbox = np.array([x0, y0, x1, y1])
|
33 |
+
|
34 |
+
# get center and scale
|
35 |
+
center, scale = bbox_xyxy2cs(bbox, padding=1.25)
|
36 |
+
|
37 |
+
# do affine transformation
|
38 |
+
resized_img, scale = top_down_affine(input_size, scale, center, img)
|
39 |
+
|
40 |
+
# normalize image
|
41 |
+
mean = np.array([123.675, 116.28, 103.53])
|
42 |
+
std = np.array([58.395, 57.12, 57.375])
|
43 |
+
resized_img = (resized_img - mean) / std
|
44 |
+
|
45 |
+
out_img.append(resized_img)
|
46 |
+
out_center.append(center)
|
47 |
+
out_scale.append(scale)
|
48 |
+
|
49 |
+
return out_img, out_center, out_scale
|
50 |
+
|
51 |
+
|
52 |
+
def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:
|
53 |
+
"""Inference RTMPose model.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
sess (ort.InferenceSession): ONNXRuntime session.
|
57 |
+
img (np.ndarray): Input image in shape.
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
outputs (np.ndarray): Output of RTMPose model.
|
61 |
+
"""
|
62 |
+
all_out = []
|
63 |
+
# build input
|
64 |
+
for i in range(len(img)):
|
65 |
+
input = [img[i].transpose(2, 0, 1)]
|
66 |
+
|
67 |
+
# build output
|
68 |
+
sess_input = {sess.get_inputs()[0].name: input}
|
69 |
+
sess_output = []
|
70 |
+
for out in sess.get_outputs():
|
71 |
+
sess_output.append(out.name)
|
72 |
+
|
73 |
+
# run model
|
74 |
+
outputs = sess.run(sess_output, sess_input)
|
75 |
+
all_out.append(outputs)
|
76 |
+
|
77 |
+
return all_out
|
78 |
+
|
79 |
+
|
80 |
+
def postprocess(outputs: List[np.ndarray],
|
81 |
+
model_input_size: Tuple[int, int],
|
82 |
+
center: Tuple[int, int],
|
83 |
+
scale: Tuple[int, int],
|
84 |
+
simcc_split_ratio: float = 2.0
|
85 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
86 |
+
"""Postprocess for RTMPose model output.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
outputs (np.ndarray): Output of RTMPose model.
|
90 |
+
model_input_size (tuple): RTMPose model Input image size.
|
91 |
+
center (tuple): Center of bbox in shape (x, y).
|
92 |
+
scale (tuple): Scale of bbox in shape (w, h).
|
93 |
+
simcc_split_ratio (float): Split ratio of simcc.
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
tuple:
|
97 |
+
- keypoints (np.ndarray): Rescaled keypoints.
|
98 |
+
- scores (np.ndarray): Model predict scores.
|
99 |
+
"""
|
100 |
+
all_key = []
|
101 |
+
all_score = []
|
102 |
+
for i in range(len(outputs)):
|
103 |
+
# use simcc to decode
|
104 |
+
simcc_x, simcc_y = outputs[i]
|
105 |
+
keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
|
106 |
+
|
107 |
+
# rescale keypoints
|
108 |
+
keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
|
109 |
+
all_key.append(keypoints[0])
|
110 |
+
all_score.append(scores[0])
|
111 |
+
|
112 |
+
return np.array(all_key), np.array(all_score)
|
113 |
+
|
114 |
+
|
115 |
+
def bbox_xyxy2cs(bbox: np.ndarray,
|
116 |
+
padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
|
117 |
+
"""Transform the bbox format from (x,y,w,h) into (center, scale)
|
118 |
+
|
119 |
+
Args:
|
120 |
+
bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
|
121 |
+
as (left, top, right, bottom)
|
122 |
+
padding (float): BBox padding factor that will be multilied to scale.
|
123 |
+
Default: 1.0
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
tuple: A tuple containing center and scale.
|
127 |
+
- np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
|
128 |
+
(n, 2)
|
129 |
+
- np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
|
130 |
+
(n, 2)
|
131 |
+
"""
|
132 |
+
# convert single bbox from (4, ) to (1, 4)
|
133 |
+
dim = bbox.ndim
|
134 |
+
if dim == 1:
|
135 |
+
bbox = bbox[None, :]
|
136 |
+
|
137 |
+
# get bbox center and scale
|
138 |
+
x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
|
139 |
+
center = np.hstack([x1 + x2, y1 + y2]) * 0.5
|
140 |
+
scale = np.hstack([x2 - x1, y2 - y1]) * padding
|
141 |
+
|
142 |
+
if dim == 1:
|
143 |
+
center = center[0]
|
144 |
+
scale = scale[0]
|
145 |
+
|
146 |
+
return center, scale
|
147 |
+
|
148 |
+
|
149 |
+
def _fix_aspect_ratio(bbox_scale: np.ndarray,
|
150 |
+
aspect_ratio: float) -> np.ndarray:
|
151 |
+
"""Extend the scale to match the given aspect ratio.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
scale (np.ndarray): The image scale (w, h) in shape (2, )
|
155 |
+
aspect_ratio (float): The ratio of ``w/h``
|
156 |
+
|
157 |
+
Returns:
|
158 |
+
np.ndarray: The reshaped image scale in (2, )
|
159 |
+
"""
|
160 |
+
w, h = np.hsplit(bbox_scale, [1])
|
161 |
+
bbox_scale = np.where(w > h * aspect_ratio,
|
162 |
+
np.hstack([w, w / aspect_ratio]),
|
163 |
+
np.hstack([h * aspect_ratio, h]))
|
164 |
+
return bbox_scale
|
165 |
+
|
166 |
+
|
167 |
+
def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
|
168 |
+
"""Rotate a point by an angle.
|
169 |
+
|
170 |
+
Args:
|
171 |
+
pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
|
172 |
+
angle_rad (float): rotation angle in radian
|
173 |
+
|
174 |
+
Returns:
|
175 |
+
np.ndarray: Rotated point in shape (2, )
|
176 |
+
"""
|
177 |
+
sn, cs = np.sin(angle_rad), np.cos(angle_rad)
|
178 |
+
rot_mat = np.array([[cs, -sn], [sn, cs]])
|
179 |
+
return rot_mat @ pt
|
180 |
+
|
181 |
+
|
182 |
+
def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
183 |
+
"""To calculate the affine matrix, three pairs of points are required. This
|
184 |
+
function is used to get the 3rd point, given 2D points a & b.
|
185 |
+
|
186 |
+
The 3rd point is defined by rotating vector `a - b` by 90 degrees
|
187 |
+
anticlockwise, using b as the rotation center.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
a (np.ndarray): The 1st point (x,y) in shape (2, )
|
191 |
+
b (np.ndarray): The 2nd point (x,y) in shape (2, )
|
192 |
+
|
193 |
+
Returns:
|
194 |
+
np.ndarray: The 3rd point.
|
195 |
+
"""
|
196 |
+
direction = a - b
|
197 |
+
c = b + np.r_[-direction[1], direction[0]]
|
198 |
+
return c
|
199 |
+
|
200 |
+
|
201 |
+
def get_warp_matrix(center: np.ndarray,
|
202 |
+
scale: np.ndarray,
|
203 |
+
rot: float,
|
204 |
+
output_size: Tuple[int, int],
|
205 |
+
shift: Tuple[float, float] = (0., 0.),
|
206 |
+
inv: bool = False) -> np.ndarray:
|
207 |
+
"""Calculate the affine transformation matrix that can warp the bbox area
|
208 |
+
in the input image to the output size.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
center (np.ndarray[2, ]): Center of the bounding box (x, y).
|
212 |
+
scale (np.ndarray[2, ]): Scale of the bounding box
|
213 |
+
wrt [width, height].
|
214 |
+
rot (float): Rotation angle (degree).
|
215 |
+
output_size (np.ndarray[2, ] | list(2,)): Size of the
|
216 |
+
destination heatmaps.
|
217 |
+
shift (0-100%): Shift translation ratio wrt the width/height.
|
218 |
+
Default (0., 0.).
|
219 |
+
inv (bool): Option to inverse the affine transform direction.
|
220 |
+
(inv=False: src->dst or inv=True: dst->src)
|
221 |
+
|
222 |
+
Returns:
|
223 |
+
np.ndarray: A 2x3 transformation matrix
|
224 |
+
"""
|
225 |
+
shift = np.array(shift)
|
226 |
+
src_w = scale[0]
|
227 |
+
dst_w = output_size[0]
|
228 |
+
dst_h = output_size[1]
|
229 |
+
|
230 |
+
# compute transformation matrix
|
231 |
+
rot_rad = np.deg2rad(rot)
|
232 |
+
src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
|
233 |
+
dst_dir = np.array([0., dst_w * -0.5])
|
234 |
+
|
235 |
+
# get four corners of the src rectangle in the original image
|
236 |
+
src = np.zeros((3, 2), dtype=np.float32)
|
237 |
+
src[0, :] = center + scale * shift
|
238 |
+
src[1, :] = center + src_dir + scale * shift
|
239 |
+
src[2, :] = _get_3rd_point(src[0, :], src[1, :])
|
240 |
+
|
241 |
+
# get four corners of the dst rectangle in the input image
|
242 |
+
dst = np.zeros((3, 2), dtype=np.float32)
|
243 |
+
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
|
244 |
+
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
|
245 |
+
dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
|
246 |
+
|
247 |
+
if inv:
|
248 |
+
warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
|
249 |
+
else:
|
250 |
+
warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
|
251 |
+
|
252 |
+
return warp_mat
|
253 |
+
|
254 |
+
|
255 |
+
def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict,
|
256 |
+
img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
257 |
+
"""Get the bbox image as the model input by affine transform.
|
258 |
+
|
259 |
+
Args:
|
260 |
+
input_size (dict): The input size of the model.
|
261 |
+
bbox_scale (dict): The bbox scale of the img.
|
262 |
+
bbox_center (dict): The bbox center of the img.
|
263 |
+
img (np.ndarray): The original image.
|
264 |
+
|
265 |
+
Returns:
|
266 |
+
tuple: A tuple containing center and scale.
|
267 |
+
- np.ndarray[float32]: img after affine transform.
|
268 |
+
- np.ndarray[float32]: bbox scale after affine transform.
|
269 |
+
"""
|
270 |
+
w, h = input_size
|
271 |
+
warp_size = (int(w), int(h))
|
272 |
+
|
273 |
+
# reshape bbox to fixed aspect ratio
|
274 |
+
bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
|
275 |
+
|
276 |
+
# get the affine matrix
|
277 |
+
center = bbox_center
|
278 |
+
scale = bbox_scale
|
279 |
+
rot = 0
|
280 |
+
warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
|
281 |
+
|
282 |
+
# do affine transform
|
283 |
+
img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
|
284 |
+
|
285 |
+
return img, bbox_scale
|
286 |
+
|
287 |
+
|
288 |
+
def get_simcc_maximum(simcc_x: np.ndarray,
|
289 |
+
simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
290 |
+
"""Get maximum response location and value from simcc representations.
|
291 |
+
|
292 |
+
Note:
|
293 |
+
instance number: N
|
294 |
+
num_keypoints: K
|
295 |
+
heatmap height: H
|
296 |
+
heatmap width: W
|
297 |
+
|
298 |
+
Args:
|
299 |
+
simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
|
300 |
+
simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
|
301 |
+
|
302 |
+
Returns:
|
303 |
+
tuple:
|
304 |
+
- locs (np.ndarray): locations of maximum heatmap responses in shape
|
305 |
+
(K, 2) or (N, K, 2)
|
306 |
+
- vals (np.ndarray): values of maximum heatmap responses in shape
|
307 |
+
(K,) or (N, K)
|
308 |
+
"""
|
309 |
+
N, K, Wx = simcc_x.shape
|
310 |
+
simcc_x = simcc_x.reshape(N * K, -1)
|
311 |
+
simcc_y = simcc_y.reshape(N * K, -1)
|
312 |
+
|
313 |
+
# get maximum value locations
|
314 |
+
x_locs = np.argmax(simcc_x, axis=1)
|
315 |
+
y_locs = np.argmax(simcc_y, axis=1)
|
316 |
+
locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
|
317 |
+
max_val_x = np.amax(simcc_x, axis=1)
|
318 |
+
max_val_y = np.amax(simcc_y, axis=1)
|
319 |
+
|
320 |
+
# get maximum value across x and y axis
|
321 |
+
mask = max_val_x > max_val_y
|
322 |
+
max_val_x[mask] = max_val_y[mask]
|
323 |
+
vals = max_val_x
|
324 |
+
locs[vals <= 0.] = -1
|
325 |
+
|
326 |
+
# reshape
|
327 |
+
locs = locs.reshape(N, K, 2)
|
328 |
+
vals = vals.reshape(N, K)
|
329 |
+
|
330 |
+
return locs, vals
|
331 |
+
|
332 |
+
|
333 |
+
def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
|
334 |
+
simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
|
335 |
+
"""Modulate simcc distribution with Gaussian.
|
336 |
+
|
337 |
+
Args:
|
338 |
+
simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
|
339 |
+
simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
|
340 |
+
simcc_split_ratio (int): The split ratio of simcc.
|
341 |
+
|
342 |
+
Returns:
|
343 |
+
tuple: A tuple containing center and scale.
|
344 |
+
- np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
|
345 |
+
- np.ndarray[float32]: scores in shape (K,) or (n, K)
|
346 |
+
"""
|
347 |
+
keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
|
348 |
+
keypoints /= simcc_split_ratio
|
349 |
+
|
350 |
+
return keypoints, scores
|
351 |
+
|
352 |
+
|
353 |
+
def inference_pose(session, out_bbox, oriImg):
|
354 |
+
h, w = session.get_inputs()[0].shape[2:]
|
355 |
+
model_input_size = (w, h)
|
356 |
+
resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
|
357 |
+
outputs = inference(session, resized_img)
|
358 |
+
keypoints, scores = postprocess(outputs, model_input_size, center, scale)
|
359 |
+
|
360 |
+
return keypoints, scores
|
src/flux/annotator/dwpose/util.py
ADDED
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
import matplotlib
|
4 |
+
import cv2
|
5 |
+
|
6 |
+
|
7 |
+
eps = 0.01
|
8 |
+
|
9 |
+
|
10 |
+
def smart_resize(x, s):
|
11 |
+
Ht, Wt = s
|
12 |
+
if x.ndim == 2:
|
13 |
+
Ho, Wo = x.shape
|
14 |
+
Co = 1
|
15 |
+
else:
|
16 |
+
Ho, Wo, Co = x.shape
|
17 |
+
if Co == 3 or Co == 1:
|
18 |
+
k = float(Ht + Wt) / float(Ho + Wo)
|
19 |
+
return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
|
20 |
+
else:
|
21 |
+
return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
|
22 |
+
|
23 |
+
|
24 |
+
def smart_resize_k(x, fx, fy):
|
25 |
+
if x.ndim == 2:
|
26 |
+
Ho, Wo = x.shape
|
27 |
+
Co = 1
|
28 |
+
else:
|
29 |
+
Ho, Wo, Co = x.shape
|
30 |
+
Ht, Wt = Ho * fy, Wo * fx
|
31 |
+
if Co == 3 or Co == 1:
|
32 |
+
k = float(Ht + Wt) / float(Ho + Wo)
|
33 |
+
return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
|
34 |
+
else:
|
35 |
+
return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
|
36 |
+
|
37 |
+
|
38 |
+
def padRightDownCorner(img, stride, padValue):
|
39 |
+
h = img.shape[0]
|
40 |
+
w = img.shape[1]
|
41 |
+
|
42 |
+
pad = 4 * [None]
|
43 |
+
pad[0] = 0 # up
|
44 |
+
pad[1] = 0 # left
|
45 |
+
pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
|
46 |
+
pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
|
47 |
+
|
48 |
+
img_padded = img
|
49 |
+
pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1))
|
50 |
+
img_padded = np.concatenate((pad_up, img_padded), axis=0)
|
51 |
+
pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1))
|
52 |
+
img_padded = np.concatenate((pad_left, img_padded), axis=1)
|
53 |
+
pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1))
|
54 |
+
img_padded = np.concatenate((img_padded, pad_down), axis=0)
|
55 |
+
pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1))
|
56 |
+
img_padded = np.concatenate((img_padded, pad_right), axis=1)
|
57 |
+
|
58 |
+
return img_padded, pad
|
59 |
+
|
60 |
+
|
61 |
+
def transfer(model, model_weights):
|
62 |
+
transfered_model_weights = {}
|
63 |
+
for weights_name in model.state_dict().keys():
|
64 |
+
transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
|
65 |
+
return transfered_model_weights
|
66 |
+
|
67 |
+
|
68 |
+
def draw_bodypose(canvas, candidate, subset):
|
69 |
+
H, W, C = canvas.shape
|
70 |
+
candidate = np.array(candidate)
|
71 |
+
subset = np.array(subset)
|
72 |
+
|
73 |
+
stickwidth = 4
|
74 |
+
|
75 |
+
limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
|
76 |
+
[10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
|
77 |
+
[1, 16], [16, 18], [3, 17], [6, 18]]
|
78 |
+
|
79 |
+
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
|
80 |
+
[0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
|
81 |
+
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
|
82 |
+
|
83 |
+
for i in range(17):
|
84 |
+
for n in range(len(subset)):
|
85 |
+
index = subset[n][np.array(limbSeq[i]) - 1]
|
86 |
+
if -1 in index:
|
87 |
+
continue
|
88 |
+
Y = candidate[index.astype(int), 0] * float(W)
|
89 |
+
X = candidate[index.astype(int), 1] * float(H)
|
90 |
+
mX = np.mean(X)
|
91 |
+
mY = np.mean(Y)
|
92 |
+
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
|
93 |
+
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
|
94 |
+
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
|
95 |
+
cv2.fillConvexPoly(canvas, polygon, colors[i])
|
96 |
+
|
97 |
+
canvas = (canvas * 0.6).astype(np.uint8)
|
98 |
+
|
99 |
+
for i in range(18):
|
100 |
+
for n in range(len(subset)):
|
101 |
+
index = int(subset[n][i])
|
102 |
+
if index == -1:
|
103 |
+
continue
|
104 |
+
x, y = candidate[index][0:2]
|
105 |
+
x = int(x * W)
|
106 |
+
y = int(y * H)
|
107 |
+
cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
|
108 |
+
|
109 |
+
return canvas
|
110 |
+
|
111 |
+
|
112 |
+
def draw_handpose(canvas, all_hand_peaks):
|
113 |
+
H, W, C = canvas.shape
|
114 |
+
|
115 |
+
edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
|
116 |
+
[10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
|
117 |
+
|
118 |
+
for peaks in all_hand_peaks:
|
119 |
+
peaks = np.array(peaks)
|
120 |
+
|
121 |
+
for ie, e in enumerate(edges):
|
122 |
+
x1, y1 = peaks[e[0]]
|
123 |
+
x2, y2 = peaks[e[1]]
|
124 |
+
x1 = int(x1 * W)
|
125 |
+
y1 = int(y1 * H)
|
126 |
+
x2 = int(x2 * W)
|
127 |
+
y2 = int(y2 * H)
|
128 |
+
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
|
129 |
+
cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2)
|
130 |
+
|
131 |
+
for i, keyponit in enumerate(peaks):
|
132 |
+
x, y = keyponit
|
133 |
+
x = int(x * W)
|
134 |
+
y = int(y * H)
|
135 |
+
if x > eps and y > eps:
|
136 |
+
cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
|
137 |
+
return canvas
|
138 |
+
|
139 |
+
|
140 |
+
def draw_facepose(canvas, all_lmks):
|
141 |
+
H, W, C = canvas.shape
|
142 |
+
for lmks in all_lmks:
|
143 |
+
lmks = np.array(lmks)
|
144 |
+
for lmk in lmks:
|
145 |
+
x, y = lmk
|
146 |
+
x = int(x * W)
|
147 |
+
y = int(y * H)
|
148 |
+
if x > eps and y > eps:
|
149 |
+
cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
|
150 |
+
return canvas
|
151 |
+
|
152 |
+
|
153 |
+
# detect hand according to body pose keypoints
|
154 |
+
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
|
155 |
+
def handDetect(candidate, subset, oriImg):
|
156 |
+
# right hand: wrist 4, elbow 3, shoulder 2
|
157 |
+
# left hand: wrist 7, elbow 6, shoulder 5
|
158 |
+
ratioWristElbow = 0.33
|
159 |
+
detect_result = []
|
160 |
+
image_height, image_width = oriImg.shape[0:2]
|
161 |
+
for person in subset.astype(int):
|
162 |
+
# if any of three not detected
|
163 |
+
has_left = np.sum(person[[5, 6, 7]] == -1) == 0
|
164 |
+
has_right = np.sum(person[[2, 3, 4]] == -1) == 0
|
165 |
+
if not (has_left or has_right):
|
166 |
+
continue
|
167 |
+
hands = []
|
168 |
+
#left hand
|
169 |
+
if has_left:
|
170 |
+
left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
|
171 |
+
x1, y1 = candidate[left_shoulder_index][:2]
|
172 |
+
x2, y2 = candidate[left_elbow_index][:2]
|
173 |
+
x3, y3 = candidate[left_wrist_index][:2]
|
174 |
+
hands.append([x1, y1, x2, y2, x3, y3, True])
|
175 |
+
# right hand
|
176 |
+
if has_right:
|
177 |
+
right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
|
178 |
+
x1, y1 = candidate[right_shoulder_index][:2]
|
179 |
+
x2, y2 = candidate[right_elbow_index][:2]
|
180 |
+
x3, y3 = candidate[right_wrist_index][:2]
|
181 |
+
hands.append([x1, y1, x2, y2, x3, y3, False])
|
182 |
+
|
183 |
+
for x1, y1, x2, y2, x3, y3, is_left in hands:
|
184 |
+
# pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
|
185 |
+
# handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
|
186 |
+
# handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
|
187 |
+
# const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
|
188 |
+
# const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
|
189 |
+
# handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
|
190 |
+
x = x3 + ratioWristElbow * (x3 - x2)
|
191 |
+
y = y3 + ratioWristElbow * (y3 - y2)
|
192 |
+
distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
|
193 |
+
distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
|
194 |
+
width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
|
195 |
+
# x-y refers to the center --> offset to topLeft point
|
196 |
+
# handRectangle.x -= handRectangle.width / 2.f;
|
197 |
+
# handRectangle.y -= handRectangle.height / 2.f;
|
198 |
+
x -= width / 2
|
199 |
+
y -= width / 2 # width = height
|
200 |
+
# overflow the image
|
201 |
+
if x < 0: x = 0
|
202 |
+
if y < 0: y = 0
|
203 |
+
width1 = width
|
204 |
+
width2 = width
|
205 |
+
if x + width > image_width: width1 = image_width - x
|
206 |
+
if y + width > image_height: width2 = image_height - y
|
207 |
+
width = min(width1, width2)
|
208 |
+
# the max hand box value is 20 pixels
|
209 |
+
if width >= 20:
|
210 |
+
detect_result.append([int(x), int(y), int(width), is_left])
|
211 |
+
|
212 |
+
'''
|
213 |
+
return value: [[x, y, w, True if left hand else False]].
|
214 |
+
width=height since the network require squared input.
|
215 |
+
x, y is the coordinate of top left
|
216 |
+
'''
|
217 |
+
return detect_result
|
218 |
+
|
219 |
+
|
220 |
+
# Written by Lvmin
|
221 |
+
def faceDetect(candidate, subset, oriImg):
|
222 |
+
# left right eye ear 14 15 16 17
|
223 |
+
detect_result = []
|
224 |
+
image_height, image_width = oriImg.shape[0:2]
|
225 |
+
for person in subset.astype(int):
|
226 |
+
has_head = person[0] > -1
|
227 |
+
if not has_head:
|
228 |
+
continue
|
229 |
+
|
230 |
+
has_left_eye = person[14] > -1
|
231 |
+
has_right_eye = person[15] > -1
|
232 |
+
has_left_ear = person[16] > -1
|
233 |
+
has_right_ear = person[17] > -1
|
234 |
+
|
235 |
+
if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear):
|
236 |
+
continue
|
237 |
+
|
238 |
+
head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]]
|
239 |
+
|
240 |
+
width = 0.0
|
241 |
+
x0, y0 = candidate[head][:2]
|
242 |
+
|
243 |
+
if has_left_eye:
|
244 |
+
x1, y1 = candidate[left_eye][:2]
|
245 |
+
d = max(abs(x0 - x1), abs(y0 - y1))
|
246 |
+
width = max(width, d * 3.0)
|
247 |
+
|
248 |
+
if has_right_eye:
|
249 |
+
x1, y1 = candidate[right_eye][:2]
|
250 |
+
d = max(abs(x0 - x1), abs(y0 - y1))
|
251 |
+
width = max(width, d * 3.0)
|
252 |
+
|
253 |
+
if has_left_ear:
|
254 |
+
x1, y1 = candidate[left_ear][:2]
|
255 |
+
d = max(abs(x0 - x1), abs(y0 - y1))
|
256 |
+
width = max(width, d * 1.5)
|
257 |
+
|
258 |
+
if has_right_ear:
|
259 |
+
x1, y1 = candidate[right_ear][:2]
|
260 |
+
d = max(abs(x0 - x1), abs(y0 - y1))
|
261 |
+
width = max(width, d * 1.5)
|
262 |
+
|
263 |
+
x, y = x0, y0
|
264 |
+
|
265 |
+
x -= width
|
266 |
+
y -= width
|
267 |
+
|
268 |
+
if x < 0:
|
269 |
+
x = 0
|
270 |
+
|
271 |
+
if y < 0:
|
272 |
+
y = 0
|
273 |
+
|
274 |
+
width1 = width * 2
|
275 |
+
width2 = width * 2
|
276 |
+
|
277 |
+
if x + width > image_width:
|
278 |
+
width1 = image_width - x
|
279 |
+
|
280 |
+
if y + width > image_height:
|
281 |
+
width2 = image_height - y
|
282 |
+
|
283 |
+
width = min(width1, width2)
|
284 |
+
|
285 |
+
if width >= 20:
|
286 |
+
detect_result.append([int(x), int(y), int(width)])
|
287 |
+
|
288 |
+
return detect_result
|
289 |
+
|
290 |
+
|
291 |
+
# get max index of 2d array
|
292 |
+
def npmax(array):
|
293 |
+
arrayindex = array.argmax(1)
|
294 |
+
arrayvalue = array.max(1)
|
295 |
+
i = arrayvalue.argmax()
|
296 |
+
j = arrayindex[i]
|
297 |
+
return i, j
|
src/flux/annotator/dwpose/wholebody.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
import onnxruntime as ort
|
5 |
+
from huggingface_hub import hf_hub_download
|
6 |
+
from .onnxdet import inference_detector
|
7 |
+
from .onnxpose import inference_pose
|
8 |
+
|
9 |
+
|
10 |
+
class Wholebody:
|
11 |
+
def __init__(self, device="cuda:0"):
|
12 |
+
providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider']
|
13 |
+
onnx_det = hf_hub_download("yzd-v/DWPose", "yolox_l.onnx")
|
14 |
+
onnx_pose = hf_hub_download("yzd-v/DWPose", "dw-ll_ucoco_384.onnx")
|
15 |
+
|
16 |
+
self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
|
17 |
+
self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
|
18 |
+
|
19 |
+
def __call__(self, oriImg):
|
20 |
+
det_result = inference_detector(self.session_det, oriImg)
|
21 |
+
keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)
|
22 |
+
|
23 |
+
keypoints_info = np.concatenate(
|
24 |
+
(keypoints, scores[..., None]), axis=-1)
|
25 |
+
# compute neck joint
|
26 |
+
neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
|
27 |
+
# neck score when visualizing pred
|
28 |
+
neck[:, 2:4] = np.logical_and(
|
29 |
+
keypoints_info[:, 5, 2:4] > 0.3,
|
30 |
+
keypoints_info[:, 6, 2:4] > 0.3).astype(int)
|
31 |
+
new_keypoints_info = np.insert(
|
32 |
+
keypoints_info, 17, neck, axis=1)
|
33 |
+
mmpose_idx = [
|
34 |
+
17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
|
35 |
+
]
|
36 |
+
openpose_idx = [
|
37 |
+
1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
|
38 |
+
]
|
39 |
+
new_keypoints_info[:, openpose_idx] = \
|
40 |
+
new_keypoints_info[:, mmpose_idx]
|
41 |
+
keypoints_info = new_keypoints_info
|
42 |
+
|
43 |
+
keypoints, scores = keypoints_info[
|
44 |
+
..., :2], keypoints_info[..., 2]
|
45 |
+
|
46 |
+
return keypoints, scores
|
47 |
+
|
48 |
+
|
src/flux/annotator/hed/__init__.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This is an improved version and model of HED edge detection with Apache License, Version 2.0.
|
2 |
+
# Please use this implementation in your products
|
3 |
+
# This implementation may produce slightly different results from Saining Xie's official implementations,
|
4 |
+
# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
|
5 |
+
# Different from official models and other implementations, this is an RGB-input model (rather than BGR)
|
6 |
+
# and in this way it works better for gradio's RGB protocol
|
7 |
+
|
8 |
+
import os
|
9 |
+
import cv2
|
10 |
+
import torch
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
from huggingface_hub import hf_hub_download
|
14 |
+
from einops import rearrange
|
15 |
+
from ...annotator.util import annotator_ckpts_path
|
16 |
+
|
17 |
+
|
18 |
+
class DoubleConvBlock(torch.nn.Module):
|
19 |
+
def __init__(self, input_channel, output_channel, layer_number):
|
20 |
+
super().__init__()
|
21 |
+
self.convs = torch.nn.Sequential()
|
22 |
+
self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
|
23 |
+
for i in range(1, layer_number):
|
24 |
+
self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
|
25 |
+
self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)
|
26 |
+
|
27 |
+
def __call__(self, x, down_sampling=False):
|
28 |
+
h = x
|
29 |
+
if down_sampling:
|
30 |
+
h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
|
31 |
+
for conv in self.convs:
|
32 |
+
h = conv(h)
|
33 |
+
h = torch.nn.functional.relu(h)
|
34 |
+
return h, self.projection(h)
|
35 |
+
|
36 |
+
|
37 |
+
class ControlNetHED_Apache2(torch.nn.Module):
|
38 |
+
def __init__(self):
|
39 |
+
super().__init__()
|
40 |
+
self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
|
41 |
+
self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
|
42 |
+
self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
|
43 |
+
self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
|
44 |
+
self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
|
45 |
+
self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)
|
46 |
+
|
47 |
+
def __call__(self, x):
|
48 |
+
h = x - self.norm
|
49 |
+
h, projection1 = self.block1(h)
|
50 |
+
h, projection2 = self.block2(h, down_sampling=True)
|
51 |
+
h, projection3 = self.block3(h, down_sampling=True)
|
52 |
+
h, projection4 = self.block4(h, down_sampling=True)
|
53 |
+
h, projection5 = self.block5(h, down_sampling=True)
|
54 |
+
return projection1, projection2, projection3, projection4, projection5
|
55 |
+
|
56 |
+
|
57 |
+
class HEDdetector:
|
58 |
+
def __init__(self):
|
59 |
+
modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth")
|
60 |
+
if not os.path.exists(modelpath):
|
61 |
+
modelpath = hf_hub_download("lllyasviel/Annotators", "ControlNetHED.pth")
|
62 |
+
self.netNetwork = ControlNetHED_Apache2().float().cuda().eval()
|
63 |
+
self.netNetwork.load_state_dict(torch.load(modelpath))
|
64 |
+
|
65 |
+
def __call__(self, input_image):
|
66 |
+
assert input_image.ndim == 3
|
67 |
+
H, W, C = input_image.shape
|
68 |
+
with torch.no_grad():
|
69 |
+
image_hed = torch.from_numpy(input_image.copy()).float().cuda()
|
70 |
+
image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
|
71 |
+
edges = self.netNetwork(image_hed)
|
72 |
+
edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
|
73 |
+
edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
|
74 |
+
edges = np.stack(edges, axis=2)
|
75 |
+
edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
|
76 |
+
edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
|
77 |
+
return edge
|
78 |
+
|
79 |
+
|
80 |
+
def nms(x, t, s):
|
81 |
+
x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
|
82 |
+
|
83 |
+
f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
|
84 |
+
f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
|
85 |
+
f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
|
86 |
+
f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
|
87 |
+
|
88 |
+
y = np.zeros_like(x)
|
89 |
+
|
90 |
+
for f in [f1, f2, f3, f4]:
|
91 |
+
np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
|
92 |
+
|
93 |
+
z = np.zeros_like(y, dtype=np.uint8)
|
94 |
+
z[y > t] = 255
|
95 |
+
return z
|
src/flux/annotator/midas/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
src/flux/annotator/midas/__init__.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Midas Depth Estimation
|
2 |
+
# From https://github.com/isl-org/MiDaS
|
3 |
+
# MIT LICENSE
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from einops import rearrange
|
10 |
+
from .api import MiDaSInference
|
11 |
+
|
12 |
+
|
13 |
+
class MidasDetector:
|
14 |
+
def __init__(self):
|
15 |
+
self.model = MiDaSInference(model_type="dpt_hybrid").cuda()
|
16 |
+
|
17 |
+
def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1):
|
18 |
+
assert input_image.ndim == 3
|
19 |
+
image_depth = input_image
|
20 |
+
with torch.no_grad():
|
21 |
+
image_depth = torch.from_numpy(image_depth).float().cuda()
|
22 |
+
image_depth = image_depth / 127.5 - 1.0
|
23 |
+
image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
|
24 |
+
depth = self.model(image_depth)[0]
|
25 |
+
|
26 |
+
depth_pt = depth.clone()
|
27 |
+
depth_pt -= torch.min(depth_pt)
|
28 |
+
depth_pt /= torch.max(depth_pt)
|
29 |
+
depth_pt = depth_pt.cpu().numpy()
|
30 |
+
depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
|
31 |
+
|
32 |
+
depth_np = depth.cpu().numpy()
|
33 |
+
x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
|
34 |
+
y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
|
35 |
+
z = np.ones_like(x) * a
|
36 |
+
x[depth_pt < bg_th] = 0
|
37 |
+
y[depth_pt < bg_th] = 0
|
38 |
+
normal = np.stack([x, y, z], axis=2)
|
39 |
+
normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
|
40 |
+
normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
|
41 |
+
|
42 |
+
return depth_image, normal_image
|
src/flux/annotator/midas/api.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# based on https://github.com/isl-org/MiDaS
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torchvision.transforms import Compose
|
8 |
+
|
9 |
+
from huggingface_hub import hf_hub_download
|
10 |
+
|
11 |
+
from .midas.dpt_depth import DPTDepthModel
|
12 |
+
from .midas.midas_net import MidasNet
|
13 |
+
from .midas.midas_net_custom import MidasNet_small
|
14 |
+
from .midas.transforms import Resize, NormalizeImage, PrepareForNet
|
15 |
+
from ...annotator.util import annotator_ckpts_path
|
16 |
+
|
17 |
+
|
18 |
+
ISL_PATHS = {
|
19 |
+
"dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"),
|
20 |
+
"dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"),
|
21 |
+
"midas_v21": "",
|
22 |
+
"midas_v21_small": "",
|
23 |
+
}
|
24 |
+
|
25 |
+
|
26 |
+
def disabled_train(self, mode=True):
|
27 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
28 |
+
does not change anymore."""
|
29 |
+
return self
|
30 |
+
|
31 |
+
|
32 |
+
def load_midas_transform(model_type):
|
33 |
+
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
34 |
+
# load transform only
|
35 |
+
if model_type == "dpt_large": # DPT-Large
|
36 |
+
net_w, net_h = 384, 384
|
37 |
+
resize_mode = "minimal"
|
38 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
39 |
+
|
40 |
+
elif model_type == "dpt_hybrid": # DPT-Hybrid
|
41 |
+
net_w, net_h = 384, 384
|
42 |
+
resize_mode = "minimal"
|
43 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
44 |
+
|
45 |
+
elif model_type == "midas_v21":
|
46 |
+
net_w, net_h = 384, 384
|
47 |
+
resize_mode = "upper_bound"
|
48 |
+
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
49 |
+
|
50 |
+
elif model_type == "midas_v21_small":
|
51 |
+
net_w, net_h = 256, 256
|
52 |
+
resize_mode = "upper_bound"
|
53 |
+
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
54 |
+
|
55 |
+
else:
|
56 |
+
assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
|
57 |
+
|
58 |
+
transform = Compose(
|
59 |
+
[
|
60 |
+
Resize(
|
61 |
+
net_w,
|
62 |
+
net_h,
|
63 |
+
resize_target=None,
|
64 |
+
keep_aspect_ratio=True,
|
65 |
+
ensure_multiple_of=32,
|
66 |
+
resize_method=resize_mode,
|
67 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
68 |
+
),
|
69 |
+
normalization,
|
70 |
+
PrepareForNet(),
|
71 |
+
]
|
72 |
+
)
|
73 |
+
|
74 |
+
return transform
|
75 |
+
|
76 |
+
|
77 |
+
def load_model(model_type):
|
78 |
+
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
79 |
+
# load network
|
80 |
+
model_path = ISL_PATHS[model_type]
|
81 |
+
if model_type == "dpt_large": # DPT-Large
|
82 |
+
model = DPTDepthModel(
|
83 |
+
path=model_path,
|
84 |
+
backbone="vitl16_384",
|
85 |
+
non_negative=True,
|
86 |
+
)
|
87 |
+
net_w, net_h = 384, 384
|
88 |
+
resize_mode = "minimal"
|
89 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
90 |
+
|
91 |
+
elif model_type == "dpt_hybrid": # DPT-Hybrid
|
92 |
+
if not os.path.exists(model_path):
|
93 |
+
model_path = hf_hub_download("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt")
|
94 |
+
|
95 |
+
model = DPTDepthModel(
|
96 |
+
path=model_path,
|
97 |
+
backbone="vitb_rn50_384",
|
98 |
+
non_negative=True,
|
99 |
+
)
|
100 |
+
net_w, net_h = 384, 384
|
101 |
+
resize_mode = "minimal"
|
102 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
103 |
+
|
104 |
+
elif model_type == "midas_v21":
|
105 |
+
model = MidasNet(model_path, non_negative=True)
|
106 |
+
net_w, net_h = 384, 384
|
107 |
+
resize_mode = "upper_bound"
|
108 |
+
normalization = NormalizeImage(
|
109 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
110 |
+
)
|
111 |
+
|
112 |
+
elif model_type == "midas_v21_small":
|
113 |
+
model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
|
114 |
+
non_negative=True, blocks={'expand': True})
|
115 |
+
net_w, net_h = 256, 256
|
116 |
+
resize_mode = "upper_bound"
|
117 |
+
normalization = NormalizeImage(
|
118 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
119 |
+
)
|
120 |
+
|
121 |
+
else:
|
122 |
+
print(f"model_type '{model_type}' not implemented, use: --model_type large")
|
123 |
+
assert False
|
124 |
+
|
125 |
+
transform = Compose(
|
126 |
+
[
|
127 |
+
Resize(
|
128 |
+
net_w,
|
129 |
+
net_h,
|
130 |
+
resize_target=None,
|
131 |
+
keep_aspect_ratio=True,
|
132 |
+
ensure_multiple_of=32,
|
133 |
+
resize_method=resize_mode,
|
134 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
135 |
+
),
|
136 |
+
normalization,
|
137 |
+
PrepareForNet(),
|
138 |
+
]
|
139 |
+
)
|
140 |
+
|
141 |
+
return model.eval(), transform
|
142 |
+
|
143 |
+
|
144 |
+
class MiDaSInference(nn.Module):
|
145 |
+
MODEL_TYPES_TORCH_HUB = [
|
146 |
+
"DPT_Large",
|
147 |
+
"DPT_Hybrid",
|
148 |
+
"MiDaS_small"
|
149 |
+
]
|
150 |
+
MODEL_TYPES_ISL = [
|
151 |
+
"dpt_large",
|
152 |
+
"dpt_hybrid",
|
153 |
+
"midas_v21",
|
154 |
+
"midas_v21_small",
|
155 |
+
]
|
156 |
+
|
157 |
+
def __init__(self, model_type):
|
158 |
+
super().__init__()
|
159 |
+
assert (model_type in self.MODEL_TYPES_ISL)
|
160 |
+
model, _ = load_model(model_type)
|
161 |
+
self.model = model
|
162 |
+
self.model.train = disabled_train
|
163 |
+
|
164 |
+
def forward(self, x):
|
165 |
+
with torch.no_grad():
|
166 |
+
prediction = self.model(x)
|
167 |
+
return prediction
|
168 |
+
|
src/flux/annotator/midas/midas/__init__.py
ADDED
File without changes
|
src/flux/annotator/midas/midas/base_model.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
class BaseModel(torch.nn.Module):
|
5 |
+
def load(self, path):
|
6 |
+
"""Load model from file.
|
7 |
+
|
8 |
+
Args:
|
9 |
+
path (str): file path
|
10 |
+
"""
|
11 |
+
parameters = torch.load(path, map_location=torch.device('cpu'))
|
12 |
+
|
13 |
+
if "optimizer" in parameters:
|
14 |
+
parameters = parameters["model"]
|
15 |
+
|
16 |
+
self.load_state_dict(parameters)
|
src/flux/annotator/midas/midas/blocks.py
ADDED
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .vit import (
|
5 |
+
_make_pretrained_vitb_rn50_384,
|
6 |
+
_make_pretrained_vitl16_384,
|
7 |
+
_make_pretrained_vitb16_384,
|
8 |
+
forward_vit,
|
9 |
+
)
|
10 |
+
|
11 |
+
def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
|
12 |
+
if backbone == "vitl16_384":
|
13 |
+
pretrained = _make_pretrained_vitl16_384(
|
14 |
+
use_pretrained, hooks=hooks, use_readout=use_readout
|
15 |
+
)
|
16 |
+
scratch = _make_scratch(
|
17 |
+
[256, 512, 1024, 1024], features, groups=groups, expand=expand
|
18 |
+
) # ViT-L/16 - 85.0% Top1 (backbone)
|
19 |
+
elif backbone == "vitb_rn50_384":
|
20 |
+
pretrained = _make_pretrained_vitb_rn50_384(
|
21 |
+
use_pretrained,
|
22 |
+
hooks=hooks,
|
23 |
+
use_vit_only=use_vit_only,
|
24 |
+
use_readout=use_readout,
|
25 |
+
)
|
26 |
+
scratch = _make_scratch(
|
27 |
+
[256, 512, 768, 768], features, groups=groups, expand=expand
|
28 |
+
) # ViT-H/16 - 85.0% Top1 (backbone)
|
29 |
+
elif backbone == "vitb16_384":
|
30 |
+
pretrained = _make_pretrained_vitb16_384(
|
31 |
+
use_pretrained, hooks=hooks, use_readout=use_readout
|
32 |
+
)
|
33 |
+
scratch = _make_scratch(
|
34 |
+
[96, 192, 384, 768], features, groups=groups, expand=expand
|
35 |
+
) # ViT-B/16 - 84.6% Top1 (backbone)
|
36 |
+
elif backbone == "resnext101_wsl":
|
37 |
+
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
|
38 |
+
scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
|
39 |
+
elif backbone == "efficientnet_lite3":
|
40 |
+
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
|
41 |
+
scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
|
42 |
+
else:
|
43 |
+
print(f"Backbone '{backbone}' not implemented")
|
44 |
+
assert False
|
45 |
+
|
46 |
+
return pretrained, scratch
|
47 |
+
|
48 |
+
|
49 |
+
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
50 |
+
scratch = nn.Module()
|
51 |
+
|
52 |
+
out_shape1 = out_shape
|
53 |
+
out_shape2 = out_shape
|
54 |
+
out_shape3 = out_shape
|
55 |
+
out_shape4 = out_shape
|
56 |
+
if expand==True:
|
57 |
+
out_shape1 = out_shape
|
58 |
+
out_shape2 = out_shape*2
|
59 |
+
out_shape3 = out_shape*4
|
60 |
+
out_shape4 = out_shape*8
|
61 |
+
|
62 |
+
scratch.layer1_rn = nn.Conv2d(
|
63 |
+
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
64 |
+
)
|
65 |
+
scratch.layer2_rn = nn.Conv2d(
|
66 |
+
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
67 |
+
)
|
68 |
+
scratch.layer3_rn = nn.Conv2d(
|
69 |
+
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
70 |
+
)
|
71 |
+
scratch.layer4_rn = nn.Conv2d(
|
72 |
+
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
73 |
+
)
|
74 |
+
|
75 |
+
return scratch
|
76 |
+
|
77 |
+
|
78 |
+
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
|
79 |
+
efficientnet = torch.hub.load(
|
80 |
+
"rwightman/gen-efficientnet-pytorch",
|
81 |
+
"tf_efficientnet_lite3",
|
82 |
+
pretrained=use_pretrained,
|
83 |
+
exportable=exportable
|
84 |
+
)
|
85 |
+
return _make_efficientnet_backbone(efficientnet)
|
86 |
+
|
87 |
+
|
88 |
+
def _make_efficientnet_backbone(effnet):
|
89 |
+
pretrained = nn.Module()
|
90 |
+
|
91 |
+
pretrained.layer1 = nn.Sequential(
|
92 |
+
effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
|
93 |
+
)
|
94 |
+
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
|
95 |
+
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
|
96 |
+
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
|
97 |
+
|
98 |
+
return pretrained
|
99 |
+
|
100 |
+
|
101 |
+
def _make_resnet_backbone(resnet):
|
102 |
+
pretrained = nn.Module()
|
103 |
+
pretrained.layer1 = nn.Sequential(
|
104 |
+
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
|
105 |
+
)
|
106 |
+
|
107 |
+
pretrained.layer2 = resnet.layer2
|
108 |
+
pretrained.layer3 = resnet.layer3
|
109 |
+
pretrained.layer4 = resnet.layer4
|
110 |
+
|
111 |
+
return pretrained
|
112 |
+
|
113 |
+
|
114 |
+
def _make_pretrained_resnext101_wsl(use_pretrained):
|
115 |
+
resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
|
116 |
+
return _make_resnet_backbone(resnet)
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
class Interpolate(nn.Module):
|
121 |
+
"""Interpolation module.
|
122 |
+
"""
|
123 |
+
|
124 |
+
def __init__(self, scale_factor, mode, align_corners=False):
|
125 |
+
"""Init.
|
126 |
+
|
127 |
+
Args:
|
128 |
+
scale_factor (float): scaling
|
129 |
+
mode (str): interpolation mode
|
130 |
+
"""
|
131 |
+
super(Interpolate, self).__init__()
|
132 |
+
|
133 |
+
self.interp = nn.functional.interpolate
|
134 |
+
self.scale_factor = scale_factor
|
135 |
+
self.mode = mode
|
136 |
+
self.align_corners = align_corners
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
"""Forward pass.
|
140 |
+
|
141 |
+
Args:
|
142 |
+
x (tensor): input
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
tensor: interpolated data
|
146 |
+
"""
|
147 |
+
|
148 |
+
x = self.interp(
|
149 |
+
x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
|
150 |
+
)
|
151 |
+
|
152 |
+
return x
|
153 |
+
|
154 |
+
|
155 |
+
class ResidualConvUnit(nn.Module):
|
156 |
+
"""Residual convolution module.
|
157 |
+
"""
|
158 |
+
|
159 |
+
def __init__(self, features):
|
160 |
+
"""Init.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
features (int): number of features
|
164 |
+
"""
|
165 |
+
super().__init__()
|
166 |
+
|
167 |
+
self.conv1 = nn.Conv2d(
|
168 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
169 |
+
)
|
170 |
+
|
171 |
+
self.conv2 = nn.Conv2d(
|
172 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
173 |
+
)
|
174 |
+
|
175 |
+
self.relu = nn.ReLU(inplace=True)
|
176 |
+
|
177 |
+
def forward(self, x):
|
178 |
+
"""Forward pass.
|
179 |
+
|
180 |
+
Args:
|
181 |
+
x (tensor): input
|
182 |
+
|
183 |
+
Returns:
|
184 |
+
tensor: output
|
185 |
+
"""
|
186 |
+
out = self.relu(x)
|
187 |
+
out = self.conv1(out)
|
188 |
+
out = self.relu(out)
|
189 |
+
out = self.conv2(out)
|
190 |
+
|
191 |
+
return out + x
|
192 |
+
|
193 |
+
|
194 |
+
class FeatureFusionBlock(nn.Module):
|
195 |
+
"""Feature fusion block.
|
196 |
+
"""
|
197 |
+
|
198 |
+
def __init__(self, features):
|
199 |
+
"""Init.
|
200 |
+
|
201 |
+
Args:
|
202 |
+
features (int): number of features
|
203 |
+
"""
|
204 |
+
super(FeatureFusionBlock, self).__init__()
|
205 |
+
|
206 |
+
self.resConfUnit1 = ResidualConvUnit(features)
|
207 |
+
self.resConfUnit2 = ResidualConvUnit(features)
|
208 |
+
|
209 |
+
def forward(self, *xs):
|
210 |
+
"""Forward pass.
|
211 |
+
|
212 |
+
Returns:
|
213 |
+
tensor: output
|
214 |
+
"""
|
215 |
+
output = xs[0]
|
216 |
+
|
217 |
+
if len(xs) == 2:
|
218 |
+
output += self.resConfUnit1(xs[1])
|
219 |
+
|
220 |
+
output = self.resConfUnit2(output)
|
221 |
+
|
222 |
+
output = nn.functional.interpolate(
|
223 |
+
output, scale_factor=2, mode="bilinear", align_corners=True
|
224 |
+
)
|
225 |
+
|
226 |
+
return output
|
227 |
+
|
228 |
+
|
229 |
+
|
230 |
+
|
231 |
+
class ResidualConvUnit_custom(nn.Module):
|
232 |
+
"""Residual convolution module.
|
233 |
+
"""
|
234 |
+
|
235 |
+
def __init__(self, features, activation, bn):
|
236 |
+
"""Init.
|
237 |
+
|
238 |
+
Args:
|
239 |
+
features (int): number of features
|
240 |
+
"""
|
241 |
+
super().__init__()
|
242 |
+
|
243 |
+
self.bn = bn
|
244 |
+
|
245 |
+
self.groups=1
|
246 |
+
|
247 |
+
self.conv1 = nn.Conv2d(
|
248 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
249 |
+
)
|
250 |
+
|
251 |
+
self.conv2 = nn.Conv2d(
|
252 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
253 |
+
)
|
254 |
+
|
255 |
+
if self.bn==True:
|
256 |
+
self.bn1 = nn.BatchNorm2d(features)
|
257 |
+
self.bn2 = nn.BatchNorm2d(features)
|
258 |
+
|
259 |
+
self.activation = activation
|
260 |
+
|
261 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
262 |
+
|
263 |
+
def forward(self, x):
|
264 |
+
"""Forward pass.
|
265 |
+
|
266 |
+
Args:
|
267 |
+
x (tensor): input
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
tensor: output
|
271 |
+
"""
|
272 |
+
|
273 |
+
out = self.activation(x)
|
274 |
+
out = self.conv1(out)
|
275 |
+
if self.bn==True:
|
276 |
+
out = self.bn1(out)
|
277 |
+
|
278 |
+
out = self.activation(out)
|
279 |
+
out = self.conv2(out)
|
280 |
+
if self.bn==True:
|
281 |
+
out = self.bn2(out)
|
282 |
+
|
283 |
+
if self.groups > 1:
|
284 |
+
out = self.conv_merge(out)
|
285 |
+
|
286 |
+
return self.skip_add.add(out, x)
|
287 |
+
|
288 |
+
# return out + x
|
289 |
+
|
290 |
+
|
291 |
+
class FeatureFusionBlock_custom(nn.Module):
|
292 |
+
"""Feature fusion block.
|
293 |
+
"""
|
294 |
+
|
295 |
+
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
|
296 |
+
"""Init.
|
297 |
+
|
298 |
+
Args:
|
299 |
+
features (int): number of features
|
300 |
+
"""
|
301 |
+
super(FeatureFusionBlock_custom, self).__init__()
|
302 |
+
|
303 |
+
self.deconv = deconv
|
304 |
+
self.align_corners = align_corners
|
305 |
+
|
306 |
+
self.groups=1
|
307 |
+
|
308 |
+
self.expand = expand
|
309 |
+
out_features = features
|
310 |
+
if self.expand==True:
|
311 |
+
out_features = features//2
|
312 |
+
|
313 |
+
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
|
314 |
+
|
315 |
+
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
|
316 |
+
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
|
317 |
+
|
318 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
319 |
+
|
320 |
+
def forward(self, *xs):
|
321 |
+
"""Forward pass.
|
322 |
+
|
323 |
+
Returns:
|
324 |
+
tensor: output
|
325 |
+
"""
|
326 |
+
output = xs[0]
|
327 |
+
|
328 |
+
if len(xs) == 2:
|
329 |
+
res = self.resConfUnit1(xs[1])
|
330 |
+
output = self.skip_add.add(output, res)
|
331 |
+
# output += res
|
332 |
+
|
333 |
+
output = self.resConfUnit2(output)
|
334 |
+
|
335 |
+
output = nn.functional.interpolate(
|
336 |
+
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
|
337 |
+
)
|
338 |
+
|
339 |
+
output = self.out_conv(output)
|
340 |
+
|
341 |
+
return output
|
342 |
+
|
src/flux/annotator/midas/midas/dpt_depth.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from .base_model import BaseModel
|
6 |
+
from .blocks import (
|
7 |
+
FeatureFusionBlock,
|
8 |
+
FeatureFusionBlock_custom,
|
9 |
+
Interpolate,
|
10 |
+
_make_encoder,
|
11 |
+
forward_vit,
|
12 |
+
)
|
13 |
+
|
14 |
+
|
15 |
+
def _make_fusion_block(features, use_bn):
|
16 |
+
return FeatureFusionBlock_custom(
|
17 |
+
features,
|
18 |
+
nn.ReLU(False),
|
19 |
+
deconv=False,
|
20 |
+
bn=use_bn,
|
21 |
+
expand=False,
|
22 |
+
align_corners=True,
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
class DPT(BaseModel):
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
head,
|
30 |
+
features=256,
|
31 |
+
backbone="vitb_rn50_384",
|
32 |
+
readout="project",
|
33 |
+
channels_last=False,
|
34 |
+
use_bn=False,
|
35 |
+
):
|
36 |
+
|
37 |
+
super(DPT, self).__init__()
|
38 |
+
|
39 |
+
self.channels_last = channels_last
|
40 |
+
|
41 |
+
hooks = {
|
42 |
+
"vitb_rn50_384": [0, 1, 8, 11],
|
43 |
+
"vitb16_384": [2, 5, 8, 11],
|
44 |
+
"vitl16_384": [5, 11, 17, 23],
|
45 |
+
}
|
46 |
+
|
47 |
+
# Instantiate backbone and reassemble blocks
|
48 |
+
self.pretrained, self.scratch = _make_encoder(
|
49 |
+
backbone,
|
50 |
+
features,
|
51 |
+
False, # Set to true of you want to train from scratch, uses ImageNet weights
|
52 |
+
groups=1,
|
53 |
+
expand=False,
|
54 |
+
exportable=False,
|
55 |
+
hooks=hooks[backbone],
|
56 |
+
use_readout=readout,
|
57 |
+
)
|
58 |
+
|
59 |
+
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
60 |
+
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
61 |
+
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
62 |
+
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
63 |
+
|
64 |
+
self.scratch.output_conv = head
|
65 |
+
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
if self.channels_last == True:
|
69 |
+
x.contiguous(memory_format=torch.channels_last)
|
70 |
+
|
71 |
+
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
|
72 |
+
|
73 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
74 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
75 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
76 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
77 |
+
|
78 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
79 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
80 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
81 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
82 |
+
|
83 |
+
out = self.scratch.output_conv(path_1)
|
84 |
+
|
85 |
+
return out
|
86 |
+
|
87 |
+
|
88 |
+
class DPTDepthModel(DPT):
|
89 |
+
def __init__(self, path=None, non_negative=True, **kwargs):
|
90 |
+
features = kwargs["features"] if "features" in kwargs else 256
|
91 |
+
|
92 |
+
head = nn.Sequential(
|
93 |
+
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
|
94 |
+
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
|
95 |
+
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
|
96 |
+
nn.ReLU(True),
|
97 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
98 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
99 |
+
nn.Identity(),
|
100 |
+
)
|
101 |
+
|
102 |
+
super().__init__(head, **kwargs)
|
103 |
+
|
104 |
+
if path is not None:
|
105 |
+
self.load(path)
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
return super().forward(x).squeeze(dim=1)
|
109 |
+
|
src/flux/annotator/midas/midas/midas_net.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
2 |
+
This file contains code that is adapted from
|
3 |
+
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
4 |
+
"""
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from .base_model import BaseModel
|
9 |
+
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
|
10 |
+
|
11 |
+
|
12 |
+
class MidasNet(BaseModel):
|
13 |
+
"""Network for monocular depth estimation.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, path=None, features=256, non_negative=True):
|
17 |
+
"""Init.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
path (str, optional): Path to saved model. Defaults to None.
|
21 |
+
features (int, optional): Number of features. Defaults to 256.
|
22 |
+
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
23 |
+
"""
|
24 |
+
print("Loading weights: ", path)
|
25 |
+
|
26 |
+
super(MidasNet, self).__init__()
|
27 |
+
|
28 |
+
use_pretrained = False if path is None else True
|
29 |
+
|
30 |
+
self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
|
31 |
+
|
32 |
+
self.scratch.refinenet4 = FeatureFusionBlock(features)
|
33 |
+
self.scratch.refinenet3 = FeatureFusionBlock(features)
|
34 |
+
self.scratch.refinenet2 = FeatureFusionBlock(features)
|
35 |
+
self.scratch.refinenet1 = FeatureFusionBlock(features)
|
36 |
+
|
37 |
+
self.scratch.output_conv = nn.Sequential(
|
38 |
+
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
|
39 |
+
Interpolate(scale_factor=2, mode="bilinear"),
|
40 |
+
nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
|
41 |
+
nn.ReLU(True),
|
42 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
43 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
44 |
+
)
|
45 |
+
|
46 |
+
if path:
|
47 |
+
self.load(path)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
"""Forward pass.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
x (tensor): input data (image)
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
tensor: depth
|
57 |
+
"""
|
58 |
+
|
59 |
+
layer_1 = self.pretrained.layer1(x)
|
60 |
+
layer_2 = self.pretrained.layer2(layer_1)
|
61 |
+
layer_3 = self.pretrained.layer3(layer_2)
|
62 |
+
layer_4 = self.pretrained.layer4(layer_3)
|
63 |
+
|
64 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
65 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
66 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
67 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
68 |
+
|
69 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
70 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
71 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
72 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
73 |
+
|
74 |
+
out = self.scratch.output_conv(path_1)
|
75 |
+
|
76 |
+
return torch.squeeze(out, dim=1)
|
src/flux/annotator/midas/midas/midas_net_custom.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
2 |
+
This file contains code that is adapted from
|
3 |
+
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
4 |
+
"""
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from .base_model import BaseModel
|
9 |
+
from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
|
10 |
+
|
11 |
+
|
12 |
+
class MidasNet_small(BaseModel):
|
13 |
+
"""Network for monocular depth estimation.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
|
17 |
+
blocks={'expand': True}):
|
18 |
+
"""Init.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
path (str, optional): Path to saved model. Defaults to None.
|
22 |
+
features (int, optional): Number of features. Defaults to 256.
|
23 |
+
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
24 |
+
"""
|
25 |
+
print("Loading weights: ", path)
|
26 |
+
|
27 |
+
super(MidasNet_small, self).__init__()
|
28 |
+
|
29 |
+
use_pretrained = False if path else True
|
30 |
+
|
31 |
+
self.channels_last = channels_last
|
32 |
+
self.blocks = blocks
|
33 |
+
self.backbone = backbone
|
34 |
+
|
35 |
+
self.groups = 1
|
36 |
+
|
37 |
+
features1=features
|
38 |
+
features2=features
|
39 |
+
features3=features
|
40 |
+
features4=features
|
41 |
+
self.expand = False
|
42 |
+
if "expand" in self.blocks and self.blocks['expand'] == True:
|
43 |
+
self.expand = True
|
44 |
+
features1=features
|
45 |
+
features2=features*2
|
46 |
+
features3=features*4
|
47 |
+
features4=features*8
|
48 |
+
|
49 |
+
self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
|
50 |
+
|
51 |
+
self.scratch.activation = nn.ReLU(False)
|
52 |
+
|
53 |
+
self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
54 |
+
self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
55 |
+
self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
56 |
+
self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
|
57 |
+
|
58 |
+
|
59 |
+
self.scratch.output_conv = nn.Sequential(
|
60 |
+
nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
|
61 |
+
Interpolate(scale_factor=2, mode="bilinear"),
|
62 |
+
nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
|
63 |
+
self.scratch.activation,
|
64 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
65 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
66 |
+
nn.Identity(),
|
67 |
+
)
|
68 |
+
|
69 |
+
if path:
|
70 |
+
self.load(path)
|
71 |
+
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
"""Forward pass.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
x (tensor): input data (image)
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
tensor: depth
|
81 |
+
"""
|
82 |
+
if self.channels_last==True:
|
83 |
+
print("self.channels_last = ", self.channels_last)
|
84 |
+
x.contiguous(memory_format=torch.channels_last)
|
85 |
+
|
86 |
+
|
87 |
+
layer_1 = self.pretrained.layer1(x)
|
88 |
+
layer_2 = self.pretrained.layer2(layer_1)
|
89 |
+
layer_3 = self.pretrained.layer3(layer_2)
|
90 |
+
layer_4 = self.pretrained.layer4(layer_3)
|
91 |
+
|
92 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
93 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
94 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
95 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
96 |
+
|
97 |
+
|
98 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
99 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
100 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
101 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
102 |
+
|
103 |
+
out = self.scratch.output_conv(path_1)
|
104 |
+
|
105 |
+
return torch.squeeze(out, dim=1)
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
def fuse_model(m):
|
110 |
+
prev_previous_type = nn.Identity()
|
111 |
+
prev_previous_name = ''
|
112 |
+
previous_type = nn.Identity()
|
113 |
+
previous_name = ''
|
114 |
+
for name, module in m.named_modules():
|
115 |
+
if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
|
116 |
+
# print("FUSED ", prev_previous_name, previous_name, name)
|
117 |
+
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
|
118 |
+
elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
|
119 |
+
# print("FUSED ", prev_previous_name, previous_name)
|
120 |
+
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
|
121 |
+
# elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
|
122 |
+
# print("FUSED ", previous_name, name)
|
123 |
+
# torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
|
124 |
+
|
125 |
+
prev_previous_type = previous_type
|
126 |
+
prev_previous_name = previous_name
|
127 |
+
previous_type = type(module)
|
128 |
+
previous_name = name
|
src/flux/annotator/midas/midas/transforms.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import math
|
4 |
+
|
5 |
+
|
6 |
+
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
|
7 |
+
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
sample (dict): sample
|
11 |
+
size (tuple): image size
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
tuple: new size
|
15 |
+
"""
|
16 |
+
shape = list(sample["disparity"].shape)
|
17 |
+
|
18 |
+
if shape[0] >= size[0] and shape[1] >= size[1]:
|
19 |
+
return sample
|
20 |
+
|
21 |
+
scale = [0, 0]
|
22 |
+
scale[0] = size[0] / shape[0]
|
23 |
+
scale[1] = size[1] / shape[1]
|
24 |
+
|
25 |
+
scale = max(scale)
|
26 |
+
|
27 |
+
shape[0] = math.ceil(scale * shape[0])
|
28 |
+
shape[1] = math.ceil(scale * shape[1])
|
29 |
+
|
30 |
+
# resize
|
31 |
+
sample["image"] = cv2.resize(
|
32 |
+
sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
|
33 |
+
)
|
34 |
+
|
35 |
+
sample["disparity"] = cv2.resize(
|
36 |
+
sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
|
37 |
+
)
|
38 |
+
sample["mask"] = cv2.resize(
|
39 |
+
sample["mask"].astype(np.float32),
|
40 |
+
tuple(shape[::-1]),
|
41 |
+
interpolation=cv2.INTER_NEAREST,
|
42 |
+
)
|
43 |
+
sample["mask"] = sample["mask"].astype(bool)
|
44 |
+
|
45 |
+
return tuple(shape)
|
46 |
+
|
47 |
+
|
48 |
+
class Resize(object):
|
49 |
+
"""Resize sample to given size (width, height).
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
width,
|
55 |
+
height,
|
56 |
+
resize_target=True,
|
57 |
+
keep_aspect_ratio=False,
|
58 |
+
ensure_multiple_of=1,
|
59 |
+
resize_method="lower_bound",
|
60 |
+
image_interpolation_method=cv2.INTER_AREA,
|
61 |
+
):
|
62 |
+
"""Init.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
width (int): desired output width
|
66 |
+
height (int): desired output height
|
67 |
+
resize_target (bool, optional):
|
68 |
+
True: Resize the full sample (image, mask, target).
|
69 |
+
False: Resize image only.
|
70 |
+
Defaults to True.
|
71 |
+
keep_aspect_ratio (bool, optional):
|
72 |
+
True: Keep the aspect ratio of the input sample.
|
73 |
+
Output sample might not have the given width and height, and
|
74 |
+
resize behaviour depends on the parameter 'resize_method'.
|
75 |
+
Defaults to False.
|
76 |
+
ensure_multiple_of (int, optional):
|
77 |
+
Output width and height is constrained to be multiple of this parameter.
|
78 |
+
Defaults to 1.
|
79 |
+
resize_method (str, optional):
|
80 |
+
"lower_bound": Output will be at least as large as the given size.
|
81 |
+
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
|
82 |
+
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
83 |
+
Defaults to "lower_bound".
|
84 |
+
"""
|
85 |
+
self.__width = width
|
86 |
+
self.__height = height
|
87 |
+
|
88 |
+
self.__resize_target = resize_target
|
89 |
+
self.__keep_aspect_ratio = keep_aspect_ratio
|
90 |
+
self.__multiple_of = ensure_multiple_of
|
91 |
+
self.__resize_method = resize_method
|
92 |
+
self.__image_interpolation_method = image_interpolation_method
|
93 |
+
|
94 |
+
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
95 |
+
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
96 |
+
|
97 |
+
if max_val is not None and y > max_val:
|
98 |
+
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
99 |
+
|
100 |
+
if y < min_val:
|
101 |
+
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
102 |
+
|
103 |
+
return y
|
104 |
+
|
105 |
+
def get_size(self, width, height):
|
106 |
+
# determine new height and width
|
107 |
+
scale_height = self.__height / height
|
108 |
+
scale_width = self.__width / width
|
109 |
+
|
110 |
+
if self.__keep_aspect_ratio:
|
111 |
+
if self.__resize_method == "lower_bound":
|
112 |
+
# scale such that output size is lower bound
|
113 |
+
if scale_width > scale_height:
|
114 |
+
# fit width
|
115 |
+
scale_height = scale_width
|
116 |
+
else:
|
117 |
+
# fit height
|
118 |
+
scale_width = scale_height
|
119 |
+
elif self.__resize_method == "upper_bound":
|
120 |
+
# scale such that output size is upper bound
|
121 |
+
if scale_width < scale_height:
|
122 |
+
# fit width
|
123 |
+
scale_height = scale_width
|
124 |
+
else:
|
125 |
+
# fit height
|
126 |
+
scale_width = scale_height
|
127 |
+
elif self.__resize_method == "minimal":
|
128 |
+
# scale as least as possbile
|
129 |
+
if abs(1 - scale_width) < abs(1 - scale_height):
|
130 |
+
# fit width
|
131 |
+
scale_height = scale_width
|
132 |
+
else:
|
133 |
+
# fit height
|
134 |
+
scale_width = scale_height
|
135 |
+
else:
|
136 |
+
raise ValueError(
|
137 |
+
f"resize_method {self.__resize_method} not implemented"
|
138 |
+
)
|
139 |
+
|
140 |
+
if self.__resize_method == "lower_bound":
|
141 |
+
new_height = self.constrain_to_multiple_of(
|
142 |
+
scale_height * height, min_val=self.__height
|
143 |
+
)
|
144 |
+
new_width = self.constrain_to_multiple_of(
|
145 |
+
scale_width * width, min_val=self.__width
|
146 |
+
)
|
147 |
+
elif self.__resize_method == "upper_bound":
|
148 |
+
new_height = self.constrain_to_multiple_of(
|
149 |
+
scale_height * height, max_val=self.__height
|
150 |
+
)
|
151 |
+
new_width = self.constrain_to_multiple_of(
|
152 |
+
scale_width * width, max_val=self.__width
|
153 |
+
)
|
154 |
+
elif self.__resize_method == "minimal":
|
155 |
+
new_height = self.constrain_to_multiple_of(scale_height * height)
|
156 |
+
new_width = self.constrain_to_multiple_of(scale_width * width)
|
157 |
+
else:
|
158 |
+
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
159 |
+
|
160 |
+
return (new_width, new_height)
|
161 |
+
|
162 |
+
def __call__(self, sample):
|
163 |
+
width, height = self.get_size(
|
164 |
+
sample["image"].shape[1], sample["image"].shape[0]
|
165 |
+
)
|
166 |
+
|
167 |
+
# resize sample
|
168 |
+
sample["image"] = cv2.resize(
|
169 |
+
sample["image"],
|
170 |
+
(width, height),
|
171 |
+
interpolation=self.__image_interpolation_method,
|
172 |
+
)
|
173 |
+
|
174 |
+
if self.__resize_target:
|
175 |
+
if "disparity" in sample:
|
176 |
+
sample["disparity"] = cv2.resize(
|
177 |
+
sample["disparity"],
|
178 |
+
(width, height),
|
179 |
+
interpolation=cv2.INTER_NEAREST,
|
180 |
+
)
|
181 |
+
|
182 |
+
if "depth" in sample:
|
183 |
+
sample["depth"] = cv2.resize(
|
184 |
+
sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
|
185 |
+
)
|
186 |
+
|
187 |
+
sample["mask"] = cv2.resize(
|
188 |
+
sample["mask"].astype(np.float32),
|
189 |
+
(width, height),
|
190 |
+
interpolation=cv2.INTER_NEAREST,
|
191 |
+
)
|
192 |
+
sample["mask"] = sample["mask"].astype(bool)
|
193 |
+
|
194 |
+
return sample
|
195 |
+
|
196 |
+
|
197 |
+
class NormalizeImage(object):
|
198 |
+
"""Normlize image by given mean and std.
|
199 |
+
"""
|
200 |
+
|
201 |
+
def __init__(self, mean, std):
|
202 |
+
self.__mean = mean
|
203 |
+
self.__std = std
|
204 |
+
|
205 |
+
def __call__(self, sample):
|
206 |
+
sample["image"] = (sample["image"] - self.__mean) / self.__std
|
207 |
+
|
208 |
+
return sample
|
209 |
+
|
210 |
+
|
211 |
+
class PrepareForNet(object):
|
212 |
+
"""Prepare sample for usage as network input.
|
213 |
+
"""
|
214 |
+
|
215 |
+
def __init__(self):
|
216 |
+
pass
|
217 |
+
|
218 |
+
def __call__(self, sample):
|
219 |
+
image = np.transpose(sample["image"], (2, 0, 1))
|
220 |
+
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
|
221 |
+
|
222 |
+
if "mask" in sample:
|
223 |
+
sample["mask"] = sample["mask"].astype(np.float32)
|
224 |
+
sample["mask"] = np.ascontiguousarray(sample["mask"])
|
225 |
+
|
226 |
+
if "disparity" in sample:
|
227 |
+
disparity = sample["disparity"].astype(np.float32)
|
228 |
+
sample["disparity"] = np.ascontiguousarray(disparity)
|
229 |
+
|
230 |
+
if "depth" in sample:
|
231 |
+
depth = sample["depth"].astype(np.float32)
|
232 |
+
sample["depth"] = np.ascontiguousarray(depth)
|
233 |
+
|
234 |
+
return sample
|
src/flux/annotator/midas/midas/vit.py
ADDED
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import timm
|
4 |
+
import types
|
5 |
+
import math
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
|
9 |
+
class Slice(nn.Module):
|
10 |
+
def __init__(self, start_index=1):
|
11 |
+
super(Slice, self).__init__()
|
12 |
+
self.start_index = start_index
|
13 |
+
|
14 |
+
def forward(self, x):
|
15 |
+
return x[:, self.start_index :]
|
16 |
+
|
17 |
+
|
18 |
+
class AddReadout(nn.Module):
|
19 |
+
def __init__(self, start_index=1):
|
20 |
+
super(AddReadout, self).__init__()
|
21 |
+
self.start_index = start_index
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
if self.start_index == 2:
|
25 |
+
readout = (x[:, 0] + x[:, 1]) / 2
|
26 |
+
else:
|
27 |
+
readout = x[:, 0]
|
28 |
+
return x[:, self.start_index :] + readout.unsqueeze(1)
|
29 |
+
|
30 |
+
|
31 |
+
class ProjectReadout(nn.Module):
|
32 |
+
def __init__(self, in_features, start_index=1):
|
33 |
+
super(ProjectReadout, self).__init__()
|
34 |
+
self.start_index = start_index
|
35 |
+
|
36 |
+
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
|
37 |
+
|
38 |
+
def forward(self, x):
|
39 |
+
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
|
40 |
+
features = torch.cat((x[:, self.start_index :], readout), -1)
|
41 |
+
|
42 |
+
return self.project(features)
|
43 |
+
|
44 |
+
|
45 |
+
class Transpose(nn.Module):
|
46 |
+
def __init__(self, dim0, dim1):
|
47 |
+
super(Transpose, self).__init__()
|
48 |
+
self.dim0 = dim0
|
49 |
+
self.dim1 = dim1
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
x = x.transpose(self.dim0, self.dim1)
|
53 |
+
return x
|
54 |
+
|
55 |
+
|
56 |
+
def forward_vit(pretrained, x):
|
57 |
+
b, c, h, w = x.shape
|
58 |
+
|
59 |
+
glob = pretrained.model.forward_flex(x)
|
60 |
+
|
61 |
+
layer_1 = pretrained.activations["1"]
|
62 |
+
layer_2 = pretrained.activations["2"]
|
63 |
+
layer_3 = pretrained.activations["3"]
|
64 |
+
layer_4 = pretrained.activations["4"]
|
65 |
+
|
66 |
+
layer_1 = pretrained.act_postprocess1[0:2](layer_1)
|
67 |
+
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
|
68 |
+
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
|
69 |
+
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
|
70 |
+
|
71 |
+
unflatten = nn.Sequential(
|
72 |
+
nn.Unflatten(
|
73 |
+
2,
|
74 |
+
torch.Size(
|
75 |
+
[
|
76 |
+
h // pretrained.model.patch_size[1],
|
77 |
+
w // pretrained.model.patch_size[0],
|
78 |
+
]
|
79 |
+
),
|
80 |
+
)
|
81 |
+
)
|
82 |
+
|
83 |
+
if layer_1.ndim == 3:
|
84 |
+
layer_1 = unflatten(layer_1)
|
85 |
+
if layer_2.ndim == 3:
|
86 |
+
layer_2 = unflatten(layer_2)
|
87 |
+
if layer_3.ndim == 3:
|
88 |
+
layer_3 = unflatten(layer_3)
|
89 |
+
if layer_4.ndim == 3:
|
90 |
+
layer_4 = unflatten(layer_4)
|
91 |
+
|
92 |
+
layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
|
93 |
+
layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
|
94 |
+
layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
|
95 |
+
layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
|
96 |
+
|
97 |
+
return layer_1, layer_2, layer_3, layer_4
|
98 |
+
|
99 |
+
|
100 |
+
def _resize_pos_embed(self, posemb, gs_h, gs_w):
|
101 |
+
posemb_tok, posemb_grid = (
|
102 |
+
posemb[:, : self.start_index],
|
103 |
+
posemb[0, self.start_index :],
|
104 |
+
)
|
105 |
+
|
106 |
+
gs_old = int(math.sqrt(len(posemb_grid)))
|
107 |
+
|
108 |
+
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
109 |
+
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
|
110 |
+
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
|
111 |
+
|
112 |
+
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
113 |
+
|
114 |
+
return posemb
|
115 |
+
|
116 |
+
|
117 |
+
def forward_flex(self, x):
|
118 |
+
b, c, h, w = x.shape
|
119 |
+
|
120 |
+
pos_embed = self._resize_pos_embed(
|
121 |
+
self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
|
122 |
+
)
|
123 |
+
|
124 |
+
B = x.shape[0]
|
125 |
+
|
126 |
+
if hasattr(self.patch_embed, "backbone"):
|
127 |
+
x = self.patch_embed.backbone(x)
|
128 |
+
if isinstance(x, (list, tuple)):
|
129 |
+
x = x[-1] # last feature if backbone outputs list/tuple of features
|
130 |
+
|
131 |
+
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
|
132 |
+
|
133 |
+
if getattr(self, "dist_token", None) is not None:
|
134 |
+
cls_tokens = self.cls_token.expand(
|
135 |
+
B, -1, -1
|
136 |
+
) # stole cls_tokens impl from Phil Wang, thanks
|
137 |
+
dist_token = self.dist_token.expand(B, -1, -1)
|
138 |
+
x = torch.cat((cls_tokens, dist_token, x), dim=1)
|
139 |
+
else:
|
140 |
+
cls_tokens = self.cls_token.expand(
|
141 |
+
B, -1, -1
|
142 |
+
) # stole cls_tokens impl from Phil Wang, thanks
|
143 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
144 |
+
|
145 |
+
x = x + pos_embed
|
146 |
+
x = self.pos_drop(x)
|
147 |
+
|
148 |
+
for blk in self.blocks:
|
149 |
+
x = blk(x)
|
150 |
+
|
151 |
+
x = self.norm(x)
|
152 |
+
|
153 |
+
return x
|
154 |
+
|
155 |
+
|
156 |
+
activations = {}
|
157 |
+
|
158 |
+
|
159 |
+
def get_activation(name):
|
160 |
+
def hook(model, input, output):
|
161 |
+
activations[name] = output
|
162 |
+
|
163 |
+
return hook
|
164 |
+
|
165 |
+
|
166 |
+
def get_readout_oper(vit_features, features, use_readout, start_index=1):
|
167 |
+
if use_readout == "ignore":
|
168 |
+
readout_oper = [Slice(start_index)] * len(features)
|
169 |
+
elif use_readout == "add":
|
170 |
+
readout_oper = [AddReadout(start_index)] * len(features)
|
171 |
+
elif use_readout == "project":
|
172 |
+
readout_oper = [
|
173 |
+
ProjectReadout(vit_features, start_index) for out_feat in features
|
174 |
+
]
|
175 |
+
else:
|
176 |
+
assert (
|
177 |
+
False
|
178 |
+
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
|
179 |
+
|
180 |
+
return readout_oper
|
181 |
+
|
182 |
+
|
183 |
+
def _make_vit_b16_backbone(
|
184 |
+
model,
|
185 |
+
features=[96, 192, 384, 768],
|
186 |
+
size=[384, 384],
|
187 |
+
hooks=[2, 5, 8, 11],
|
188 |
+
vit_features=768,
|
189 |
+
use_readout="ignore",
|
190 |
+
start_index=1,
|
191 |
+
):
|
192 |
+
pretrained = nn.Module()
|
193 |
+
|
194 |
+
pretrained.model = model
|
195 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
196 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
197 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
198 |
+
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
199 |
+
|
200 |
+
pretrained.activations = activations
|
201 |
+
|
202 |
+
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
203 |
+
|
204 |
+
# 32, 48, 136, 384
|
205 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
206 |
+
readout_oper[0],
|
207 |
+
Transpose(1, 2),
|
208 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
209 |
+
nn.Conv2d(
|
210 |
+
in_channels=vit_features,
|
211 |
+
out_channels=features[0],
|
212 |
+
kernel_size=1,
|
213 |
+
stride=1,
|
214 |
+
padding=0,
|
215 |
+
),
|
216 |
+
nn.ConvTranspose2d(
|
217 |
+
in_channels=features[0],
|
218 |
+
out_channels=features[0],
|
219 |
+
kernel_size=4,
|
220 |
+
stride=4,
|
221 |
+
padding=0,
|
222 |
+
bias=True,
|
223 |
+
dilation=1,
|
224 |
+
groups=1,
|
225 |
+
),
|
226 |
+
)
|
227 |
+
|
228 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
229 |
+
readout_oper[1],
|
230 |
+
Transpose(1, 2),
|
231 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
232 |
+
nn.Conv2d(
|
233 |
+
in_channels=vit_features,
|
234 |
+
out_channels=features[1],
|
235 |
+
kernel_size=1,
|
236 |
+
stride=1,
|
237 |
+
padding=0,
|
238 |
+
),
|
239 |
+
nn.ConvTranspose2d(
|
240 |
+
in_channels=features[1],
|
241 |
+
out_channels=features[1],
|
242 |
+
kernel_size=2,
|
243 |
+
stride=2,
|
244 |
+
padding=0,
|
245 |
+
bias=True,
|
246 |
+
dilation=1,
|
247 |
+
groups=1,
|
248 |
+
),
|
249 |
+
)
|
250 |
+
|
251 |
+
pretrained.act_postprocess3 = nn.Sequential(
|
252 |
+
readout_oper[2],
|
253 |
+
Transpose(1, 2),
|
254 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
255 |
+
nn.Conv2d(
|
256 |
+
in_channels=vit_features,
|
257 |
+
out_channels=features[2],
|
258 |
+
kernel_size=1,
|
259 |
+
stride=1,
|
260 |
+
padding=0,
|
261 |
+
),
|
262 |
+
)
|
263 |
+
|
264 |
+
pretrained.act_postprocess4 = nn.Sequential(
|
265 |
+
readout_oper[3],
|
266 |
+
Transpose(1, 2),
|
267 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
268 |
+
nn.Conv2d(
|
269 |
+
in_channels=vit_features,
|
270 |
+
out_channels=features[3],
|
271 |
+
kernel_size=1,
|
272 |
+
stride=1,
|
273 |
+
padding=0,
|
274 |
+
),
|
275 |
+
nn.Conv2d(
|
276 |
+
in_channels=features[3],
|
277 |
+
out_channels=features[3],
|
278 |
+
kernel_size=3,
|
279 |
+
stride=2,
|
280 |
+
padding=1,
|
281 |
+
),
|
282 |
+
)
|
283 |
+
|
284 |
+
pretrained.model.start_index = start_index
|
285 |
+
pretrained.model.patch_size = [16, 16]
|
286 |
+
|
287 |
+
# We inject this function into the VisionTransformer instances so that
|
288 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
289 |
+
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
290 |
+
pretrained.model._resize_pos_embed = types.MethodType(
|
291 |
+
_resize_pos_embed, pretrained.model
|
292 |
+
)
|
293 |
+
|
294 |
+
return pretrained
|
295 |
+
|
296 |
+
|
297 |
+
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
|
298 |
+
model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
|
299 |
+
|
300 |
+
hooks = [5, 11, 17, 23] if hooks == None else hooks
|
301 |
+
return _make_vit_b16_backbone(
|
302 |
+
model,
|
303 |
+
features=[256, 512, 1024, 1024],
|
304 |
+
hooks=hooks,
|
305 |
+
vit_features=1024,
|
306 |
+
use_readout=use_readout,
|
307 |
+
)
|
308 |
+
|
309 |
+
|
310 |
+
def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
|
311 |
+
model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
|
312 |
+
|
313 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
314 |
+
return _make_vit_b16_backbone(
|
315 |
+
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
|
316 |
+
)
|
317 |
+
|
318 |
+
|
319 |
+
def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
|
320 |
+
model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
|
321 |
+
|
322 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
323 |
+
return _make_vit_b16_backbone(
|
324 |
+
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
|
325 |
+
)
|
326 |
+
|
327 |
+
|
328 |
+
def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
|
329 |
+
model = timm.create_model(
|
330 |
+
"vit_deit_base_distilled_patch16_384", pretrained=pretrained
|
331 |
+
)
|
332 |
+
|
333 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
334 |
+
return _make_vit_b16_backbone(
|
335 |
+
model,
|
336 |
+
features=[96, 192, 384, 768],
|
337 |
+
hooks=hooks,
|
338 |
+
use_readout=use_readout,
|
339 |
+
start_index=2,
|
340 |
+
)
|
341 |
+
|
342 |
+
|
343 |
+
def _make_vit_b_rn50_backbone(
|
344 |
+
model,
|
345 |
+
features=[256, 512, 768, 768],
|
346 |
+
size=[384, 384],
|
347 |
+
hooks=[0, 1, 8, 11],
|
348 |
+
vit_features=768,
|
349 |
+
use_vit_only=False,
|
350 |
+
use_readout="ignore",
|
351 |
+
start_index=1,
|
352 |
+
):
|
353 |
+
pretrained = nn.Module()
|
354 |
+
|
355 |
+
pretrained.model = model
|
356 |
+
|
357 |
+
if use_vit_only == True:
|
358 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
359 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
360 |
+
else:
|
361 |
+
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
|
362 |
+
get_activation("1")
|
363 |
+
)
|
364 |
+
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
|
365 |
+
get_activation("2")
|
366 |
+
)
|
367 |
+
|
368 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
369 |
+
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
370 |
+
|
371 |
+
pretrained.activations = activations
|
372 |
+
|
373 |
+
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
374 |
+
|
375 |
+
if use_vit_only == True:
|
376 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
377 |
+
readout_oper[0],
|
378 |
+
Transpose(1, 2),
|
379 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
380 |
+
nn.Conv2d(
|
381 |
+
in_channels=vit_features,
|
382 |
+
out_channels=features[0],
|
383 |
+
kernel_size=1,
|
384 |
+
stride=1,
|
385 |
+
padding=0,
|
386 |
+
),
|
387 |
+
nn.ConvTranspose2d(
|
388 |
+
in_channels=features[0],
|
389 |
+
out_channels=features[0],
|
390 |
+
kernel_size=4,
|
391 |
+
stride=4,
|
392 |
+
padding=0,
|
393 |
+
bias=True,
|
394 |
+
dilation=1,
|
395 |
+
groups=1,
|
396 |
+
),
|
397 |
+
)
|
398 |
+
|
399 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
400 |
+
readout_oper[1],
|
401 |
+
Transpose(1, 2),
|
402 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
403 |
+
nn.Conv2d(
|
404 |
+
in_channels=vit_features,
|
405 |
+
out_channels=features[1],
|
406 |
+
kernel_size=1,
|
407 |
+
stride=1,
|
408 |
+
padding=0,
|
409 |
+
),
|
410 |
+
nn.ConvTranspose2d(
|
411 |
+
in_channels=features[1],
|
412 |
+
out_channels=features[1],
|
413 |
+
kernel_size=2,
|
414 |
+
stride=2,
|
415 |
+
padding=0,
|
416 |
+
bias=True,
|
417 |
+
dilation=1,
|
418 |
+
groups=1,
|
419 |
+
),
|
420 |
+
)
|
421 |
+
else:
|
422 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
423 |
+
nn.Identity(), nn.Identity(), nn.Identity()
|
424 |
+
)
|
425 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
426 |
+
nn.Identity(), nn.Identity(), nn.Identity()
|
427 |
+
)
|
428 |
+
|
429 |
+
pretrained.act_postprocess3 = nn.Sequential(
|
430 |
+
readout_oper[2],
|
431 |
+
Transpose(1, 2),
|
432 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
433 |
+
nn.Conv2d(
|
434 |
+
in_channels=vit_features,
|
435 |
+
out_channels=features[2],
|
436 |
+
kernel_size=1,
|
437 |
+
stride=1,
|
438 |
+
padding=0,
|
439 |
+
),
|
440 |
+
)
|
441 |
+
|
442 |
+
pretrained.act_postprocess4 = nn.Sequential(
|
443 |
+
readout_oper[3],
|
444 |
+
Transpose(1, 2),
|
445 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
446 |
+
nn.Conv2d(
|
447 |
+
in_channels=vit_features,
|
448 |
+
out_channels=features[3],
|
449 |
+
kernel_size=1,
|
450 |
+
stride=1,
|
451 |
+
padding=0,
|
452 |
+
),
|
453 |
+
nn.Conv2d(
|
454 |
+
in_channels=features[3],
|
455 |
+
out_channels=features[3],
|
456 |
+
kernel_size=3,
|
457 |
+
stride=2,
|
458 |
+
padding=1,
|
459 |
+
),
|
460 |
+
)
|
461 |
+
|
462 |
+
pretrained.model.start_index = start_index
|
463 |
+
pretrained.model.patch_size = [16, 16]
|
464 |
+
|
465 |
+
# We inject this function into the VisionTransformer instances so that
|
466 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
467 |
+
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
468 |
+
|
469 |
+
# We inject this function into the VisionTransformer instances so that
|
470 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
471 |
+
pretrained.model._resize_pos_embed = types.MethodType(
|
472 |
+
_resize_pos_embed, pretrained.model
|
473 |
+
)
|
474 |
+
|
475 |
+
return pretrained
|
476 |
+
|
477 |
+
|
478 |
+
def _make_pretrained_vitb_rn50_384(
|
479 |
+
pretrained, use_readout="ignore", hooks=None, use_vit_only=False
|
480 |
+
):
|
481 |
+
model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
|
482 |
+
|
483 |
+
hooks = [0, 1, 8, 11] if hooks == None else hooks
|
484 |
+
return _make_vit_b_rn50_backbone(
|
485 |
+
model,
|
486 |
+
features=[256, 512, 768, 768],
|
487 |
+
size=[384, 384],
|
488 |
+
hooks=hooks,
|
489 |
+
use_vit_only=use_vit_only,
|
490 |
+
use_readout=use_readout,
|
491 |
+
)
|
src/flux/annotator/midas/utils.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utils for monoDepth."""
|
2 |
+
import sys
|
3 |
+
import re
|
4 |
+
import numpy as np
|
5 |
+
import cv2
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
def read_pfm(path):
|
10 |
+
"""Read pfm file.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
path (str): path to file
|
14 |
+
|
15 |
+
Returns:
|
16 |
+
tuple: (data, scale)
|
17 |
+
"""
|
18 |
+
with open(path, "rb") as file:
|
19 |
+
|
20 |
+
color = None
|
21 |
+
width = None
|
22 |
+
height = None
|
23 |
+
scale = None
|
24 |
+
endian = None
|
25 |
+
|
26 |
+
header = file.readline().rstrip()
|
27 |
+
if header.decode("ascii") == "PF":
|
28 |
+
color = True
|
29 |
+
elif header.decode("ascii") == "Pf":
|
30 |
+
color = False
|
31 |
+
else:
|
32 |
+
raise Exception("Not a PFM file: " + path)
|
33 |
+
|
34 |
+
dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
|
35 |
+
if dim_match:
|
36 |
+
width, height = list(map(int, dim_match.groups()))
|
37 |
+
else:
|
38 |
+
raise Exception("Malformed PFM header.")
|
39 |
+
|
40 |
+
scale = float(file.readline().decode("ascii").rstrip())
|
41 |
+
if scale < 0:
|
42 |
+
# little-endian
|
43 |
+
endian = "<"
|
44 |
+
scale = -scale
|
45 |
+
else:
|
46 |
+
# big-endian
|
47 |
+
endian = ">"
|
48 |
+
|
49 |
+
data = np.fromfile(file, endian + "f")
|
50 |
+
shape = (height, width, 3) if color else (height, width)
|
51 |
+
|
52 |
+
data = np.reshape(data, shape)
|
53 |
+
data = np.flipud(data)
|
54 |
+
|
55 |
+
return data, scale
|
56 |
+
|
57 |
+
|
58 |
+
def write_pfm(path, image, scale=1):
|
59 |
+
"""Write pfm file.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
path (str): pathto file
|
63 |
+
image (array): data
|
64 |
+
scale (int, optional): Scale. Defaults to 1.
|
65 |
+
"""
|
66 |
+
|
67 |
+
with open(path, "wb") as file:
|
68 |
+
color = None
|
69 |
+
|
70 |
+
if image.dtype.name != "float32":
|
71 |
+
raise Exception("Image dtype must be float32.")
|
72 |
+
|
73 |
+
image = np.flipud(image)
|
74 |
+
|
75 |
+
if len(image.shape) == 3 and image.shape[2] == 3: # color image
|
76 |
+
color = True
|
77 |
+
elif (
|
78 |
+
len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
|
79 |
+
): # greyscale
|
80 |
+
color = False
|
81 |
+
else:
|
82 |
+
raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
|
83 |
+
|
84 |
+
file.write("PF\n" if color else "Pf\n".encode())
|
85 |
+
file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
|
86 |
+
|
87 |
+
endian = image.dtype.byteorder
|
88 |
+
|
89 |
+
if endian == "<" or endian == "=" and sys.byteorder == "little":
|
90 |
+
scale = -scale
|
91 |
+
|
92 |
+
file.write("%f\n".encode() % scale)
|
93 |
+
|
94 |
+
image.tofile(file)
|
95 |
+
|
96 |
+
|
97 |
+
def read_image(path):
|
98 |
+
"""Read image and output RGB image (0-1).
|
99 |
+
|
100 |
+
Args:
|
101 |
+
path (str): path to file
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
array: RGB image (0-1)
|
105 |
+
"""
|
106 |
+
img = cv2.imread(path)
|
107 |
+
|
108 |
+
if img.ndim == 2:
|
109 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
110 |
+
|
111 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
|
112 |
+
|
113 |
+
return img
|
114 |
+
|
115 |
+
|
116 |
+
def resize_image(img):
|
117 |
+
"""Resize image and make it fit for network.
|
118 |
+
|
119 |
+
Args:
|
120 |
+
img (array): image
|
121 |
+
|
122 |
+
Returns:
|
123 |
+
tensor: data ready for network
|
124 |
+
"""
|
125 |
+
height_orig = img.shape[0]
|
126 |
+
width_orig = img.shape[1]
|
127 |
+
|
128 |
+
if width_orig > height_orig:
|
129 |
+
scale = width_orig / 384
|
130 |
+
else:
|
131 |
+
scale = height_orig / 384
|
132 |
+
|
133 |
+
height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
|
134 |
+
width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
|
135 |
+
|
136 |
+
img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
|
137 |
+
|
138 |
+
img_resized = (
|
139 |
+
torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
|
140 |
+
)
|
141 |
+
img_resized = img_resized.unsqueeze(0)
|
142 |
+
|
143 |
+
return img_resized
|
144 |
+
|
145 |
+
|
146 |
+
def resize_depth(depth, width, height):
|
147 |
+
"""Resize depth map and bring to CPU (numpy).
|
148 |
+
|
149 |
+
Args:
|
150 |
+
depth (tensor): depth
|
151 |
+
width (int): image width
|
152 |
+
height (int): image height
|
153 |
+
|
154 |
+
Returns:
|
155 |
+
array: processed depth
|
156 |
+
"""
|
157 |
+
depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
|
158 |
+
|
159 |
+
depth_resized = cv2.resize(
|
160 |
+
depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
|
161 |
+
)
|
162 |
+
|
163 |
+
return depth_resized
|
164 |
+
|
165 |
+
def write_depth(path, depth, bits=1):
|
166 |
+
"""Write depth map to pfm and png file.
|
167 |
+
|
168 |
+
Args:
|
169 |
+
path (str): filepath without extension
|
170 |
+
depth (array): depth
|
171 |
+
"""
|
172 |
+
write_pfm(path + ".pfm", depth.astype(np.float32))
|
173 |
+
|
174 |
+
depth_min = depth.min()
|
175 |
+
depth_max = depth.max()
|
176 |
+
|
177 |
+
max_val = (2**(8*bits))-1
|
178 |
+
|
179 |
+
if depth_max - depth_min > np.finfo("float").eps:
|
180 |
+
out = max_val * (depth - depth_min) / (depth_max - depth_min)
|
181 |
+
else:
|
182 |
+
out = np.zeros(depth.shape, dtype=depth.type)
|
183 |
+
|
184 |
+
if bits == 1:
|
185 |
+
cv2.imwrite(path + ".png", out.astype("uint8"))
|
186 |
+
elif bits == 2:
|
187 |
+
cv2.imwrite(path + ".png", out.astype("uint16"))
|
188 |
+
|
189 |
+
return
|
src/flux/annotator/mlsd/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "{}"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright 2021-present NAVER Corp.
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
src/flux/annotator/mlsd/__init__.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MLSD Line Detection
|
2 |
+
# From https://github.com/navervision/mlsd
|
3 |
+
# Apache-2.0 license
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import os
|
9 |
+
|
10 |
+
from einops import rearrange
|
11 |
+
from huggingface_hub import hf_hub_download
|
12 |
+
from .models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny
|
13 |
+
from .models.mbv2_mlsd_large import MobileV2_MLSD_Large
|
14 |
+
from .utils import pred_lines
|
15 |
+
|
16 |
+
from ...annotator.util import annotator_ckpts_path
|
17 |
+
|
18 |
+
|
19 |
+
class MLSDdetector:
|
20 |
+
def __init__(self):
|
21 |
+
model_path = os.path.join(annotator_ckpts_path, "mlsd_large_512_fp32.pth")
|
22 |
+
if not os.path.exists(model_path):
|
23 |
+
model_path = hf_hub_download("lllyasviel/Annotators", "mlsd_large_512_fp32.pth")
|
24 |
+
model = MobileV2_MLSD_Large()
|
25 |
+
model.load_state_dict(torch.load(model_path), strict=True)
|
26 |
+
self.model = model.cuda().eval()
|
27 |
+
|
28 |
+
def __call__(self, input_image, thr_v, thr_d):
|
29 |
+
assert input_image.ndim == 3
|
30 |
+
img = input_image
|
31 |
+
img_output = np.zeros_like(img)
|
32 |
+
try:
|
33 |
+
with torch.no_grad():
|
34 |
+
lines = pred_lines(img, self.model, [img.shape[0], img.shape[1]], thr_v, thr_d)
|
35 |
+
for line in lines:
|
36 |
+
x_start, y_start, x_end, y_end = [int(val) for val in line]
|
37 |
+
cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1)
|
38 |
+
except Exception as e:
|
39 |
+
pass
|
40 |
+
return img_output[:, :, 0]
|
src/flux/annotator/mlsd/models/mbv2_mlsd_large.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.utils.model_zoo as model_zoo
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
|
9 |
+
class BlockTypeA(nn.Module):
|
10 |
+
def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
|
11 |
+
super(BlockTypeA, self).__init__()
|
12 |
+
self.conv1 = nn.Sequential(
|
13 |
+
nn.Conv2d(in_c2, out_c2, kernel_size=1),
|
14 |
+
nn.BatchNorm2d(out_c2),
|
15 |
+
nn.ReLU(inplace=True)
|
16 |
+
)
|
17 |
+
self.conv2 = nn.Sequential(
|
18 |
+
nn.Conv2d(in_c1, out_c1, kernel_size=1),
|
19 |
+
nn.BatchNorm2d(out_c1),
|
20 |
+
nn.ReLU(inplace=True)
|
21 |
+
)
|
22 |
+
self.upscale = upscale
|
23 |
+
|
24 |
+
def forward(self, a, b):
|
25 |
+
b = self.conv1(b)
|
26 |
+
a = self.conv2(a)
|
27 |
+
if self.upscale:
|
28 |
+
b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
|
29 |
+
return torch.cat((a, b), dim=1)
|
30 |
+
|
31 |
+
|
32 |
+
class BlockTypeB(nn.Module):
|
33 |
+
def __init__(self, in_c, out_c):
|
34 |
+
super(BlockTypeB, self).__init__()
|
35 |
+
self.conv1 = nn.Sequential(
|
36 |
+
nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
|
37 |
+
nn.BatchNorm2d(in_c),
|
38 |
+
nn.ReLU()
|
39 |
+
)
|
40 |
+
self.conv2 = nn.Sequential(
|
41 |
+
nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
|
42 |
+
nn.BatchNorm2d(out_c),
|
43 |
+
nn.ReLU()
|
44 |
+
)
|
45 |
+
|
46 |
+
def forward(self, x):
|
47 |
+
x = self.conv1(x) + x
|
48 |
+
x = self.conv2(x)
|
49 |
+
return x
|
50 |
+
|
51 |
+
class BlockTypeC(nn.Module):
|
52 |
+
def __init__(self, in_c, out_c):
|
53 |
+
super(BlockTypeC, self).__init__()
|
54 |
+
self.conv1 = nn.Sequential(
|
55 |
+
nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
|
56 |
+
nn.BatchNorm2d(in_c),
|
57 |
+
nn.ReLU()
|
58 |
+
)
|
59 |
+
self.conv2 = nn.Sequential(
|
60 |
+
nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
|
61 |
+
nn.BatchNorm2d(in_c),
|
62 |
+
nn.ReLU()
|
63 |
+
)
|
64 |
+
self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
|
65 |
+
|
66 |
+
def forward(self, x):
|
67 |
+
x = self.conv1(x)
|
68 |
+
x = self.conv2(x)
|
69 |
+
x = self.conv3(x)
|
70 |
+
return x
|
71 |
+
|
72 |
+
def _make_divisible(v, divisor, min_value=None):
|
73 |
+
"""
|
74 |
+
This function is taken from the original tf repo.
|
75 |
+
It ensures that all layers have a channel number that is divisible by 8
|
76 |
+
It can be seen here:
|
77 |
+
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
78 |
+
:param v:
|
79 |
+
:param divisor:
|
80 |
+
:param min_value:
|
81 |
+
:return:
|
82 |
+
"""
|
83 |
+
if min_value is None:
|
84 |
+
min_value = divisor
|
85 |
+
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
86 |
+
# Make sure that round down does not go down by more than 10%.
|
87 |
+
if new_v < 0.9 * v:
|
88 |
+
new_v += divisor
|
89 |
+
return new_v
|
90 |
+
|
91 |
+
|
92 |
+
class ConvBNReLU(nn.Sequential):
|
93 |
+
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
|
94 |
+
self.channel_pad = out_planes - in_planes
|
95 |
+
self.stride = stride
|
96 |
+
#padding = (kernel_size - 1) // 2
|
97 |
+
|
98 |
+
# TFLite uses slightly different padding than PyTorch
|
99 |
+
if stride == 2:
|
100 |
+
padding = 0
|
101 |
+
else:
|
102 |
+
padding = (kernel_size - 1) // 2
|
103 |
+
|
104 |
+
super(ConvBNReLU, self).__init__(
|
105 |
+
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
|
106 |
+
nn.BatchNorm2d(out_planes),
|
107 |
+
nn.ReLU6(inplace=True)
|
108 |
+
)
|
109 |
+
self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
|
110 |
+
|
111 |
+
|
112 |
+
def forward(self, x):
|
113 |
+
# TFLite uses different padding
|
114 |
+
if self.stride == 2:
|
115 |
+
x = F.pad(x, (0, 1, 0, 1), "constant", 0)
|
116 |
+
#print(x.shape)
|
117 |
+
|
118 |
+
for module in self:
|
119 |
+
if not isinstance(module, nn.MaxPool2d):
|
120 |
+
x = module(x)
|
121 |
+
return x
|
122 |
+
|
123 |
+
|
124 |
+
class InvertedResidual(nn.Module):
|
125 |
+
def __init__(self, inp, oup, stride, expand_ratio):
|
126 |
+
super(InvertedResidual, self).__init__()
|
127 |
+
self.stride = stride
|
128 |
+
assert stride in [1, 2]
|
129 |
+
|
130 |
+
hidden_dim = int(round(inp * expand_ratio))
|
131 |
+
self.use_res_connect = self.stride == 1 and inp == oup
|
132 |
+
|
133 |
+
layers = []
|
134 |
+
if expand_ratio != 1:
|
135 |
+
# pw
|
136 |
+
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
|
137 |
+
layers.extend([
|
138 |
+
# dw
|
139 |
+
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
|
140 |
+
# pw-linear
|
141 |
+
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
142 |
+
nn.BatchNorm2d(oup),
|
143 |
+
])
|
144 |
+
self.conv = nn.Sequential(*layers)
|
145 |
+
|
146 |
+
def forward(self, x):
|
147 |
+
if self.use_res_connect:
|
148 |
+
return x + self.conv(x)
|
149 |
+
else:
|
150 |
+
return self.conv(x)
|
151 |
+
|
152 |
+
|
153 |
+
class MobileNetV2(nn.Module):
|
154 |
+
def __init__(self, pretrained=True):
|
155 |
+
"""
|
156 |
+
MobileNet V2 main class
|
157 |
+
Args:
|
158 |
+
num_classes (int): Number of classes
|
159 |
+
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
|
160 |
+
inverted_residual_setting: Network structure
|
161 |
+
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
|
162 |
+
Set to 1 to turn off rounding
|
163 |
+
block: Module specifying inverted residual building block for mobilenet
|
164 |
+
"""
|
165 |
+
super(MobileNetV2, self).__init__()
|
166 |
+
|
167 |
+
block = InvertedResidual
|
168 |
+
input_channel = 32
|
169 |
+
last_channel = 1280
|
170 |
+
width_mult = 1.0
|
171 |
+
round_nearest = 8
|
172 |
+
|
173 |
+
inverted_residual_setting = [
|
174 |
+
# t, c, n, s
|
175 |
+
[1, 16, 1, 1],
|
176 |
+
[6, 24, 2, 2],
|
177 |
+
[6, 32, 3, 2],
|
178 |
+
[6, 64, 4, 2],
|
179 |
+
[6, 96, 3, 1],
|
180 |
+
#[6, 160, 3, 2],
|
181 |
+
#[6, 320, 1, 1],
|
182 |
+
]
|
183 |
+
|
184 |
+
# only check the first element, assuming user knows t,c,n,s are required
|
185 |
+
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
|
186 |
+
raise ValueError("inverted_residual_setting should be non-empty "
|
187 |
+
"or a 4-element list, got {}".format(inverted_residual_setting))
|
188 |
+
|
189 |
+
# building first layer
|
190 |
+
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
|
191 |
+
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
|
192 |
+
features = [ConvBNReLU(4, input_channel, stride=2)]
|
193 |
+
# building inverted residual blocks
|
194 |
+
for t, c, n, s in inverted_residual_setting:
|
195 |
+
output_channel = _make_divisible(c * width_mult, round_nearest)
|
196 |
+
for i in range(n):
|
197 |
+
stride = s if i == 0 else 1
|
198 |
+
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
|
199 |
+
input_channel = output_channel
|
200 |
+
|
201 |
+
self.features = nn.Sequential(*features)
|
202 |
+
self.fpn_selected = [1, 3, 6, 10, 13]
|
203 |
+
# weight initialization
|
204 |
+
for m in self.modules():
|
205 |
+
if isinstance(m, nn.Conv2d):
|
206 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
207 |
+
if m.bias is not None:
|
208 |
+
nn.init.zeros_(m.bias)
|
209 |
+
elif isinstance(m, nn.BatchNorm2d):
|
210 |
+
nn.init.ones_(m.weight)
|
211 |
+
nn.init.zeros_(m.bias)
|
212 |
+
elif isinstance(m, nn.Linear):
|
213 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
214 |
+
nn.init.zeros_(m.bias)
|
215 |
+
if pretrained:
|
216 |
+
self._load_pretrained_model()
|
217 |
+
|
218 |
+
def _forward_impl(self, x):
|
219 |
+
# This exists since TorchScript doesn't support inheritance, so the superclass method
|
220 |
+
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
|
221 |
+
fpn_features = []
|
222 |
+
for i, f in enumerate(self.features):
|
223 |
+
if i > self.fpn_selected[-1]:
|
224 |
+
break
|
225 |
+
x = f(x)
|
226 |
+
if i in self.fpn_selected:
|
227 |
+
fpn_features.append(x)
|
228 |
+
|
229 |
+
c1, c2, c3, c4, c5 = fpn_features
|
230 |
+
return c1, c2, c3, c4, c5
|
231 |
+
|
232 |
+
|
233 |
+
def forward(self, x):
|
234 |
+
return self._forward_impl(x)
|
235 |
+
|
236 |
+
def _load_pretrained_model(self):
|
237 |
+
pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
|
238 |
+
model_dict = {}
|
239 |
+
state_dict = self.state_dict()
|
240 |
+
for k, v in pretrain_dict.items():
|
241 |
+
if k in state_dict:
|
242 |
+
model_dict[k] = v
|
243 |
+
state_dict.update(model_dict)
|
244 |
+
self.load_state_dict(state_dict)
|
245 |
+
|
246 |
+
|
247 |
+
class MobileV2_MLSD_Large(nn.Module):
|
248 |
+
def __init__(self):
|
249 |
+
super(MobileV2_MLSD_Large, self).__init__()
|
250 |
+
|
251 |
+
self.backbone = MobileNetV2(pretrained=False)
|
252 |
+
## A, B
|
253 |
+
self.block15 = BlockTypeA(in_c1= 64, in_c2= 96,
|
254 |
+
out_c1= 64, out_c2=64,
|
255 |
+
upscale=False)
|
256 |
+
self.block16 = BlockTypeB(128, 64)
|
257 |
+
|
258 |
+
## A, B
|
259 |
+
self.block17 = BlockTypeA(in_c1 = 32, in_c2 = 64,
|
260 |
+
out_c1= 64, out_c2= 64)
|
261 |
+
self.block18 = BlockTypeB(128, 64)
|
262 |
+
|
263 |
+
## A, B
|
264 |
+
self.block19 = BlockTypeA(in_c1=24, in_c2=64,
|
265 |
+
out_c1=64, out_c2=64)
|
266 |
+
self.block20 = BlockTypeB(128, 64)
|
267 |
+
|
268 |
+
## A, B, C
|
269 |
+
self.block21 = BlockTypeA(in_c1=16, in_c2=64,
|
270 |
+
out_c1=64, out_c2=64)
|
271 |
+
self.block22 = BlockTypeB(128, 64)
|
272 |
+
|
273 |
+
self.block23 = BlockTypeC(64, 16)
|
274 |
+
|
275 |
+
def forward(self, x):
|
276 |
+
c1, c2, c3, c4, c5 = self.backbone(x)
|
277 |
+
|
278 |
+
x = self.block15(c4, c5)
|
279 |
+
x = self.block16(x)
|
280 |
+
|
281 |
+
x = self.block17(c3, x)
|
282 |
+
x = self.block18(x)
|
283 |
+
|
284 |
+
x = self.block19(c2, x)
|
285 |
+
x = self.block20(x)
|
286 |
+
|
287 |
+
x = self.block21(c1, x)
|
288 |
+
x = self.block22(x)
|
289 |
+
x = self.block23(x)
|
290 |
+
x = x[:, 7:, :, :]
|
291 |
+
|
292 |
+
return x
|
src/flux/annotator/mlsd/models/mbv2_mlsd_tiny.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.utils.model_zoo as model_zoo
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
|
9 |
+
class BlockTypeA(nn.Module):
|
10 |
+
def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
|
11 |
+
super(BlockTypeA, self).__init__()
|
12 |
+
self.conv1 = nn.Sequential(
|
13 |
+
nn.Conv2d(in_c2, out_c2, kernel_size=1),
|
14 |
+
nn.BatchNorm2d(out_c2),
|
15 |
+
nn.ReLU(inplace=True)
|
16 |
+
)
|
17 |
+
self.conv2 = nn.Sequential(
|
18 |
+
nn.Conv2d(in_c1, out_c1, kernel_size=1),
|
19 |
+
nn.BatchNorm2d(out_c1),
|
20 |
+
nn.ReLU(inplace=True)
|
21 |
+
)
|
22 |
+
self.upscale = upscale
|
23 |
+
|
24 |
+
def forward(self, a, b):
|
25 |
+
b = self.conv1(b)
|
26 |
+
a = self.conv2(a)
|
27 |
+
b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
|
28 |
+
return torch.cat((a, b), dim=1)
|
29 |
+
|
30 |
+
|
31 |
+
class BlockTypeB(nn.Module):
|
32 |
+
def __init__(self, in_c, out_c):
|
33 |
+
super(BlockTypeB, self).__init__()
|
34 |
+
self.conv1 = nn.Sequential(
|
35 |
+
nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
|
36 |
+
nn.BatchNorm2d(in_c),
|
37 |
+
nn.ReLU()
|
38 |
+
)
|
39 |
+
self.conv2 = nn.Sequential(
|
40 |
+
nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
|
41 |
+
nn.BatchNorm2d(out_c),
|
42 |
+
nn.ReLU()
|
43 |
+
)
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
x = self.conv1(x) + x
|
47 |
+
x = self.conv2(x)
|
48 |
+
return x
|
49 |
+
|
50 |
+
class BlockTypeC(nn.Module):
|
51 |
+
def __init__(self, in_c, out_c):
|
52 |
+
super(BlockTypeC, self).__init__()
|
53 |
+
self.conv1 = nn.Sequential(
|
54 |
+
nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
|
55 |
+
nn.BatchNorm2d(in_c),
|
56 |
+
nn.ReLU()
|
57 |
+
)
|
58 |
+
self.conv2 = nn.Sequential(
|
59 |
+
nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
|
60 |
+
nn.BatchNorm2d(in_c),
|
61 |
+
nn.ReLU()
|
62 |
+
)
|
63 |
+
self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
|
64 |
+
|
65 |
+
def forward(self, x):
|
66 |
+
x = self.conv1(x)
|
67 |
+
x = self.conv2(x)
|
68 |
+
x = self.conv3(x)
|
69 |
+
return x
|
70 |
+
|
71 |
+
def _make_divisible(v, divisor, min_value=None):
|
72 |
+
"""
|
73 |
+
This function is taken from the original tf repo.
|
74 |
+
It ensures that all layers have a channel number that is divisible by 8
|
75 |
+
It can be seen here:
|
76 |
+
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
77 |
+
:param v:
|
78 |
+
:param divisor:
|
79 |
+
:param min_value:
|
80 |
+
:return:
|
81 |
+
"""
|
82 |
+
if min_value is None:
|
83 |
+
min_value = divisor
|
84 |
+
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
85 |
+
# Make sure that round down does not go down by more than 10%.
|
86 |
+
if new_v < 0.9 * v:
|
87 |
+
new_v += divisor
|
88 |
+
return new_v
|
89 |
+
|
90 |
+
|
91 |
+
class ConvBNReLU(nn.Sequential):
|
92 |
+
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
|
93 |
+
self.channel_pad = out_planes - in_planes
|
94 |
+
self.stride = stride
|
95 |
+
#padding = (kernel_size - 1) // 2
|
96 |
+
|
97 |
+
# TFLite uses slightly different padding than PyTorch
|
98 |
+
if stride == 2:
|
99 |
+
padding = 0
|
100 |
+
else:
|
101 |
+
padding = (kernel_size - 1) // 2
|
102 |
+
|
103 |
+
super(ConvBNReLU, self).__init__(
|
104 |
+
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
|
105 |
+
nn.BatchNorm2d(out_planes),
|
106 |
+
nn.ReLU6(inplace=True)
|
107 |
+
)
|
108 |
+
self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
|
109 |
+
|
110 |
+
|
111 |
+
def forward(self, x):
|
112 |
+
# TFLite uses different padding
|
113 |
+
if self.stride == 2:
|
114 |
+
x = F.pad(x, (0, 1, 0, 1), "constant", 0)
|
115 |
+
#print(x.shape)
|
116 |
+
|
117 |
+
for module in self:
|
118 |
+
if not isinstance(module, nn.MaxPool2d):
|
119 |
+
x = module(x)
|
120 |
+
return x
|
121 |
+
|
122 |
+
|
123 |
+
class InvertedResidual(nn.Module):
|
124 |
+
def __init__(self, inp, oup, stride, expand_ratio):
|
125 |
+
super(InvertedResidual, self).__init__()
|
126 |
+
self.stride = stride
|
127 |
+
assert stride in [1, 2]
|
128 |
+
|
129 |
+
hidden_dim = int(round(inp * expand_ratio))
|
130 |
+
self.use_res_connect = self.stride == 1 and inp == oup
|
131 |
+
|
132 |
+
layers = []
|
133 |
+
if expand_ratio != 1:
|
134 |
+
# pw
|
135 |
+
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
|
136 |
+
layers.extend([
|
137 |
+
# dw
|
138 |
+
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
|
139 |
+
# pw-linear
|
140 |
+
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
141 |
+
nn.BatchNorm2d(oup),
|
142 |
+
])
|
143 |
+
self.conv = nn.Sequential(*layers)
|
144 |
+
|
145 |
+
def forward(self, x):
|
146 |
+
if self.use_res_connect:
|
147 |
+
return x + self.conv(x)
|
148 |
+
else:
|
149 |
+
return self.conv(x)
|
150 |
+
|
151 |
+
|
152 |
+
class MobileNetV2(nn.Module):
|
153 |
+
def __init__(self, pretrained=True):
|
154 |
+
"""
|
155 |
+
MobileNet V2 main class
|
156 |
+
Args:
|
157 |
+
num_classes (int): Number of classes
|
158 |
+
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
|
159 |
+
inverted_residual_setting: Network structure
|
160 |
+
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
|
161 |
+
Set to 1 to turn off rounding
|
162 |
+
block: Module specifying inverted residual building block for mobilenet
|
163 |
+
"""
|
164 |
+
super(MobileNetV2, self).__init__()
|
165 |
+
|
166 |
+
block = InvertedResidual
|
167 |
+
input_channel = 32
|
168 |
+
last_channel = 1280
|
169 |
+
width_mult = 1.0
|
170 |
+
round_nearest = 8
|
171 |
+
|
172 |
+
inverted_residual_setting = [
|
173 |
+
# t, c, n, s
|
174 |
+
[1, 16, 1, 1],
|
175 |
+
[6, 24, 2, 2],
|
176 |
+
[6, 32, 3, 2],
|
177 |
+
[6, 64, 4, 2],
|
178 |
+
#[6, 96, 3, 1],
|
179 |
+
#[6, 160, 3, 2],
|
180 |
+
#[6, 320, 1, 1],
|
181 |
+
]
|
182 |
+
|
183 |
+
# only check the first element, assuming user knows t,c,n,s are required
|
184 |
+
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
|
185 |
+
raise ValueError("inverted_residual_setting should be non-empty "
|
186 |
+
"or a 4-element list, got {}".format(inverted_residual_setting))
|
187 |
+
|
188 |
+
# building first layer
|
189 |
+
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
|
190 |
+
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
|
191 |
+
features = [ConvBNReLU(4, input_channel, stride=2)]
|
192 |
+
# building inverted residual blocks
|
193 |
+
for t, c, n, s in inverted_residual_setting:
|
194 |
+
output_channel = _make_divisible(c * width_mult, round_nearest)
|
195 |
+
for i in range(n):
|
196 |
+
stride = s if i == 0 else 1
|
197 |
+
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
|
198 |
+
input_channel = output_channel
|
199 |
+
self.features = nn.Sequential(*features)
|
200 |
+
|
201 |
+
self.fpn_selected = [3, 6, 10]
|
202 |
+
# weight initialization
|
203 |
+
for m in self.modules():
|
204 |
+
if isinstance(m, nn.Conv2d):
|
205 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
206 |
+
if m.bias is not None:
|
207 |
+
nn.init.zeros_(m.bias)
|
208 |
+
elif isinstance(m, nn.BatchNorm2d):
|
209 |
+
nn.init.ones_(m.weight)
|
210 |
+
nn.init.zeros_(m.bias)
|
211 |
+
elif isinstance(m, nn.Linear):
|
212 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
213 |
+
nn.init.zeros_(m.bias)
|
214 |
+
|
215 |
+
#if pretrained:
|
216 |
+
# self._load_pretrained_model()
|
217 |
+
|
218 |
+
def _forward_impl(self, x):
|
219 |
+
# This exists since TorchScript doesn't support inheritance, so the superclass method
|
220 |
+
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
|
221 |
+
fpn_features = []
|
222 |
+
for i, f in enumerate(self.features):
|
223 |
+
if i > self.fpn_selected[-1]:
|
224 |
+
break
|
225 |
+
x = f(x)
|
226 |
+
if i in self.fpn_selected:
|
227 |
+
fpn_features.append(x)
|
228 |
+
|
229 |
+
c2, c3, c4 = fpn_features
|
230 |
+
return c2, c3, c4
|
231 |
+
|
232 |
+
|
233 |
+
def forward(self, x):
|
234 |
+
return self._forward_impl(x)
|
235 |
+
|
236 |
+
def _load_pretrained_model(self):
|
237 |
+
pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
|
238 |
+
model_dict = {}
|
239 |
+
state_dict = self.state_dict()
|
240 |
+
for k, v in pretrain_dict.items():
|
241 |
+
if k in state_dict:
|
242 |
+
model_dict[k] = v
|
243 |
+
state_dict.update(model_dict)
|
244 |
+
self.load_state_dict(state_dict)
|
245 |
+
|
246 |
+
|
247 |
+
class MobileV2_MLSD_Tiny(nn.Module):
|
248 |
+
def __init__(self):
|
249 |
+
super(MobileV2_MLSD_Tiny, self).__init__()
|
250 |
+
|
251 |
+
self.backbone = MobileNetV2(pretrained=True)
|
252 |
+
|
253 |
+
self.block12 = BlockTypeA(in_c1= 32, in_c2= 64,
|
254 |
+
out_c1= 64, out_c2=64)
|
255 |
+
self.block13 = BlockTypeB(128, 64)
|
256 |
+
|
257 |
+
self.block14 = BlockTypeA(in_c1 = 24, in_c2 = 64,
|
258 |
+
out_c1= 32, out_c2= 32)
|
259 |
+
self.block15 = BlockTypeB(64, 64)
|
260 |
+
|
261 |
+
self.block16 = BlockTypeC(64, 16)
|
262 |
+
|
263 |
+
def forward(self, x):
|
264 |
+
c2, c3, c4 = self.backbone(x)
|
265 |
+
|
266 |
+
x = self.block12(c3, c4)
|
267 |
+
x = self.block13(x)
|
268 |
+
x = self.block14(c2, x)
|
269 |
+
x = self.block15(x)
|
270 |
+
x = self.block16(x)
|
271 |
+
x = x[:, 7:, :, :]
|
272 |
+
#print(x.shape)
|
273 |
+
x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True)
|
274 |
+
|
275 |
+
return x
|
src/flux/annotator/mlsd/utils.py
ADDED
@@ -0,0 +1,580 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
modified by lihaoweicv
|
3 |
+
pytorch version
|
4 |
+
'''
|
5 |
+
|
6 |
+
'''
|
7 |
+
M-LSD
|
8 |
+
Copyright 2021-present NAVER Corp.
|
9 |
+
Apache License v2.0
|
10 |
+
'''
|
11 |
+
|
12 |
+
import os
|
13 |
+
import numpy as np
|
14 |
+
import cv2
|
15 |
+
import torch
|
16 |
+
from torch.nn import functional as F
|
17 |
+
|
18 |
+
|
19 |
+
def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5):
|
20 |
+
'''
|
21 |
+
tpMap:
|
22 |
+
center: tpMap[1, 0, :, :]
|
23 |
+
displacement: tpMap[1, 1:5, :, :]
|
24 |
+
'''
|
25 |
+
b, c, h, w = tpMap.shape
|
26 |
+
assert b==1, 'only support bsize==1'
|
27 |
+
displacement = tpMap[:, 1:5, :, :][0]
|
28 |
+
center = tpMap[:, 0, :, :]
|
29 |
+
heat = torch.sigmoid(center)
|
30 |
+
hmax = F.max_pool2d( heat, (ksize, ksize), stride=1, padding=(ksize-1)//2)
|
31 |
+
keep = (hmax == heat).float()
|
32 |
+
heat = heat * keep
|
33 |
+
heat = heat.reshape(-1, )
|
34 |
+
|
35 |
+
scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True)
|
36 |
+
yy = torch.floor_divide(indices, w).unsqueeze(-1)
|
37 |
+
xx = torch.fmod(indices, w).unsqueeze(-1)
|
38 |
+
ptss = torch.cat((yy, xx),dim=-1)
|
39 |
+
|
40 |
+
ptss = ptss.detach().cpu().numpy()
|
41 |
+
scores = scores.detach().cpu().numpy()
|
42 |
+
displacement = displacement.detach().cpu().numpy()
|
43 |
+
displacement = displacement.transpose((1,2,0))
|
44 |
+
return ptss, scores, displacement
|
45 |
+
|
46 |
+
|
47 |
+
def pred_lines(image, model,
|
48 |
+
input_shape=[512, 512],
|
49 |
+
score_thr=0.10,
|
50 |
+
dist_thr=20.0):
|
51 |
+
h, w, _ = image.shape
|
52 |
+
h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]]
|
53 |
+
|
54 |
+
resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA),
|
55 |
+
np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
|
56 |
+
|
57 |
+
resized_image = resized_image.transpose((2,0,1))
|
58 |
+
batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
|
59 |
+
batch_image = (batch_image / 127.5) - 1.0
|
60 |
+
|
61 |
+
batch_image = torch.from_numpy(batch_image).float().to("cuda:4")
|
62 |
+
outputs = model(batch_image)
|
63 |
+
pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
|
64 |
+
start = vmap[:, :, :2]
|
65 |
+
end = vmap[:, :, 2:]
|
66 |
+
dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
|
67 |
+
|
68 |
+
segments_list = []
|
69 |
+
for center, score in zip(pts, pts_score):
|
70 |
+
y, x = center
|
71 |
+
distance = dist_map[y, x]
|
72 |
+
if score > score_thr and distance > dist_thr:
|
73 |
+
disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
|
74 |
+
x_start = x + disp_x_start
|
75 |
+
y_start = y + disp_y_start
|
76 |
+
x_end = x + disp_x_end
|
77 |
+
y_end = y + disp_y_end
|
78 |
+
segments_list.append([x_start, y_start, x_end, y_end])
|
79 |
+
|
80 |
+
lines = 2 * np.array(segments_list) # 256 > 512
|
81 |
+
lines[:, 0] = lines[:, 0] * w_ratio
|
82 |
+
lines[:, 1] = lines[:, 1] * h_ratio
|
83 |
+
lines[:, 2] = lines[:, 2] * w_ratio
|
84 |
+
lines[:, 3] = lines[:, 3] * h_ratio
|
85 |
+
|
86 |
+
return lines
|
87 |
+
|
88 |
+
|
89 |
+
def pred_squares(image,
|
90 |
+
model,
|
91 |
+
input_shape=[512, 512],
|
92 |
+
params={'score': 0.06,
|
93 |
+
'outside_ratio': 0.28,
|
94 |
+
'inside_ratio': 0.45,
|
95 |
+
'w_overlap': 0.0,
|
96 |
+
'w_degree': 1.95,
|
97 |
+
'w_length': 0.0,
|
98 |
+
'w_area': 1.86,
|
99 |
+
'w_center': 0.14}):
|
100 |
+
'''
|
101 |
+
shape = [height, width]
|
102 |
+
'''
|
103 |
+
h, w, _ = image.shape
|
104 |
+
original_shape = [h, w]
|
105 |
+
|
106 |
+
resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA),
|
107 |
+
np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
|
108 |
+
resized_image = resized_image.transpose((2, 0, 1))
|
109 |
+
batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
|
110 |
+
batch_image = (batch_image / 127.5) - 1.0
|
111 |
+
|
112 |
+
batch_image = torch.from_numpy(batch_image).float().cuda()
|
113 |
+
outputs = model(batch_image)
|
114 |
+
|
115 |
+
pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
|
116 |
+
start = vmap[:, :, :2] # (x, y)
|
117 |
+
end = vmap[:, :, 2:] # (x, y)
|
118 |
+
dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
|
119 |
+
|
120 |
+
junc_list = []
|
121 |
+
segments_list = []
|
122 |
+
for junc, score in zip(pts, pts_score):
|
123 |
+
y, x = junc
|
124 |
+
distance = dist_map[y, x]
|
125 |
+
if score > params['score'] and distance > 20.0:
|
126 |
+
junc_list.append([x, y])
|
127 |
+
disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
|
128 |
+
d_arrow = 1.0
|
129 |
+
x_start = x + d_arrow * disp_x_start
|
130 |
+
y_start = y + d_arrow * disp_y_start
|
131 |
+
x_end = x + d_arrow * disp_x_end
|
132 |
+
y_end = y + d_arrow * disp_y_end
|
133 |
+
segments_list.append([x_start, y_start, x_end, y_end])
|
134 |
+
|
135 |
+
segments = np.array(segments_list)
|
136 |
+
|
137 |
+
####### post processing for squares
|
138 |
+
# 1. get unique lines
|
139 |
+
point = np.array([[0, 0]])
|
140 |
+
point = point[0]
|
141 |
+
start = segments[:, :2]
|
142 |
+
end = segments[:, 2:]
|
143 |
+
diff = start - end
|
144 |
+
a = diff[:, 1]
|
145 |
+
b = -diff[:, 0]
|
146 |
+
c = a * start[:, 0] + b * start[:, 1]
|
147 |
+
|
148 |
+
d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10)
|
149 |
+
theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi
|
150 |
+
theta[theta < 0.0] += 180
|
151 |
+
hough = np.concatenate([d[:, None], theta[:, None]], axis=-1)
|
152 |
+
|
153 |
+
d_quant = 1
|
154 |
+
theta_quant = 2
|
155 |
+
hough[:, 0] //= d_quant
|
156 |
+
hough[:, 1] //= theta_quant
|
157 |
+
_, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True)
|
158 |
+
|
159 |
+
acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32')
|
160 |
+
idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1
|
161 |
+
yx_indices = hough[indices, :].astype('int32')
|
162 |
+
acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts
|
163 |
+
idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices
|
164 |
+
|
165 |
+
acc_map_np = acc_map
|
166 |
+
# acc_map = acc_map[None, :, :, None]
|
167 |
+
#
|
168 |
+
# ### fast suppression using tensorflow op
|
169 |
+
# acc_map = tf.constant(acc_map, dtype=tf.float32)
|
170 |
+
# max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map)
|
171 |
+
# acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32)
|
172 |
+
# flatten_acc_map = tf.reshape(acc_map, [1, -1])
|
173 |
+
# topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts))
|
174 |
+
# _, h, w, _ = acc_map.shape
|
175 |
+
# y = tf.expand_dims(topk_indices // w, axis=-1)
|
176 |
+
# x = tf.expand_dims(topk_indices % w, axis=-1)
|
177 |
+
# yx = tf.concat([y, x], axis=-1)
|
178 |
+
|
179 |
+
### fast suppression using pytorch op
|
180 |
+
acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0)
|
181 |
+
_,_, h, w = acc_map.shape
|
182 |
+
max_acc_map = F.max_pool2d(acc_map,kernel_size=5, stride=1, padding=2)
|
183 |
+
acc_map = acc_map * ( (acc_map == max_acc_map).float() )
|
184 |
+
flatten_acc_map = acc_map.reshape([-1, ])
|
185 |
+
|
186 |
+
scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True)
|
187 |
+
yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1)
|
188 |
+
xx = torch.fmod(indices, w).unsqueeze(-1)
|
189 |
+
yx = torch.cat((yy, xx), dim=-1)
|
190 |
+
|
191 |
+
yx = yx.detach().cpu().numpy()
|
192 |
+
|
193 |
+
topk_values = scores.detach().cpu().numpy()
|
194 |
+
indices = idx_map[yx[:, 0], yx[:, 1]]
|
195 |
+
basis = 5 // 2
|
196 |
+
|
197 |
+
merged_segments = []
|
198 |
+
for yx_pt, max_indice, value in zip(yx, indices, topk_values):
|
199 |
+
y, x = yx_pt
|
200 |
+
if max_indice == -1 or value == 0:
|
201 |
+
continue
|
202 |
+
segment_list = []
|
203 |
+
for y_offset in range(-basis, basis + 1):
|
204 |
+
for x_offset in range(-basis, basis + 1):
|
205 |
+
indice = idx_map[y + y_offset, x + x_offset]
|
206 |
+
cnt = int(acc_map_np[y + y_offset, x + x_offset])
|
207 |
+
if indice != -1:
|
208 |
+
segment_list.append(segments[indice])
|
209 |
+
if cnt > 1:
|
210 |
+
check_cnt = 1
|
211 |
+
current_hough = hough[indice]
|
212 |
+
for new_indice, new_hough in enumerate(hough):
|
213 |
+
if (current_hough == new_hough).all() and indice != new_indice:
|
214 |
+
segment_list.append(segments[new_indice])
|
215 |
+
check_cnt += 1
|
216 |
+
if check_cnt == cnt:
|
217 |
+
break
|
218 |
+
group_segments = np.array(segment_list).reshape([-1, 2])
|
219 |
+
sorted_group_segments = np.sort(group_segments, axis=0)
|
220 |
+
x_min, y_min = sorted_group_segments[0, :]
|
221 |
+
x_max, y_max = sorted_group_segments[-1, :]
|
222 |
+
|
223 |
+
deg = theta[max_indice]
|
224 |
+
if deg >= 90:
|
225 |
+
merged_segments.append([x_min, y_max, x_max, y_min])
|
226 |
+
else:
|
227 |
+
merged_segments.append([x_min, y_min, x_max, y_max])
|
228 |
+
|
229 |
+
# 2. get intersections
|
230 |
+
new_segments = np.array(merged_segments) # (x1, y1, x2, y2)
|
231 |
+
start = new_segments[:, :2] # (x1, y1)
|
232 |
+
end = new_segments[:, 2:] # (x2, y2)
|
233 |
+
new_centers = (start + end) / 2.0
|
234 |
+
diff = start - end
|
235 |
+
dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1))
|
236 |
+
|
237 |
+
# ax + by = c
|
238 |
+
a = diff[:, 1]
|
239 |
+
b = -diff[:, 0]
|
240 |
+
c = a * start[:, 0] + b * start[:, 1]
|
241 |
+
pre_det = a[:, None] * b[None, :]
|
242 |
+
det = pre_det - np.transpose(pre_det)
|
243 |
+
|
244 |
+
pre_inter_y = a[:, None] * c[None, :]
|
245 |
+
inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10)
|
246 |
+
pre_inter_x = c[:, None] * b[None, :]
|
247 |
+
inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10)
|
248 |
+
inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32')
|
249 |
+
|
250 |
+
# 3. get corner information
|
251 |
+
# 3.1 get distance
|
252 |
+
'''
|
253 |
+
dist_segments:
|
254 |
+
| dist(0), dist(1), dist(2), ...|
|
255 |
+
dist_inter_to_segment1:
|
256 |
+
| dist(inter,0), dist(inter,0), dist(inter,0), ... |
|
257 |
+
| dist(inter,1), dist(inter,1), dist(inter,1), ... |
|
258 |
+
...
|
259 |
+
dist_inter_to_semgnet2:
|
260 |
+
| dist(inter,0), dist(inter,1), dist(inter,2), ... |
|
261 |
+
| dist(inter,0), dist(inter,1), dist(inter,2), ... |
|
262 |
+
...
|
263 |
+
'''
|
264 |
+
|
265 |
+
dist_inter_to_segment1_start = np.sqrt(
|
266 |
+
np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
|
267 |
+
dist_inter_to_segment1_end = np.sqrt(
|
268 |
+
np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
|
269 |
+
dist_inter_to_segment2_start = np.sqrt(
|
270 |
+
np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
|
271 |
+
dist_inter_to_segment2_end = np.sqrt(
|
272 |
+
np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
|
273 |
+
|
274 |
+
# sort ascending
|
275 |
+
dist_inter_to_segment1 = np.sort(
|
276 |
+
np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1),
|
277 |
+
axis=-1) # [n_batch, n_batch, 2]
|
278 |
+
dist_inter_to_segment2 = np.sort(
|
279 |
+
np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1),
|
280 |
+
axis=-1) # [n_batch, n_batch, 2]
|
281 |
+
|
282 |
+
# 3.2 get degree
|
283 |
+
inter_to_start = new_centers[:, None, :] - inter_pts
|
284 |
+
deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi
|
285 |
+
deg_inter_to_start[deg_inter_to_start < 0.0] += 360
|
286 |
+
inter_to_end = new_centers[None, :, :] - inter_pts
|
287 |
+
deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi
|
288 |
+
deg_inter_to_end[deg_inter_to_end < 0.0] += 360
|
289 |
+
|
290 |
+
'''
|
291 |
+
B -- G
|
292 |
+
| |
|
293 |
+
C -- R
|
294 |
+
B : blue / G: green / C: cyan / R: red
|
295 |
+
|
296 |
+
0 -- 1
|
297 |
+
| |
|
298 |
+
3 -- 2
|
299 |
+
'''
|
300 |
+
# rename variables
|
301 |
+
deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end
|
302 |
+
# sort deg ascending
|
303 |
+
deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1)
|
304 |
+
|
305 |
+
deg_diff_map = np.abs(deg1_map - deg2_map)
|
306 |
+
# we only consider the smallest degree of intersect
|
307 |
+
deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180]
|
308 |
+
|
309 |
+
# define available degree range
|
310 |
+
deg_range = [60, 120]
|
311 |
+
|
312 |
+
corner_dict = {corner_info: [] for corner_info in range(4)}
|
313 |
+
inter_points = []
|
314 |
+
for i in range(inter_pts.shape[0]):
|
315 |
+
for j in range(i + 1, inter_pts.shape[1]):
|
316 |
+
# i, j > line index, always i < j
|
317 |
+
x, y = inter_pts[i, j, :]
|
318 |
+
deg1, deg2 = deg_sort[i, j, :]
|
319 |
+
deg_diff = deg_diff_map[i, j]
|
320 |
+
|
321 |
+
check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1]
|
322 |
+
|
323 |
+
outside_ratio = params['outside_ratio'] # over ratio >>> drop it!
|
324 |
+
inside_ratio = params['inside_ratio'] # over ratio >>> drop it!
|
325 |
+
check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and \
|
326 |
+
dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or \
|
327 |
+
(dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and \
|
328 |
+
dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \
|
329 |
+
((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and \
|
330 |
+
dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or \
|
331 |
+
(dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and \
|
332 |
+
dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio))
|
333 |
+
|
334 |
+
if check_degree and check_distance:
|
335 |
+
corner_info = None
|
336 |
+
|
337 |
+
if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \
|
338 |
+
(deg2 >= 315 and deg1 >= 45 and deg1 <= 120):
|
339 |
+
corner_info, color_info = 0, 'blue'
|
340 |
+
elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225):
|
341 |
+
corner_info, color_info = 1, 'green'
|
342 |
+
elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315):
|
343 |
+
corner_info, color_info = 2, 'black'
|
344 |
+
elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \
|
345 |
+
(deg2 >= 315 and deg1 >= 225 and deg1 <= 315):
|
346 |
+
corner_info, color_info = 3, 'cyan'
|
347 |
+
else:
|
348 |
+
corner_info, color_info = 4, 'red' # we don't use it
|
349 |
+
continue
|
350 |
+
|
351 |
+
corner_dict[corner_info].append([x, y, i, j])
|
352 |
+
inter_points.append([x, y])
|
353 |
+
|
354 |
+
square_list = []
|
355 |
+
connect_list = []
|
356 |
+
segments_list = []
|
357 |
+
for corner0 in corner_dict[0]:
|
358 |
+
for corner1 in corner_dict[1]:
|
359 |
+
connect01 = False
|
360 |
+
for corner0_line in corner0[2:]:
|
361 |
+
if corner0_line in corner1[2:]:
|
362 |
+
connect01 = True
|
363 |
+
break
|
364 |
+
if connect01:
|
365 |
+
for corner2 in corner_dict[2]:
|
366 |
+
connect12 = False
|
367 |
+
for corner1_line in corner1[2:]:
|
368 |
+
if corner1_line in corner2[2:]:
|
369 |
+
connect12 = True
|
370 |
+
break
|
371 |
+
if connect12:
|
372 |
+
for corner3 in corner_dict[3]:
|
373 |
+
connect23 = False
|
374 |
+
for corner2_line in corner2[2:]:
|
375 |
+
if corner2_line in corner3[2:]:
|
376 |
+
connect23 = True
|
377 |
+
break
|
378 |
+
if connect23:
|
379 |
+
for corner3_line in corner3[2:]:
|
380 |
+
if corner3_line in corner0[2:]:
|
381 |
+
# SQUARE!!!
|
382 |
+
'''
|
383 |
+
0 -- 1
|
384 |
+
| |
|
385 |
+
3 -- 2
|
386 |
+
square_list:
|
387 |
+
order: 0 > 1 > 2 > 3
|
388 |
+
| x0, y0, x1, y1, x2, y2, x3, y3 |
|
389 |
+
| x0, y0, x1, y1, x2, y2, x3, y3 |
|
390 |
+
...
|
391 |
+
connect_list:
|
392 |
+
order: 01 > 12 > 23 > 30
|
393 |
+
| line_idx01, line_idx12, line_idx23, line_idx30 |
|
394 |
+
| line_idx01, line_idx12, line_idx23, line_idx30 |
|
395 |
+
...
|
396 |
+
segments_list:
|
397 |
+
order: 0 > 1 > 2 > 3
|
398 |
+
| line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
|
399 |
+
| line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
|
400 |
+
...
|
401 |
+
'''
|
402 |
+
square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2])
|
403 |
+
connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line])
|
404 |
+
segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:])
|
405 |
+
|
406 |
+
def check_outside_inside(segments_info, connect_idx):
|
407 |
+
# return 'outside or inside', min distance, cover_param, peri_param
|
408 |
+
if connect_idx == segments_info[0]:
|
409 |
+
check_dist_mat = dist_inter_to_segment1
|
410 |
+
else:
|
411 |
+
check_dist_mat = dist_inter_to_segment2
|
412 |
+
|
413 |
+
i, j = segments_info
|
414 |
+
min_dist, max_dist = check_dist_mat[i, j, :]
|
415 |
+
connect_dist = dist_segments[connect_idx]
|
416 |
+
if max_dist > connect_dist:
|
417 |
+
return 'outside', min_dist, 0, 1
|
418 |
+
else:
|
419 |
+
return 'inside', min_dist, -1, -1
|
420 |
+
|
421 |
+
top_square = None
|
422 |
+
|
423 |
+
try:
|
424 |
+
map_size = input_shape[0] / 2
|
425 |
+
squares = np.array(square_list).reshape([-1, 4, 2])
|
426 |
+
score_array = []
|
427 |
+
connect_array = np.array(connect_list)
|
428 |
+
segments_array = np.array(segments_list).reshape([-1, 4, 2])
|
429 |
+
|
430 |
+
# get degree of corners:
|
431 |
+
squares_rollup = np.roll(squares, 1, axis=1)
|
432 |
+
squares_rolldown = np.roll(squares, -1, axis=1)
|
433 |
+
vec1 = squares_rollup - squares
|
434 |
+
normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10)
|
435 |
+
vec2 = squares_rolldown - squares
|
436 |
+
normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10)
|
437 |
+
inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1) # [n_squares, 4]
|
438 |
+
squares_degree = np.arccos(inner_products) * 180 / np.pi # [n_squares, 4]
|
439 |
+
|
440 |
+
# get square score
|
441 |
+
overlap_scores = []
|
442 |
+
degree_scores = []
|
443 |
+
length_scores = []
|
444 |
+
|
445 |
+
for connects, segments, square, degree in zip(connect_array, segments_array, squares, squares_degree):
|
446 |
+
'''
|
447 |
+
0 -- 1
|
448 |
+
| |
|
449 |
+
3 -- 2
|
450 |
+
|
451 |
+
# segments: [4, 2]
|
452 |
+
# connects: [4]
|
453 |
+
'''
|
454 |
+
|
455 |
+
###################################### OVERLAP SCORES
|
456 |
+
cover = 0
|
457 |
+
perimeter = 0
|
458 |
+
# check 0 > 1 > 2 > 3
|
459 |
+
square_length = []
|
460 |
+
|
461 |
+
for start_idx in range(4):
|
462 |
+
end_idx = (start_idx + 1) % 4
|
463 |
+
|
464 |
+
connect_idx = connects[start_idx] # segment idx of segment01
|
465 |
+
start_segments = segments[start_idx]
|
466 |
+
end_segments = segments[end_idx]
|
467 |
+
|
468 |
+
start_point = square[start_idx]
|
469 |
+
end_point = square[end_idx]
|
470 |
+
|
471 |
+
# check whether outside or inside
|
472 |
+
start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments,
|
473 |
+
connect_idx)
|
474 |
+
end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx)
|
475 |
+
|
476 |
+
cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min
|
477 |
+
perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min
|
478 |
+
|
479 |
+
square_length.append(
|
480 |
+
dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min)
|
481 |
+
|
482 |
+
overlap_scores.append(cover / perimeter)
|
483 |
+
######################################
|
484 |
+
###################################### DEGREE SCORES
|
485 |
+
'''
|
486 |
+
deg0 vs deg2
|
487 |
+
deg1 vs deg3
|
488 |
+
'''
|
489 |
+
deg0, deg1, deg2, deg3 = degree
|
490 |
+
deg_ratio1 = deg0 / deg2
|
491 |
+
if deg_ratio1 > 1.0:
|
492 |
+
deg_ratio1 = 1 / deg_ratio1
|
493 |
+
deg_ratio2 = deg1 / deg3
|
494 |
+
if deg_ratio2 > 1.0:
|
495 |
+
deg_ratio2 = 1 / deg_ratio2
|
496 |
+
degree_scores.append((deg_ratio1 + deg_ratio2) / 2)
|
497 |
+
######################################
|
498 |
+
###################################### LENGTH SCORES
|
499 |
+
'''
|
500 |
+
len0 vs len2
|
501 |
+
len1 vs len3
|
502 |
+
'''
|
503 |
+
len0, len1, len2, len3 = square_length
|
504 |
+
len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0
|
505 |
+
len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1
|
506 |
+
length_scores.append((len_ratio1 + len_ratio2) / 2)
|
507 |
+
|
508 |
+
######################################
|
509 |
+
|
510 |
+
overlap_scores = np.array(overlap_scores)
|
511 |
+
overlap_scores /= np.max(overlap_scores)
|
512 |
+
|
513 |
+
degree_scores = np.array(degree_scores)
|
514 |
+
# degree_scores /= np.max(degree_scores)
|
515 |
+
|
516 |
+
length_scores = np.array(length_scores)
|
517 |
+
|
518 |
+
###################################### AREA SCORES
|
519 |
+
area_scores = np.reshape(squares, [-1, 4, 2])
|
520 |
+
area_x = area_scores[:, :, 0]
|
521 |
+
area_y = area_scores[:, :, 1]
|
522 |
+
correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0]
|
523 |
+
area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1)
|
524 |
+
area_scores = 0.5 * np.abs(area_scores + correction)
|
525 |
+
area_scores /= (map_size * map_size) # np.max(area_scores)
|
526 |
+
######################################
|
527 |
+
|
528 |
+
###################################### CENTER SCORES
|
529 |
+
centers = np.array([[256 // 2, 256 // 2]], dtype='float32') # [1, 2]
|
530 |
+
# squares: [n, 4, 2]
|
531 |
+
square_centers = np.mean(squares, axis=1) # [n, 2]
|
532 |
+
center2center = np.sqrt(np.sum((centers - square_centers) ** 2))
|
533 |
+
center_scores = center2center / (map_size / np.sqrt(2.0))
|
534 |
+
|
535 |
+
'''
|
536 |
+
score_w = [overlap, degree, area, center, length]
|
537 |
+
'''
|
538 |
+
score_w = [0.0, 1.0, 10.0, 0.5, 1.0]
|
539 |
+
score_array = params['w_overlap'] * overlap_scores \
|
540 |
+
+ params['w_degree'] * degree_scores \
|
541 |
+
+ params['w_area'] * area_scores \
|
542 |
+
- params['w_center'] * center_scores \
|
543 |
+
+ params['w_length'] * length_scores
|
544 |
+
|
545 |
+
best_square = []
|
546 |
+
|
547 |
+
sorted_idx = np.argsort(score_array)[::-1]
|
548 |
+
score_array = score_array[sorted_idx]
|
549 |
+
squares = squares[sorted_idx]
|
550 |
+
|
551 |
+
except Exception as e:
|
552 |
+
pass
|
553 |
+
|
554 |
+
'''return list
|
555 |
+
merged_lines, squares, scores
|
556 |
+
'''
|
557 |
+
|
558 |
+
try:
|
559 |
+
new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1]
|
560 |
+
new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0]
|
561 |
+
new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1]
|
562 |
+
new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0]
|
563 |
+
except:
|
564 |
+
new_segments = []
|
565 |
+
|
566 |
+
try:
|
567 |
+
squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1]
|
568 |
+
squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0]
|
569 |
+
except:
|
570 |
+
squares = []
|
571 |
+
score_array = []
|
572 |
+
|
573 |
+
try:
|
574 |
+
inter_points = np.array(inter_points)
|
575 |
+
inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1]
|
576 |
+
inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0]
|
577 |
+
except:
|
578 |
+
inter_points = []
|
579 |
+
|
580 |
+
return new_segments, squares, score_array, inter_points
|
src/flux/annotator/tile/__init__.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import cv2
|
3 |
+
from .guided_filter import FastGuidedFilter
|
4 |
+
|
5 |
+
|
6 |
+
class TileDetector:
|
7 |
+
# https://huggingface.co/xinsir/controlnet-tile-sdxl-1.0
|
8 |
+
def __init__(self):
|
9 |
+
pass
|
10 |
+
|
11 |
+
def __call__(self, image):
|
12 |
+
blur_strength = random.sample([i / 10. for i in range(10, 201, 2)], k=1)[0]
|
13 |
+
radius = random.sample([i for i in range(1, 40, 2)], k=1)[0]
|
14 |
+
eps = random.sample([i / 1000. for i in range(1, 101, 2)], k=1)[0]
|
15 |
+
scale_factor = random.sample([i / 10. for i in range(10, 181, 5)], k=1)[0]
|
16 |
+
|
17 |
+
ksize = int(blur_strength)
|
18 |
+
if ksize % 2 == 0:
|
19 |
+
ksize += 1
|
20 |
+
|
21 |
+
if random.random() > 0.5:
|
22 |
+
image = cv2.GaussianBlur(image, (ksize, ksize), blur_strength / 2)
|
23 |
+
if random.random() > 0.5:
|
24 |
+
filter = FastGuidedFilter(image, radius, eps, scale_factor)
|
25 |
+
image = filter.filter(image)
|
26 |
+
return image
|
src/flux/annotator/tile/guided_filter.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
## @package guided_filter.core.filters
|
3 |
+
#
|
4 |
+
# Implementation of guided filter.
|
5 |
+
# * GuidedFilter: Original guided filter.
|
6 |
+
# * FastGuidedFilter: Fast version of the guided filter.
|
7 |
+
# @author tody
|
8 |
+
# @date 2015/08/26
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import cv2
|
12 |
+
|
13 |
+
## Convert image into float32 type.
|
14 |
+
def to32F(img):
|
15 |
+
if img.dtype == np.float32:
|
16 |
+
return img
|
17 |
+
return (1.0 / 255.0) * np.float32(img)
|
18 |
+
|
19 |
+
## Convert image into uint8 type.
|
20 |
+
def to8U(img):
|
21 |
+
if img.dtype == np.uint8:
|
22 |
+
return img
|
23 |
+
return np.clip(np.uint8(255.0 * img), 0, 255)
|
24 |
+
|
25 |
+
## Return if the input image is gray or not.
|
26 |
+
def _isGray(I):
|
27 |
+
return len(I.shape) == 2
|
28 |
+
|
29 |
+
|
30 |
+
## Return down sampled image.
|
31 |
+
# @param scale (w/s, h/s) image will be created.
|
32 |
+
# @param shape I.shape[:2]=(h, w). numpy friendly size parameter.
|
33 |
+
def _downSample(I, scale=4, shape=None):
|
34 |
+
if shape is not None:
|
35 |
+
h, w = shape
|
36 |
+
return cv2.resize(I, (w, h), interpolation=cv2.INTER_NEAREST)
|
37 |
+
|
38 |
+
h, w = I.shape[:2]
|
39 |
+
return cv2.resize(I, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_NEAREST)
|
40 |
+
|
41 |
+
|
42 |
+
## Return up sampled image.
|
43 |
+
# @param scale (w*s, h*s) image will be created.
|
44 |
+
# @param shape I.shape[:2]=(h, w). numpy friendly size parameter.
|
45 |
+
def _upSample(I, scale=2, shape=None):
|
46 |
+
if shape is not None:
|
47 |
+
h, w = shape
|
48 |
+
return cv2.resize(I, (w, h), interpolation=cv2.INTER_LINEAR)
|
49 |
+
|
50 |
+
h, w = I.shape[:2]
|
51 |
+
return cv2.resize(I, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR)
|
52 |
+
|
53 |
+
## Fast guide filter.
|
54 |
+
class FastGuidedFilter:
|
55 |
+
## Constructor.
|
56 |
+
# @param I Input guidance image. Color or gray.
|
57 |
+
# @param radius Radius of Guided Filter.
|
58 |
+
# @param epsilon Regularization term of Guided Filter.
|
59 |
+
# @param scale Down sampled scale.
|
60 |
+
def __init__(self, I, radius=5, epsilon=0.4, scale=4):
|
61 |
+
I_32F = to32F(I)
|
62 |
+
self._I = I_32F
|
63 |
+
h, w = I.shape[:2]
|
64 |
+
|
65 |
+
I_sub = _downSample(I_32F, scale)
|
66 |
+
|
67 |
+
self._I_sub = I_sub
|
68 |
+
radius = int(radius / scale)
|
69 |
+
|
70 |
+
if _isGray(I):
|
71 |
+
self._guided_filter = GuidedFilterGray(I_sub, radius, epsilon)
|
72 |
+
else:
|
73 |
+
self._guided_filter = GuidedFilterColor(I_sub, radius, epsilon)
|
74 |
+
|
75 |
+
## Apply filter for the input image.
|
76 |
+
# @param p Input image for the filtering.
|
77 |
+
def filter(self, p):
|
78 |
+
p_32F = to32F(p)
|
79 |
+
shape_original = p.shape[:2]
|
80 |
+
|
81 |
+
p_sub = _downSample(p_32F, shape=self._I_sub.shape[:2])
|
82 |
+
|
83 |
+
if _isGray(p_sub):
|
84 |
+
return self._filterGray(p_sub, shape_original)
|
85 |
+
|
86 |
+
cs = p.shape[2]
|
87 |
+
q = np.array(p_32F)
|
88 |
+
|
89 |
+
for ci in range(cs):
|
90 |
+
q[:, :, ci] = self._filterGray(p_sub[:, :, ci], shape_original)
|
91 |
+
return to8U(q)
|
92 |
+
|
93 |
+
def _filterGray(self, p_sub, shape_original):
|
94 |
+
ab_sub = self._guided_filter._computeCoefficients(p_sub)
|
95 |
+
ab = [_upSample(abi, shape=shape_original) for abi in ab_sub]
|
96 |
+
return self._guided_filter._computeOutput(ab, self._I)
|
97 |
+
|
98 |
+
|
99 |
+
## Guide filter.
|
100 |
+
class GuidedFilter:
|
101 |
+
## Constructor.
|
102 |
+
# @param I Input guidance image. Color or gray.
|
103 |
+
# @param radius Radius of Guided Filter.
|
104 |
+
# @param epsilon Regularization term of Guided Filter.
|
105 |
+
def __init__(self, I, radius=5, epsilon=0.4):
|
106 |
+
I_32F = to32F(I)
|
107 |
+
|
108 |
+
if _isGray(I):
|
109 |
+
self._guided_filter = GuidedFilterGray(I_32F, radius, epsilon)
|
110 |
+
else:
|
111 |
+
self._guided_filter = GuidedFilterColor(I_32F, radius, epsilon)
|
112 |
+
|
113 |
+
## Apply filter for the input image.
|
114 |
+
# @param p Input image for the filtering.
|
115 |
+
def filter(self, p):
|
116 |
+
return to8U(self._guided_filter.filter(p))
|
117 |
+
|
118 |
+
|
119 |
+
## Common parts of guided filter.
|
120 |
+
#
|
121 |
+
# This class is used by guided_filter class. GuidedFilterGray and GuidedFilterColor.
|
122 |
+
# Based on guided_filter._computeCoefficients, guided_filter._computeOutput,
|
123 |
+
# GuidedFilterCommon.filter computes filtered image for color and gray.
|
124 |
+
class GuidedFilterCommon:
|
125 |
+
def __init__(self, guided_filter):
|
126 |
+
self._guided_filter = guided_filter
|
127 |
+
|
128 |
+
## Apply filter for the input image.
|
129 |
+
# @param p Input image for the filtering.
|
130 |
+
def filter(self, p):
|
131 |
+
p_32F = to32F(p)
|
132 |
+
if _isGray(p_32F):
|
133 |
+
return self._filterGray(p_32F)
|
134 |
+
|
135 |
+
cs = p.shape[2]
|
136 |
+
q = np.array(p_32F)
|
137 |
+
|
138 |
+
for ci in range(cs):
|
139 |
+
q[:, :, ci] = self._filterGray(p_32F[:, :, ci])
|
140 |
+
return q
|
141 |
+
|
142 |
+
def _filterGray(self, p):
|
143 |
+
ab = self._guided_filter._computeCoefficients(p)
|
144 |
+
return self._guided_filter._computeOutput(ab, self._guided_filter._I)
|
145 |
+
|
146 |
+
|
147 |
+
## Guided filter for gray guidance image.
|
148 |
+
class GuidedFilterGray:
|
149 |
+
# @param I Input gray guidance image.
|
150 |
+
# @param radius Radius of Guided Filter.
|
151 |
+
# @param epsilon Regularization term of Guided Filter.
|
152 |
+
def __init__(self, I, radius=5, epsilon=0.4):
|
153 |
+
self._radius = 2 * radius + 1
|
154 |
+
self._epsilon = epsilon
|
155 |
+
self._I = to32F(I)
|
156 |
+
self._initFilter()
|
157 |
+
self._filter_common = GuidedFilterCommon(self)
|
158 |
+
|
159 |
+
## Apply filter for the input image.
|
160 |
+
# @param p Input image for the filtering.
|
161 |
+
def filter(self, p):
|
162 |
+
return self._filter_common.filter(p)
|
163 |
+
|
164 |
+
def _initFilter(self):
|
165 |
+
I = self._I
|
166 |
+
r = self._radius
|
167 |
+
self._I_mean = cv2.blur(I, (r, r))
|
168 |
+
I_mean_sq = cv2.blur(I ** 2, (r, r))
|
169 |
+
self._I_var = I_mean_sq - self._I_mean ** 2
|
170 |
+
|
171 |
+
def _computeCoefficients(self, p):
|
172 |
+
r = self._radius
|
173 |
+
p_mean = cv2.blur(p, (r, r))
|
174 |
+
p_cov = p_mean - self._I_mean * p_mean
|
175 |
+
a = p_cov / (self._I_var + self._epsilon)
|
176 |
+
b = p_mean - a * self._I_mean
|
177 |
+
a_mean = cv2.blur(a, (r, r))
|
178 |
+
b_mean = cv2.blur(b, (r, r))
|
179 |
+
return a_mean, b_mean
|
180 |
+
|
181 |
+
def _computeOutput(self, ab, I):
|
182 |
+
a_mean, b_mean = ab
|
183 |
+
return a_mean * I + b_mean
|
184 |
+
|
185 |
+
|
186 |
+
## Guided filter for color guidance image.
|
187 |
+
class GuidedFilterColor:
|
188 |
+
# @param I Input color guidance image.
|
189 |
+
# @param radius Radius of Guided Filter.
|
190 |
+
# @param epsilon Regularization term of Guided Filter.
|
191 |
+
def __init__(self, I, radius=5, epsilon=0.2):
|
192 |
+
self._radius = 2 * radius + 1
|
193 |
+
self._epsilon = epsilon
|
194 |
+
self._I = to32F(I)
|
195 |
+
self._initFilter()
|
196 |
+
self._filter_common = GuidedFilterCommon(self)
|
197 |
+
|
198 |
+
## Apply filter for the input image.
|
199 |
+
# @param p Input image for the filtering.
|
200 |
+
def filter(self, p):
|
201 |
+
return self._filter_common.filter(p)
|
202 |
+
|
203 |
+
def _initFilter(self):
|
204 |
+
I = self._I
|
205 |
+
r = self._radius
|
206 |
+
eps = self._epsilon
|
207 |
+
|
208 |
+
Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]
|
209 |
+
|
210 |
+
self._Ir_mean = cv2.blur(Ir, (r, r))
|
211 |
+
self._Ig_mean = cv2.blur(Ig, (r, r))
|
212 |
+
self._Ib_mean = cv2.blur(Ib, (r, r))
|
213 |
+
|
214 |
+
Irr_var = cv2.blur(Ir ** 2, (r, r)) - self._Ir_mean ** 2 + eps
|
215 |
+
Irg_var = cv2.blur(Ir * Ig, (r, r)) - self._Ir_mean * self._Ig_mean
|
216 |
+
Irb_var = cv2.blur(Ir * Ib, (r, r)) - self._Ir_mean * self._Ib_mean
|
217 |
+
Igg_var = cv2.blur(Ig * Ig, (r, r)) - self._Ig_mean * self._Ig_mean + eps
|
218 |
+
Igb_var = cv2.blur(Ig * Ib, (r, r)) - self._Ig_mean * self._Ib_mean
|
219 |
+
Ibb_var = cv2.blur(Ib * Ib, (r, r)) - self._Ib_mean * self._Ib_mean + eps
|
220 |
+
|
221 |
+
Irr_inv = Igg_var * Ibb_var - Igb_var * Igb_var
|
222 |
+
Irg_inv = Igb_var * Irb_var - Irg_var * Ibb_var
|
223 |
+
Irb_inv = Irg_var * Igb_var - Igg_var * Irb_var
|
224 |
+
Igg_inv = Irr_var * Ibb_var - Irb_var * Irb_var
|
225 |
+
Igb_inv = Irb_var * Irg_var - Irr_var * Igb_var
|
226 |
+
Ibb_inv = Irr_var * Igg_var - Irg_var * Irg_var
|
227 |
+
|
228 |
+
I_cov = Irr_inv * Irr_var + Irg_inv * Irg_var + Irb_inv * Irb_var
|
229 |
+
Irr_inv /= I_cov
|
230 |
+
Irg_inv /= I_cov
|
231 |
+
Irb_inv /= I_cov
|
232 |
+
Igg_inv /= I_cov
|
233 |
+
Igb_inv /= I_cov
|
234 |
+
Ibb_inv /= I_cov
|
235 |
+
|
236 |
+
self._Irr_inv = Irr_inv
|
237 |
+
self._Irg_inv = Irg_inv
|
238 |
+
self._Irb_inv = Irb_inv
|
239 |
+
self._Igg_inv = Igg_inv
|
240 |
+
self._Igb_inv = Igb_inv
|
241 |
+
self._Ibb_inv = Ibb_inv
|
242 |
+
|
243 |
+
def _computeCoefficients(self, p):
|
244 |
+
r = self._radius
|
245 |
+
I = self._I
|
246 |
+
Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]
|
247 |
+
|
248 |
+
p_mean = cv2.blur(p, (r, r))
|
249 |
+
|
250 |
+
Ipr_mean = cv2.blur(Ir * p, (r, r))
|
251 |
+
Ipg_mean = cv2.blur(Ig * p, (r, r))
|
252 |
+
Ipb_mean = cv2.blur(Ib * p, (r, r))
|
253 |
+
|
254 |
+
Ipr_cov = Ipr_mean - self._Ir_mean * p_mean
|
255 |
+
Ipg_cov = Ipg_mean - self._Ig_mean * p_mean
|
256 |
+
Ipb_cov = Ipb_mean - self._Ib_mean * p_mean
|
257 |
+
|
258 |
+
ar = self._Irr_inv * Ipr_cov + self._Irg_inv * Ipg_cov + self._Irb_inv * Ipb_cov
|
259 |
+
ag = self._Irg_inv * Ipr_cov + self._Igg_inv * Ipg_cov + self._Igb_inv * Ipb_cov
|
260 |
+
ab = self._Irb_inv * Ipr_cov + self._Igb_inv * Ipg_cov + self._Ibb_inv * Ipb_cov
|
261 |
+
b = p_mean - ar * self._Ir_mean - ag * self._Ig_mean - ab * self._Ib_mean
|
262 |
+
|
263 |
+
ar_mean = cv2.blur(ar, (r, r))
|
264 |
+
ag_mean = cv2.blur(ag, (r, r))
|
265 |
+
ab_mean = cv2.blur(ab, (r, r))
|
266 |
+
b_mean = cv2.blur(b, (r, r))
|
267 |
+
|
268 |
+
return ar_mean, ag_mean, ab_mean, b_mean
|
269 |
+
|
270 |
+
def _computeOutput(self, ab, I):
|
271 |
+
ar_mean, ag_mean, ab_mean, b_mean = ab
|
272 |
+
|
273 |
+
Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]
|
274 |
+
|
275 |
+
q = (ar_mean * Ir +
|
276 |
+
ag_mean * Ig +
|
277 |
+
ab_mean * Ib +
|
278 |
+
b_mean)
|
279 |
+
|
280 |
+
return q
|
src/flux/annotator/util.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
|
7 |
+
|
8 |
+
|
9 |
+
def HWC3(x):
|
10 |
+
assert x.dtype == np.uint8
|
11 |
+
if x.ndim == 2:
|
12 |
+
x = x[:, :, None]
|
13 |
+
assert x.ndim == 3
|
14 |
+
H, W, C = x.shape
|
15 |
+
assert C == 1 or C == 3 or C == 4
|
16 |
+
if C == 3:
|
17 |
+
return x
|
18 |
+
if C == 1:
|
19 |
+
return np.concatenate([x, x, x], axis=2)
|
20 |
+
if C == 4:
|
21 |
+
color = x[:, :, 0:3].astype(np.float32)
|
22 |
+
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
|
23 |
+
y = color * alpha + 255.0 * (1.0 - alpha)
|
24 |
+
y = y.clip(0, 255).astype(np.uint8)
|
25 |
+
return y
|
26 |
+
|
27 |
+
|
28 |
+
def resize_image(input_image, resolution):
|
29 |
+
H, W, C = input_image.shape
|
30 |
+
H = float(H)
|
31 |
+
W = float(W)
|
32 |
+
k = float(resolution) / min(H, W)
|
33 |
+
H *= k
|
34 |
+
W *= k
|
35 |
+
H = int(np.round(H / 64.0)) * 64
|
36 |
+
W = int(np.round(W / 64.0)) * 64
|
37 |
+
img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
|
38 |
+
return img
|
src/flux/api.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import requests
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
API_ENDPOINT = "https://api.bfl.ml"
|
10 |
+
|
11 |
+
|
12 |
+
class ApiException(Exception):
|
13 |
+
def __init__(self, status_code: int, detail: str | list[dict] | None = None):
|
14 |
+
super().__init__()
|
15 |
+
self.detail = detail
|
16 |
+
self.status_code = status_code
|
17 |
+
|
18 |
+
def __str__(self) -> str:
|
19 |
+
return self.__repr__()
|
20 |
+
|
21 |
+
def __repr__(self) -> str:
|
22 |
+
if self.detail is None:
|
23 |
+
message = None
|
24 |
+
elif isinstance(self.detail, str):
|
25 |
+
message = self.detail
|
26 |
+
else:
|
27 |
+
message = "[" + ",".join(d["msg"] for d in self.detail) + "]"
|
28 |
+
return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"
|
29 |
+
|
30 |
+
|
31 |
+
class ImageRequest:
|
32 |
+
def __init__(
|
33 |
+
self,
|
34 |
+
prompt: str,
|
35 |
+
width: int = 1024,
|
36 |
+
height: int = 1024,
|
37 |
+
name: str = "flux.1-pro",
|
38 |
+
num_steps: int = 50,
|
39 |
+
prompt_upsampling: bool = False,
|
40 |
+
seed: int | None = None,
|
41 |
+
validate: bool = True,
|
42 |
+
launch: bool = True,
|
43 |
+
api_key: str | None = None,
|
44 |
+
):
|
45 |
+
"""
|
46 |
+
Manages an image generation request to the API.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
prompt: Prompt to sample
|
50 |
+
width: Width of the image in pixel
|
51 |
+
height: Height of the image in pixel
|
52 |
+
name: Name of the model
|
53 |
+
num_steps: Number of network evaluations
|
54 |
+
prompt_upsampling: Use prompt upsampling
|
55 |
+
seed: Fix the generation seed
|
56 |
+
validate: Run input validation
|
57 |
+
launch: Directly launches request
|
58 |
+
api_key: Your API key if not provided by the environment
|
59 |
+
|
60 |
+
Raises:
|
61 |
+
ValueError: For invalid input
|
62 |
+
ApiException: For errors raised from the API
|
63 |
+
"""
|
64 |
+
if validate:
|
65 |
+
if name not in ["flux.1-pro"]:
|
66 |
+
raise ValueError(f"Invalid model {name}")
|
67 |
+
elif width % 32 != 0:
|
68 |
+
raise ValueError(f"width must be divisible by 32, got {width}")
|
69 |
+
elif not (256 <= width <= 1440):
|
70 |
+
raise ValueError(f"width must be between 256 and 1440, got {width}")
|
71 |
+
elif height % 32 != 0:
|
72 |
+
raise ValueError(f"height must be divisible by 32, got {height}")
|
73 |
+
elif not (256 <= height <= 1440):
|
74 |
+
raise ValueError(f"height must be between 256 and 1440, got {height}")
|
75 |
+
elif not (1 <= num_steps <= 50):
|
76 |
+
raise ValueError(f"steps must be between 1 and 50, got {num_steps}")
|
77 |
+
|
78 |
+
self.request_json = {
|
79 |
+
"prompt": prompt,
|
80 |
+
"width": width,
|
81 |
+
"height": height,
|
82 |
+
"variant": name,
|
83 |
+
"steps": num_steps,
|
84 |
+
"prompt_upsampling": prompt_upsampling,
|
85 |
+
}
|
86 |
+
if seed is not None:
|
87 |
+
self.request_json["seed"] = seed
|
88 |
+
|
89 |
+
self.request_id: str | None = None
|
90 |
+
self.result: dict | None = None
|
91 |
+
self._image_bytes: bytes | None = None
|
92 |
+
self._url: str | None = None
|
93 |
+
if api_key is None:
|
94 |
+
self.api_key = os.environ.get("BFL_API_KEY")
|
95 |
+
else:
|
96 |
+
self.api_key = api_key
|
97 |
+
|
98 |
+
if launch:
|
99 |
+
self.request()
|
100 |
+
|
101 |
+
def request(self):
|
102 |
+
"""
|
103 |
+
Request to generate the image.
|
104 |
+
"""
|
105 |
+
if self.request_id is not None:
|
106 |
+
return
|
107 |
+
response = requests.post(
|
108 |
+
f"{API_ENDPOINT}/v1/image",
|
109 |
+
headers={
|
110 |
+
"accept": "application/json",
|
111 |
+
"x-key": self.api_key,
|
112 |
+
"Content-Type": "application/json",
|
113 |
+
},
|
114 |
+
json=self.request_json,
|
115 |
+
)
|
116 |
+
result = response.json()
|
117 |
+
if response.status_code != 200:
|
118 |
+
raise ApiException(status_code=response.status_code, detail=result.get("detail"))
|
119 |
+
self.request_id = response.json()["id"]
|
120 |
+
|
121 |
+
def retrieve(self) -> dict:
|
122 |
+
"""
|
123 |
+
Wait for the generation to finish and retrieve response.
|
124 |
+
"""
|
125 |
+
if self.request_id is None:
|
126 |
+
self.request()
|
127 |
+
while self.result is None:
|
128 |
+
response = requests.get(
|
129 |
+
f"{API_ENDPOINT}/v1/get_result",
|
130 |
+
headers={
|
131 |
+
"accept": "application/json",
|
132 |
+
"x-key": self.api_key,
|
133 |
+
},
|
134 |
+
params={
|
135 |
+
"id": self.request_id,
|
136 |
+
},
|
137 |
+
)
|
138 |
+
result = response.json()
|
139 |
+
if "status" not in result:
|
140 |
+
raise ApiException(status_code=response.status_code, detail=result.get("detail"))
|
141 |
+
elif result["status"] == "Ready":
|
142 |
+
self.result = result["result"]
|
143 |
+
elif result["status"] == "Pending":
|
144 |
+
time.sleep(0.5)
|
145 |
+
else:
|
146 |
+
raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
|
147 |
+
return self.result
|
148 |
+
|
149 |
+
@property
|
150 |
+
def bytes(self) -> bytes:
|
151 |
+
"""
|
152 |
+
Generated image as bytes.
|
153 |
+
"""
|
154 |
+
if self._image_bytes is None:
|
155 |
+
response = requests.get(self.url)
|
156 |
+
if response.status_code == 200:
|
157 |
+
self._image_bytes = response.content
|
158 |
+
else:
|
159 |
+
raise ApiException(status_code=response.status_code)
|
160 |
+
return self._image_bytes
|
161 |
+
|
162 |
+
@property
|
163 |
+
def url(self) -> str:
|
164 |
+
"""
|
165 |
+
Public url to retrieve the image from
|
166 |
+
"""
|
167 |
+
if self._url is None:
|
168 |
+
result = self.retrieve()
|
169 |
+
self._url = result["sample"]
|
170 |
+
return self._url
|
171 |
+
|
172 |
+
@property
|
173 |
+
def image(self) -> Image.Image:
|
174 |
+
"""
|
175 |
+
Load the image as a PIL Image
|
176 |
+
"""
|
177 |
+
return Image.open(io.BytesIO(self.bytes))
|
178 |
+
|
179 |
+
def save(self, path: str):
|
180 |
+
"""
|
181 |
+
Save the generated image to a local path
|
182 |
+
"""
|
183 |
+
suffix = Path(self.url).suffix
|
184 |
+
if not path.endswith(suffix):
|
185 |
+
path = path + suffix
|
186 |
+
Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
|
187 |
+
with open(path, "wb") as file:
|
188 |
+
file.write(self.bytes)
|
189 |
+
|
190 |
+
|
191 |
+
if __name__ == "__main__":
|
192 |
+
from fire import Fire
|
193 |
+
|
194 |
+
Fire(ImageRequest)
|
src/flux/cli.py
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import time
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from glob import iglob
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from einops import rearrange
|
9 |
+
from fire import Fire
|
10 |
+
from PIL import ExifTags, Image
|
11 |
+
|
12 |
+
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
|
13 |
+
from flux.util import (configs, embed_watermark, load_ae, load_clip,
|
14 |
+
load_flow_model, load_t5)
|
15 |
+
from transformers import pipeline
|
16 |
+
|
17 |
+
NSFW_THRESHOLD = 0.85
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class SamplingOptions:
|
21 |
+
prompt: str
|
22 |
+
width: int
|
23 |
+
height: int
|
24 |
+
num_steps: int
|
25 |
+
guidance: float
|
26 |
+
seed: int | None
|
27 |
+
|
28 |
+
|
29 |
+
def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
|
30 |
+
user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
|
31 |
+
usage = (
|
32 |
+
"Usage: Either write your prompt directly, leave this field empty "
|
33 |
+
"to repeat the prompt or write a command starting with a slash:\n"
|
34 |
+
"- '/w <width>' will set the width of the generated image\n"
|
35 |
+
"- '/h <height>' will set the height of the generated image\n"
|
36 |
+
"- '/s <seed>' sets the next seed\n"
|
37 |
+
"- '/g <guidance>' sets the guidance (flux-dev only)\n"
|
38 |
+
"- '/n <steps>' sets the number of steps\n"
|
39 |
+
"- '/q' to quit"
|
40 |
+
)
|
41 |
+
|
42 |
+
while (prompt := input(user_question)).startswith("/"):
|
43 |
+
if prompt.startswith("/w"):
|
44 |
+
if prompt.count(" ") != 1:
|
45 |
+
print(f"Got invalid command '{prompt}'\n{usage}")
|
46 |
+
continue
|
47 |
+
_, width = prompt.split()
|
48 |
+
options.width = 16 * (int(width) // 16)
|
49 |
+
print(
|
50 |
+
f"Setting resolution to {options.width} x {options.height} "
|
51 |
+
f"({options.height *options.width/1e6:.2f}MP)"
|
52 |
+
)
|
53 |
+
elif prompt.startswith("/h"):
|
54 |
+
if prompt.count(" ") != 1:
|
55 |
+
print(f"Got invalid command '{prompt}'\n{usage}")
|
56 |
+
continue
|
57 |
+
_, height = prompt.split()
|
58 |
+
options.height = 16 * (int(height) // 16)
|
59 |
+
print(
|
60 |
+
f"Setting resolution to {options.width} x {options.height} "
|
61 |
+
f"({options.height *options.width/1e6:.2f}MP)"
|
62 |
+
)
|
63 |
+
elif prompt.startswith("/g"):
|
64 |
+
if prompt.count(" ") != 1:
|
65 |
+
print(f"Got invalid command '{prompt}'\n{usage}")
|
66 |
+
continue
|
67 |
+
_, guidance = prompt.split()
|
68 |
+
options.guidance = float(guidance)
|
69 |
+
print(f"Setting guidance to {options.guidance}")
|
70 |
+
elif prompt.startswith("/s"):
|
71 |
+
if prompt.count(" ") != 1:
|
72 |
+
print(f"Got invalid command '{prompt}'\n{usage}")
|
73 |
+
continue
|
74 |
+
_, seed = prompt.split()
|
75 |
+
options.seed = int(seed)
|
76 |
+
print(f"Setting seed to {options.seed}")
|
77 |
+
elif prompt.startswith("/n"):
|
78 |
+
if prompt.count(" ") != 1:
|
79 |
+
print(f"Got invalid command '{prompt}'\n{usage}")
|
80 |
+
continue
|
81 |
+
_, steps = prompt.split()
|
82 |
+
options.num_steps = int(steps)
|
83 |
+
print(f"Setting seed to {options.num_steps}")
|
84 |
+
elif prompt.startswith("/q"):
|
85 |
+
print("Quitting")
|
86 |
+
return None
|
87 |
+
else:
|
88 |
+
if not prompt.startswith("/h"):
|
89 |
+
print(f"Got invalid command '{prompt}'\n{usage}")
|
90 |
+
print(usage)
|
91 |
+
if prompt != "":
|
92 |
+
options.prompt = prompt
|
93 |
+
return options
|
94 |
+
|
95 |
+
|
96 |
+
@torch.inference_mode()
|
97 |
+
def main(
|
98 |
+
name: str = "flux-schnell",
|
99 |
+
width: int = 1360,
|
100 |
+
height: int = 768,
|
101 |
+
seed: int | None = None,
|
102 |
+
prompt: str = (
|
103 |
+
"a photo of a forest with mist swirling around the tree trunks. The word "
|
104 |
+
'"FLUX" is painted over it in big, red brush strokes with visible texture'
|
105 |
+
),
|
106 |
+
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
107 |
+
num_steps: int | None = None,
|
108 |
+
loop: bool = False,
|
109 |
+
guidance: float = 3.5,
|
110 |
+
offload: bool = False,
|
111 |
+
output_dir: str = "output",
|
112 |
+
add_sampling_metadata: bool = True,
|
113 |
+
):
|
114 |
+
"""
|
115 |
+
Sample the flux model. Either interactively (set `--loop`) or run for a
|
116 |
+
single image.
|
117 |
+
|
118 |
+
Args:
|
119 |
+
name: Name of the model to load
|
120 |
+
height: height of the sample in pixels (should be a multiple of 16)
|
121 |
+
width: width of the sample in pixels (should be a multiple of 16)
|
122 |
+
seed: Set a seed for sampling
|
123 |
+
output_name: where to save the output image, `{idx}` will be replaced
|
124 |
+
by the index of the sample
|
125 |
+
prompt: Prompt used for sampling
|
126 |
+
device: Pytorch device
|
127 |
+
num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
|
128 |
+
loop: start an interactive session and sample multiple times
|
129 |
+
guidance: guidance value used for guidance distillation
|
130 |
+
add_sampling_metadata: Add the prompt to the image Exif metadata
|
131 |
+
"""
|
132 |
+
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
|
133 |
+
|
134 |
+
if name not in configs:
|
135 |
+
available = ", ".join(configs.keys())
|
136 |
+
raise ValueError(f"Got unknown model name: {name}, chose from {available}")
|
137 |
+
|
138 |
+
torch_device = torch.device(device)
|
139 |
+
if num_steps is None:
|
140 |
+
num_steps = 4 if name == "flux-schnell" else 50
|
141 |
+
|
142 |
+
# allow for packing and conversion to latent space
|
143 |
+
height = 16 * (height // 16)
|
144 |
+
width = 16 * (width // 16)
|
145 |
+
|
146 |
+
output_name = os.path.join(output_dir, "img_{idx}.jpg")
|
147 |
+
if not os.path.exists(output_dir):
|
148 |
+
os.makedirs(output_dir)
|
149 |
+
idx = 0
|
150 |
+
else:
|
151 |
+
fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]\.jpg$", fn)]
|
152 |
+
if len(fns) > 0:
|
153 |
+
idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
|
154 |
+
else:
|
155 |
+
idx = 0
|
156 |
+
|
157 |
+
# init all components
|
158 |
+
t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
|
159 |
+
clip = load_clip(torch_device)
|
160 |
+
model = load_flow_model(name, device="cpu" if offload else torch_device)
|
161 |
+
ae = load_ae(name, device="cpu" if offload else torch_device)
|
162 |
+
|
163 |
+
rng = torch.Generator(device="cpu")
|
164 |
+
opts = SamplingOptions(
|
165 |
+
prompt=prompt,
|
166 |
+
width=width,
|
167 |
+
height=height,
|
168 |
+
num_steps=num_steps,
|
169 |
+
guidance=guidance,
|
170 |
+
seed=seed,
|
171 |
+
)
|
172 |
+
|
173 |
+
if loop:
|
174 |
+
opts = parse_prompt(opts)
|
175 |
+
|
176 |
+
while opts is not None:
|
177 |
+
if opts.seed is None:
|
178 |
+
opts.seed = rng.seed()
|
179 |
+
print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
|
180 |
+
t0 = time.perf_counter()
|
181 |
+
|
182 |
+
# prepare input
|
183 |
+
x = get_noise(
|
184 |
+
1,
|
185 |
+
opts.height,
|
186 |
+
opts.width,
|
187 |
+
device=torch_device,
|
188 |
+
dtype=torch.bfloat16,
|
189 |
+
seed=opts.seed,
|
190 |
+
)
|
191 |
+
opts.seed = None
|
192 |
+
if offload:
|
193 |
+
ae = ae.cpu()
|
194 |
+
torch.cuda.empty_cache()
|
195 |
+
t5, clip = t5.to(torch_device), clip.to(torch_device)
|
196 |
+
inp = prepare(t5, clip, x, prompt=opts.prompt)
|
197 |
+
timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
|
198 |
+
|
199 |
+
# offload TEs to CPU, load model to gpu
|
200 |
+
if offload:
|
201 |
+
t5, clip = t5.cpu(), clip.cpu()
|
202 |
+
torch.cuda.empty_cache()
|
203 |
+
model = model.to(torch_device)
|
204 |
+
|
205 |
+
# denoise initial noise
|
206 |
+
x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
|
207 |
+
|
208 |
+
# offload model, load autoencoder to gpu
|
209 |
+
if offload:
|
210 |
+
model.cpu()
|
211 |
+
torch.cuda.empty_cache()
|
212 |
+
ae.decoder.to(x.device)
|
213 |
+
|
214 |
+
# decode latents to pixel space
|
215 |
+
x = unpack(x.float(), opts.height, opts.width)
|
216 |
+
with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
|
217 |
+
x = ae.decode(x)
|
218 |
+
t1 = time.perf_counter()
|
219 |
+
|
220 |
+
fn = output_name.format(idx=idx)
|
221 |
+
print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
|
222 |
+
# bring into PIL format and save
|
223 |
+
x = x.clamp(-1, 1)
|
224 |
+
x = embed_watermark(x.float())
|
225 |
+
x = rearrange(x[0], "c h w -> h w c")
|
226 |
+
|
227 |
+
img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
|
228 |
+
nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]
|
229 |
+
|
230 |
+
if nsfw_score < NSFW_THRESHOLD:
|
231 |
+
exif_data = Image.Exif()
|
232 |
+
exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
|
233 |
+
exif_data[ExifTags.Base.Make] = "Black Forest Labs"
|
234 |
+
exif_data[ExifTags.Base.Model] = name
|
235 |
+
if add_sampling_metadata:
|
236 |
+
exif_data[ExifTags.Base.ImageDescription] = prompt
|
237 |
+
img.save(fn, exif=exif_data, quality=95, subsampling=0)
|
238 |
+
idx += 1
|
239 |
+
else:
|
240 |
+
print("Your generated image may contain NSFW content.")
|
241 |
+
|
242 |
+
if loop:
|
243 |
+
print("-" * 80)
|
244 |
+
opts = parse_prompt(opts)
|
245 |
+
else:
|
246 |
+
opts = None
|
247 |
+
|
248 |
+
|
249 |
+
def app():
|
250 |
+
Fire(main)
|
251 |
+
|
252 |
+
|
253 |
+
if __name__ == "__main__":
|
254 |
+
app()
|
src/flux/controlnet.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import Tensor, nn
|
5 |
+
from einops import rearrange
|
6 |
+
|
7 |
+
from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
8 |
+
MLPEmbedder, SingleStreamBlock,
|
9 |
+
timestep_embedding)
|
10 |
+
|
11 |
+
|
12 |
+
@dataclass
|
13 |
+
class FluxParams:
|
14 |
+
in_channels: int
|
15 |
+
vec_in_dim: int
|
16 |
+
context_in_dim: int
|
17 |
+
hidden_size: int
|
18 |
+
mlp_ratio: float
|
19 |
+
num_heads: int
|
20 |
+
depth: int
|
21 |
+
depth_single_blocks: int
|
22 |
+
axes_dim: list[int]
|
23 |
+
theta: int
|
24 |
+
qkv_bias: bool
|
25 |
+
guidance_embed: bool
|
26 |
+
|
27 |
+
def zero_module(module):
|
28 |
+
for p in module.parameters():
|
29 |
+
nn.init.zeros_(p)
|
30 |
+
return module
|
31 |
+
|
32 |
+
|
33 |
+
class ControlNetFlux(nn.Module):
|
34 |
+
"""
|
35 |
+
Transformer model for flow matching on sequences.
|
36 |
+
"""
|
37 |
+
_supports_gradient_checkpointing = True
|
38 |
+
|
39 |
+
def __init__(self, params: FluxParams, controlnet_depth=2):
|
40 |
+
super().__init__()
|
41 |
+
|
42 |
+
self.params = params
|
43 |
+
self.in_channels = params.in_channels
|
44 |
+
self.out_channels = self.in_channels
|
45 |
+
if params.hidden_size % params.num_heads != 0:
|
46 |
+
raise ValueError(
|
47 |
+
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
48 |
+
)
|
49 |
+
pe_dim = params.hidden_size // params.num_heads
|
50 |
+
if sum(params.axes_dim) != pe_dim:
|
51 |
+
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
52 |
+
self.hidden_size = params.hidden_size
|
53 |
+
self.num_heads = params.num_heads
|
54 |
+
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
55 |
+
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
56 |
+
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
57 |
+
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
58 |
+
self.guidance_in = (
|
59 |
+
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
|
60 |
+
)
|
61 |
+
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
62 |
+
|
63 |
+
self.double_blocks = nn.ModuleList(
|
64 |
+
[
|
65 |
+
DoubleStreamBlock(
|
66 |
+
self.hidden_size,
|
67 |
+
self.num_heads,
|
68 |
+
mlp_ratio=params.mlp_ratio,
|
69 |
+
qkv_bias=params.qkv_bias,
|
70 |
+
)
|
71 |
+
for _ in range(controlnet_depth)
|
72 |
+
]
|
73 |
+
)
|
74 |
+
|
75 |
+
# add ControlNet blocks
|
76 |
+
self.controlnet_blocks = nn.ModuleList([])
|
77 |
+
for _ in range(controlnet_depth):
|
78 |
+
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
|
79 |
+
controlnet_block = zero_module(controlnet_block)
|
80 |
+
self.controlnet_blocks.append(controlnet_block)
|
81 |
+
self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
82 |
+
self.gradient_checkpointing = False
|
83 |
+
self.input_hint_block = nn.Sequential(
|
84 |
+
nn.Conv2d(3, 16, 3, padding=1),
|
85 |
+
nn.SiLU(),
|
86 |
+
nn.Conv2d(16, 16, 3, padding=1),
|
87 |
+
nn.SiLU(),
|
88 |
+
nn.Conv2d(16, 16, 3, padding=1, stride=2),
|
89 |
+
nn.SiLU(),
|
90 |
+
nn.Conv2d(16, 16, 3, padding=1),
|
91 |
+
nn.SiLU(),
|
92 |
+
nn.Conv2d(16, 16, 3, padding=1, stride=2),
|
93 |
+
nn.SiLU(),
|
94 |
+
nn.Conv2d(16, 16, 3, padding=1),
|
95 |
+
nn.SiLU(),
|
96 |
+
nn.Conv2d(16, 16, 3, padding=1, stride=2),
|
97 |
+
nn.SiLU(),
|
98 |
+
zero_module(nn.Conv2d(16, 16, 3, padding=1))
|
99 |
+
)
|
100 |
+
|
101 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
102 |
+
if hasattr(module, "gradient_checkpointing"):
|
103 |
+
module.gradient_checkpointing = value
|
104 |
+
|
105 |
+
|
106 |
+
@property
|
107 |
+
def attn_processors(self):
|
108 |
+
# set recursively
|
109 |
+
processors = {}
|
110 |
+
|
111 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
|
112 |
+
if hasattr(module, "set_processor"):
|
113 |
+
processors[f"{name}.processor"] = module.processor
|
114 |
+
|
115 |
+
for sub_name, child in module.named_children():
|
116 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
117 |
+
|
118 |
+
return processors
|
119 |
+
|
120 |
+
for name, module in self.named_children():
|
121 |
+
fn_recursive_add_processors(name, module, processors)
|
122 |
+
|
123 |
+
return processors
|
124 |
+
|
125 |
+
def set_attn_processor(self, processor):
|
126 |
+
r"""
|
127 |
+
Sets the attention processor to use to compute attention.
|
128 |
+
|
129 |
+
Parameters:
|
130 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
131 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
132 |
+
for **all** `Attention` layers.
|
133 |
+
|
134 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
135 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
136 |
+
|
137 |
+
"""
|
138 |
+
count = len(self.attn_processors.keys())
|
139 |
+
|
140 |
+
if isinstance(processor, dict) and len(processor) != count:
|
141 |
+
raise ValueError(
|
142 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
143 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
144 |
+
)
|
145 |
+
|
146 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
147 |
+
if hasattr(module, "set_processor"):
|
148 |
+
if not isinstance(processor, dict):
|
149 |
+
module.set_processor(processor)
|
150 |
+
else:
|
151 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
152 |
+
|
153 |
+
for sub_name, child in module.named_children():
|
154 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
155 |
+
|
156 |
+
for name, module in self.named_children():
|
157 |
+
fn_recursive_attn_processor(name, module, processor)
|
158 |
+
|
159 |
+
def forward(
|
160 |
+
self,
|
161 |
+
img: Tensor,
|
162 |
+
img_ids: Tensor,
|
163 |
+
controlnet_cond: Tensor,
|
164 |
+
txt: Tensor,
|
165 |
+
txt_ids: Tensor,
|
166 |
+
timesteps: Tensor,
|
167 |
+
y: Tensor,
|
168 |
+
guidance: Tensor | None = None,
|
169 |
+
) -> Tensor:
|
170 |
+
if img.ndim != 3 or txt.ndim != 3:
|
171 |
+
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
172 |
+
|
173 |
+
# running on sequences img
|
174 |
+
img = self.img_in(img)
|
175 |
+
controlnet_cond = self.input_hint_block(controlnet_cond)
|
176 |
+
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
177 |
+
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
178 |
+
img = img + controlnet_cond
|
179 |
+
vec = self.time_in(timestep_embedding(timesteps, 256))
|
180 |
+
if self.params.guidance_embed:
|
181 |
+
if guidance is None:
|
182 |
+
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
183 |
+
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
184 |
+
vec = vec + self.vector_in(y)
|
185 |
+
txt = self.txt_in(txt)
|
186 |
+
|
187 |
+
ids = torch.cat((txt_ids, img_ids), dim=1)
|
188 |
+
pe = self.pe_embedder(ids)
|
189 |
+
|
190 |
+
block_res_samples = ()
|
191 |
+
|
192 |
+
for block in self.double_blocks:
|
193 |
+
if self.training and self.gradient_checkpointing:
|
194 |
+
|
195 |
+
def create_custom_forward(module, return_dict=None):
|
196 |
+
def custom_forward(*inputs):
|
197 |
+
if return_dict is not None:
|
198 |
+
return module(*inputs, return_dict=return_dict)
|
199 |
+
else:
|
200 |
+
return module(*inputs)
|
201 |
+
|
202 |
+
return custom_forward
|
203 |
+
|
204 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
205 |
+
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
206 |
+
create_custom_forward(block),
|
207 |
+
img,
|
208 |
+
txt,
|
209 |
+
vec,
|
210 |
+
pe,
|
211 |
+
)
|
212 |
+
else:
|
213 |
+
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
214 |
+
|
215 |
+
block_res_samples = block_res_samples + (img,)
|
216 |
+
|
217 |
+
controlnet_block_res_samples = ()
|
218 |
+
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
|
219 |
+
block_res_sample = controlnet_block(block_res_sample)
|
220 |
+
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
|
221 |
+
|
222 |
+
return controlnet_block_res_samples
|
src/flux/math.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from einops import rearrange
|
3 |
+
from torch import Tensor
|
4 |
+
|
5 |
+
|
6 |
+
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
7 |
+
q, k = apply_rope(q, k, pe)
|
8 |
+
|
9 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
10 |
+
x = rearrange(x, "B H L D -> B L (H D)")
|
11 |
+
|
12 |
+
return x
|
13 |
+
|
14 |
+
|
15 |
+
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
16 |
+
assert dim % 2 == 0
|
17 |
+
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
18 |
+
omega = 1.0 / (theta**scale)
|
19 |
+
out = torch.einsum("...n,d->...nd", pos, omega)
|
20 |
+
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
21 |
+
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
22 |
+
return out.float()
|
23 |
+
|
24 |
+
|
25 |
+
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
26 |
+
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
27 |
+
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
28 |
+
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
29 |
+
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
30 |
+
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
src/flux/model.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import Tensor, nn
|
5 |
+
from einops import rearrange
|
6 |
+
|
7 |
+
from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
8 |
+
MLPEmbedder, SingleStreamBlock,
|
9 |
+
timestep_embedding)
|
10 |
+
|
11 |
+
|
12 |
+
@dataclass
|
13 |
+
class FluxParams:
|
14 |
+
in_channels: int
|
15 |
+
vec_in_dim: int
|
16 |
+
context_in_dim: int
|
17 |
+
hidden_size: int
|
18 |
+
mlp_ratio: float
|
19 |
+
num_heads: int
|
20 |
+
depth: int
|
21 |
+
depth_single_blocks: int
|
22 |
+
axes_dim: list[int]
|
23 |
+
theta: int
|
24 |
+
qkv_bias: bool
|
25 |
+
guidance_embed: bool
|
26 |
+
|
27 |
+
|
28 |
+
class Flux(nn.Module):
|
29 |
+
"""
|
30 |
+
Transformer model for flow matching on sequences.
|
31 |
+
"""
|
32 |
+
_supports_gradient_checkpointing = True
|
33 |
+
|
34 |
+
def __init__(self, params: FluxParams):
|
35 |
+
super().__init__()
|
36 |
+
|
37 |
+
self.params = params
|
38 |
+
self.in_channels = params.in_channels
|
39 |
+
self.out_channels = self.in_channels
|
40 |
+
if params.hidden_size % params.num_heads != 0:
|
41 |
+
raise ValueError(
|
42 |
+
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
43 |
+
)
|
44 |
+
pe_dim = params.hidden_size // params.num_heads
|
45 |
+
if sum(params.axes_dim) != pe_dim:
|
46 |
+
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
47 |
+
self.hidden_size = params.hidden_size
|
48 |
+
self.num_heads = params.num_heads
|
49 |
+
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
50 |
+
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
51 |
+
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
52 |
+
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
53 |
+
self.guidance_in = (
|
54 |
+
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
|
55 |
+
)
|
56 |
+
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
57 |
+
|
58 |
+
self.double_blocks = nn.ModuleList(
|
59 |
+
[
|
60 |
+
DoubleStreamBlock(
|
61 |
+
self.hidden_size,
|
62 |
+
self.num_heads,
|
63 |
+
mlp_ratio=params.mlp_ratio,
|
64 |
+
qkv_bias=params.qkv_bias,
|
65 |
+
)
|
66 |
+
for _ in range(params.depth)
|
67 |
+
]
|
68 |
+
)
|
69 |
+
|
70 |
+
self.single_blocks = nn.ModuleList(
|
71 |
+
[
|
72 |
+
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
|
73 |
+
for _ in range(params.depth_single_blocks)
|
74 |
+
]
|
75 |
+
)
|
76 |
+
|
77 |
+
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
78 |
+
self.gradient_checkpointing = False
|
79 |
+
|
80 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
81 |
+
if hasattr(module, "gradient_checkpointing"):
|
82 |
+
module.gradient_checkpointing = value
|
83 |
+
|
84 |
+
@property
|
85 |
+
def attn_processors(self):
|
86 |
+
# set recursively
|
87 |
+
processors = {}
|
88 |
+
|
89 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
|
90 |
+
if hasattr(module, "set_processor"):
|
91 |
+
processors[f"{name}.processor"] = module.processor
|
92 |
+
|
93 |
+
for sub_name, child in module.named_children():
|
94 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
95 |
+
|
96 |
+
return processors
|
97 |
+
|
98 |
+
for name, module in self.named_children():
|
99 |
+
fn_recursive_add_processors(name, module, processors)
|
100 |
+
|
101 |
+
return processors
|
102 |
+
|
103 |
+
def set_attn_processor(self, processor):
|
104 |
+
r"""
|
105 |
+
Sets the attention processor to use to compute attention.
|
106 |
+
|
107 |
+
Parameters:
|
108 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
109 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
110 |
+
for **all** `Attention` layers.
|
111 |
+
|
112 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
113 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
114 |
+
|
115 |
+
"""
|
116 |
+
count = len(self.attn_processors.keys())
|
117 |
+
|
118 |
+
if isinstance(processor, dict) and len(processor) != count:
|
119 |
+
raise ValueError(
|
120 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
121 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
122 |
+
)
|
123 |
+
|
124 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
125 |
+
if hasattr(module, "set_processor"):
|
126 |
+
if not isinstance(processor, dict):
|
127 |
+
module.set_processor(processor)
|
128 |
+
else:
|
129 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
130 |
+
|
131 |
+
for sub_name, child in module.named_children():
|
132 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
133 |
+
|
134 |
+
for name, module in self.named_children():
|
135 |
+
fn_recursive_attn_processor(name, module, processor)
|
136 |
+
|
137 |
+
def forward(
|
138 |
+
self,
|
139 |
+
img: Tensor,
|
140 |
+
img_ids: Tensor,
|
141 |
+
txt: Tensor,
|
142 |
+
txt_ids: Tensor,
|
143 |
+
timesteps: Tensor,
|
144 |
+
y: Tensor,
|
145 |
+
block_controlnet_hidden_states=None,
|
146 |
+
guidance: Tensor | None = None,
|
147 |
+
image_proj: Tensor | None = None,
|
148 |
+
ip_scale: Tensor | float = 1.0,
|
149 |
+
) -> Tensor:
|
150 |
+
if img.ndim != 3 or txt.ndim != 3:
|
151 |
+
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
152 |
+
|
153 |
+
# running on sequences img
|
154 |
+
img = self.img_in(img)
|
155 |
+
vec = self.time_in(timestep_embedding(timesteps, 256))
|
156 |
+
if self.params.guidance_embed:
|
157 |
+
if guidance is None:
|
158 |
+
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
159 |
+
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
160 |
+
vec = vec + self.vector_in(y)
|
161 |
+
txt = self.txt_in(txt)
|
162 |
+
|
163 |
+
ids = torch.cat((txt_ids, img_ids), dim=1)
|
164 |
+
pe = self.pe_embedder(ids)
|
165 |
+
if block_controlnet_hidden_states is not None:
|
166 |
+
controlnet_depth = len(block_controlnet_hidden_states)
|
167 |
+
for index_block, block in enumerate(self.double_blocks):
|
168 |
+
if self.training and self.gradient_checkpointing:
|
169 |
+
|
170 |
+
def create_custom_forward(module, return_dict=None):
|
171 |
+
def custom_forward(*inputs):
|
172 |
+
if return_dict is not None:
|
173 |
+
return module(*inputs, return_dict=return_dict)
|
174 |
+
else:
|
175 |
+
return module(*inputs)
|
176 |
+
|
177 |
+
return custom_forward
|
178 |
+
|
179 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
180 |
+
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
181 |
+
create_custom_forward(block),
|
182 |
+
img,
|
183 |
+
txt,
|
184 |
+
vec,
|
185 |
+
pe,
|
186 |
+
image_proj,
|
187 |
+
ip_scale,
|
188 |
+
)
|
189 |
+
else:
|
190 |
+
img, txt = block(
|
191 |
+
img=img,
|
192 |
+
txt=txt,
|
193 |
+
vec=vec,
|
194 |
+
pe=pe,
|
195 |
+
image_proj=image_proj,
|
196 |
+
ip_scale=ip_scale,
|
197 |
+
)
|
198 |
+
# controlnet residual
|
199 |
+
if block_controlnet_hidden_states is not None:
|
200 |
+
img = img + block_controlnet_hidden_states[index_block % 2]
|
201 |
+
|
202 |
+
|
203 |
+
img = torch.cat((txt, img), 1)
|
204 |
+
for block in self.single_blocks:
|
205 |
+
if self.training and self.gradient_checkpointing:
|
206 |
+
|
207 |
+
def create_custom_forward(module, return_dict=None):
|
208 |
+
def custom_forward(*inputs):
|
209 |
+
if return_dict is not None:
|
210 |
+
return module(*inputs, return_dict=return_dict)
|
211 |
+
else:
|
212 |
+
return module(*inputs)
|
213 |
+
|
214 |
+
return custom_forward
|
215 |
+
|
216 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
217 |
+
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
218 |
+
create_custom_forward(block),
|
219 |
+
img,
|
220 |
+
vec,
|
221 |
+
pe,
|
222 |
+
)
|
223 |
+
else:
|
224 |
+
img = block(img, vec=vec, pe=pe)
|
225 |
+
img = img[:, txt.shape[1] :, ...]
|
226 |
+
|
227 |
+
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
228 |
+
return img
|
src/flux/modules/autoencoder.py
ADDED
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from einops import rearrange
|
5 |
+
from torch import Tensor, nn
|
6 |
+
|
7 |
+
|
8 |
+
@dataclass
|
9 |
+
class AutoEncoderParams:
|
10 |
+
resolution: int
|
11 |
+
in_channels: int
|
12 |
+
ch: int
|
13 |
+
out_ch: int
|
14 |
+
ch_mult: list[int]
|
15 |
+
num_res_blocks: int
|
16 |
+
z_channels: int
|
17 |
+
scale_factor: float
|
18 |
+
shift_factor: float
|
19 |
+
|
20 |
+
|
21 |
+
def swish(x: Tensor) -> Tensor:
|
22 |
+
return x * torch.sigmoid(x)
|
23 |
+
|
24 |
+
|
25 |
+
class AttnBlock(nn.Module):
|
26 |
+
def __init__(self, in_channels: int):
|
27 |
+
super().__init__()
|
28 |
+
self.in_channels = in_channels
|
29 |
+
|
30 |
+
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
31 |
+
|
32 |
+
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
33 |
+
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
34 |
+
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
35 |
+
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
36 |
+
|
37 |
+
def attention(self, h_: Tensor) -> Tensor:
|
38 |
+
h_ = self.norm(h_)
|
39 |
+
q = self.q(h_)
|
40 |
+
k = self.k(h_)
|
41 |
+
v = self.v(h_)
|
42 |
+
|
43 |
+
b, c, h, w = q.shape
|
44 |
+
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
|
45 |
+
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
|
46 |
+
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
|
47 |
+
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
48 |
+
|
49 |
+
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
50 |
+
|
51 |
+
def forward(self, x: Tensor) -> Tensor:
|
52 |
+
return x + self.proj_out(self.attention(x))
|
53 |
+
|
54 |
+
|
55 |
+
class ResnetBlock(nn.Module):
|
56 |
+
def __init__(self, in_channels: int, out_channels: int):
|
57 |
+
super().__init__()
|
58 |
+
self.in_channels = in_channels
|
59 |
+
out_channels = in_channels if out_channels is None else out_channels
|
60 |
+
self.out_channels = out_channels
|
61 |
+
|
62 |
+
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
63 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
64 |
+
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
65 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
66 |
+
if self.in_channels != self.out_channels:
|
67 |
+
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
68 |
+
|
69 |
+
def forward(self, x):
|
70 |
+
h = x
|
71 |
+
h = self.norm1(h)
|
72 |
+
h = swish(h)
|
73 |
+
h = self.conv1(h)
|
74 |
+
|
75 |
+
h = self.norm2(h)
|
76 |
+
h = swish(h)
|
77 |
+
h = self.conv2(h)
|
78 |
+
|
79 |
+
if self.in_channels != self.out_channels:
|
80 |
+
x = self.nin_shortcut(x)
|
81 |
+
|
82 |
+
return x + h
|
83 |
+
|
84 |
+
|
85 |
+
class Downsample(nn.Module):
|
86 |
+
def __init__(self, in_channels: int):
|
87 |
+
super().__init__()
|
88 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
89 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
90 |
+
|
91 |
+
def forward(self, x: Tensor):
|
92 |
+
pad = (0, 1, 0, 1)
|
93 |
+
x = nn.functional.pad(x, pad, mode="constant", value=0)
|
94 |
+
x = self.conv(x)
|
95 |
+
return x
|
96 |
+
|
97 |
+
|
98 |
+
class Upsample(nn.Module):
|
99 |
+
def __init__(self, in_channels: int):
|
100 |
+
super().__init__()
|
101 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
102 |
+
|
103 |
+
def forward(self, x: Tensor):
|
104 |
+
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
105 |
+
x = self.conv(x)
|
106 |
+
return x
|
107 |
+
|
108 |
+
|
109 |
+
class Encoder(nn.Module):
|
110 |
+
def __init__(
|
111 |
+
self,
|
112 |
+
resolution: int,
|
113 |
+
in_channels: int,
|
114 |
+
ch: int,
|
115 |
+
ch_mult: list[int],
|
116 |
+
num_res_blocks: int,
|
117 |
+
z_channels: int,
|
118 |
+
):
|
119 |
+
super().__init__()
|
120 |
+
self.ch = ch
|
121 |
+
self.num_resolutions = len(ch_mult)
|
122 |
+
self.num_res_blocks = num_res_blocks
|
123 |
+
self.resolution = resolution
|
124 |
+
self.in_channels = in_channels
|
125 |
+
# downsampling
|
126 |
+
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
127 |
+
|
128 |
+
curr_res = resolution
|
129 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
130 |
+
self.in_ch_mult = in_ch_mult
|
131 |
+
self.down = nn.ModuleList()
|
132 |
+
block_in = self.ch
|
133 |
+
for i_level in range(self.num_resolutions):
|
134 |
+
block = nn.ModuleList()
|
135 |
+
attn = nn.ModuleList()
|
136 |
+
block_in = ch * in_ch_mult[i_level]
|
137 |
+
block_out = ch * ch_mult[i_level]
|
138 |
+
for _ in range(self.num_res_blocks):
|
139 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
140 |
+
block_in = block_out
|
141 |
+
down = nn.Module()
|
142 |
+
down.block = block
|
143 |
+
down.attn = attn
|
144 |
+
if i_level != self.num_resolutions - 1:
|
145 |
+
down.downsample = Downsample(block_in)
|
146 |
+
curr_res = curr_res // 2
|
147 |
+
self.down.append(down)
|
148 |
+
|
149 |
+
# middle
|
150 |
+
self.mid = nn.Module()
|
151 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
152 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
153 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
154 |
+
|
155 |
+
# end
|
156 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
157 |
+
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
158 |
+
|
159 |
+
def forward(self, x: Tensor) -> Tensor:
|
160 |
+
# downsampling
|
161 |
+
hs = [self.conv_in(x)]
|
162 |
+
for i_level in range(self.num_resolutions):
|
163 |
+
for i_block in range(self.num_res_blocks):
|
164 |
+
h = self.down[i_level].block[i_block](hs[-1])
|
165 |
+
if len(self.down[i_level].attn) > 0:
|
166 |
+
h = self.down[i_level].attn[i_block](h)
|
167 |
+
hs.append(h)
|
168 |
+
if i_level != self.num_resolutions - 1:
|
169 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
170 |
+
|
171 |
+
# middle
|
172 |
+
h = hs[-1]
|
173 |
+
h = self.mid.block_1(h)
|
174 |
+
h = self.mid.attn_1(h)
|
175 |
+
h = self.mid.block_2(h)
|
176 |
+
# end
|
177 |
+
h = self.norm_out(h)
|
178 |
+
h = swish(h)
|
179 |
+
h = self.conv_out(h)
|
180 |
+
return h
|
181 |
+
|
182 |
+
|
183 |
+
class Decoder(nn.Module):
|
184 |
+
def __init__(
|
185 |
+
self,
|
186 |
+
ch: int,
|
187 |
+
out_ch: int,
|
188 |
+
ch_mult: list[int],
|
189 |
+
num_res_blocks: int,
|
190 |
+
in_channels: int,
|
191 |
+
resolution: int,
|
192 |
+
z_channels: int,
|
193 |
+
):
|
194 |
+
super().__init__()
|
195 |
+
self.ch = ch
|
196 |
+
self.num_resolutions = len(ch_mult)
|
197 |
+
self.num_res_blocks = num_res_blocks
|
198 |
+
self.resolution = resolution
|
199 |
+
self.in_channels = in_channels
|
200 |
+
self.ffactor = 2 ** (self.num_resolutions - 1)
|
201 |
+
|
202 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
203 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
204 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
205 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
206 |
+
|
207 |
+
# z to block_in
|
208 |
+
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
209 |
+
|
210 |
+
# middle
|
211 |
+
self.mid = nn.Module()
|
212 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
213 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
214 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
215 |
+
|
216 |
+
# upsampling
|
217 |
+
self.up = nn.ModuleList()
|
218 |
+
for i_level in reversed(range(self.num_resolutions)):
|
219 |
+
block = nn.ModuleList()
|
220 |
+
attn = nn.ModuleList()
|
221 |
+
block_out = ch * ch_mult[i_level]
|
222 |
+
for _ in range(self.num_res_blocks + 1):
|
223 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
224 |
+
block_in = block_out
|
225 |
+
up = nn.Module()
|
226 |
+
up.block = block
|
227 |
+
up.attn = attn
|
228 |
+
if i_level != 0:
|
229 |
+
up.upsample = Upsample(block_in)
|
230 |
+
curr_res = curr_res * 2
|
231 |
+
self.up.insert(0, up) # prepend to get consistent order
|
232 |
+
|
233 |
+
# end
|
234 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
235 |
+
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
236 |
+
|
237 |
+
def forward(self, z: Tensor) -> Tensor:
|
238 |
+
# z to block_in
|
239 |
+
h = self.conv_in(z)
|
240 |
+
|
241 |
+
# middle
|
242 |
+
h = self.mid.block_1(h)
|
243 |
+
h = self.mid.attn_1(h)
|
244 |
+
h = self.mid.block_2(h)
|
245 |
+
|
246 |
+
# upsampling
|
247 |
+
for i_level in reversed(range(self.num_resolutions)):
|
248 |
+
for i_block in range(self.num_res_blocks + 1):
|
249 |
+
h = self.up[i_level].block[i_block](h)
|
250 |
+
if len(self.up[i_level].attn) > 0:
|
251 |
+
h = self.up[i_level].attn[i_block](h)
|
252 |
+
if i_level != 0:
|
253 |
+
h = self.up[i_level].upsample(h)
|
254 |
+
|
255 |
+
# end
|
256 |
+
h = self.norm_out(h)
|
257 |
+
h = swish(h)
|
258 |
+
h = self.conv_out(h)
|
259 |
+
return h
|
260 |
+
|
261 |
+
|
262 |
+
class DiagonalGaussian(nn.Module):
|
263 |
+
def __init__(self, sample: bool = True, chunk_dim: int = 1):
|
264 |
+
super().__init__()
|
265 |
+
self.sample = sample
|
266 |
+
self.chunk_dim = chunk_dim
|
267 |
+
|
268 |
+
def forward(self, z: Tensor) -> Tensor:
|
269 |
+
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
|
270 |
+
if self.sample:
|
271 |
+
std = torch.exp(0.5 * logvar)
|
272 |
+
return mean + std * torch.randn_like(mean)
|
273 |
+
else:
|
274 |
+
return mean
|
275 |
+
|
276 |
+
|
277 |
+
class AutoEncoder(nn.Module):
|
278 |
+
def __init__(self, params: AutoEncoderParams):
|
279 |
+
super().__init__()
|
280 |
+
self.encoder = Encoder(
|
281 |
+
resolution=params.resolution,
|
282 |
+
in_channels=params.in_channels,
|
283 |
+
ch=params.ch,
|
284 |
+
ch_mult=params.ch_mult,
|
285 |
+
num_res_blocks=params.num_res_blocks,
|
286 |
+
z_channels=params.z_channels,
|
287 |
+
)
|
288 |
+
self.decoder = Decoder(
|
289 |
+
resolution=params.resolution,
|
290 |
+
in_channels=params.in_channels,
|
291 |
+
ch=params.ch,
|
292 |
+
out_ch=params.out_ch,
|
293 |
+
ch_mult=params.ch_mult,
|
294 |
+
num_res_blocks=params.num_res_blocks,
|
295 |
+
z_channels=params.z_channels,
|
296 |
+
)
|
297 |
+
self.reg = DiagonalGaussian()
|
298 |
+
|
299 |
+
self.scale_factor = params.scale_factor
|
300 |
+
self.shift_factor = params.shift_factor
|
301 |
+
|
302 |
+
def encode(self, x: Tensor) -> Tensor:
|
303 |
+
z = self.reg(self.encoder(x))
|
304 |
+
z = self.scale_factor * (z - self.shift_factor)
|
305 |
+
return z
|
306 |
+
|
307 |
+
def decode(self, z: Tensor) -> Tensor:
|
308 |
+
z = z / self.scale_factor + self.shift_factor
|
309 |
+
return self.decoder(z)
|
310 |
+
|
311 |
+
def forward(self, x: Tensor) -> Tensor:
|
312 |
+
return self.decode(self.encode(x))
|
src/flux/modules/conditioner.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import Tensor, nn
|
2 |
+
from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
|
3 |
+
T5Tokenizer)
|
4 |
+
|
5 |
+
|
6 |
+
class HFEmbedder(nn.Module):
|
7 |
+
def __init__(self, version: str, max_length: int, **hf_kwargs):
|
8 |
+
super().__init__()
|
9 |
+
self.is_clip = version.startswith("openai")
|
10 |
+
self.max_length = max_length
|
11 |
+
self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
|
12 |
+
|
13 |
+
if self.is_clip:
|
14 |
+
self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
|
15 |
+
self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
|
16 |
+
else:
|
17 |
+
self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
|
18 |
+
self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
|
19 |
+
|
20 |
+
self.hf_module = self.hf_module.eval().requires_grad_(False)
|
21 |
+
|
22 |
+
def forward(self, text: list[str]) -> Tensor:
|
23 |
+
batch_encoding = self.tokenizer(
|
24 |
+
text,
|
25 |
+
truncation=True,
|
26 |
+
max_length=self.max_length,
|
27 |
+
return_length=False,
|
28 |
+
return_overflowing_tokens=False,
|
29 |
+
padding="max_length",
|
30 |
+
return_tensors="pt",
|
31 |
+
)
|
32 |
+
|
33 |
+
outputs = self.hf_module(
|
34 |
+
input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
|
35 |
+
attention_mask=None,
|
36 |
+
output_hidden_states=False,
|
37 |
+
)
|
38 |
+
return outputs[self.output_key]
|
src/flux/modules/layers.py
ADDED
@@ -0,0 +1,567 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from dataclasses import dataclass
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from einops import rearrange
|
6 |
+
from torch import Tensor, nn
|
7 |
+
|
8 |
+
from ..math import attention, rope
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
class EmbedND(nn.Module):
|
12 |
+
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
13 |
+
super().__init__()
|
14 |
+
self.dim = dim
|
15 |
+
self.theta = theta
|
16 |
+
self.axes_dim = axes_dim
|
17 |
+
|
18 |
+
def forward(self, ids: Tensor) -> Tensor:
|
19 |
+
n_axes = ids.shape[-1]
|
20 |
+
emb = torch.cat(
|
21 |
+
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
22 |
+
dim=-3,
|
23 |
+
)
|
24 |
+
|
25 |
+
return emb.unsqueeze(1)
|
26 |
+
|
27 |
+
|
28 |
+
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
29 |
+
"""
|
30 |
+
Create sinusoidal timestep embeddings.
|
31 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
32 |
+
These may be fractional.
|
33 |
+
:param dim: the dimension of the output.
|
34 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
35 |
+
:return: an (N, D) Tensor of positional embeddings.
|
36 |
+
"""
|
37 |
+
t = time_factor * t
|
38 |
+
half = dim // 2
|
39 |
+
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
40 |
+
t.device
|
41 |
+
)
|
42 |
+
|
43 |
+
args = t[:, None].float() * freqs[None]
|
44 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
45 |
+
if dim % 2:
|
46 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
47 |
+
if torch.is_floating_point(t):
|
48 |
+
embedding = embedding.to(t)
|
49 |
+
return embedding
|
50 |
+
|
51 |
+
|
52 |
+
class MLPEmbedder(nn.Module):
|
53 |
+
def __init__(self, in_dim: int, hidden_dim: int):
|
54 |
+
super().__init__()
|
55 |
+
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
56 |
+
self.silu = nn.SiLU()
|
57 |
+
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
58 |
+
|
59 |
+
def forward(self, x: Tensor) -> Tensor:
|
60 |
+
return self.out_layer(self.silu(self.in_layer(x)))
|
61 |
+
|
62 |
+
|
63 |
+
class RMSNorm(torch.nn.Module):
|
64 |
+
def __init__(self, dim: int):
|
65 |
+
super().__init__()
|
66 |
+
self.scale = nn.Parameter(torch.ones(dim))
|
67 |
+
|
68 |
+
def forward(self, x: Tensor):
|
69 |
+
x_dtype = x.dtype
|
70 |
+
x = x.float()
|
71 |
+
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
72 |
+
return (x * rrms).to(dtype=x_dtype) * self.scale
|
73 |
+
|
74 |
+
|
75 |
+
class QKNorm(torch.nn.Module):
|
76 |
+
def __init__(self, dim: int):
|
77 |
+
super().__init__()
|
78 |
+
self.query_norm = RMSNorm(dim)
|
79 |
+
self.key_norm = RMSNorm(dim)
|
80 |
+
|
81 |
+
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
82 |
+
q = self.query_norm(q)
|
83 |
+
k = self.key_norm(k)
|
84 |
+
return q.to(v), k.to(v)
|
85 |
+
|
86 |
+
class LoRALinearLayer(nn.Module):
|
87 |
+
def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
|
88 |
+
super().__init__()
|
89 |
+
|
90 |
+
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
|
91 |
+
self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
|
92 |
+
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
93 |
+
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
94 |
+
self.network_alpha = network_alpha
|
95 |
+
self.rank = rank
|
96 |
+
|
97 |
+
nn.init.normal_(self.down.weight, std=1 / rank)
|
98 |
+
nn.init.zeros_(self.up.weight)
|
99 |
+
|
100 |
+
def forward(self, hidden_states):
|
101 |
+
orig_dtype = hidden_states.dtype
|
102 |
+
dtype = self.down.weight.dtype
|
103 |
+
|
104 |
+
down_hidden_states = self.down(hidden_states.to(dtype))
|
105 |
+
up_hidden_states = self.up(down_hidden_states)
|
106 |
+
|
107 |
+
if self.network_alpha is not None:
|
108 |
+
up_hidden_states *= self.network_alpha / self.rank
|
109 |
+
|
110 |
+
return up_hidden_states.to(orig_dtype)
|
111 |
+
|
112 |
+
class FLuxSelfAttnProcessor:
|
113 |
+
def __call__(self, attn, x, pe, **attention_kwargs):
|
114 |
+
print('2' * 30)
|
115 |
+
|
116 |
+
qkv = attn.qkv(x)
|
117 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
118 |
+
q, k = attn.norm(q, k, v)
|
119 |
+
x = attention(q, k, v, pe=pe)
|
120 |
+
x = attn.proj(x)
|
121 |
+
return x
|
122 |
+
|
123 |
+
class LoraFluxAttnProcessor(nn.Module):
|
124 |
+
|
125 |
+
def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
|
126 |
+
super().__init__()
|
127 |
+
self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
|
128 |
+
self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
|
129 |
+
self.lora_weight = lora_weight
|
130 |
+
|
131 |
+
|
132 |
+
def __call__(self, attn, x, pe, **attention_kwargs):
|
133 |
+
qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight
|
134 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
135 |
+
q, k = attn.norm(q, k, v)
|
136 |
+
x = attention(q, k, v, pe=pe)
|
137 |
+
x = attn.proj(x) + self.proj_lora(x) * self.lora_weight
|
138 |
+
print('1' * 30)
|
139 |
+
print(x.norm(), (self.proj_lora(x) * self.lora_weight).norm(), 'norm')
|
140 |
+
return x
|
141 |
+
|
142 |
+
class SelfAttention(nn.Module):
|
143 |
+
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
144 |
+
super().__init__()
|
145 |
+
self.num_heads = num_heads
|
146 |
+
head_dim = dim // num_heads
|
147 |
+
|
148 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
149 |
+
self.norm = QKNorm(head_dim)
|
150 |
+
self.proj = nn.Linear(dim, dim)
|
151 |
+
def forward():
|
152 |
+
pass
|
153 |
+
|
154 |
+
|
155 |
+
@dataclass
|
156 |
+
class ModulationOut:
|
157 |
+
shift: Tensor
|
158 |
+
scale: Tensor
|
159 |
+
gate: Tensor
|
160 |
+
|
161 |
+
|
162 |
+
class Modulation(nn.Module):
|
163 |
+
def __init__(self, dim: int, double: bool):
|
164 |
+
super().__init__()
|
165 |
+
self.is_double = double
|
166 |
+
self.multiplier = 6 if double else 3
|
167 |
+
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
168 |
+
|
169 |
+
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
|
170 |
+
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
171 |
+
|
172 |
+
return (
|
173 |
+
ModulationOut(*out[:3]),
|
174 |
+
ModulationOut(*out[3:]) if self.is_double else None,
|
175 |
+
)
|
176 |
+
|
177 |
+
class DoubleStreamBlockLoraProcessor(nn.Module):
|
178 |
+
def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
|
179 |
+
super().__init__()
|
180 |
+
self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
|
181 |
+
self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
|
182 |
+
self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
|
183 |
+
self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
|
184 |
+
self.lora_weight = lora_weight
|
185 |
+
|
186 |
+
def forward(self, attn, img, txt, vec, pe, **attention_kwargs):
|
187 |
+
img_mod1, img_mod2 = attn.img_mod(vec)
|
188 |
+
txt_mod1, txt_mod2 = attn.txt_mod(vec)
|
189 |
+
|
190 |
+
# prepare image for attention
|
191 |
+
img_modulated = attn.img_norm1(img)
|
192 |
+
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
193 |
+
img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight
|
194 |
+
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
|
195 |
+
img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
|
196 |
+
|
197 |
+
# prepare txt for attention
|
198 |
+
txt_modulated = attn.txt_norm1(txt)
|
199 |
+
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
200 |
+
txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight
|
201 |
+
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
|
202 |
+
txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
|
203 |
+
|
204 |
+
# run actual attention
|
205 |
+
q = torch.cat((txt_q, img_q), dim=2)
|
206 |
+
k = torch.cat((txt_k, img_k), dim=2)
|
207 |
+
v = torch.cat((txt_v, img_v), dim=2)
|
208 |
+
|
209 |
+
attn1 = attention(q, k, v, pe=pe)
|
210 |
+
txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
|
211 |
+
|
212 |
+
# calculate the img bloks
|
213 |
+
img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img_mod1.gate * self.proj_lora1(img_attn) * self.lora_weight
|
214 |
+
img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
|
215 |
+
|
216 |
+
# calculate the txt bloks
|
217 |
+
txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt_mod1.gate * self.proj_lora2(txt_attn) * self.lora_weight
|
218 |
+
txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
|
219 |
+
return img, txt
|
220 |
+
|
221 |
+
class IPDoubleStreamBlockProcessor(nn.Module):
|
222 |
+
"""Attention processor for handling IP-adapter with double stream block."""
|
223 |
+
|
224 |
+
def __init__(self, context_dim, hidden_dim):
|
225 |
+
super().__init__()
|
226 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
227 |
+
raise ImportError(
|
228 |
+
"IPDoubleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch."
|
229 |
+
)
|
230 |
+
|
231 |
+
# Ensure context_dim matches the dimension of image_proj
|
232 |
+
self.context_dim = context_dim
|
233 |
+
self.hidden_dim = hidden_dim
|
234 |
+
|
235 |
+
# Initialize projections for IP-adapter
|
236 |
+
self.ip_adapter_double_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=True)
|
237 |
+
self.ip_adapter_double_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=True)
|
238 |
+
|
239 |
+
nn.init.zeros_(self.ip_adapter_double_stream_k_proj.weight)
|
240 |
+
nn.init.zeros_(self.ip_adapter_double_stream_k_proj.bias)
|
241 |
+
|
242 |
+
nn.init.zeros_(self.ip_adapter_double_stream_v_proj.weight)
|
243 |
+
nn.init.zeros_(self.ip_adapter_double_stream_v_proj.bias)
|
244 |
+
|
245 |
+
def __call__(self, attn, img, txt, vec, pe, image_proj, ip_scale=1.0, **attention_kwargs):
|
246 |
+
|
247 |
+
# Prepare image for attention
|
248 |
+
img_mod1, img_mod2 = attn.img_mod(vec)
|
249 |
+
txt_mod1, txt_mod2 = attn.txt_mod(vec)
|
250 |
+
|
251 |
+
img_modulated = attn.img_norm1(img)
|
252 |
+
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
253 |
+
img_qkv = attn.img_attn.qkv(img_modulated)
|
254 |
+
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
|
255 |
+
img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
|
256 |
+
|
257 |
+
txt_modulated = attn.txt_norm1(txt)
|
258 |
+
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
259 |
+
txt_qkv = attn.txt_attn.qkv(txt_modulated)
|
260 |
+
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
|
261 |
+
txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
|
262 |
+
|
263 |
+
q = torch.cat((txt_q, img_q), dim=2)
|
264 |
+
k = torch.cat((txt_k, img_k), dim=2)
|
265 |
+
v = torch.cat((txt_v, img_v), dim=2)
|
266 |
+
|
267 |
+
attn1 = attention(q, k, v, pe=pe)
|
268 |
+
txt_attn, img_attn = attn1[:, :txt.shape[1]], attn1[:, txt.shape[1]:]
|
269 |
+
|
270 |
+
# print(f"txt_attn shape: {txt_attn.size()}")
|
271 |
+
# print(f"img_attn shape: {img_attn.size()}")
|
272 |
+
|
273 |
+
img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
|
274 |
+
img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
|
275 |
+
|
276 |
+
txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
|
277 |
+
txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
|
278 |
+
|
279 |
+
|
280 |
+
# IP-adapter processing
|
281 |
+
ip_query = img_q # latent sample query
|
282 |
+
ip_key = self.ip_adapter_double_stream_k_proj(image_proj)
|
283 |
+
ip_value = self.ip_adapter_double_stream_v_proj(image_proj)
|
284 |
+
|
285 |
+
# Reshape projections for multi-head attention
|
286 |
+
ip_key = rearrange(ip_key, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim)
|
287 |
+
ip_value = rearrange(ip_value, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim)
|
288 |
+
|
289 |
+
# Compute attention between IP projections and the latent query
|
290 |
+
ip_attention = F.scaled_dot_product_attention(
|
291 |
+
ip_query,
|
292 |
+
ip_key,
|
293 |
+
ip_value,
|
294 |
+
dropout_p=0.0,
|
295 |
+
is_causal=False
|
296 |
+
)
|
297 |
+
ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)", H=attn.num_heads, D=attn.head_dim)
|
298 |
+
|
299 |
+
img = img + ip_scale * ip_attention
|
300 |
+
|
301 |
+
return img, txt
|
302 |
+
|
303 |
+
class DoubleStreamBlockProcessor:
|
304 |
+
def __call__(self, attn, img, txt, vec, pe, **attention_kwargs):
|
305 |
+
img_mod1, img_mod2 = attn.img_mod(vec)
|
306 |
+
txt_mod1, txt_mod2 = attn.txt_mod(vec)
|
307 |
+
|
308 |
+
# prepare image for attention
|
309 |
+
img_modulated = attn.img_norm1(img)
|
310 |
+
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
311 |
+
img_qkv = attn.img_attn.qkv(img_modulated)
|
312 |
+
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
|
313 |
+
img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
|
314 |
+
|
315 |
+
# prepare txt for attention
|
316 |
+
txt_modulated = attn.txt_norm1(txt)
|
317 |
+
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
318 |
+
txt_qkv = attn.txt_attn.qkv(txt_modulated)
|
319 |
+
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
|
320 |
+
txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
|
321 |
+
|
322 |
+
# run actual attention
|
323 |
+
q = torch.cat((txt_q, img_q), dim=2)
|
324 |
+
k = torch.cat((txt_k, img_k), dim=2)
|
325 |
+
v = torch.cat((txt_v, img_v), dim=2)
|
326 |
+
|
327 |
+
attn1 = attention(q, k, v, pe=pe)
|
328 |
+
txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
|
329 |
+
|
330 |
+
# calculate the img bloks
|
331 |
+
img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
|
332 |
+
img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
|
333 |
+
|
334 |
+
# calculate the txt bloks
|
335 |
+
txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
|
336 |
+
txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
|
337 |
+
return img, txt
|
338 |
+
|
339 |
+
class DoubleStreamBlock(nn.Module):
|
340 |
+
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
|
341 |
+
super().__init__()
|
342 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
343 |
+
self.num_heads = num_heads
|
344 |
+
self.hidden_size = hidden_size
|
345 |
+
self.head_dim = hidden_size // num_heads
|
346 |
+
|
347 |
+
self.img_mod = Modulation(hidden_size, double=True)
|
348 |
+
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
349 |
+
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
350 |
+
|
351 |
+
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
352 |
+
self.img_mlp = nn.Sequential(
|
353 |
+
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
354 |
+
nn.GELU(approximate="tanh"),
|
355 |
+
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
356 |
+
)
|
357 |
+
|
358 |
+
self.txt_mod = Modulation(hidden_size, double=True)
|
359 |
+
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
360 |
+
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
361 |
+
|
362 |
+
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
363 |
+
self.txt_mlp = nn.Sequential(
|
364 |
+
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
365 |
+
nn.GELU(approximate="tanh"),
|
366 |
+
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
367 |
+
)
|
368 |
+
processor = DoubleStreamBlockProcessor()
|
369 |
+
self.set_processor(processor)
|
370 |
+
|
371 |
+
def set_processor(self, processor) -> None:
|
372 |
+
self.processor = processor
|
373 |
+
|
374 |
+
def get_processor(self):
|
375 |
+
return self.processor
|
376 |
+
|
377 |
+
def forward(
|
378 |
+
self,
|
379 |
+
img: Tensor,
|
380 |
+
txt: Tensor,
|
381 |
+
vec: Tensor,
|
382 |
+
pe: Tensor,
|
383 |
+
image_proj: Tensor = None,
|
384 |
+
ip_scale: float =1.0,
|
385 |
+
) -> tuple[Tensor, Tensor]:
|
386 |
+
if image_proj is None:
|
387 |
+
return self.processor(self, img, txt, vec, pe)
|
388 |
+
else:
|
389 |
+
return self.processor(self, img, txt, vec, pe, image_proj, ip_scale)
|
390 |
+
|
391 |
+
class IPSingleStreamBlockProcessor(nn.Module):
|
392 |
+
"""Attention processor for handling IP-adapter with single stream block."""
|
393 |
+
def __init__(self, context_dim, hidden_dim):
|
394 |
+
super().__init__()
|
395 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
396 |
+
raise ImportError(
|
397 |
+
"IPSingleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch."
|
398 |
+
)
|
399 |
+
|
400 |
+
# Ensure context_dim matches the dimension of image_proj
|
401 |
+
self.context_dim = context_dim
|
402 |
+
self.hidden_dim = hidden_dim
|
403 |
+
|
404 |
+
# Initialize projections for IP-adapter
|
405 |
+
self.ip_adapter_single_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
|
406 |
+
self.ip_adapter_single_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=False)
|
407 |
+
|
408 |
+
nn.init.zeros_(self.ip_adapter_single_stream_k_proj.weight)
|
409 |
+
nn.init.zeros_(self.ip_adapter_single_stream_v_proj.weight)
|
410 |
+
|
411 |
+
def __call__(
|
412 |
+
self,
|
413 |
+
attn: nn.Module,
|
414 |
+
x: Tensor,
|
415 |
+
vec: Tensor,
|
416 |
+
pe: Tensor,
|
417 |
+
image_proj: Tensor | None = None,
|
418 |
+
ip_scale: float = 1.0
|
419 |
+
) -> Tensor:
|
420 |
+
|
421 |
+
mod, _ = attn.modulation(vec)
|
422 |
+
x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
|
423 |
+
qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
|
424 |
+
|
425 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
|
426 |
+
q, k = attn.norm(q, k, v)
|
427 |
+
|
428 |
+
# compute attention
|
429 |
+
attn_1 = attention(q, k, v, pe=pe)
|
430 |
+
|
431 |
+
# IP-adapter processing
|
432 |
+
ip_query = q
|
433 |
+
ip_key = self.ip_adapter_single_stream_k_proj(image_proj)
|
434 |
+
ip_value = self.ip_adapter_single_stream_v_proj(image_proj)
|
435 |
+
|
436 |
+
# Reshape projections for multi-head attention
|
437 |
+
ip_key = rearrange(ip_key, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim)
|
438 |
+
ip_value = rearrange(ip_value, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim)
|
439 |
+
|
440 |
+
|
441 |
+
# Compute attention between IP projections and the latent query
|
442 |
+
ip_attention = F.scaled_dot_product_attention(
|
443 |
+
ip_query,
|
444 |
+
ip_key,
|
445 |
+
ip_value
|
446 |
+
)
|
447 |
+
ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)")
|
448 |
+
|
449 |
+
attn_out = attn_1 + ip_scale * ip_attention
|
450 |
+
|
451 |
+
# compute activation in mlp stream, cat again and run second linear layer
|
452 |
+
output = attn.linear2(torch.cat((attn_out, attn.mlp_act(mlp)), 2))
|
453 |
+
out = x + mod.gate * output
|
454 |
+
|
455 |
+
return out
|
456 |
+
|
457 |
+
class SingleStreamBlockProcessor:
|
458 |
+
def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
459 |
+
|
460 |
+
mod, _ = attn.modulation(vec)
|
461 |
+
x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
|
462 |
+
qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
|
463 |
+
|
464 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
|
465 |
+
q, k = attn.norm(q, k, v)
|
466 |
+
|
467 |
+
# compute attention
|
468 |
+
attn_1 = attention(q, k, v, pe=pe)
|
469 |
+
|
470 |
+
# compute activation in mlp stream, cat again and run second linear layer
|
471 |
+
output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
|
472 |
+
output = x + mod.gate * output
|
473 |
+
return output
|
474 |
+
|
475 |
+
class SingleStreamBlock(nn.Module):
|
476 |
+
"""
|
477 |
+
A DiT block with parallel linear layers as described in
|
478 |
+
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
479 |
+
"""
|
480 |
+
|
481 |
+
def __init__(
|
482 |
+
self,
|
483 |
+
hidden_size: int,
|
484 |
+
num_heads: int,
|
485 |
+
mlp_ratio: float = 4.0,
|
486 |
+
qk_scale: float | None = None,
|
487 |
+
):
|
488 |
+
super().__init__()
|
489 |
+
self.hidden_dim = hidden_size
|
490 |
+
self.num_heads = num_heads
|
491 |
+
self.head_dim = hidden_size // num_heads
|
492 |
+
self.scale = qk_scale or self.head_dim**-0.5
|
493 |
+
|
494 |
+
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
495 |
+
# qkv and mlp_in
|
496 |
+
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
497 |
+
# proj and mlp_out
|
498 |
+
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
499 |
+
|
500 |
+
self.norm = QKNorm(self.head_dim)
|
501 |
+
|
502 |
+
self.hidden_size = hidden_size
|
503 |
+
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
504 |
+
|
505 |
+
self.mlp_act = nn.GELU(approximate="tanh")
|
506 |
+
self.modulation = Modulation(hidden_size, double=False)
|
507 |
+
|
508 |
+
processor = SingleStreamBlockProcessor()
|
509 |
+
self.set_processor(processor)
|
510 |
+
|
511 |
+
|
512 |
+
def set_processor(self, processor) -> None:
|
513 |
+
self.processor = processor
|
514 |
+
|
515 |
+
def get_processor(self):
|
516 |
+
return self.processor
|
517 |
+
|
518 |
+
def forward(
|
519 |
+
self,
|
520 |
+
x: Tensor,
|
521 |
+
vec: Tensor,
|
522 |
+
pe: Tensor,
|
523 |
+
image_proj: Tensor | None = None,
|
524 |
+
ip_scale: float = 1.0
|
525 |
+
) -> Tensor:
|
526 |
+
if image_proj is None:
|
527 |
+
return self.processor(self, x, vec, pe)
|
528 |
+
else:
|
529 |
+
return self.processor(self, x, vec, pe, image_proj, ip_scale)
|
530 |
+
|
531 |
+
|
532 |
+
|
533 |
+
class LastLayer(nn.Module):
|
534 |
+
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
535 |
+
super().__init__()
|
536 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
537 |
+
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
538 |
+
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
539 |
+
|
540 |
+
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
541 |
+
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
542 |
+
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
543 |
+
x = self.linear(x)
|
544 |
+
return x
|
545 |
+
|
546 |
+
class ImageProjModel(torch.nn.Module):
|
547 |
+
"""Projection Model
|
548 |
+
https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter.py#L28
|
549 |
+
"""
|
550 |
+
|
551 |
+
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
|
552 |
+
super().__init__()
|
553 |
+
|
554 |
+
self.generator = None
|
555 |
+
self.cross_attention_dim = cross_attention_dim
|
556 |
+
self.clip_extra_context_tokens = clip_extra_context_tokens
|
557 |
+
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
|
558 |
+
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
559 |
+
|
560 |
+
def forward(self, image_embeds):
|
561 |
+
embeds = image_embeds
|
562 |
+
clip_extra_context_tokens = self.proj(embeds).reshape(
|
563 |
+
-1, self.clip_extra_context_tokens, self.cross_attention_dim
|
564 |
+
)
|
565 |
+
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
|
566 |
+
return clip_extra_context_tokens
|
567 |
+
|
src/flux/sampling.py
ADDED
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Callable
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from einops import rearrange, repeat
|
6 |
+
from torch import Tensor
|
7 |
+
|
8 |
+
from .model import Flux
|
9 |
+
from .modules.conditioner import HFEmbedder
|
10 |
+
|
11 |
+
|
12 |
+
def get_noise(
|
13 |
+
num_samples: int,
|
14 |
+
height: int,
|
15 |
+
width: int,
|
16 |
+
device: torch.device,
|
17 |
+
dtype: torch.dtype,
|
18 |
+
seed: int,
|
19 |
+
):
|
20 |
+
return torch.randn(
|
21 |
+
num_samples,
|
22 |
+
16,
|
23 |
+
# allow for packing
|
24 |
+
2 * math.ceil(height / 16),
|
25 |
+
2 * math.ceil(width / 16),
|
26 |
+
device=device,
|
27 |
+
dtype=dtype,
|
28 |
+
generator=torch.Generator(device=device).manual_seed(seed),
|
29 |
+
)
|
30 |
+
|
31 |
+
|
32 |
+
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
|
33 |
+
bs, c, h, w = img.shape
|
34 |
+
if bs == 1 and not isinstance(prompt, str):
|
35 |
+
bs = len(prompt)
|
36 |
+
|
37 |
+
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
38 |
+
if img.shape[0] == 1 and bs > 1:
|
39 |
+
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
40 |
+
|
41 |
+
img_ids = torch.zeros(h // 2, w // 2, 3)
|
42 |
+
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
43 |
+
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
44 |
+
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
45 |
+
|
46 |
+
if isinstance(prompt, str):
|
47 |
+
prompt = [prompt]
|
48 |
+
txt = t5(prompt)
|
49 |
+
if txt.shape[0] == 1 and bs > 1:
|
50 |
+
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
|
51 |
+
txt_ids = torch.zeros(bs, txt.shape[1], 3)
|
52 |
+
|
53 |
+
vec = clip(prompt)
|
54 |
+
if vec.shape[0] == 1 and bs > 1:
|
55 |
+
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
|
56 |
+
|
57 |
+
return {
|
58 |
+
"img": img,
|
59 |
+
"img_ids": img_ids.to(img.device),
|
60 |
+
"txt": txt.to(img.device),
|
61 |
+
"txt_ids": txt_ids.to(img.device),
|
62 |
+
"vec": vec.to(img.device),
|
63 |
+
}
|
64 |
+
|
65 |
+
|
66 |
+
def time_shift(mu: float, sigma: float, t: Tensor):
|
67 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
68 |
+
|
69 |
+
|
70 |
+
def get_lin_function(
|
71 |
+
x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
|
72 |
+
) -> Callable[[float], float]:
|
73 |
+
m = (y2 - y1) / (x2 - x1)
|
74 |
+
b = y1 - m * x1
|
75 |
+
return lambda x: m * x + b
|
76 |
+
|
77 |
+
|
78 |
+
def get_schedule(
|
79 |
+
num_steps: int,
|
80 |
+
image_seq_len: int,
|
81 |
+
base_shift: float = 0.5,
|
82 |
+
max_shift: float = 1.15,
|
83 |
+
shift: bool = True,
|
84 |
+
) -> list[float]:
|
85 |
+
# extra step for zero
|
86 |
+
timesteps = torch.linspace(1, 0, num_steps + 1)
|
87 |
+
|
88 |
+
# shifting the schedule to favor high timesteps for higher signal images
|
89 |
+
if shift:
|
90 |
+
# eastimate mu based on linear estimation between two points
|
91 |
+
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
92 |
+
timesteps = time_shift(mu, 1.0, timesteps)
|
93 |
+
|
94 |
+
return timesteps.tolist()
|
95 |
+
|
96 |
+
|
97 |
+
def denoise(
|
98 |
+
model: Flux,
|
99 |
+
# model input
|
100 |
+
img: Tensor,
|
101 |
+
img_ids: Tensor,
|
102 |
+
txt: Tensor,
|
103 |
+
txt_ids: Tensor,
|
104 |
+
vec: Tensor,
|
105 |
+
neg_txt: Tensor,
|
106 |
+
neg_txt_ids: Tensor,
|
107 |
+
neg_vec: Tensor,
|
108 |
+
# sampling parameters
|
109 |
+
timesteps: list[float],
|
110 |
+
guidance: float = 4.0,
|
111 |
+
true_gs = 1,
|
112 |
+
timestep_to_start_cfg=0,
|
113 |
+
# ip-adapter parameters
|
114 |
+
image_proj: Tensor=None,
|
115 |
+
neg_image_proj: Tensor=None,
|
116 |
+
ip_scale: Tensor | float = 1.0,
|
117 |
+
neg_ip_scale: Tensor | float = 1.0
|
118 |
+
):
|
119 |
+
i = 0
|
120 |
+
# this is ignored for schnell
|
121 |
+
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
122 |
+
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
|
123 |
+
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
124 |
+
pred = model(
|
125 |
+
img=img,
|
126 |
+
img_ids=img_ids,
|
127 |
+
txt=txt,
|
128 |
+
txt_ids=txt_ids,
|
129 |
+
y=vec,
|
130 |
+
timesteps=t_vec,
|
131 |
+
guidance=guidance_vec,
|
132 |
+
image_proj=image_proj,
|
133 |
+
ip_scale=ip_scale,
|
134 |
+
)
|
135 |
+
if i >= timestep_to_start_cfg:
|
136 |
+
neg_pred = model(
|
137 |
+
img=img,
|
138 |
+
img_ids=img_ids,
|
139 |
+
txt=neg_txt,
|
140 |
+
txt_ids=neg_txt_ids,
|
141 |
+
y=neg_vec,
|
142 |
+
timesteps=t_vec,
|
143 |
+
guidance=guidance_vec,
|
144 |
+
image_proj=neg_image_proj,
|
145 |
+
ip_scale=neg_ip_scale,
|
146 |
+
)
|
147 |
+
pred = neg_pred + true_gs * (pred - neg_pred)
|
148 |
+
img = img + (t_prev - t_curr) * pred
|
149 |
+
i += 1
|
150 |
+
return img
|
151 |
+
|
152 |
+
def denoise_controlnet(
|
153 |
+
model: Flux,
|
154 |
+
controlnet:None,
|
155 |
+
# model input
|
156 |
+
img: Tensor,
|
157 |
+
img_ids: Tensor,
|
158 |
+
txt: Tensor,
|
159 |
+
txt_ids: Tensor,
|
160 |
+
vec: Tensor,
|
161 |
+
neg_txt: Tensor,
|
162 |
+
neg_txt_ids: Tensor,
|
163 |
+
neg_vec: Tensor,
|
164 |
+
controlnet_cond,
|
165 |
+
# sampling parameters
|
166 |
+
timesteps: list[float],
|
167 |
+
guidance: float = 4.0,
|
168 |
+
true_gs = 1,
|
169 |
+
controlnet_gs=0.7,
|
170 |
+
timestep_to_start_cfg=0,
|
171 |
+
# ip-adapter parameters
|
172 |
+
image_proj: Tensor=None,
|
173 |
+
neg_image_proj: Tensor=None,
|
174 |
+
ip_scale: Tensor | float = 1,
|
175 |
+
neg_ip_scale: Tensor | float = 1,
|
176 |
+
):
|
177 |
+
# this is ignored for schnell
|
178 |
+
i = 0
|
179 |
+
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
180 |
+
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
|
181 |
+
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
182 |
+
block_res_samples = controlnet(
|
183 |
+
img=img,
|
184 |
+
img_ids=img_ids,
|
185 |
+
controlnet_cond=controlnet_cond,
|
186 |
+
txt=txt,
|
187 |
+
txt_ids=txt_ids,
|
188 |
+
y=vec,
|
189 |
+
timesteps=t_vec,
|
190 |
+
guidance=guidance_vec,
|
191 |
+
)
|
192 |
+
pred = model(
|
193 |
+
img=img,
|
194 |
+
img_ids=img_ids,
|
195 |
+
txt=txt,
|
196 |
+
txt_ids=txt_ids,
|
197 |
+
y=vec,
|
198 |
+
timesteps=t_vec,
|
199 |
+
guidance=guidance_vec,
|
200 |
+
block_controlnet_hidden_states=[i * controlnet_gs for i in block_res_samples],
|
201 |
+
image_proj=image_proj,
|
202 |
+
ip_scale=ip_scale,
|
203 |
+
)
|
204 |
+
if i >= timestep_to_start_cfg:
|
205 |
+
neg_block_res_samples = controlnet(
|
206 |
+
img=img,
|
207 |
+
img_ids=img_ids,
|
208 |
+
controlnet_cond=controlnet_cond,
|
209 |
+
txt=neg_txt,
|
210 |
+
txt_ids=neg_txt_ids,
|
211 |
+
y=neg_vec,
|
212 |
+
timesteps=t_vec,
|
213 |
+
guidance=guidance_vec,
|
214 |
+
)
|
215 |
+
neg_pred = model(
|
216 |
+
img=img,
|
217 |
+
img_ids=img_ids,
|
218 |
+
txt=neg_txt,
|
219 |
+
txt_ids=neg_txt_ids,
|
220 |
+
y=neg_vec,
|
221 |
+
timesteps=t_vec,
|
222 |
+
guidance=guidance_vec,
|
223 |
+
block_controlnet_hidden_states=[i * controlnet_gs for i in neg_block_res_samples],
|
224 |
+
image_proj=neg_image_proj,
|
225 |
+
ip_scale=neg_ip_scale,
|
226 |
+
)
|
227 |
+
pred = neg_pred + true_gs * (pred - neg_pred)
|
228 |
+
|
229 |
+
img = img + (t_prev - t_curr) * pred
|
230 |
+
|
231 |
+
i += 1
|
232 |
+
return img
|
233 |
+
|
234 |
+
def unpack(x: Tensor, height: int, width: int) -> Tensor:
|
235 |
+
return rearrange(
|
236 |
+
x,
|
237 |
+
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
238 |
+
h=math.ceil(height / 16),
|
239 |
+
w=math.ceil(width / 16),
|
240 |
+
ph=2,
|
241 |
+
pw=2,
|
242 |
+
)
|
src/flux/util.py
ADDED
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from dataclasses import dataclass
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import json
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image
|
9 |
+
from huggingface_hub import hf_hub_download
|
10 |
+
from safetensors import safe_open
|
11 |
+
from safetensors.torch import load_file as load_sft
|
12 |
+
|
13 |
+
from optimum.quanto import requantize
|
14 |
+
|
15 |
+
from .model import Flux, FluxParams
|
16 |
+
from .controlnet import ControlNetFlux
|
17 |
+
from .modules.autoencoder import AutoEncoder, AutoEncoderParams
|
18 |
+
from .modules.conditioner import HFEmbedder
|
19 |
+
from .annotator.dwpose import DWposeDetector
|
20 |
+
from .annotator.mlsd import MLSDdetector
|
21 |
+
from .annotator.canny import CannyDetector
|
22 |
+
from .annotator.midas import MidasDetector
|
23 |
+
from .annotator.hed import HEDdetector
|
24 |
+
from .annotator.tile import TileDetector
|
25 |
+
|
26 |
+
|
27 |
+
def load_safetensors(path):
|
28 |
+
tensors = {}
|
29 |
+
with safe_open(path, framework="pt", device="cpu") as f:
|
30 |
+
for key in f.keys():
|
31 |
+
tensors[key] = f.get_tensor(key)
|
32 |
+
return tensors
|
33 |
+
|
34 |
+
def get_lora_rank(checkpoint):
|
35 |
+
for k in checkpoint.keys():
|
36 |
+
if k.endswith(".down.weight"):
|
37 |
+
return checkpoint[k].shape[0]
|
38 |
+
|
39 |
+
def load_checkpoint(local_path, repo_id, name):
|
40 |
+
if local_path is not None:
|
41 |
+
if '.safetensors' in local_path:
|
42 |
+
print("Loading .safetensors checkpoint...")
|
43 |
+
checkpoint = load_safetensors(local_path)
|
44 |
+
else:
|
45 |
+
print("Loading checkpoint...")
|
46 |
+
checkpoint = torch.load(local_path, map_location='cpu')
|
47 |
+
elif repo_id is not None and name is not None:
|
48 |
+
print("Loading checkpoint from repo id...")
|
49 |
+
checkpoint = load_from_repo_id(repo_id, name)
|
50 |
+
else:
|
51 |
+
raise ValueError(
|
52 |
+
"LOADING ERROR: you must specify local_path or repo_id with name in HF to download"
|
53 |
+
)
|
54 |
+
return checkpoint
|
55 |
+
|
56 |
+
|
57 |
+
def c_crop(image):
|
58 |
+
width, height = image.size
|
59 |
+
new_size = min(width, height)
|
60 |
+
left = (width - new_size) / 2
|
61 |
+
top = (height - new_size) / 2
|
62 |
+
right = (width + new_size) / 2
|
63 |
+
bottom = (height + new_size) / 2
|
64 |
+
return image.crop((left, top, right, bottom))
|
65 |
+
|
66 |
+
|
67 |
+
class Annotator:
|
68 |
+
def __init__(self, name: str, device: str):
|
69 |
+
if name == "canny":
|
70 |
+
processor = CannyDetector()
|
71 |
+
elif name == "openpose":
|
72 |
+
processor = DWposeDetector(device)
|
73 |
+
elif name == "depth":
|
74 |
+
processor = MidasDetector()
|
75 |
+
elif name == "hed":
|
76 |
+
processor = HEDdetector()
|
77 |
+
elif name == "hough":
|
78 |
+
processor = MLSDdetector()
|
79 |
+
elif name == "tile":
|
80 |
+
processor = TileDetector()
|
81 |
+
self.name = name
|
82 |
+
self.processor = processor
|
83 |
+
|
84 |
+
def __call__(self, image: Image, width: int, height: int):
|
85 |
+
image = c_crop(image)
|
86 |
+
image = image.resize((width, height))
|
87 |
+
image = np.array(image)
|
88 |
+
if self.name == "canny":
|
89 |
+
result = self.processor(image, low_threshold=100, high_threshold=200)
|
90 |
+
elif self.name == "hough":
|
91 |
+
result = self.processor(image, thr_v=0.05, thr_d=5)
|
92 |
+
elif self.name == "depth":
|
93 |
+
result = self.processor(image)
|
94 |
+
result, _ = result
|
95 |
+
else:
|
96 |
+
result = self.processor(image)
|
97 |
+
|
98 |
+
if result.ndim != 3:
|
99 |
+
result = result[:, :, None]
|
100 |
+
result = np.concatenate([result, result, result], axis=2)
|
101 |
+
return result
|
102 |
+
|
103 |
+
|
104 |
+
@dataclass
|
105 |
+
class ModelSpec:
|
106 |
+
params: FluxParams
|
107 |
+
ae_params: AutoEncoderParams
|
108 |
+
ckpt_path: str | None
|
109 |
+
ae_path: str | None
|
110 |
+
repo_id: str | None
|
111 |
+
repo_flow: str | None
|
112 |
+
repo_ae: str | None
|
113 |
+
repo_id_ae: str | None
|
114 |
+
|
115 |
+
|
116 |
+
configs = {
|
117 |
+
"flux-dev": ModelSpec(
|
118 |
+
repo_id="black-forest-labs/FLUX.1-dev",
|
119 |
+
repo_id_ae="black-forest-labs/FLUX.1-dev",
|
120 |
+
repo_flow="flux1-dev.safetensors",
|
121 |
+
repo_ae="ae.safetensors",
|
122 |
+
ckpt_path=os.getenv("FLUX_DEV"),
|
123 |
+
params=FluxParams(
|
124 |
+
in_channels=64,
|
125 |
+
vec_in_dim=768,
|
126 |
+
context_in_dim=4096,
|
127 |
+
hidden_size=3072,
|
128 |
+
mlp_ratio=4.0,
|
129 |
+
num_heads=24,
|
130 |
+
depth=19,
|
131 |
+
depth_single_blocks=38,
|
132 |
+
axes_dim=[16, 56, 56],
|
133 |
+
theta=10_000,
|
134 |
+
qkv_bias=True,
|
135 |
+
guidance_embed=True,
|
136 |
+
),
|
137 |
+
ae_path=os.getenv("AE"),
|
138 |
+
ae_params=AutoEncoderParams(
|
139 |
+
resolution=256,
|
140 |
+
in_channels=3,
|
141 |
+
ch=128,
|
142 |
+
out_ch=3,
|
143 |
+
ch_mult=[1, 2, 4, 4],
|
144 |
+
num_res_blocks=2,
|
145 |
+
z_channels=16,
|
146 |
+
scale_factor=0.3611,
|
147 |
+
shift_factor=0.1159,
|
148 |
+
),
|
149 |
+
),
|
150 |
+
"flux-dev-fp8": ModelSpec(
|
151 |
+
repo_id="XLabs-AI/flux-dev-fp8",
|
152 |
+
repo_id_ae="black-forest-labs/FLUX.1-dev",
|
153 |
+
repo_flow="flux-dev-fp8.safetensors",
|
154 |
+
repo_ae="ae.safetensors",
|
155 |
+
ckpt_path=os.getenv("FLUX_DEV_FP8"),
|
156 |
+
params=FluxParams(
|
157 |
+
in_channels=64,
|
158 |
+
vec_in_dim=768,
|
159 |
+
context_in_dim=4096,
|
160 |
+
hidden_size=3072,
|
161 |
+
mlp_ratio=4.0,
|
162 |
+
num_heads=24,
|
163 |
+
depth=19,
|
164 |
+
depth_single_blocks=38,
|
165 |
+
axes_dim=[16, 56, 56],
|
166 |
+
theta=10_000,
|
167 |
+
qkv_bias=True,
|
168 |
+
guidance_embed=True,
|
169 |
+
),
|
170 |
+
ae_path=os.getenv("AE"),
|
171 |
+
ae_params=AutoEncoderParams(
|
172 |
+
resolution=256,
|
173 |
+
in_channels=3,
|
174 |
+
ch=128,
|
175 |
+
out_ch=3,
|
176 |
+
ch_mult=[1, 2, 4, 4],
|
177 |
+
num_res_blocks=2,
|
178 |
+
z_channels=16,
|
179 |
+
scale_factor=0.3611,
|
180 |
+
shift_factor=0.1159,
|
181 |
+
),
|
182 |
+
),
|
183 |
+
"flux-schnell": ModelSpec(
|
184 |
+
repo_id="black-forest-labs/FLUX.1-schnell",
|
185 |
+
repo_id_ae="black-forest-labs/FLUX.1-dev",
|
186 |
+
repo_flow="flux1-schnell.safetensors",
|
187 |
+
repo_ae="ae.safetensors",
|
188 |
+
ckpt_path=os.getenv("FLUX_SCHNELL"),
|
189 |
+
params=FluxParams(
|
190 |
+
in_channels=64,
|
191 |
+
vec_in_dim=768,
|
192 |
+
context_in_dim=4096,
|
193 |
+
hidden_size=3072,
|
194 |
+
mlp_ratio=4.0,
|
195 |
+
num_heads=24,
|
196 |
+
depth=19,
|
197 |
+
depth_single_blocks=38,
|
198 |
+
axes_dim=[16, 56, 56],
|
199 |
+
theta=10_000,
|
200 |
+
qkv_bias=True,
|
201 |
+
guidance_embed=False,
|
202 |
+
),
|
203 |
+
ae_path=os.getenv("AE"),
|
204 |
+
ae_params=AutoEncoderParams(
|
205 |
+
resolution=256,
|
206 |
+
in_channels=3,
|
207 |
+
ch=128,
|
208 |
+
out_ch=3,
|
209 |
+
ch_mult=[1, 2, 4, 4],
|
210 |
+
num_res_blocks=2,
|
211 |
+
z_channels=16,
|
212 |
+
scale_factor=0.3611,
|
213 |
+
shift_factor=0.1159,
|
214 |
+
),
|
215 |
+
),
|
216 |
+
}
|
217 |
+
|
218 |
+
|
219 |
+
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
|
220 |
+
if len(missing) > 0 and len(unexpected) > 0:
|
221 |
+
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
|
222 |
+
print("\n" + "-" * 79 + "\n")
|
223 |
+
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
|
224 |
+
elif len(missing) > 0:
|
225 |
+
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
|
226 |
+
elif len(unexpected) > 0:
|
227 |
+
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
|
228 |
+
|
229 |
+
def load_from_repo_id(repo_id, checkpoint_name):
|
230 |
+
ckpt_path = hf_hub_download(repo_id, checkpoint_name)
|
231 |
+
sd = load_sft(ckpt_path, device='cpu')
|
232 |
+
return sd
|
233 |
+
|
234 |
+
def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
|
235 |
+
# Loading Flux
|
236 |
+
print("Init model")
|
237 |
+
ckpt_path = configs[name].ckpt_path
|
238 |
+
if (
|
239 |
+
ckpt_path is None
|
240 |
+
and configs[name].repo_id is not None
|
241 |
+
and configs[name].repo_flow is not None
|
242 |
+
and hf_download
|
243 |
+
):
|
244 |
+
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
|
245 |
+
|
246 |
+
with torch.device("meta" if ckpt_path is not None else device):
|
247 |
+
model = Flux(configs[name].params).to(torch.bfloat16)
|
248 |
+
|
249 |
+
if ckpt_path is not None:
|
250 |
+
print("Loading checkpoint")
|
251 |
+
# load_sft doesn't support torch.device
|
252 |
+
sd = load_sft(ckpt_path, device=str(device))
|
253 |
+
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
|
254 |
+
print_load_warning(missing, unexpected)
|
255 |
+
return model
|
256 |
+
|
257 |
+
def load_flow_model2(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
|
258 |
+
# Loading Flux
|
259 |
+
print("Init model")
|
260 |
+
ckpt_path = configs[name].ckpt_path
|
261 |
+
if (
|
262 |
+
ckpt_path is None
|
263 |
+
and configs[name].repo_id is not None
|
264 |
+
and configs[name].repo_flow is not None
|
265 |
+
and hf_download
|
266 |
+
):
|
267 |
+
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))
|
268 |
+
|
269 |
+
with torch.device("meta" if ckpt_path is not None else device):
|
270 |
+
model = Flux(configs[name].params)
|
271 |
+
|
272 |
+
if ckpt_path is not None:
|
273 |
+
print("Loading checkpoint")
|
274 |
+
# load_sft doesn't support torch.device
|
275 |
+
sd = load_sft(ckpt_path, device=str(device))
|
276 |
+
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
|
277 |
+
print_load_warning(missing, unexpected)
|
278 |
+
return model
|
279 |
+
|
280 |
+
def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
|
281 |
+
# Loading Flux
|
282 |
+
print("Init model")
|
283 |
+
ckpt_path = configs[name].ckpt_path
|
284 |
+
if (
|
285 |
+
ckpt_path is None
|
286 |
+
and configs[name].repo_id is not None
|
287 |
+
and configs[name].repo_flow is not None
|
288 |
+
and hf_download
|
289 |
+
):
|
290 |
+
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
|
291 |
+
json_path = hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json')
|
292 |
+
|
293 |
+
|
294 |
+
model = Flux(configs[name].params).to(torch.bfloat16)
|
295 |
+
|
296 |
+
print("Loading checkpoint")
|
297 |
+
# load_sft doesn't support torch.device
|
298 |
+
sd = load_sft(ckpt_path, device='cpu')
|
299 |
+
with open(json_path, "r") as f:
|
300 |
+
quantization_map = json.load(f)
|
301 |
+
print("Start a quantization process...")
|
302 |
+
requantize(model, sd, quantization_map, device=device)
|
303 |
+
print("Model is quantized!")
|
304 |
+
return model
|
305 |
+
|
306 |
+
def load_controlnet(name, device, transformer=None):
|
307 |
+
with torch.device(device):
|
308 |
+
controlnet = ControlNetFlux(configs[name].params)
|
309 |
+
if transformer is not None:
|
310 |
+
controlnet.load_state_dict(transformer.state_dict(), strict=False)
|
311 |
+
return controlnet
|
312 |
+
|
313 |
+
def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
|
314 |
+
# max length 64, 128, 256 and 512 should work (if your sequence is short enough)
|
315 |
+
return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
|
316 |
+
|
317 |
+
def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
|
318 |
+
return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
|
319 |
+
|
320 |
+
|
321 |
+
def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
|
322 |
+
ckpt_path = configs[name].ae_path
|
323 |
+
if (
|
324 |
+
ckpt_path is None
|
325 |
+
and configs[name].repo_id is not None
|
326 |
+
and configs[name].repo_ae is not None
|
327 |
+
and hf_download
|
328 |
+
):
|
329 |
+
ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae)
|
330 |
+
|
331 |
+
# Loading the autoencoder
|
332 |
+
print("Init AE")
|
333 |
+
with torch.device("meta" if ckpt_path is not None else device):
|
334 |
+
ae = AutoEncoder(configs[name].ae_params)
|
335 |
+
|
336 |
+
if ckpt_path is not None:
|
337 |
+
sd = load_sft(ckpt_path, device=str(device))
|
338 |
+
missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
|
339 |
+
print_load_warning(missing, unexpected)
|
340 |
+
return ae
|
341 |
+
|
342 |
+
|
343 |
+
class WatermarkEmbedder:
|
344 |
+
def __init__(self, watermark):
|
345 |
+
self.watermark = watermark
|
346 |
+
self.num_bits = len(WATERMARK_BITS)
|
347 |
+
self.encoder = WatermarkEncoder()
|
348 |
+
self.encoder.set_watermark("bits", self.watermark)
|
349 |
+
|
350 |
+
def __call__(self, image: torch.Tensor) -> torch.Tensor:
|
351 |
+
"""
|
352 |
+
Adds a predefined watermark to the input image
|
353 |
+
|
354 |
+
Args:
|
355 |
+
image: ([N,] B, RGB, H, W) in range [-1, 1]
|
356 |
+
|
357 |
+
Returns:
|
358 |
+
same as input but watermarked
|
359 |
+
"""
|
360 |
+
image = 0.5 * image + 0.5
|
361 |
+
squeeze = len(image.shape) == 4
|
362 |
+
if squeeze:
|
363 |
+
image = image[None, ...]
|
364 |
+
n = image.shape[0]
|
365 |
+
image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
|
366 |
+
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
|
367 |
+
# watermarking libary expects input as cv2 BGR format
|
368 |
+
for k in range(image_np.shape[0]):
|
369 |
+
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
|
370 |
+
image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
|
371 |
+
image.device
|
372 |
+
)
|
373 |
+
image = torch.clamp(image / 255, min=0.0, max=1.0)
|
374 |
+
if squeeze:
|
375 |
+
image = image[0]
|
376 |
+
image = 2 * image - 1
|
377 |
+
return image
|
378 |
+
|
379 |
+
|
380 |
+
# A fixed 48-bit message that was choosen at random
|
381 |
+
WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
|
382 |
+
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
|
383 |
+
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
|