fountai commited on
Commit
fca8815
·
1 Parent(s): 4d207eb
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +167 -0
  2. LICENSE +201 -0
  3. app.py +1 -1
  4. cog.yaml +24 -0
  5. flux +0 -1
  6. image_datasets/canny_dataset.py +59 -0
  7. image_datasets/dataset.py +45 -0
  8. main.py +180 -0
  9. models_licence/LICENSE-FLUX1-dev +42 -0
  10. predict.py +134 -0
  11. src/flux/__init__.py +11 -0
  12. src/flux/__main__.py +4 -0
  13. src/flux/annotator/canny/__init__.py +6 -0
  14. src/flux/annotator/ckpts/ckpts.txt +1 -0
  15. src/flux/annotator/dwpose/__init__.py +68 -0
  16. src/flux/annotator/dwpose/onnxdet.py +125 -0
  17. src/flux/annotator/dwpose/onnxpose.py +360 -0
  18. src/flux/annotator/dwpose/util.py +297 -0
  19. src/flux/annotator/dwpose/wholebody.py +48 -0
  20. src/flux/annotator/hed/__init__.py +95 -0
  21. src/flux/annotator/midas/LICENSE +21 -0
  22. src/flux/annotator/midas/__init__.py +42 -0
  23. src/flux/annotator/midas/api.py +168 -0
  24. src/flux/annotator/midas/midas/__init__.py +0 -0
  25. src/flux/annotator/midas/midas/base_model.py +16 -0
  26. src/flux/annotator/midas/midas/blocks.py +342 -0
  27. src/flux/annotator/midas/midas/dpt_depth.py +109 -0
  28. src/flux/annotator/midas/midas/midas_net.py +76 -0
  29. src/flux/annotator/midas/midas/midas_net_custom.py +128 -0
  30. src/flux/annotator/midas/midas/transforms.py +234 -0
  31. src/flux/annotator/midas/midas/vit.py +491 -0
  32. src/flux/annotator/midas/utils.py +189 -0
  33. src/flux/annotator/mlsd/LICENSE +201 -0
  34. src/flux/annotator/mlsd/__init__.py +40 -0
  35. src/flux/annotator/mlsd/models/mbv2_mlsd_large.py +292 -0
  36. src/flux/annotator/mlsd/models/mbv2_mlsd_tiny.py +275 -0
  37. src/flux/annotator/mlsd/utils.py +580 -0
  38. src/flux/annotator/tile/__init__.py +26 -0
  39. src/flux/annotator/tile/guided_filter.py +280 -0
  40. src/flux/annotator/util.py +38 -0
  41. src/flux/api.py +194 -0
  42. src/flux/cli.py +254 -0
  43. src/flux/controlnet.py +222 -0
  44. src/flux/math.py +30 -0
  45. src/flux/model.py +228 -0
  46. src/flux/modules/autoencoder.py +312 -0
  47. src/flux/modules/conditioner.py +38 -0
  48. src/flux/modules/layers.py +567 -0
  49. src/flux/sampling.py +242 -0
  50. src/flux/util.py +383 -0
.gitignore ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ Makefile
12
+ build/
13
+ develop-eggs/
14
+ dist/
15
+ downloads/
16
+ eggs/
17
+ .eggs/
18
+ lib/
19
+ lib64/
20
+ parts/
21
+ sdist/
22
+ var/
23
+ wheels/
24
+ weights/
25
+
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache/
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # poetry
101
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105
+ #poetry.lock
106
+
107
+ # pdm
108
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109
+ #pdm.lock
110
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111
+ # in version control.
112
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
113
+ .pdm.toml
114
+ .pdm-python
115
+ .pdm-build/
116
+
117
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118
+ __pypackages__/
119
+
120
+ # Celery stuff
121
+ celerybeat-schedule
122
+ celerybeat.pid
123
+
124
+ # SageMath parsed files
125
+ *.sage.py
126
+
127
+ # Environments
128
+ .env
129
+ .venv
130
+ env/
131
+ venv/
132
+ ENV/
133
+ env.bak/
134
+ venv.bak/
135
+
136
+ # Spyder project settings
137
+ .spyderproject
138
+ .spyproject
139
+
140
+ # Rope project settings
141
+ .ropeproject
142
+
143
+ # mkdocs documentation
144
+ /site
145
+
146
+ # mypy
147
+ .mypy_cache/
148
+ .dmypy.json
149
+ dmypy.json
150
+
151
+ # Pyre type checker
152
+ .pyre/
153
+
154
+ # pytype static type analyzer
155
+ .pytype/
156
+
157
+ # Cython debug symbols
158
+ cython_debug/
159
+
160
+ # PyCharm
161
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
164
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165
+ #.idea/
166
+
167
+ .DS_Store
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  from PIL import Image
3
  import os
4
- from flux.src.flux.xflux_pipeline import XFluxPipeline
5
  import random
6
  import spaces
7
 
 
1
  import gradio as gr
2
  from PIL import Image
3
  import os
4
+ from src.flux.xflux_pipeline import XFluxPipeline
5
  import random
6
  import spaces
7
 
cog.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Configuration for Cog ⚙️
2
+ # Reference: https://cog.run/yaml
3
+
4
+ build:
5
+ gpu: true
6
+ cuda: "12.1"
7
+ python_version: "3.11"
8
+ python_packages:
9
+ - "accelerate==0.30.1"
10
+ - "deepspeed==0.14.4"
11
+ - "einops==0.8.0"
12
+ - "transformers==4.43.3"
13
+ - "huggingface-hub==0.24.5"
14
+ - "einops==0.8.0"
15
+ - "pandas==2.2.2"
16
+ - "opencv-python==4.10.0.84"
17
+ - "pillow==10.4.0"
18
+ - "optimum-quanto==0.2.4"
19
+ - "sentencepiece==0.2.0"
20
+ run:
21
+ - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
22
+
23
+ # predict.py defines how predictions are run on your model
24
+ predict: "predict.py:Predictor"
flux DELETED
@@ -1 +0,0 @@
1
- Subproject commit 9e1dd391b2316b1cfc20e523e2885fd30134a2e4
 
 
image_datasets/canny_dataset.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch
6
+ from torch.utils.data import Dataset, DataLoader
7
+ import json
8
+ import random
9
+ import cv2
10
+
11
+
12
+ def canny_processor(image, low_threshold=100, high_threshold=200):
13
+ image = np.array(image)
14
+ image = cv2.Canny(image, low_threshold, high_threshold)
15
+ image = image[:, :, None]
16
+ image = np.concatenate([image, image, image], axis=2)
17
+ canny_image = Image.fromarray(image)
18
+ return canny_image
19
+
20
+
21
+ def c_crop(image):
22
+ width, height = image.size
23
+ new_size = min(width, height)
24
+ left = (width - new_size) / 2
25
+ top = (height - new_size) / 2
26
+ right = (width + new_size) / 2
27
+ bottom = (height + new_size) / 2
28
+ return image.crop((left, top, right, bottom))
29
+
30
+ class CustomImageDataset(Dataset):
31
+ def __init__(self, img_dir, img_size=512):
32
+ self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i]
33
+ self.images.sort()
34
+ self.img_size = img_size
35
+
36
+ def __len__(self):
37
+ return len(self.images)
38
+
39
+ def __getitem__(self, idx):
40
+ try:
41
+ img = Image.open(self.images[idx])
42
+ img = c_crop(img)
43
+ img = img.resize((self.img_size, self.img_size))
44
+ hint = canny_processor(img)
45
+ img = torch.from_numpy((np.array(img) / 127.5) - 1)
46
+ img = img.permute(2, 0, 1)
47
+ hint = torch.from_numpy((np.array(hint) / 127.5) - 1)
48
+ hint = hint.permute(2, 0, 1)
49
+ json_path = self.images[idx].split('.')[0] + '.json'
50
+ prompt = json.load(open(json_path))['caption']
51
+ return img, hint, prompt
52
+ except Exception as e:
53
+ print(e)
54
+ return self.__getitem__(random.randint(0, len(self.images) - 1))
55
+
56
+
57
+ def loader(train_batch_size, num_workers, **args):
58
+ dataset = CustomImageDataset(**args)
59
+ return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True)
image_datasets/dataset.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch
6
+ from torch.utils.data import Dataset, DataLoader
7
+ import json
8
+ import random
9
+
10
+ def c_crop(image):
11
+ width, height = image.size
12
+ new_size = min(width, height)
13
+ left = (width - new_size) / 2
14
+ top = (height - new_size) / 2
15
+ right = (width + new_size) / 2
16
+ bottom = (height + new_size) / 2
17
+ return image.crop((left, top, right, bottom))
18
+
19
+ class CustomImageDataset(Dataset):
20
+ def __init__(self, img_dir, img_size=512):
21
+ self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i]
22
+ self.images.sort()
23
+ self.img_size = img_size
24
+
25
+ def __len__(self):
26
+ return len(self.images)
27
+
28
+ def __getitem__(self, idx):
29
+ try:
30
+ img = Image.open(self.images[idx])
31
+ img = c_crop(img)
32
+ img = img.resize((self.img_size, self.img_size))
33
+ img = torch.from_numpy((np.array(img) / 127.5) - 1)
34
+ img = img.permute(2, 0, 1)
35
+ json_path = self.images[idx].split('.')[0] + '.json'
36
+ prompt = json.load(open(json_path))['caption']
37
+ return img, prompt
38
+ except Exception as e:
39
+ print(e)
40
+ return self.__getitem__(random.randint(0, len(self.images) - 1))
41
+
42
+
43
+ def loader(train_batch_size, num_workers, **args):
44
+ dataset = CustomImageDataset(**args)
45
+ return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True)
main.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from PIL import Image
3
+ import os
4
+
5
+ from src.flux.xflux_pipeline import XFluxPipeline
6
+
7
+
8
+ def create_argparser():
9
+ parser = argparse.ArgumentParser()
10
+
11
+ parser.add_argument(
12
+ "--prompt", type=str, required=True,
13
+ help="The input text prompt"
14
+ )
15
+ parser.add_argument(
16
+ "--neg_prompt", type=str, default="",
17
+ help="The input text negative prompt"
18
+ )
19
+ parser.add_argument(
20
+ "--img_prompt", type=str, default=None,
21
+ help="Path to input image prompt"
22
+ )
23
+ parser.add_argument(
24
+ "--neg_img_prompt", type=str, default=None,
25
+ help="Path to input negative image prompt"
26
+ )
27
+ parser.add_argument(
28
+ "--ip_scale", type=float, default=1.0,
29
+ help="Strength of input image prompt"
30
+ )
31
+ parser.add_argument(
32
+ "--neg_ip_scale", type=float, default=1.0,
33
+ help="Strength of negative input image prompt"
34
+ )
35
+ parser.add_argument(
36
+ "--local_path", type=str, default=None,
37
+ help="Local path to the model checkpoint (Controlnet)"
38
+ )
39
+ parser.add_argument(
40
+ "--repo_id", type=str, default=None,
41
+ help="A HuggingFace repo id to download model (Controlnet)"
42
+ )
43
+ parser.add_argument(
44
+ "--name", type=str, default=None,
45
+ help="A filename to download from HuggingFace"
46
+ )
47
+ parser.add_argument(
48
+ "--ip_repo_id", type=str, default=None,
49
+ help="A HuggingFace repo id to download model (IP-Adapter)"
50
+ )
51
+ parser.add_argument(
52
+ "--ip_name", type=str, default=None,
53
+ help="A IP-Adapter filename to download from HuggingFace"
54
+ )
55
+ parser.add_argument(
56
+ "--ip_local_path", type=str, default=None,
57
+ help="Local path to the model checkpoint (IP-Adapter)"
58
+ )
59
+ parser.add_argument(
60
+ "--lora_repo_id", type=str, default=None,
61
+ help="A HuggingFace repo id to download model (LoRA)"
62
+ )
63
+ parser.add_argument(
64
+ "--lora_name", type=str, default=None,
65
+ help="A LoRA filename to download from HuggingFace"
66
+ )
67
+ parser.add_argument(
68
+ "--lora_local_path", type=str, default=None,
69
+ help="Local path to the model checkpoint (Controlnet)"
70
+ )
71
+ parser.add_argument(
72
+ "--device", type=str, default="cuda",
73
+ help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)"
74
+ )
75
+ parser.add_argument(
76
+ "--offload", action='store_true', help="Offload model to CPU when not in use"
77
+ )
78
+ parser.add_argument(
79
+ "--use_ip", action='store_true', help="Load IP model"
80
+ )
81
+ parser.add_argument(
82
+ "--use_lora", action='store_true', help="Load Lora model"
83
+ )
84
+ parser.add_argument(
85
+ "--use_controlnet", action='store_true', help="Load Controlnet model"
86
+ )
87
+ parser.add_argument(
88
+ "--num_images_per_prompt", type=int, default=1,
89
+ help="The number of images to generate per prompt"
90
+ )
91
+ parser.add_argument(
92
+ "--image", type=str, default=None, help="Path to image"
93
+ )
94
+ parser.add_argument(
95
+ "--lora_weight", type=float, default=0.9, help="Lora model strength (from 0 to 1.0)"
96
+ )
97
+ parser.add_argument(
98
+ "--control_type", type=str, default="canny",
99
+ choices=("canny", "openpose", "depth", "hed", "hough", "tile"),
100
+ help="Name of controlnet condition, example: canny"
101
+ )
102
+ parser.add_argument(
103
+ "--model_type", type=str, default="flux-dev",
104
+ choices=("flux-dev", "flux-dev-fp8", "flux-schnell"),
105
+ help="Model type to use (flux-dev, flux-dev-fp8, flux-schnell)"
106
+ )
107
+ parser.add_argument(
108
+ "--width", type=int, default=1024, help="The width for generated image"
109
+ )
110
+ parser.add_argument(
111
+ "--height", type=int, default=1024, help="The height for generated image"
112
+ )
113
+ parser.add_argument(
114
+ "--num_steps", type=int, default=25, help="The num_steps for diffusion process"
115
+ )
116
+ parser.add_argument(
117
+ "--guidance", type=float, default=4, help="The guidance for diffusion process"
118
+ )
119
+ parser.add_argument(
120
+ "--seed", type=int, default=123456789, help="A seed for reproducible inference"
121
+ )
122
+ parser.add_argument(
123
+ "--true_gs", type=float, default=3.5, help="true guidance"
124
+ )
125
+ parser.add_argument(
126
+ "--timestep_to_start_cfg", type=int, default=5, help="timestep to start true guidance"
127
+ )
128
+ parser.add_argument(
129
+ "--save_path", type=str, default='results', help="Path to save"
130
+ )
131
+ return parser
132
+
133
+
134
+ def main(args):
135
+ if args.image:
136
+ image = Image.open(args.image)
137
+ else:
138
+ image = None
139
+
140
+ xflux_pipeline = XFluxPipeline(args.model_type, args.device, args.offload)
141
+ if args.use_ip:
142
+ print('load ip-adapter:', args.ip_local_path, args.ip_repo_id, args.ip_name)
143
+ xflux_pipeline.set_ip(args.ip_local_path, args.ip_repo_id, args.ip_name)
144
+ if args.use_lora:
145
+ print('load lora:', args.lora_local_path, args.lora_repo_id, args.lora_name)
146
+ xflux_pipeline.set_lora(args.lora_local_path, args.lora_repo_id, args.lora_name, args.lora_weight)
147
+ if args.use_controlnet:
148
+ print('load controlnet:', args.local_path, args.repo_id, args.name)
149
+ xflux_pipeline.set_controlnet(args.control_type, args.local_path, args.repo_id, args.name)
150
+
151
+ image_prompt = Image.open(args.img_prompt) if args.img_prompt else None
152
+ neg_image_prompt = Image.open(args.neg_img_prompt) if args.neg_img_prompt else None
153
+
154
+ for _ in range(args.num_images_per_prompt):
155
+ result = xflux_pipeline(
156
+ prompt=args.prompt,
157
+ controlnet_image=image,
158
+ width=args.width,
159
+ height=args.height,
160
+ guidance=args.guidance,
161
+ num_steps=args.num_steps,
162
+ seed=args.seed,
163
+ true_gs=args.true_gs,
164
+ neg_prompt=args.neg_prompt,
165
+ timestep_to_start_cfg=args.timestep_to_start_cfg,
166
+ image_prompt=image_prompt,
167
+ neg_image_prompt=neg_image_prompt,
168
+ ip_scale=args.ip_scale,
169
+ neg_ip_scale=args.neg_ip_scale,
170
+ )
171
+ if not os.path.exists(args.save_path):
172
+ os.mkdir(args.save_path)
173
+ ind = len(os.listdir(args.save_path))
174
+ result.save(os.path.join(args.save_path, f"result_{ind}.png"))
175
+ args.seed = args.seed + 1
176
+
177
+
178
+ if __name__ == "__main__":
179
+ args = create_argparser().parse_args()
180
+ main(args)
models_licence/LICENSE-FLUX1-dev ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FLUX.1 [dev] Non-Commercial License
2
+ Black Forest Labs, Inc. (“we” or “our” or “Company”) is pleased to make available the weights, parameters and inference code for the FLUX.1 [dev] Model (as defined below) freely available for your non-commercial and non-production use as set forth in this FLUX.1 [dev] Non-Commercial License (“License”). The “FLUX.1 [dev] Model” means the FLUX.1 [dev] text-to-image AI model and its elements which includes algorithms, software, checkpoints, parameters, source code (inference code, evaluation code, and if applicable, fine-tuning code) and any other materials associated with the FLUX.1 [dev] AI model made available by Company under this License, including if any, the technical documentation, manuals and instructions for the use and operation thereof (collectively, “FLUX.1 [dev] Model”).
3
+ By downloading, accessing, use, Distributing (as defined below), or creating a Derivative (as defined below) of the FLUX.1 [dev] Model, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to access, use, Distribute or create a Derivative of the FLUX.1 [dev] Model and you must immediately cease using the FLUX.1 [dev] Model. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to us that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the FLUX.1 [dev] Model on behalf of your employer or other entity.
4
+ 1. Definitions. Capitalized terms used in this License but not defined herein have the following meanings:
5
+ a. “Derivative” means any (i) modified version of the FLUX.1 [dev] Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the FLUX.1 [dev] Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered Derivatives under this License.
6
+ b. “Distribution” or “Distribute” or “Distributing” means providing or making available, by any means, a copy of the FLUX.1 [dev] Models and/or the Derivatives as the case may be.
7
+ c. “Non-Commercial Purpose” means any of the following uses, but only so far as you do not receive any direct or indirect payment arising from the use of the model or its output: (i) personal use for research, experiment, and testing for the benefit of public knowledge, personal study, private entertainment, hobby projects, or otherwise not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities; (ii) use by commercial or for-profit entities for testing, evaluation, or non-commercial research and development in a non-production environment, (iii) use by any charitable organization for charitable purposes, or for testing or evaluation. For clarity, use for revenue-generating activity or direct interactions with or impacts on end users, or use to train, fine tune or distill other models for commercial use is not a Non-Commercial purpose.
8
+ d. “Outputs” means any content generated by the operation of the FLUX.1 [dev] Models or the Derivatives from a prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of a FLUX.1 [dev] Models, such as any fine-tuned versions of the FLUX.1 [dev] Models, the weights, or parameters.
9
+ e. “you” or “your” means the individual or entity entering into this License with Company.
10
+ 2. License Grant.
11
+ a. License. Subject to your compliance with this License, Company grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license to access, use, create Derivatives of, and Distribute the FLUX.1 [dev] Models solely for your Non-Commercial Purposes. The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Company’s prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License. Any restrictions set forth herein in regarding the FLUX.1 [dev] Model also applies to any Derivative you create or that are created on your behalf.
12
+ b. Non-Commercial Use Only. You may only access, use, Distribute, or creative Derivatives of or the FLUX.1 [dev] Model or Derivatives for Non-Commercial Purposes. If You want to use a FLUX.1 [dev] Model a Derivative for any purpose that is not expressly authorized under this License, such as for a commercial activity, you must request a license from Company, which Company may grant to you in Company’s sole discretion and which additional use may be subject to a fee, royalty or other revenue share. Please contact Company at the following e-mail address if you want to discuss such a license: [email protected].
13
+ c. Reserved Rights. The grant of rights expressly set forth in this License are the complete grant of rights to you in the FLUX.1 [dev] Model, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. Company and its licensors reserve all rights not expressly granted by this License.
14
+ d. Outputs. We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs you generate and their subsequent uses in accordance with this License. You may use Output for any purpose (including for commercial purposes), except as expressly prohibited herein. You may not use the Output to train, fine-tune or distill a model that is competitive with the FLUX.1 [dev] Model.
15
+ 3. Distribution. Subject to this License, you may Distribute copies of the FLUX.1 [dev] Model and/or Derivatives made by you, under the following conditions:
16
+ a. you must make available a copy of this License to third-party recipients of the FLUX.1 [dev] Models and/or Derivatives you Distribute, and specify that any rights to use the FLUX.1 [dev] Models and/or Derivatives shall be directly granted by Company to said third-party recipients pursuant to this License;
17
+ b. you must make prominently display the following notice alongside the Distribution of the FLUX.1 [dev] Model or Derivative (such as via a “Notice” text file distributed as part of such FLUX.1 [dev] Model or Derivative) (the “Attribution Notice”):
18
+ “The FLUX.1 [dev] Model is licensed by Black Forest Labs. Inc. under the FLUX.1 [dev] Non-Commercial License. Copyright Black Forest Labs. Inc.
19
+ IN NO EVENT SHALL BLACK FOREST LABS, INC. BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH USE OF THIS MODEL.”
20
+ c. in the case of Distribution of Derivatives made by you, you must also include in the Attribution Notice a statement that you have modified the applicable FLUX.1 [dev] Model; and
21
+ d. in the case of Distribution of Derivatives made by you, any terms and conditions you impose on any third-party recipients relating to Derivatives made by or for you shall neither limit such third-party recipients’ use of the FLUX.1 [dev] Model or any Derivatives made by or for Company in accordance with this License nor conflict with any of its terms and conditions.
22
+ e. In the case of Distribution of Derivatives made by you, you must not misrepresent or imply, through any means, that the Derivatives made by or for you and/or any modified version of the FLUX.1 [dev] Model you Distribute under your name and responsibility is an official product of the Company or has been endorsed, approved or validated by the Company, unless you are authorized by Company to do so in writing.
23
+ 4. Restrictions. You will not, and will not permit, assist or cause any third party to
24
+ a. use, modify, copy, reproduce, create Derivatives of, or Distribute the FLUX.1 [dev] Model (or any Derivative thereof, or any data produced by the FLUX.1 [dev] Model), in whole or in part, for (i) any commercial or production purposes, (ii) military purposes, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates any third-party rights, or (vi) in any manner that violates any applicable law and violating any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Privacy Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, and any and all laws governing the processing of biometric information), as well as all amendments and successor laws to any of the foregoing;
25
+ b. alter or remove copyright and other proprietary notices which appear on or in any portion of the FLUX.1 [dev] Model;
26
+ c. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Company in connection with the FLUX.1 [dev] Model, or to circumvent or remove any usage restrictions, or to enable functionality disabled by FLUX.1 [dev] Model; or
27
+ d. offer or impose any terms on the FLUX.1 [dev] Model that alter, restrict, or are inconsistent with the terms of this License.
28
+ e. violate any applicable U.S. and non-U.S. export control and trade sanctions laws (“Export Laws”) in connection with your use or Distribution of any FLUX.1 [dev] Model;
29
+ f. directly or indirectly Distribute, export, or otherwise transfer FLUX.1 [dev] Model (a) to any individual, entity, or country prohibited by Export Laws; (b) to anyone on U.S. or non-U.S. government restricted parties lists; or (c) for any purpose prohibited by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications; 3) use or download FLUX.1 [dev] Model if you or they are (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any purpose prohibited by Export Laws; and (4) will not disguise your location through IP proxying or other methods.
30
+ 5. DISCLAIMERS. THE FLUX.1 [dev] MODEL IS PROVIDED “AS IS” AND “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. COMPANY EXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE FLUX.1 [dev] MODEL, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. COMPANY MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE FLUX.1 [dev] MODEL WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS.
31
+ 6. LIMITATION OF LIABILITY. TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL COMPANY BE LIABLE TO YOU OR YOUR EMPLOYEES, AFFILIATES, USERS, OFFICERS OR DIRECTORS (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF COMPANY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. THE FLUX.1 [dev] MODEL, ITS CONSTITUENT COMPONENTS, AND ANY OUTPUT (COLLECTIVELY, “MODEL MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE MODEL MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE MODEL MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE MODEL MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE.
32
+ 7. INDEMNIFICATION
33
+
34
+ You will indemnify, defend and hold harmless Company and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “Company Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any Company Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to (a) your access to or use of the FLUX.1 [dev] Model (as well as any Output, results or data generated from such access or use), including any High-Risk Use (defined below); (b) your violation of this License; or (c) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Company Parties of any such Claims, and cooperate with Company Parties in defending such Claims. You will also grant the Company Parties sole control of the defense or settlement, at Company’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Company or the other Company Parties.
35
+ 8. Termination; Survival.
36
+ a. This License will automatically terminate upon any breach by you of the terms of this License.
37
+ b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you.
38
+ c. If You initiate any legal action or proceedings against Company or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the FLUX.1 [dev] Model or any Derivative, or any part thereof, infringe upon intellectual property or other rights owned or licensable by you, then any licenses granted to you under this License will immediately terminate as of the date such legal action or claim is filed or initiated.
39
+ d. Upon termination of this License, you must cease all use, access or Distribution of the FLUX.1 [dev] Model and any Derivatives. The following sections survive termination of this License 2(c), 2(d), 4-11.
40
+ 9. Third Party Materials. The FLUX.1 [dev] Model may contain third-party software or other components (including free and open source software) (all of the foregoing, “Third Party Materials”), which are subject to the license terms of the respective third-party licensors. Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Company does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk.
41
+ 10. Trademarks. You have not been granted any trademark license as part of this License and may not use any name or mark associated with Company without the prior written permission of Company, except to the extent necessary to make the reference required in the Attribution Notice as specified above or as is reasonably necessary in describing the FLUX.1 [dev] Model and its creators.
42
+ 11. General. This License will be governed and construed under the laws of the State of Delaware without regard to conflicts of law provisions. If any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Company to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. This License, together with the Documentation, contains the entire understanding between you and Company regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Company regarding such subject matter. No change or addition to any provision of this License will be binding unless it is in writing and signed by an authorized representative of both you and Company.
predict.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Prediction interface for Cog ⚙️
2
+ # https://cog.run/python
3
+
4
+ from cog import BasePredictor, Input, Path
5
+ import os
6
+ import time
7
+ import torch
8
+ import subprocess
9
+ from PIL import Image
10
+ from typing import List
11
+ from image_datasets.canny_dataset import canny_processor, c_crop
12
+ from src.flux.util import load_ae, load_clip, load_t5, load_flow_model, load_controlnet, load_safetensors
13
+
14
+ OUTPUT_DIR = "controlnet_results"
15
+ MODEL_CACHE = "checkpoints"
16
+ CONTROLNET_URL = "https://huggingface.co/XLabs-AI/flux-controlnet-canny/resolve/main/controlnet.safetensors"
17
+ T5_URL = "https://weights.replicate.delivery/default/black-forest-labs/FLUX.1-dev/t5-cache.tar"
18
+ CLIP_URL = "https://weights.replicate.delivery/default/black-forest-labs/FLUX.1-dev/clip-cache.tar"
19
+ HF_TOKEN = "hf_..." # Your HuggingFace token
20
+
21
+ def download_weights(url, dest):
22
+ start = time.time()
23
+ print("downloading url: ", url)
24
+ print("downloading to: ", dest)
25
+ subprocess.check_call(["pget", "-xf", url, dest], close_fds=False)
26
+ print("downloading took: ", time.time() - start)
27
+
28
+ def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
29
+ t5 = load_t5(device, max_length=256 if is_schnell else 512)
30
+ clip = load_clip(device)
31
+ model = load_flow_model(name, device="cpu" if offload else device)
32
+ ae = load_ae(name, device="cpu" if offload else device)
33
+ controlnet = load_controlnet(name, device).to(torch.bfloat16)
34
+ return model, ae, t5, clip, controlnet
35
+
36
+ class Predictor(BasePredictor):
37
+ def setup(self) -> None:
38
+ """Load the model into memory to make running multiple predictions efficient"""
39
+ t1 = time.time()
40
+ os.system(f"huggingface-cli login --token {HF_TOKEN}")
41
+ name = "flux-dev"
42
+ self.offload = False
43
+ checkpoint = "controlnet.safetensors"
44
+
45
+ print("Checking ControlNet weights")
46
+ checkpoint = "controlnet.safetensors"
47
+ if not os.path.exists(checkpoint):
48
+ os.system(f"wget {CONTROLNET_URL}")
49
+ print("Checking T5 weights")
50
+ if not os.path.exists(MODEL_CACHE+"/models--google--t5-v1_1-xxl"):
51
+ download_weights(T5_URL, MODEL_CACHE)
52
+ print("Checking CLIP weights")
53
+ if not os.path.exists(MODEL_CACHE+"/models--openai--clip-vit-large-patch14"):
54
+ download_weights(CLIP_URL, MODEL_CACHE)
55
+
56
+ self.is_schnell = False
57
+ device = "cuda"
58
+ self.torch_device = torch.device(device)
59
+ model, ae, t5, clip, controlnet = get_models(
60
+ name,
61
+ device=self.torch_device,
62
+ offload=self.offload,
63
+ is_schnell=self.is_schnell,
64
+ )
65
+ self.ae = ae
66
+ self.t5 = t5
67
+ self.clip = clip
68
+ self.controlnet = controlnet
69
+ self.model = model.to(self.torch_device)
70
+ if '.safetensors' in checkpoint:
71
+ checkpoint1 = load_safetensors(checkpoint)
72
+ else:
73
+ checkpoint1 = torch.load(checkpoint, map_location='cpu')
74
+
75
+ controlnet.load_state_dict(checkpoint1, strict=False)
76
+ t2 = time.time()
77
+ print(f"Setup time: {t2 - t1}")
78
+
79
+ def preprocess_canny_image(self, image_path: str, width: int = 512, height: int = 512):
80
+ image = Image.open(image_path)
81
+ image = c_crop(image)
82
+ image = image.resize((width, height))
83
+ image = canny_processor(image)
84
+ return image
85
+
86
+ def predict(
87
+ self,
88
+ prompt: str = Input(description="Input prompt", default="a handsome viking man with white hair, cinematic, MM full HD"),
89
+ image: Path = Input(description="Input image", default=None),
90
+ num_inference_steps: int = Input(description="Number of inference steps", ge=1, le=64, default=28),
91
+ cfg: float = Input(description="CFG", ge=0, le=10, default=3.5),
92
+ seed: int = Input(description="Random seed", default=None)
93
+ ) -> List[Path]:
94
+ """Run a single prediction on the model"""
95
+ if seed is None:
96
+ seed = int.from_bytes(os.urandom(2), "big")
97
+ print(f"Using seed: {seed}")
98
+
99
+ # clean output dir
100
+ output_dir = "controlnet_results"
101
+ os.system(f"rm -rf {output_dir}")
102
+
103
+ input_image = str(image)
104
+ img = Image.open(input_image)
105
+ width, height = img.size
106
+ # Resize input image if it's too large
107
+ max_image_size = 1536
108
+ scale = min(max_image_size / width, max_image_size / height, 1)
109
+ if scale < 1:
110
+ width = int(width * scale)
111
+ height = int(height * scale)
112
+ print(f"Scaling image down to {width}x{height}")
113
+ img = img.resize((width, height), resample=Image.Resampling.LANCZOS)
114
+ input_image = "/tmp/resized_image.png"
115
+ img.save(input_image)
116
+
117
+ subprocess.check_call(
118
+ ["python3", "main.py",
119
+ "--local_path", "controlnet.safetensors",
120
+ "--image", input_image,
121
+ "--use_controlnet",
122
+ "--control_type", "canny",
123
+ "--prompt", prompt,
124
+ "--width", str(width),
125
+ "--height", str(height),
126
+ "--num_steps", str(num_inference_steps),
127
+ "--guidance", str(cfg),
128
+ "--seed", str(seed)
129
+ ], close_fds=False)
130
+
131
+ # Find the first file that begins with "controlnet_result_"
132
+ for file in os.listdir(output_dir):
133
+ if file.startswith("controlnet_result_"):
134
+ return [Path(os.path.join(output_dir, file))]
src/flux/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try:
2
+ from ._version import version as __version__ # type: ignore
3
+ from ._version import version_tuple
4
+ except ImportError:
5
+ __version__ = "unknown (no version information available)"
6
+ version_tuple = (0, 0, "unknown", "noinfo")
7
+
8
+ from pathlib import Path
9
+
10
+ PACKAGE = __package__.replace("_", "-")
11
+ PACKAGE_ROOT = Path(__file__).parent
src/flux/__main__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .cli import app
2
+
3
+ if __name__ == "__main__":
4
+ app()
src/flux/annotator/canny/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import cv2
2
+
3
+
4
+ class CannyDetector:
5
+ def __call__(self, img, low_threshold, high_threshold):
6
+ return cv2.Canny(img, low_threshold, high_threshold)
src/flux/annotator/ckpts/ckpts.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Weights here.
src/flux/annotator/dwpose/__init__.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Openpose
2
+ # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
3
+ # 2nd Edited by https://github.com/Hzzone/pytorch-openpose
4
+ # 3rd Edited by ControlNet
5
+ # 4th Edited by ControlNet (added face and correct hands)
6
+
7
+ import os
8
+ os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
9
+
10
+ import torch
11
+ import numpy as np
12
+ from . import util
13
+ from .wholebody import Wholebody
14
+
15
+ def draw_pose(pose, H, W):
16
+ bodies = pose['bodies']
17
+ faces = pose['faces']
18
+ hands = pose['hands']
19
+ candidate = bodies['candidate']
20
+ subset = bodies['subset']
21
+ canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
22
+
23
+ canvas = util.draw_bodypose(canvas, candidate, subset)
24
+
25
+ canvas = util.draw_handpose(canvas, hands)
26
+
27
+ canvas = util.draw_facepose(canvas, faces)
28
+
29
+ return canvas
30
+
31
+
32
+ class DWposeDetector:
33
+ def __init__(self, device):
34
+
35
+ self.pose_estimation = Wholebody(device)
36
+
37
+ def __call__(self, oriImg):
38
+ oriImg = oriImg.copy()
39
+ H, W, C = oriImg.shape
40
+ with torch.no_grad():
41
+ candidate, subset = self.pose_estimation(oriImg)
42
+ nums, keys, locs = candidate.shape
43
+ candidate[..., 0] /= float(W)
44
+ candidate[..., 1] /= float(H)
45
+ body = candidate[:,:18].copy()
46
+ body = body.reshape(nums*18, locs)
47
+ score = subset[:,:18]
48
+ for i in range(len(score)):
49
+ for j in range(len(score[i])):
50
+ if score[i][j] > 0.3:
51
+ score[i][j] = int(18*i+j)
52
+ else:
53
+ score[i][j] = -1
54
+
55
+ un_visible = subset<0.3
56
+ candidate[un_visible] = -1
57
+
58
+ foot = candidate[:,18:24]
59
+
60
+ faces = candidate[:,24:92]
61
+
62
+ hands = candidate[:,92:113]
63
+ hands = np.vstack([hands, candidate[:,113:]])
64
+
65
+ bodies = dict(candidate=body, subset=score)
66
+ pose = dict(bodies=bodies, hands=hands, faces=faces)
67
+
68
+ return draw_pose(pose, H, W)
src/flux/annotator/dwpose/onnxdet.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
+ import onnxruntime
5
+
6
+ def nms(boxes, scores, nms_thr):
7
+ """Single class NMS implemented in Numpy."""
8
+ x1 = boxes[:, 0]
9
+ y1 = boxes[:, 1]
10
+ x2 = boxes[:, 2]
11
+ y2 = boxes[:, 3]
12
+
13
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
14
+ order = scores.argsort()[::-1]
15
+
16
+ keep = []
17
+ while order.size > 0:
18
+ i = order[0]
19
+ keep.append(i)
20
+ xx1 = np.maximum(x1[i], x1[order[1:]])
21
+ yy1 = np.maximum(y1[i], y1[order[1:]])
22
+ xx2 = np.minimum(x2[i], x2[order[1:]])
23
+ yy2 = np.minimum(y2[i], y2[order[1:]])
24
+
25
+ w = np.maximum(0.0, xx2 - xx1 + 1)
26
+ h = np.maximum(0.0, yy2 - yy1 + 1)
27
+ inter = w * h
28
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
29
+
30
+ inds = np.where(ovr <= nms_thr)[0]
31
+ order = order[inds + 1]
32
+
33
+ return keep
34
+
35
+ def multiclass_nms(boxes, scores, nms_thr, score_thr):
36
+ """Multiclass NMS implemented in Numpy. Class-aware version."""
37
+ final_dets = []
38
+ num_classes = scores.shape[1]
39
+ for cls_ind in range(num_classes):
40
+ cls_scores = scores[:, cls_ind]
41
+ valid_score_mask = cls_scores > score_thr
42
+ if valid_score_mask.sum() == 0:
43
+ continue
44
+ else:
45
+ valid_scores = cls_scores[valid_score_mask]
46
+ valid_boxes = boxes[valid_score_mask]
47
+ keep = nms(valid_boxes, valid_scores, nms_thr)
48
+ if len(keep) > 0:
49
+ cls_inds = np.ones((len(keep), 1)) * cls_ind
50
+ dets = np.concatenate(
51
+ [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
52
+ )
53
+ final_dets.append(dets)
54
+ if len(final_dets) == 0:
55
+ return None
56
+ return np.concatenate(final_dets, 0)
57
+
58
+ def demo_postprocess(outputs, img_size, p6=False):
59
+ grids = []
60
+ expanded_strides = []
61
+ strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
62
+
63
+ hsizes = [img_size[0] // stride for stride in strides]
64
+ wsizes = [img_size[1] // stride for stride in strides]
65
+
66
+ for hsize, wsize, stride in zip(hsizes, wsizes, strides):
67
+ xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
68
+ grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
69
+ grids.append(grid)
70
+ shape = grid.shape[:2]
71
+ expanded_strides.append(np.full((*shape, 1), stride))
72
+
73
+ grids = np.concatenate(grids, 1)
74
+ expanded_strides = np.concatenate(expanded_strides, 1)
75
+ outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
76
+ outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
77
+
78
+ return outputs
79
+
80
+ def preprocess(img, input_size, swap=(2, 0, 1)):
81
+ if len(img.shape) == 3:
82
+ padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
83
+ else:
84
+ padded_img = np.ones(input_size, dtype=np.uint8) * 114
85
+
86
+ r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
87
+ resized_img = cv2.resize(
88
+ img,
89
+ (int(img.shape[1] * r), int(img.shape[0] * r)),
90
+ interpolation=cv2.INTER_LINEAR,
91
+ ).astype(np.uint8)
92
+ padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
93
+
94
+ padded_img = padded_img.transpose(swap)
95
+ padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
96
+ return padded_img, r
97
+
98
+ def inference_detector(session, oriImg):
99
+ input_shape = (640,640)
100
+ img, ratio = preprocess(oriImg, input_shape)
101
+
102
+ ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
103
+ output = session.run(None, ort_inputs)
104
+ predictions = demo_postprocess(output[0], input_shape)[0]
105
+
106
+ boxes = predictions[:, :4]
107
+ scores = predictions[:, 4:5] * predictions[:, 5:]
108
+
109
+ boxes_xyxy = np.ones_like(boxes)
110
+ boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
111
+ boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
112
+ boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
113
+ boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
114
+ boxes_xyxy /= ratio
115
+ dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
116
+ if dets is not None:
117
+ final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
118
+ isscore = final_scores>0.3
119
+ iscat = final_cls_inds == 0
120
+ isbbox = [ i and j for (i, j) in zip(isscore, iscat)]
121
+ final_boxes = final_boxes[isbbox]
122
+ else:
123
+ final_boxes = np.array([])
124
+
125
+ return final_boxes
src/flux/annotator/dwpose/onnxpose.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import onnxruntime as ort
6
+
7
+ def preprocess(
8
+ img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
9
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
10
+ """Do preprocessing for RTMPose model inference.
11
+
12
+ Args:
13
+ img (np.ndarray): Input image in shape.
14
+ input_size (tuple): Input image size in shape (w, h).
15
+
16
+ Returns:
17
+ tuple:
18
+ - resized_img (np.ndarray): Preprocessed image.
19
+ - center (np.ndarray): Center of image.
20
+ - scale (np.ndarray): Scale of image.
21
+ """
22
+ # get shape of image
23
+ img_shape = img.shape[:2]
24
+ out_img, out_center, out_scale = [], [], []
25
+ if len(out_bbox) == 0:
26
+ out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
27
+ for i in range(len(out_bbox)):
28
+ x0 = out_bbox[i][0]
29
+ y0 = out_bbox[i][1]
30
+ x1 = out_bbox[i][2]
31
+ y1 = out_bbox[i][3]
32
+ bbox = np.array([x0, y0, x1, y1])
33
+
34
+ # get center and scale
35
+ center, scale = bbox_xyxy2cs(bbox, padding=1.25)
36
+
37
+ # do affine transformation
38
+ resized_img, scale = top_down_affine(input_size, scale, center, img)
39
+
40
+ # normalize image
41
+ mean = np.array([123.675, 116.28, 103.53])
42
+ std = np.array([58.395, 57.12, 57.375])
43
+ resized_img = (resized_img - mean) / std
44
+
45
+ out_img.append(resized_img)
46
+ out_center.append(center)
47
+ out_scale.append(scale)
48
+
49
+ return out_img, out_center, out_scale
50
+
51
+
52
+ def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:
53
+ """Inference RTMPose model.
54
+
55
+ Args:
56
+ sess (ort.InferenceSession): ONNXRuntime session.
57
+ img (np.ndarray): Input image in shape.
58
+
59
+ Returns:
60
+ outputs (np.ndarray): Output of RTMPose model.
61
+ """
62
+ all_out = []
63
+ # build input
64
+ for i in range(len(img)):
65
+ input = [img[i].transpose(2, 0, 1)]
66
+
67
+ # build output
68
+ sess_input = {sess.get_inputs()[0].name: input}
69
+ sess_output = []
70
+ for out in sess.get_outputs():
71
+ sess_output.append(out.name)
72
+
73
+ # run model
74
+ outputs = sess.run(sess_output, sess_input)
75
+ all_out.append(outputs)
76
+
77
+ return all_out
78
+
79
+
80
+ def postprocess(outputs: List[np.ndarray],
81
+ model_input_size: Tuple[int, int],
82
+ center: Tuple[int, int],
83
+ scale: Tuple[int, int],
84
+ simcc_split_ratio: float = 2.0
85
+ ) -> Tuple[np.ndarray, np.ndarray]:
86
+ """Postprocess for RTMPose model output.
87
+
88
+ Args:
89
+ outputs (np.ndarray): Output of RTMPose model.
90
+ model_input_size (tuple): RTMPose model Input image size.
91
+ center (tuple): Center of bbox in shape (x, y).
92
+ scale (tuple): Scale of bbox in shape (w, h).
93
+ simcc_split_ratio (float): Split ratio of simcc.
94
+
95
+ Returns:
96
+ tuple:
97
+ - keypoints (np.ndarray): Rescaled keypoints.
98
+ - scores (np.ndarray): Model predict scores.
99
+ """
100
+ all_key = []
101
+ all_score = []
102
+ for i in range(len(outputs)):
103
+ # use simcc to decode
104
+ simcc_x, simcc_y = outputs[i]
105
+ keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
106
+
107
+ # rescale keypoints
108
+ keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
109
+ all_key.append(keypoints[0])
110
+ all_score.append(scores[0])
111
+
112
+ return np.array(all_key), np.array(all_score)
113
+
114
+
115
+ def bbox_xyxy2cs(bbox: np.ndarray,
116
+ padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
117
+ """Transform the bbox format from (x,y,w,h) into (center, scale)
118
+
119
+ Args:
120
+ bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
121
+ as (left, top, right, bottom)
122
+ padding (float): BBox padding factor that will be multilied to scale.
123
+ Default: 1.0
124
+
125
+ Returns:
126
+ tuple: A tuple containing center and scale.
127
+ - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
128
+ (n, 2)
129
+ - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
130
+ (n, 2)
131
+ """
132
+ # convert single bbox from (4, ) to (1, 4)
133
+ dim = bbox.ndim
134
+ if dim == 1:
135
+ bbox = bbox[None, :]
136
+
137
+ # get bbox center and scale
138
+ x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
139
+ center = np.hstack([x1 + x2, y1 + y2]) * 0.5
140
+ scale = np.hstack([x2 - x1, y2 - y1]) * padding
141
+
142
+ if dim == 1:
143
+ center = center[0]
144
+ scale = scale[0]
145
+
146
+ return center, scale
147
+
148
+
149
+ def _fix_aspect_ratio(bbox_scale: np.ndarray,
150
+ aspect_ratio: float) -> np.ndarray:
151
+ """Extend the scale to match the given aspect ratio.
152
+
153
+ Args:
154
+ scale (np.ndarray): The image scale (w, h) in shape (2, )
155
+ aspect_ratio (float): The ratio of ``w/h``
156
+
157
+ Returns:
158
+ np.ndarray: The reshaped image scale in (2, )
159
+ """
160
+ w, h = np.hsplit(bbox_scale, [1])
161
+ bbox_scale = np.where(w > h * aspect_ratio,
162
+ np.hstack([w, w / aspect_ratio]),
163
+ np.hstack([h * aspect_ratio, h]))
164
+ return bbox_scale
165
+
166
+
167
+ def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
168
+ """Rotate a point by an angle.
169
+
170
+ Args:
171
+ pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
172
+ angle_rad (float): rotation angle in radian
173
+
174
+ Returns:
175
+ np.ndarray: Rotated point in shape (2, )
176
+ """
177
+ sn, cs = np.sin(angle_rad), np.cos(angle_rad)
178
+ rot_mat = np.array([[cs, -sn], [sn, cs]])
179
+ return rot_mat @ pt
180
+
181
+
182
+ def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
183
+ """To calculate the affine matrix, three pairs of points are required. This
184
+ function is used to get the 3rd point, given 2D points a & b.
185
+
186
+ The 3rd point is defined by rotating vector `a - b` by 90 degrees
187
+ anticlockwise, using b as the rotation center.
188
+
189
+ Args:
190
+ a (np.ndarray): The 1st point (x,y) in shape (2, )
191
+ b (np.ndarray): The 2nd point (x,y) in shape (2, )
192
+
193
+ Returns:
194
+ np.ndarray: The 3rd point.
195
+ """
196
+ direction = a - b
197
+ c = b + np.r_[-direction[1], direction[0]]
198
+ return c
199
+
200
+
201
+ def get_warp_matrix(center: np.ndarray,
202
+ scale: np.ndarray,
203
+ rot: float,
204
+ output_size: Tuple[int, int],
205
+ shift: Tuple[float, float] = (0., 0.),
206
+ inv: bool = False) -> np.ndarray:
207
+ """Calculate the affine transformation matrix that can warp the bbox area
208
+ in the input image to the output size.
209
+
210
+ Args:
211
+ center (np.ndarray[2, ]): Center of the bounding box (x, y).
212
+ scale (np.ndarray[2, ]): Scale of the bounding box
213
+ wrt [width, height].
214
+ rot (float): Rotation angle (degree).
215
+ output_size (np.ndarray[2, ] | list(2,)): Size of the
216
+ destination heatmaps.
217
+ shift (0-100%): Shift translation ratio wrt the width/height.
218
+ Default (0., 0.).
219
+ inv (bool): Option to inverse the affine transform direction.
220
+ (inv=False: src->dst or inv=True: dst->src)
221
+
222
+ Returns:
223
+ np.ndarray: A 2x3 transformation matrix
224
+ """
225
+ shift = np.array(shift)
226
+ src_w = scale[0]
227
+ dst_w = output_size[0]
228
+ dst_h = output_size[1]
229
+
230
+ # compute transformation matrix
231
+ rot_rad = np.deg2rad(rot)
232
+ src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
233
+ dst_dir = np.array([0., dst_w * -0.5])
234
+
235
+ # get four corners of the src rectangle in the original image
236
+ src = np.zeros((3, 2), dtype=np.float32)
237
+ src[0, :] = center + scale * shift
238
+ src[1, :] = center + src_dir + scale * shift
239
+ src[2, :] = _get_3rd_point(src[0, :], src[1, :])
240
+
241
+ # get four corners of the dst rectangle in the input image
242
+ dst = np.zeros((3, 2), dtype=np.float32)
243
+ dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
244
+ dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
245
+ dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
246
+
247
+ if inv:
248
+ warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
249
+ else:
250
+ warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
251
+
252
+ return warp_mat
253
+
254
+
255
+ def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict,
256
+ img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
257
+ """Get the bbox image as the model input by affine transform.
258
+
259
+ Args:
260
+ input_size (dict): The input size of the model.
261
+ bbox_scale (dict): The bbox scale of the img.
262
+ bbox_center (dict): The bbox center of the img.
263
+ img (np.ndarray): The original image.
264
+
265
+ Returns:
266
+ tuple: A tuple containing center and scale.
267
+ - np.ndarray[float32]: img after affine transform.
268
+ - np.ndarray[float32]: bbox scale after affine transform.
269
+ """
270
+ w, h = input_size
271
+ warp_size = (int(w), int(h))
272
+
273
+ # reshape bbox to fixed aspect ratio
274
+ bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
275
+
276
+ # get the affine matrix
277
+ center = bbox_center
278
+ scale = bbox_scale
279
+ rot = 0
280
+ warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
281
+
282
+ # do affine transform
283
+ img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
284
+
285
+ return img, bbox_scale
286
+
287
+
288
+ def get_simcc_maximum(simcc_x: np.ndarray,
289
+ simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
290
+ """Get maximum response location and value from simcc representations.
291
+
292
+ Note:
293
+ instance number: N
294
+ num_keypoints: K
295
+ heatmap height: H
296
+ heatmap width: W
297
+
298
+ Args:
299
+ simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
300
+ simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
301
+
302
+ Returns:
303
+ tuple:
304
+ - locs (np.ndarray): locations of maximum heatmap responses in shape
305
+ (K, 2) or (N, K, 2)
306
+ - vals (np.ndarray): values of maximum heatmap responses in shape
307
+ (K,) or (N, K)
308
+ """
309
+ N, K, Wx = simcc_x.shape
310
+ simcc_x = simcc_x.reshape(N * K, -1)
311
+ simcc_y = simcc_y.reshape(N * K, -1)
312
+
313
+ # get maximum value locations
314
+ x_locs = np.argmax(simcc_x, axis=1)
315
+ y_locs = np.argmax(simcc_y, axis=1)
316
+ locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
317
+ max_val_x = np.amax(simcc_x, axis=1)
318
+ max_val_y = np.amax(simcc_y, axis=1)
319
+
320
+ # get maximum value across x and y axis
321
+ mask = max_val_x > max_val_y
322
+ max_val_x[mask] = max_val_y[mask]
323
+ vals = max_val_x
324
+ locs[vals <= 0.] = -1
325
+
326
+ # reshape
327
+ locs = locs.reshape(N, K, 2)
328
+ vals = vals.reshape(N, K)
329
+
330
+ return locs, vals
331
+
332
+
333
+ def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
334
+ simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
335
+ """Modulate simcc distribution with Gaussian.
336
+
337
+ Args:
338
+ simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
339
+ simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
340
+ simcc_split_ratio (int): The split ratio of simcc.
341
+
342
+ Returns:
343
+ tuple: A tuple containing center and scale.
344
+ - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
345
+ - np.ndarray[float32]: scores in shape (K,) or (n, K)
346
+ """
347
+ keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
348
+ keypoints /= simcc_split_ratio
349
+
350
+ return keypoints, scores
351
+
352
+
353
+ def inference_pose(session, out_bbox, oriImg):
354
+ h, w = session.get_inputs()[0].shape[2:]
355
+ model_input_size = (w, h)
356
+ resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
357
+ outputs = inference(session, resized_img)
358
+ keypoints, scores = postprocess(outputs, model_input_size, center, scale)
359
+
360
+ return keypoints, scores
src/flux/annotator/dwpose/util.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import matplotlib
4
+ import cv2
5
+
6
+
7
+ eps = 0.01
8
+
9
+
10
+ def smart_resize(x, s):
11
+ Ht, Wt = s
12
+ if x.ndim == 2:
13
+ Ho, Wo = x.shape
14
+ Co = 1
15
+ else:
16
+ Ho, Wo, Co = x.shape
17
+ if Co == 3 or Co == 1:
18
+ k = float(Ht + Wt) / float(Ho + Wo)
19
+ return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
20
+ else:
21
+ return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
22
+
23
+
24
+ def smart_resize_k(x, fx, fy):
25
+ if x.ndim == 2:
26
+ Ho, Wo = x.shape
27
+ Co = 1
28
+ else:
29
+ Ho, Wo, Co = x.shape
30
+ Ht, Wt = Ho * fy, Wo * fx
31
+ if Co == 3 or Co == 1:
32
+ k = float(Ht + Wt) / float(Ho + Wo)
33
+ return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
34
+ else:
35
+ return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
36
+
37
+
38
+ def padRightDownCorner(img, stride, padValue):
39
+ h = img.shape[0]
40
+ w = img.shape[1]
41
+
42
+ pad = 4 * [None]
43
+ pad[0] = 0 # up
44
+ pad[1] = 0 # left
45
+ pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
46
+ pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
47
+
48
+ img_padded = img
49
+ pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1))
50
+ img_padded = np.concatenate((pad_up, img_padded), axis=0)
51
+ pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1))
52
+ img_padded = np.concatenate((pad_left, img_padded), axis=1)
53
+ pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1))
54
+ img_padded = np.concatenate((img_padded, pad_down), axis=0)
55
+ pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1))
56
+ img_padded = np.concatenate((img_padded, pad_right), axis=1)
57
+
58
+ return img_padded, pad
59
+
60
+
61
+ def transfer(model, model_weights):
62
+ transfered_model_weights = {}
63
+ for weights_name in model.state_dict().keys():
64
+ transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
65
+ return transfered_model_weights
66
+
67
+
68
+ def draw_bodypose(canvas, candidate, subset):
69
+ H, W, C = canvas.shape
70
+ candidate = np.array(candidate)
71
+ subset = np.array(subset)
72
+
73
+ stickwidth = 4
74
+
75
+ limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
76
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
77
+ [1, 16], [16, 18], [3, 17], [6, 18]]
78
+
79
+ colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
80
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
81
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
82
+
83
+ for i in range(17):
84
+ for n in range(len(subset)):
85
+ index = subset[n][np.array(limbSeq[i]) - 1]
86
+ if -1 in index:
87
+ continue
88
+ Y = candidate[index.astype(int), 0] * float(W)
89
+ X = candidate[index.astype(int), 1] * float(H)
90
+ mX = np.mean(X)
91
+ mY = np.mean(Y)
92
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
93
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
94
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
95
+ cv2.fillConvexPoly(canvas, polygon, colors[i])
96
+
97
+ canvas = (canvas * 0.6).astype(np.uint8)
98
+
99
+ for i in range(18):
100
+ for n in range(len(subset)):
101
+ index = int(subset[n][i])
102
+ if index == -1:
103
+ continue
104
+ x, y = candidate[index][0:2]
105
+ x = int(x * W)
106
+ y = int(y * H)
107
+ cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
108
+
109
+ return canvas
110
+
111
+
112
+ def draw_handpose(canvas, all_hand_peaks):
113
+ H, W, C = canvas.shape
114
+
115
+ edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
116
+ [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
117
+
118
+ for peaks in all_hand_peaks:
119
+ peaks = np.array(peaks)
120
+
121
+ for ie, e in enumerate(edges):
122
+ x1, y1 = peaks[e[0]]
123
+ x2, y2 = peaks[e[1]]
124
+ x1 = int(x1 * W)
125
+ y1 = int(y1 * H)
126
+ x2 = int(x2 * W)
127
+ y2 = int(y2 * H)
128
+ if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
129
+ cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2)
130
+
131
+ for i, keyponit in enumerate(peaks):
132
+ x, y = keyponit
133
+ x = int(x * W)
134
+ y = int(y * H)
135
+ if x > eps and y > eps:
136
+ cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
137
+ return canvas
138
+
139
+
140
+ def draw_facepose(canvas, all_lmks):
141
+ H, W, C = canvas.shape
142
+ for lmks in all_lmks:
143
+ lmks = np.array(lmks)
144
+ for lmk in lmks:
145
+ x, y = lmk
146
+ x = int(x * W)
147
+ y = int(y * H)
148
+ if x > eps and y > eps:
149
+ cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
150
+ return canvas
151
+
152
+
153
+ # detect hand according to body pose keypoints
154
+ # please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
155
+ def handDetect(candidate, subset, oriImg):
156
+ # right hand: wrist 4, elbow 3, shoulder 2
157
+ # left hand: wrist 7, elbow 6, shoulder 5
158
+ ratioWristElbow = 0.33
159
+ detect_result = []
160
+ image_height, image_width = oriImg.shape[0:2]
161
+ for person in subset.astype(int):
162
+ # if any of three not detected
163
+ has_left = np.sum(person[[5, 6, 7]] == -1) == 0
164
+ has_right = np.sum(person[[2, 3, 4]] == -1) == 0
165
+ if not (has_left or has_right):
166
+ continue
167
+ hands = []
168
+ #left hand
169
+ if has_left:
170
+ left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
171
+ x1, y1 = candidate[left_shoulder_index][:2]
172
+ x2, y2 = candidate[left_elbow_index][:2]
173
+ x3, y3 = candidate[left_wrist_index][:2]
174
+ hands.append([x1, y1, x2, y2, x3, y3, True])
175
+ # right hand
176
+ if has_right:
177
+ right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
178
+ x1, y1 = candidate[right_shoulder_index][:2]
179
+ x2, y2 = candidate[right_elbow_index][:2]
180
+ x3, y3 = candidate[right_wrist_index][:2]
181
+ hands.append([x1, y1, x2, y2, x3, y3, False])
182
+
183
+ for x1, y1, x2, y2, x3, y3, is_left in hands:
184
+ # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
185
+ # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
186
+ # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
187
+ # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
188
+ # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
189
+ # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
190
+ x = x3 + ratioWristElbow * (x3 - x2)
191
+ y = y3 + ratioWristElbow * (y3 - y2)
192
+ distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
193
+ distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
194
+ width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
195
+ # x-y refers to the center --> offset to topLeft point
196
+ # handRectangle.x -= handRectangle.width / 2.f;
197
+ # handRectangle.y -= handRectangle.height / 2.f;
198
+ x -= width / 2
199
+ y -= width / 2 # width = height
200
+ # overflow the image
201
+ if x < 0: x = 0
202
+ if y < 0: y = 0
203
+ width1 = width
204
+ width2 = width
205
+ if x + width > image_width: width1 = image_width - x
206
+ if y + width > image_height: width2 = image_height - y
207
+ width = min(width1, width2)
208
+ # the max hand box value is 20 pixels
209
+ if width >= 20:
210
+ detect_result.append([int(x), int(y), int(width), is_left])
211
+
212
+ '''
213
+ return value: [[x, y, w, True if left hand else False]].
214
+ width=height since the network require squared input.
215
+ x, y is the coordinate of top left
216
+ '''
217
+ return detect_result
218
+
219
+
220
+ # Written by Lvmin
221
+ def faceDetect(candidate, subset, oriImg):
222
+ # left right eye ear 14 15 16 17
223
+ detect_result = []
224
+ image_height, image_width = oriImg.shape[0:2]
225
+ for person in subset.astype(int):
226
+ has_head = person[0] > -1
227
+ if not has_head:
228
+ continue
229
+
230
+ has_left_eye = person[14] > -1
231
+ has_right_eye = person[15] > -1
232
+ has_left_ear = person[16] > -1
233
+ has_right_ear = person[17] > -1
234
+
235
+ if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear):
236
+ continue
237
+
238
+ head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]]
239
+
240
+ width = 0.0
241
+ x0, y0 = candidate[head][:2]
242
+
243
+ if has_left_eye:
244
+ x1, y1 = candidate[left_eye][:2]
245
+ d = max(abs(x0 - x1), abs(y0 - y1))
246
+ width = max(width, d * 3.0)
247
+
248
+ if has_right_eye:
249
+ x1, y1 = candidate[right_eye][:2]
250
+ d = max(abs(x0 - x1), abs(y0 - y1))
251
+ width = max(width, d * 3.0)
252
+
253
+ if has_left_ear:
254
+ x1, y1 = candidate[left_ear][:2]
255
+ d = max(abs(x0 - x1), abs(y0 - y1))
256
+ width = max(width, d * 1.5)
257
+
258
+ if has_right_ear:
259
+ x1, y1 = candidate[right_ear][:2]
260
+ d = max(abs(x0 - x1), abs(y0 - y1))
261
+ width = max(width, d * 1.5)
262
+
263
+ x, y = x0, y0
264
+
265
+ x -= width
266
+ y -= width
267
+
268
+ if x < 0:
269
+ x = 0
270
+
271
+ if y < 0:
272
+ y = 0
273
+
274
+ width1 = width * 2
275
+ width2 = width * 2
276
+
277
+ if x + width > image_width:
278
+ width1 = image_width - x
279
+
280
+ if y + width > image_height:
281
+ width2 = image_height - y
282
+
283
+ width = min(width1, width2)
284
+
285
+ if width >= 20:
286
+ detect_result.append([int(x), int(y), int(width)])
287
+
288
+ return detect_result
289
+
290
+
291
+ # get max index of 2d array
292
+ def npmax(array):
293
+ arrayindex = array.argmax(1)
294
+ arrayvalue = array.max(1)
295
+ i = arrayvalue.argmax()
296
+ j = arrayindex[i]
297
+ return i, j
src/flux/annotator/dwpose/wholebody.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
+ import onnxruntime as ort
5
+ from huggingface_hub import hf_hub_download
6
+ from .onnxdet import inference_detector
7
+ from .onnxpose import inference_pose
8
+
9
+
10
+ class Wholebody:
11
+ def __init__(self, device="cuda:0"):
12
+ providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider']
13
+ onnx_det = hf_hub_download("yzd-v/DWPose", "yolox_l.onnx")
14
+ onnx_pose = hf_hub_download("yzd-v/DWPose", "dw-ll_ucoco_384.onnx")
15
+
16
+ self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
17
+ self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
18
+
19
+ def __call__(self, oriImg):
20
+ det_result = inference_detector(self.session_det, oriImg)
21
+ keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)
22
+
23
+ keypoints_info = np.concatenate(
24
+ (keypoints, scores[..., None]), axis=-1)
25
+ # compute neck joint
26
+ neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
27
+ # neck score when visualizing pred
28
+ neck[:, 2:4] = np.logical_and(
29
+ keypoints_info[:, 5, 2:4] > 0.3,
30
+ keypoints_info[:, 6, 2:4] > 0.3).astype(int)
31
+ new_keypoints_info = np.insert(
32
+ keypoints_info, 17, neck, axis=1)
33
+ mmpose_idx = [
34
+ 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
35
+ ]
36
+ openpose_idx = [
37
+ 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
38
+ ]
39
+ new_keypoints_info[:, openpose_idx] = \
40
+ new_keypoints_info[:, mmpose_idx]
41
+ keypoints_info = new_keypoints_info
42
+
43
+ keypoints, scores = keypoints_info[
44
+ ..., :2], keypoints_info[..., 2]
45
+
46
+ return keypoints, scores
47
+
48
+
src/flux/annotator/hed/__init__.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This is an improved version and model of HED edge detection with Apache License, Version 2.0.
2
+ # Please use this implementation in your products
3
+ # This implementation may produce slightly different results from Saining Xie's official implementations,
4
+ # but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
5
+ # Different from official models and other implementations, this is an RGB-input model (rather than BGR)
6
+ # and in this way it works better for gradio's RGB protocol
7
+
8
+ import os
9
+ import cv2
10
+ import torch
11
+ import numpy as np
12
+
13
+ from huggingface_hub import hf_hub_download
14
+ from einops import rearrange
15
+ from ...annotator.util import annotator_ckpts_path
16
+
17
+
18
+ class DoubleConvBlock(torch.nn.Module):
19
+ def __init__(self, input_channel, output_channel, layer_number):
20
+ super().__init__()
21
+ self.convs = torch.nn.Sequential()
22
+ self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
23
+ for i in range(1, layer_number):
24
+ self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
25
+ self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)
26
+
27
+ def __call__(self, x, down_sampling=False):
28
+ h = x
29
+ if down_sampling:
30
+ h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
31
+ for conv in self.convs:
32
+ h = conv(h)
33
+ h = torch.nn.functional.relu(h)
34
+ return h, self.projection(h)
35
+
36
+
37
+ class ControlNetHED_Apache2(torch.nn.Module):
38
+ def __init__(self):
39
+ super().__init__()
40
+ self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
41
+ self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
42
+ self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
43
+ self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
44
+ self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
45
+ self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)
46
+
47
+ def __call__(self, x):
48
+ h = x - self.norm
49
+ h, projection1 = self.block1(h)
50
+ h, projection2 = self.block2(h, down_sampling=True)
51
+ h, projection3 = self.block3(h, down_sampling=True)
52
+ h, projection4 = self.block4(h, down_sampling=True)
53
+ h, projection5 = self.block5(h, down_sampling=True)
54
+ return projection1, projection2, projection3, projection4, projection5
55
+
56
+
57
+ class HEDdetector:
58
+ def __init__(self):
59
+ modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth")
60
+ if not os.path.exists(modelpath):
61
+ modelpath = hf_hub_download("lllyasviel/Annotators", "ControlNetHED.pth")
62
+ self.netNetwork = ControlNetHED_Apache2().float().cuda().eval()
63
+ self.netNetwork.load_state_dict(torch.load(modelpath))
64
+
65
+ def __call__(self, input_image):
66
+ assert input_image.ndim == 3
67
+ H, W, C = input_image.shape
68
+ with torch.no_grad():
69
+ image_hed = torch.from_numpy(input_image.copy()).float().cuda()
70
+ image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
71
+ edges = self.netNetwork(image_hed)
72
+ edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
73
+ edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
74
+ edges = np.stack(edges, axis=2)
75
+ edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
76
+ edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
77
+ return edge
78
+
79
+
80
+ def nms(x, t, s):
81
+ x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
82
+
83
+ f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
84
+ f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
85
+ f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
86
+ f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
87
+
88
+ y = np.zeros_like(x)
89
+
90
+ for f in [f1, f2, f3, f4]:
91
+ np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
92
+
93
+ z = np.zeros_like(y, dtype=np.uint8)
94
+ z[y > t] = 255
95
+ return z
src/flux/annotator/midas/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
src/flux/annotator/midas/__init__.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Midas Depth Estimation
2
+ # From https://github.com/isl-org/MiDaS
3
+ # MIT LICENSE
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+
9
+ from einops import rearrange
10
+ from .api import MiDaSInference
11
+
12
+
13
+ class MidasDetector:
14
+ def __init__(self):
15
+ self.model = MiDaSInference(model_type="dpt_hybrid").cuda()
16
+
17
+ def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1):
18
+ assert input_image.ndim == 3
19
+ image_depth = input_image
20
+ with torch.no_grad():
21
+ image_depth = torch.from_numpy(image_depth).float().cuda()
22
+ image_depth = image_depth / 127.5 - 1.0
23
+ image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
24
+ depth = self.model(image_depth)[0]
25
+
26
+ depth_pt = depth.clone()
27
+ depth_pt -= torch.min(depth_pt)
28
+ depth_pt /= torch.max(depth_pt)
29
+ depth_pt = depth_pt.cpu().numpy()
30
+ depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
31
+
32
+ depth_np = depth.cpu().numpy()
33
+ x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
34
+ y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
35
+ z = np.ones_like(x) * a
36
+ x[depth_pt < bg_th] = 0
37
+ y[depth_pt < bg_th] = 0
38
+ normal = np.stack([x, y, z], axis=2)
39
+ normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
40
+ normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
41
+
42
+ return depth_image, normal_image
src/flux/annotator/midas/api.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # based on https://github.com/isl-org/MiDaS
2
+
3
+ import cv2
4
+ import os
5
+ import torch
6
+ import torch.nn as nn
7
+ from torchvision.transforms import Compose
8
+
9
+ from huggingface_hub import hf_hub_download
10
+
11
+ from .midas.dpt_depth import DPTDepthModel
12
+ from .midas.midas_net import MidasNet
13
+ from .midas.midas_net_custom import MidasNet_small
14
+ from .midas.transforms import Resize, NormalizeImage, PrepareForNet
15
+ from ...annotator.util import annotator_ckpts_path
16
+
17
+
18
+ ISL_PATHS = {
19
+ "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"),
20
+ "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"),
21
+ "midas_v21": "",
22
+ "midas_v21_small": "",
23
+ }
24
+
25
+
26
+ def disabled_train(self, mode=True):
27
+ """Overwrite model.train with this function to make sure train/eval mode
28
+ does not change anymore."""
29
+ return self
30
+
31
+
32
+ def load_midas_transform(model_type):
33
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
34
+ # load transform only
35
+ if model_type == "dpt_large": # DPT-Large
36
+ net_w, net_h = 384, 384
37
+ resize_mode = "minimal"
38
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
39
+
40
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
41
+ net_w, net_h = 384, 384
42
+ resize_mode = "minimal"
43
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
44
+
45
+ elif model_type == "midas_v21":
46
+ net_w, net_h = 384, 384
47
+ resize_mode = "upper_bound"
48
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
49
+
50
+ elif model_type == "midas_v21_small":
51
+ net_w, net_h = 256, 256
52
+ resize_mode = "upper_bound"
53
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
54
+
55
+ else:
56
+ assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
57
+
58
+ transform = Compose(
59
+ [
60
+ Resize(
61
+ net_w,
62
+ net_h,
63
+ resize_target=None,
64
+ keep_aspect_ratio=True,
65
+ ensure_multiple_of=32,
66
+ resize_method=resize_mode,
67
+ image_interpolation_method=cv2.INTER_CUBIC,
68
+ ),
69
+ normalization,
70
+ PrepareForNet(),
71
+ ]
72
+ )
73
+
74
+ return transform
75
+
76
+
77
+ def load_model(model_type):
78
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
79
+ # load network
80
+ model_path = ISL_PATHS[model_type]
81
+ if model_type == "dpt_large": # DPT-Large
82
+ model = DPTDepthModel(
83
+ path=model_path,
84
+ backbone="vitl16_384",
85
+ non_negative=True,
86
+ )
87
+ net_w, net_h = 384, 384
88
+ resize_mode = "minimal"
89
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
90
+
91
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
92
+ if not os.path.exists(model_path):
93
+ model_path = hf_hub_download("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt")
94
+
95
+ model = DPTDepthModel(
96
+ path=model_path,
97
+ backbone="vitb_rn50_384",
98
+ non_negative=True,
99
+ )
100
+ net_w, net_h = 384, 384
101
+ resize_mode = "minimal"
102
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
103
+
104
+ elif model_type == "midas_v21":
105
+ model = MidasNet(model_path, non_negative=True)
106
+ net_w, net_h = 384, 384
107
+ resize_mode = "upper_bound"
108
+ normalization = NormalizeImage(
109
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
110
+ )
111
+
112
+ elif model_type == "midas_v21_small":
113
+ model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
114
+ non_negative=True, blocks={'expand': True})
115
+ net_w, net_h = 256, 256
116
+ resize_mode = "upper_bound"
117
+ normalization = NormalizeImage(
118
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
119
+ )
120
+
121
+ else:
122
+ print(f"model_type '{model_type}' not implemented, use: --model_type large")
123
+ assert False
124
+
125
+ transform = Compose(
126
+ [
127
+ Resize(
128
+ net_w,
129
+ net_h,
130
+ resize_target=None,
131
+ keep_aspect_ratio=True,
132
+ ensure_multiple_of=32,
133
+ resize_method=resize_mode,
134
+ image_interpolation_method=cv2.INTER_CUBIC,
135
+ ),
136
+ normalization,
137
+ PrepareForNet(),
138
+ ]
139
+ )
140
+
141
+ return model.eval(), transform
142
+
143
+
144
+ class MiDaSInference(nn.Module):
145
+ MODEL_TYPES_TORCH_HUB = [
146
+ "DPT_Large",
147
+ "DPT_Hybrid",
148
+ "MiDaS_small"
149
+ ]
150
+ MODEL_TYPES_ISL = [
151
+ "dpt_large",
152
+ "dpt_hybrid",
153
+ "midas_v21",
154
+ "midas_v21_small",
155
+ ]
156
+
157
+ def __init__(self, model_type):
158
+ super().__init__()
159
+ assert (model_type in self.MODEL_TYPES_ISL)
160
+ model, _ = load_model(model_type)
161
+ self.model = model
162
+ self.model.train = disabled_train
163
+
164
+ def forward(self, x):
165
+ with torch.no_grad():
166
+ prediction = self.model(x)
167
+ return prediction
168
+
src/flux/annotator/midas/midas/__init__.py ADDED
File without changes
src/flux/annotator/midas/midas/base_model.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class BaseModel(torch.nn.Module):
5
+ def load(self, path):
6
+ """Load model from file.
7
+
8
+ Args:
9
+ path (str): file path
10
+ """
11
+ parameters = torch.load(path, map_location=torch.device('cpu'))
12
+
13
+ if "optimizer" in parameters:
14
+ parameters = parameters["model"]
15
+
16
+ self.load_state_dict(parameters)
src/flux/annotator/midas/midas/blocks.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .vit import (
5
+ _make_pretrained_vitb_rn50_384,
6
+ _make_pretrained_vitl16_384,
7
+ _make_pretrained_vitb16_384,
8
+ forward_vit,
9
+ )
10
+
11
+ def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
12
+ if backbone == "vitl16_384":
13
+ pretrained = _make_pretrained_vitl16_384(
14
+ use_pretrained, hooks=hooks, use_readout=use_readout
15
+ )
16
+ scratch = _make_scratch(
17
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
18
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
19
+ elif backbone == "vitb_rn50_384":
20
+ pretrained = _make_pretrained_vitb_rn50_384(
21
+ use_pretrained,
22
+ hooks=hooks,
23
+ use_vit_only=use_vit_only,
24
+ use_readout=use_readout,
25
+ )
26
+ scratch = _make_scratch(
27
+ [256, 512, 768, 768], features, groups=groups, expand=expand
28
+ ) # ViT-H/16 - 85.0% Top1 (backbone)
29
+ elif backbone == "vitb16_384":
30
+ pretrained = _make_pretrained_vitb16_384(
31
+ use_pretrained, hooks=hooks, use_readout=use_readout
32
+ )
33
+ scratch = _make_scratch(
34
+ [96, 192, 384, 768], features, groups=groups, expand=expand
35
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
36
+ elif backbone == "resnext101_wsl":
37
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
38
+ scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
39
+ elif backbone == "efficientnet_lite3":
40
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
41
+ scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
42
+ else:
43
+ print(f"Backbone '{backbone}' not implemented")
44
+ assert False
45
+
46
+ return pretrained, scratch
47
+
48
+
49
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
50
+ scratch = nn.Module()
51
+
52
+ out_shape1 = out_shape
53
+ out_shape2 = out_shape
54
+ out_shape3 = out_shape
55
+ out_shape4 = out_shape
56
+ if expand==True:
57
+ out_shape1 = out_shape
58
+ out_shape2 = out_shape*2
59
+ out_shape3 = out_shape*4
60
+ out_shape4 = out_shape*8
61
+
62
+ scratch.layer1_rn = nn.Conv2d(
63
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
64
+ )
65
+ scratch.layer2_rn = nn.Conv2d(
66
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
67
+ )
68
+ scratch.layer3_rn = nn.Conv2d(
69
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
70
+ )
71
+ scratch.layer4_rn = nn.Conv2d(
72
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
73
+ )
74
+
75
+ return scratch
76
+
77
+
78
+ def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
79
+ efficientnet = torch.hub.load(
80
+ "rwightman/gen-efficientnet-pytorch",
81
+ "tf_efficientnet_lite3",
82
+ pretrained=use_pretrained,
83
+ exportable=exportable
84
+ )
85
+ return _make_efficientnet_backbone(efficientnet)
86
+
87
+
88
+ def _make_efficientnet_backbone(effnet):
89
+ pretrained = nn.Module()
90
+
91
+ pretrained.layer1 = nn.Sequential(
92
+ effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
93
+ )
94
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
95
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
96
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
97
+
98
+ return pretrained
99
+
100
+
101
+ def _make_resnet_backbone(resnet):
102
+ pretrained = nn.Module()
103
+ pretrained.layer1 = nn.Sequential(
104
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
105
+ )
106
+
107
+ pretrained.layer2 = resnet.layer2
108
+ pretrained.layer3 = resnet.layer3
109
+ pretrained.layer4 = resnet.layer4
110
+
111
+ return pretrained
112
+
113
+
114
+ def _make_pretrained_resnext101_wsl(use_pretrained):
115
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
116
+ return _make_resnet_backbone(resnet)
117
+
118
+
119
+
120
+ class Interpolate(nn.Module):
121
+ """Interpolation module.
122
+ """
123
+
124
+ def __init__(self, scale_factor, mode, align_corners=False):
125
+ """Init.
126
+
127
+ Args:
128
+ scale_factor (float): scaling
129
+ mode (str): interpolation mode
130
+ """
131
+ super(Interpolate, self).__init__()
132
+
133
+ self.interp = nn.functional.interpolate
134
+ self.scale_factor = scale_factor
135
+ self.mode = mode
136
+ self.align_corners = align_corners
137
+
138
+ def forward(self, x):
139
+ """Forward pass.
140
+
141
+ Args:
142
+ x (tensor): input
143
+
144
+ Returns:
145
+ tensor: interpolated data
146
+ """
147
+
148
+ x = self.interp(
149
+ x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
150
+ )
151
+
152
+ return x
153
+
154
+
155
+ class ResidualConvUnit(nn.Module):
156
+ """Residual convolution module.
157
+ """
158
+
159
+ def __init__(self, features):
160
+ """Init.
161
+
162
+ Args:
163
+ features (int): number of features
164
+ """
165
+ super().__init__()
166
+
167
+ self.conv1 = nn.Conv2d(
168
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
169
+ )
170
+
171
+ self.conv2 = nn.Conv2d(
172
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
173
+ )
174
+
175
+ self.relu = nn.ReLU(inplace=True)
176
+
177
+ def forward(self, x):
178
+ """Forward pass.
179
+
180
+ Args:
181
+ x (tensor): input
182
+
183
+ Returns:
184
+ tensor: output
185
+ """
186
+ out = self.relu(x)
187
+ out = self.conv1(out)
188
+ out = self.relu(out)
189
+ out = self.conv2(out)
190
+
191
+ return out + x
192
+
193
+
194
+ class FeatureFusionBlock(nn.Module):
195
+ """Feature fusion block.
196
+ """
197
+
198
+ def __init__(self, features):
199
+ """Init.
200
+
201
+ Args:
202
+ features (int): number of features
203
+ """
204
+ super(FeatureFusionBlock, self).__init__()
205
+
206
+ self.resConfUnit1 = ResidualConvUnit(features)
207
+ self.resConfUnit2 = ResidualConvUnit(features)
208
+
209
+ def forward(self, *xs):
210
+ """Forward pass.
211
+
212
+ Returns:
213
+ tensor: output
214
+ """
215
+ output = xs[0]
216
+
217
+ if len(xs) == 2:
218
+ output += self.resConfUnit1(xs[1])
219
+
220
+ output = self.resConfUnit2(output)
221
+
222
+ output = nn.functional.interpolate(
223
+ output, scale_factor=2, mode="bilinear", align_corners=True
224
+ )
225
+
226
+ return output
227
+
228
+
229
+
230
+
231
+ class ResidualConvUnit_custom(nn.Module):
232
+ """Residual convolution module.
233
+ """
234
+
235
+ def __init__(self, features, activation, bn):
236
+ """Init.
237
+
238
+ Args:
239
+ features (int): number of features
240
+ """
241
+ super().__init__()
242
+
243
+ self.bn = bn
244
+
245
+ self.groups=1
246
+
247
+ self.conv1 = nn.Conv2d(
248
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
249
+ )
250
+
251
+ self.conv2 = nn.Conv2d(
252
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
253
+ )
254
+
255
+ if self.bn==True:
256
+ self.bn1 = nn.BatchNorm2d(features)
257
+ self.bn2 = nn.BatchNorm2d(features)
258
+
259
+ self.activation = activation
260
+
261
+ self.skip_add = nn.quantized.FloatFunctional()
262
+
263
+ def forward(self, x):
264
+ """Forward pass.
265
+
266
+ Args:
267
+ x (tensor): input
268
+
269
+ Returns:
270
+ tensor: output
271
+ """
272
+
273
+ out = self.activation(x)
274
+ out = self.conv1(out)
275
+ if self.bn==True:
276
+ out = self.bn1(out)
277
+
278
+ out = self.activation(out)
279
+ out = self.conv2(out)
280
+ if self.bn==True:
281
+ out = self.bn2(out)
282
+
283
+ if self.groups > 1:
284
+ out = self.conv_merge(out)
285
+
286
+ return self.skip_add.add(out, x)
287
+
288
+ # return out + x
289
+
290
+
291
+ class FeatureFusionBlock_custom(nn.Module):
292
+ """Feature fusion block.
293
+ """
294
+
295
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
296
+ """Init.
297
+
298
+ Args:
299
+ features (int): number of features
300
+ """
301
+ super(FeatureFusionBlock_custom, self).__init__()
302
+
303
+ self.deconv = deconv
304
+ self.align_corners = align_corners
305
+
306
+ self.groups=1
307
+
308
+ self.expand = expand
309
+ out_features = features
310
+ if self.expand==True:
311
+ out_features = features//2
312
+
313
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
314
+
315
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
316
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
317
+
318
+ self.skip_add = nn.quantized.FloatFunctional()
319
+
320
+ def forward(self, *xs):
321
+ """Forward pass.
322
+
323
+ Returns:
324
+ tensor: output
325
+ """
326
+ output = xs[0]
327
+
328
+ if len(xs) == 2:
329
+ res = self.resConfUnit1(xs[1])
330
+ output = self.skip_add.add(output, res)
331
+ # output += res
332
+
333
+ output = self.resConfUnit2(output)
334
+
335
+ output = nn.functional.interpolate(
336
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
337
+ )
338
+
339
+ output = self.out_conv(output)
340
+
341
+ return output
342
+
src/flux/annotator/midas/midas/dpt_depth.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .base_model import BaseModel
6
+ from .blocks import (
7
+ FeatureFusionBlock,
8
+ FeatureFusionBlock_custom,
9
+ Interpolate,
10
+ _make_encoder,
11
+ forward_vit,
12
+ )
13
+
14
+
15
+ def _make_fusion_block(features, use_bn):
16
+ return FeatureFusionBlock_custom(
17
+ features,
18
+ nn.ReLU(False),
19
+ deconv=False,
20
+ bn=use_bn,
21
+ expand=False,
22
+ align_corners=True,
23
+ )
24
+
25
+
26
+ class DPT(BaseModel):
27
+ def __init__(
28
+ self,
29
+ head,
30
+ features=256,
31
+ backbone="vitb_rn50_384",
32
+ readout="project",
33
+ channels_last=False,
34
+ use_bn=False,
35
+ ):
36
+
37
+ super(DPT, self).__init__()
38
+
39
+ self.channels_last = channels_last
40
+
41
+ hooks = {
42
+ "vitb_rn50_384": [0, 1, 8, 11],
43
+ "vitb16_384": [2, 5, 8, 11],
44
+ "vitl16_384": [5, 11, 17, 23],
45
+ }
46
+
47
+ # Instantiate backbone and reassemble blocks
48
+ self.pretrained, self.scratch = _make_encoder(
49
+ backbone,
50
+ features,
51
+ False, # Set to true of you want to train from scratch, uses ImageNet weights
52
+ groups=1,
53
+ expand=False,
54
+ exportable=False,
55
+ hooks=hooks[backbone],
56
+ use_readout=readout,
57
+ )
58
+
59
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
60
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
61
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
62
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
63
+
64
+ self.scratch.output_conv = head
65
+
66
+
67
+ def forward(self, x):
68
+ if self.channels_last == True:
69
+ x.contiguous(memory_format=torch.channels_last)
70
+
71
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
72
+
73
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
74
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
75
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
76
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
77
+
78
+ path_4 = self.scratch.refinenet4(layer_4_rn)
79
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
80
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
81
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
82
+
83
+ out = self.scratch.output_conv(path_1)
84
+
85
+ return out
86
+
87
+
88
+ class DPTDepthModel(DPT):
89
+ def __init__(self, path=None, non_negative=True, **kwargs):
90
+ features = kwargs["features"] if "features" in kwargs else 256
91
+
92
+ head = nn.Sequential(
93
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
94
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
95
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
96
+ nn.ReLU(True),
97
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
98
+ nn.ReLU(True) if non_negative else nn.Identity(),
99
+ nn.Identity(),
100
+ )
101
+
102
+ super().__init__(head, **kwargs)
103
+
104
+ if path is not None:
105
+ self.load(path)
106
+
107
+ def forward(self, x):
108
+ return super().forward(x).squeeze(dim=1)
109
+
src/flux/annotator/midas/midas/midas_net.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2
+ This file contains code that is adapted from
3
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .base_model import BaseModel
9
+ from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10
+
11
+
12
+ class MidasNet(BaseModel):
13
+ """Network for monocular depth estimation.
14
+ """
15
+
16
+ def __init__(self, path=None, features=256, non_negative=True):
17
+ """Init.
18
+
19
+ Args:
20
+ path (str, optional): Path to saved model. Defaults to None.
21
+ features (int, optional): Number of features. Defaults to 256.
22
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
23
+ """
24
+ print("Loading weights: ", path)
25
+
26
+ super(MidasNet, self).__init__()
27
+
28
+ use_pretrained = False if path is None else True
29
+
30
+ self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
31
+
32
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
33
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
34
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
35
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
36
+
37
+ self.scratch.output_conv = nn.Sequential(
38
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
39
+ Interpolate(scale_factor=2, mode="bilinear"),
40
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
41
+ nn.ReLU(True),
42
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
43
+ nn.ReLU(True) if non_negative else nn.Identity(),
44
+ )
45
+
46
+ if path:
47
+ self.load(path)
48
+
49
+ def forward(self, x):
50
+ """Forward pass.
51
+
52
+ Args:
53
+ x (tensor): input data (image)
54
+
55
+ Returns:
56
+ tensor: depth
57
+ """
58
+
59
+ layer_1 = self.pretrained.layer1(x)
60
+ layer_2 = self.pretrained.layer2(layer_1)
61
+ layer_3 = self.pretrained.layer3(layer_2)
62
+ layer_4 = self.pretrained.layer4(layer_3)
63
+
64
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
65
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
66
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
67
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
68
+
69
+ path_4 = self.scratch.refinenet4(layer_4_rn)
70
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
71
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
72
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
73
+
74
+ out = self.scratch.output_conv(path_1)
75
+
76
+ return torch.squeeze(out, dim=1)
src/flux/annotator/midas/midas/midas_net_custom.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2
+ This file contains code that is adapted from
3
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .base_model import BaseModel
9
+ from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
10
+
11
+
12
+ class MidasNet_small(BaseModel):
13
+ """Network for monocular depth estimation.
14
+ """
15
+
16
+ def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
17
+ blocks={'expand': True}):
18
+ """Init.
19
+
20
+ Args:
21
+ path (str, optional): Path to saved model. Defaults to None.
22
+ features (int, optional): Number of features. Defaults to 256.
23
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
24
+ """
25
+ print("Loading weights: ", path)
26
+
27
+ super(MidasNet_small, self).__init__()
28
+
29
+ use_pretrained = False if path else True
30
+
31
+ self.channels_last = channels_last
32
+ self.blocks = blocks
33
+ self.backbone = backbone
34
+
35
+ self.groups = 1
36
+
37
+ features1=features
38
+ features2=features
39
+ features3=features
40
+ features4=features
41
+ self.expand = False
42
+ if "expand" in self.blocks and self.blocks['expand'] == True:
43
+ self.expand = True
44
+ features1=features
45
+ features2=features*2
46
+ features3=features*4
47
+ features4=features*8
48
+
49
+ self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
50
+
51
+ self.scratch.activation = nn.ReLU(False)
52
+
53
+ self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
54
+ self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
55
+ self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
56
+ self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
57
+
58
+
59
+ self.scratch.output_conv = nn.Sequential(
60
+ nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
61
+ Interpolate(scale_factor=2, mode="bilinear"),
62
+ nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
63
+ self.scratch.activation,
64
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
65
+ nn.ReLU(True) if non_negative else nn.Identity(),
66
+ nn.Identity(),
67
+ )
68
+
69
+ if path:
70
+ self.load(path)
71
+
72
+
73
+ def forward(self, x):
74
+ """Forward pass.
75
+
76
+ Args:
77
+ x (tensor): input data (image)
78
+
79
+ Returns:
80
+ tensor: depth
81
+ """
82
+ if self.channels_last==True:
83
+ print("self.channels_last = ", self.channels_last)
84
+ x.contiguous(memory_format=torch.channels_last)
85
+
86
+
87
+ layer_1 = self.pretrained.layer1(x)
88
+ layer_2 = self.pretrained.layer2(layer_1)
89
+ layer_3 = self.pretrained.layer3(layer_2)
90
+ layer_4 = self.pretrained.layer4(layer_3)
91
+
92
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
93
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
94
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
95
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
96
+
97
+
98
+ path_4 = self.scratch.refinenet4(layer_4_rn)
99
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
100
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
101
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
102
+
103
+ out = self.scratch.output_conv(path_1)
104
+
105
+ return torch.squeeze(out, dim=1)
106
+
107
+
108
+
109
+ def fuse_model(m):
110
+ prev_previous_type = nn.Identity()
111
+ prev_previous_name = ''
112
+ previous_type = nn.Identity()
113
+ previous_name = ''
114
+ for name, module in m.named_modules():
115
+ if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
116
+ # print("FUSED ", prev_previous_name, previous_name, name)
117
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
118
+ elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
119
+ # print("FUSED ", prev_previous_name, previous_name)
120
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
121
+ # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
122
+ # print("FUSED ", previous_name, name)
123
+ # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
124
+
125
+ prev_previous_type = previous_type
126
+ prev_previous_name = previous_name
127
+ previous_type = type(module)
128
+ previous_name = name
src/flux/annotator/midas/midas/transforms.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import math
4
+
5
+
6
+ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
8
+
9
+ Args:
10
+ sample (dict): sample
11
+ size (tuple): image size
12
+
13
+ Returns:
14
+ tuple: new size
15
+ """
16
+ shape = list(sample["disparity"].shape)
17
+
18
+ if shape[0] >= size[0] and shape[1] >= size[1]:
19
+ return sample
20
+
21
+ scale = [0, 0]
22
+ scale[0] = size[0] / shape[0]
23
+ scale[1] = size[1] / shape[1]
24
+
25
+ scale = max(scale)
26
+
27
+ shape[0] = math.ceil(scale * shape[0])
28
+ shape[1] = math.ceil(scale * shape[1])
29
+
30
+ # resize
31
+ sample["image"] = cv2.resize(
32
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33
+ )
34
+
35
+ sample["disparity"] = cv2.resize(
36
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37
+ )
38
+ sample["mask"] = cv2.resize(
39
+ sample["mask"].astype(np.float32),
40
+ tuple(shape[::-1]),
41
+ interpolation=cv2.INTER_NEAREST,
42
+ )
43
+ sample["mask"] = sample["mask"].astype(bool)
44
+
45
+ return tuple(shape)
46
+
47
+
48
+ class Resize(object):
49
+ """Resize sample to given size (width, height).
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ width,
55
+ height,
56
+ resize_target=True,
57
+ keep_aspect_ratio=False,
58
+ ensure_multiple_of=1,
59
+ resize_method="lower_bound",
60
+ image_interpolation_method=cv2.INTER_AREA,
61
+ ):
62
+ """Init.
63
+
64
+ Args:
65
+ width (int): desired output width
66
+ height (int): desired output height
67
+ resize_target (bool, optional):
68
+ True: Resize the full sample (image, mask, target).
69
+ False: Resize image only.
70
+ Defaults to True.
71
+ keep_aspect_ratio (bool, optional):
72
+ True: Keep the aspect ratio of the input sample.
73
+ Output sample might not have the given width and height, and
74
+ resize behaviour depends on the parameter 'resize_method'.
75
+ Defaults to False.
76
+ ensure_multiple_of (int, optional):
77
+ Output width and height is constrained to be multiple of this parameter.
78
+ Defaults to 1.
79
+ resize_method (str, optional):
80
+ "lower_bound": Output will be at least as large as the given size.
81
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
82
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
83
+ Defaults to "lower_bound".
84
+ """
85
+ self.__width = width
86
+ self.__height = height
87
+
88
+ self.__resize_target = resize_target
89
+ self.__keep_aspect_ratio = keep_aspect_ratio
90
+ self.__multiple_of = ensure_multiple_of
91
+ self.__resize_method = resize_method
92
+ self.__image_interpolation_method = image_interpolation_method
93
+
94
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
95
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
96
+
97
+ if max_val is not None and y > max_val:
98
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
99
+
100
+ if y < min_val:
101
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
102
+
103
+ return y
104
+
105
+ def get_size(self, width, height):
106
+ # determine new height and width
107
+ scale_height = self.__height / height
108
+ scale_width = self.__width / width
109
+
110
+ if self.__keep_aspect_ratio:
111
+ if self.__resize_method == "lower_bound":
112
+ # scale such that output size is lower bound
113
+ if scale_width > scale_height:
114
+ # fit width
115
+ scale_height = scale_width
116
+ else:
117
+ # fit height
118
+ scale_width = scale_height
119
+ elif self.__resize_method == "upper_bound":
120
+ # scale such that output size is upper bound
121
+ if scale_width < scale_height:
122
+ # fit width
123
+ scale_height = scale_width
124
+ else:
125
+ # fit height
126
+ scale_width = scale_height
127
+ elif self.__resize_method == "minimal":
128
+ # scale as least as possbile
129
+ if abs(1 - scale_width) < abs(1 - scale_height):
130
+ # fit width
131
+ scale_height = scale_width
132
+ else:
133
+ # fit height
134
+ scale_width = scale_height
135
+ else:
136
+ raise ValueError(
137
+ f"resize_method {self.__resize_method} not implemented"
138
+ )
139
+
140
+ if self.__resize_method == "lower_bound":
141
+ new_height = self.constrain_to_multiple_of(
142
+ scale_height * height, min_val=self.__height
143
+ )
144
+ new_width = self.constrain_to_multiple_of(
145
+ scale_width * width, min_val=self.__width
146
+ )
147
+ elif self.__resize_method == "upper_bound":
148
+ new_height = self.constrain_to_multiple_of(
149
+ scale_height * height, max_val=self.__height
150
+ )
151
+ new_width = self.constrain_to_multiple_of(
152
+ scale_width * width, max_val=self.__width
153
+ )
154
+ elif self.__resize_method == "minimal":
155
+ new_height = self.constrain_to_multiple_of(scale_height * height)
156
+ new_width = self.constrain_to_multiple_of(scale_width * width)
157
+ else:
158
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
159
+
160
+ return (new_width, new_height)
161
+
162
+ def __call__(self, sample):
163
+ width, height = self.get_size(
164
+ sample["image"].shape[1], sample["image"].shape[0]
165
+ )
166
+
167
+ # resize sample
168
+ sample["image"] = cv2.resize(
169
+ sample["image"],
170
+ (width, height),
171
+ interpolation=self.__image_interpolation_method,
172
+ )
173
+
174
+ if self.__resize_target:
175
+ if "disparity" in sample:
176
+ sample["disparity"] = cv2.resize(
177
+ sample["disparity"],
178
+ (width, height),
179
+ interpolation=cv2.INTER_NEAREST,
180
+ )
181
+
182
+ if "depth" in sample:
183
+ sample["depth"] = cv2.resize(
184
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
185
+ )
186
+
187
+ sample["mask"] = cv2.resize(
188
+ sample["mask"].astype(np.float32),
189
+ (width, height),
190
+ interpolation=cv2.INTER_NEAREST,
191
+ )
192
+ sample["mask"] = sample["mask"].astype(bool)
193
+
194
+ return sample
195
+
196
+
197
+ class NormalizeImage(object):
198
+ """Normlize image by given mean and std.
199
+ """
200
+
201
+ def __init__(self, mean, std):
202
+ self.__mean = mean
203
+ self.__std = std
204
+
205
+ def __call__(self, sample):
206
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
207
+
208
+ return sample
209
+
210
+
211
+ class PrepareForNet(object):
212
+ """Prepare sample for usage as network input.
213
+ """
214
+
215
+ def __init__(self):
216
+ pass
217
+
218
+ def __call__(self, sample):
219
+ image = np.transpose(sample["image"], (2, 0, 1))
220
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
221
+
222
+ if "mask" in sample:
223
+ sample["mask"] = sample["mask"].astype(np.float32)
224
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
225
+
226
+ if "disparity" in sample:
227
+ disparity = sample["disparity"].astype(np.float32)
228
+ sample["disparity"] = np.ascontiguousarray(disparity)
229
+
230
+ if "depth" in sample:
231
+ depth = sample["depth"].astype(np.float32)
232
+ sample["depth"] = np.ascontiguousarray(depth)
233
+
234
+ return sample
src/flux/annotator/midas/midas/vit.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+ import types
5
+ import math
6
+ import torch.nn.functional as F
7
+
8
+
9
+ class Slice(nn.Module):
10
+ def __init__(self, start_index=1):
11
+ super(Slice, self).__init__()
12
+ self.start_index = start_index
13
+
14
+ def forward(self, x):
15
+ return x[:, self.start_index :]
16
+
17
+
18
+ class AddReadout(nn.Module):
19
+ def __init__(self, start_index=1):
20
+ super(AddReadout, self).__init__()
21
+ self.start_index = start_index
22
+
23
+ def forward(self, x):
24
+ if self.start_index == 2:
25
+ readout = (x[:, 0] + x[:, 1]) / 2
26
+ else:
27
+ readout = x[:, 0]
28
+ return x[:, self.start_index :] + readout.unsqueeze(1)
29
+
30
+
31
+ class ProjectReadout(nn.Module):
32
+ def __init__(self, in_features, start_index=1):
33
+ super(ProjectReadout, self).__init__()
34
+ self.start_index = start_index
35
+
36
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
37
+
38
+ def forward(self, x):
39
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
40
+ features = torch.cat((x[:, self.start_index :], readout), -1)
41
+
42
+ return self.project(features)
43
+
44
+
45
+ class Transpose(nn.Module):
46
+ def __init__(self, dim0, dim1):
47
+ super(Transpose, self).__init__()
48
+ self.dim0 = dim0
49
+ self.dim1 = dim1
50
+
51
+ def forward(self, x):
52
+ x = x.transpose(self.dim0, self.dim1)
53
+ return x
54
+
55
+
56
+ def forward_vit(pretrained, x):
57
+ b, c, h, w = x.shape
58
+
59
+ glob = pretrained.model.forward_flex(x)
60
+
61
+ layer_1 = pretrained.activations["1"]
62
+ layer_2 = pretrained.activations["2"]
63
+ layer_3 = pretrained.activations["3"]
64
+ layer_4 = pretrained.activations["4"]
65
+
66
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
67
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
68
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
69
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
70
+
71
+ unflatten = nn.Sequential(
72
+ nn.Unflatten(
73
+ 2,
74
+ torch.Size(
75
+ [
76
+ h // pretrained.model.patch_size[1],
77
+ w // pretrained.model.patch_size[0],
78
+ ]
79
+ ),
80
+ )
81
+ )
82
+
83
+ if layer_1.ndim == 3:
84
+ layer_1 = unflatten(layer_1)
85
+ if layer_2.ndim == 3:
86
+ layer_2 = unflatten(layer_2)
87
+ if layer_3.ndim == 3:
88
+ layer_3 = unflatten(layer_3)
89
+ if layer_4.ndim == 3:
90
+ layer_4 = unflatten(layer_4)
91
+
92
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
93
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
94
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
95
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
96
+
97
+ return layer_1, layer_2, layer_3, layer_4
98
+
99
+
100
+ def _resize_pos_embed(self, posemb, gs_h, gs_w):
101
+ posemb_tok, posemb_grid = (
102
+ posemb[:, : self.start_index],
103
+ posemb[0, self.start_index :],
104
+ )
105
+
106
+ gs_old = int(math.sqrt(len(posemb_grid)))
107
+
108
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
109
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
110
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
111
+
112
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
113
+
114
+ return posemb
115
+
116
+
117
+ def forward_flex(self, x):
118
+ b, c, h, w = x.shape
119
+
120
+ pos_embed = self._resize_pos_embed(
121
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
122
+ )
123
+
124
+ B = x.shape[0]
125
+
126
+ if hasattr(self.patch_embed, "backbone"):
127
+ x = self.patch_embed.backbone(x)
128
+ if isinstance(x, (list, tuple)):
129
+ x = x[-1] # last feature if backbone outputs list/tuple of features
130
+
131
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
132
+
133
+ if getattr(self, "dist_token", None) is not None:
134
+ cls_tokens = self.cls_token.expand(
135
+ B, -1, -1
136
+ ) # stole cls_tokens impl from Phil Wang, thanks
137
+ dist_token = self.dist_token.expand(B, -1, -1)
138
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
139
+ else:
140
+ cls_tokens = self.cls_token.expand(
141
+ B, -1, -1
142
+ ) # stole cls_tokens impl from Phil Wang, thanks
143
+ x = torch.cat((cls_tokens, x), dim=1)
144
+
145
+ x = x + pos_embed
146
+ x = self.pos_drop(x)
147
+
148
+ for blk in self.blocks:
149
+ x = blk(x)
150
+
151
+ x = self.norm(x)
152
+
153
+ return x
154
+
155
+
156
+ activations = {}
157
+
158
+
159
+ def get_activation(name):
160
+ def hook(model, input, output):
161
+ activations[name] = output
162
+
163
+ return hook
164
+
165
+
166
+ def get_readout_oper(vit_features, features, use_readout, start_index=1):
167
+ if use_readout == "ignore":
168
+ readout_oper = [Slice(start_index)] * len(features)
169
+ elif use_readout == "add":
170
+ readout_oper = [AddReadout(start_index)] * len(features)
171
+ elif use_readout == "project":
172
+ readout_oper = [
173
+ ProjectReadout(vit_features, start_index) for out_feat in features
174
+ ]
175
+ else:
176
+ assert (
177
+ False
178
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
179
+
180
+ return readout_oper
181
+
182
+
183
+ def _make_vit_b16_backbone(
184
+ model,
185
+ features=[96, 192, 384, 768],
186
+ size=[384, 384],
187
+ hooks=[2, 5, 8, 11],
188
+ vit_features=768,
189
+ use_readout="ignore",
190
+ start_index=1,
191
+ ):
192
+ pretrained = nn.Module()
193
+
194
+ pretrained.model = model
195
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
196
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
197
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
198
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
199
+
200
+ pretrained.activations = activations
201
+
202
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
203
+
204
+ # 32, 48, 136, 384
205
+ pretrained.act_postprocess1 = nn.Sequential(
206
+ readout_oper[0],
207
+ Transpose(1, 2),
208
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
209
+ nn.Conv2d(
210
+ in_channels=vit_features,
211
+ out_channels=features[0],
212
+ kernel_size=1,
213
+ stride=1,
214
+ padding=0,
215
+ ),
216
+ nn.ConvTranspose2d(
217
+ in_channels=features[0],
218
+ out_channels=features[0],
219
+ kernel_size=4,
220
+ stride=4,
221
+ padding=0,
222
+ bias=True,
223
+ dilation=1,
224
+ groups=1,
225
+ ),
226
+ )
227
+
228
+ pretrained.act_postprocess2 = nn.Sequential(
229
+ readout_oper[1],
230
+ Transpose(1, 2),
231
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
232
+ nn.Conv2d(
233
+ in_channels=vit_features,
234
+ out_channels=features[1],
235
+ kernel_size=1,
236
+ stride=1,
237
+ padding=0,
238
+ ),
239
+ nn.ConvTranspose2d(
240
+ in_channels=features[1],
241
+ out_channels=features[1],
242
+ kernel_size=2,
243
+ stride=2,
244
+ padding=0,
245
+ bias=True,
246
+ dilation=1,
247
+ groups=1,
248
+ ),
249
+ )
250
+
251
+ pretrained.act_postprocess3 = nn.Sequential(
252
+ readout_oper[2],
253
+ Transpose(1, 2),
254
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
255
+ nn.Conv2d(
256
+ in_channels=vit_features,
257
+ out_channels=features[2],
258
+ kernel_size=1,
259
+ stride=1,
260
+ padding=0,
261
+ ),
262
+ )
263
+
264
+ pretrained.act_postprocess4 = nn.Sequential(
265
+ readout_oper[3],
266
+ Transpose(1, 2),
267
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
268
+ nn.Conv2d(
269
+ in_channels=vit_features,
270
+ out_channels=features[3],
271
+ kernel_size=1,
272
+ stride=1,
273
+ padding=0,
274
+ ),
275
+ nn.Conv2d(
276
+ in_channels=features[3],
277
+ out_channels=features[3],
278
+ kernel_size=3,
279
+ stride=2,
280
+ padding=1,
281
+ ),
282
+ )
283
+
284
+ pretrained.model.start_index = start_index
285
+ pretrained.model.patch_size = [16, 16]
286
+
287
+ # We inject this function into the VisionTransformer instances so that
288
+ # we can use it with interpolated position embeddings without modifying the library source.
289
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
290
+ pretrained.model._resize_pos_embed = types.MethodType(
291
+ _resize_pos_embed, pretrained.model
292
+ )
293
+
294
+ return pretrained
295
+
296
+
297
+ def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
298
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
299
+
300
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
301
+ return _make_vit_b16_backbone(
302
+ model,
303
+ features=[256, 512, 1024, 1024],
304
+ hooks=hooks,
305
+ vit_features=1024,
306
+ use_readout=use_readout,
307
+ )
308
+
309
+
310
+ def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
311
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
312
+
313
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
314
+ return _make_vit_b16_backbone(
315
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
316
+ )
317
+
318
+
319
+ def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
320
+ model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
321
+
322
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
323
+ return _make_vit_b16_backbone(
324
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
325
+ )
326
+
327
+
328
+ def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
329
+ model = timm.create_model(
330
+ "vit_deit_base_distilled_patch16_384", pretrained=pretrained
331
+ )
332
+
333
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
334
+ return _make_vit_b16_backbone(
335
+ model,
336
+ features=[96, 192, 384, 768],
337
+ hooks=hooks,
338
+ use_readout=use_readout,
339
+ start_index=2,
340
+ )
341
+
342
+
343
+ def _make_vit_b_rn50_backbone(
344
+ model,
345
+ features=[256, 512, 768, 768],
346
+ size=[384, 384],
347
+ hooks=[0, 1, 8, 11],
348
+ vit_features=768,
349
+ use_vit_only=False,
350
+ use_readout="ignore",
351
+ start_index=1,
352
+ ):
353
+ pretrained = nn.Module()
354
+
355
+ pretrained.model = model
356
+
357
+ if use_vit_only == True:
358
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
359
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
360
+ else:
361
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
362
+ get_activation("1")
363
+ )
364
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
365
+ get_activation("2")
366
+ )
367
+
368
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
369
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
370
+
371
+ pretrained.activations = activations
372
+
373
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
374
+
375
+ if use_vit_only == True:
376
+ pretrained.act_postprocess1 = nn.Sequential(
377
+ readout_oper[0],
378
+ Transpose(1, 2),
379
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
380
+ nn.Conv2d(
381
+ in_channels=vit_features,
382
+ out_channels=features[0],
383
+ kernel_size=1,
384
+ stride=1,
385
+ padding=0,
386
+ ),
387
+ nn.ConvTranspose2d(
388
+ in_channels=features[0],
389
+ out_channels=features[0],
390
+ kernel_size=4,
391
+ stride=4,
392
+ padding=0,
393
+ bias=True,
394
+ dilation=1,
395
+ groups=1,
396
+ ),
397
+ )
398
+
399
+ pretrained.act_postprocess2 = nn.Sequential(
400
+ readout_oper[1],
401
+ Transpose(1, 2),
402
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
403
+ nn.Conv2d(
404
+ in_channels=vit_features,
405
+ out_channels=features[1],
406
+ kernel_size=1,
407
+ stride=1,
408
+ padding=0,
409
+ ),
410
+ nn.ConvTranspose2d(
411
+ in_channels=features[1],
412
+ out_channels=features[1],
413
+ kernel_size=2,
414
+ stride=2,
415
+ padding=0,
416
+ bias=True,
417
+ dilation=1,
418
+ groups=1,
419
+ ),
420
+ )
421
+ else:
422
+ pretrained.act_postprocess1 = nn.Sequential(
423
+ nn.Identity(), nn.Identity(), nn.Identity()
424
+ )
425
+ pretrained.act_postprocess2 = nn.Sequential(
426
+ nn.Identity(), nn.Identity(), nn.Identity()
427
+ )
428
+
429
+ pretrained.act_postprocess3 = nn.Sequential(
430
+ readout_oper[2],
431
+ Transpose(1, 2),
432
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
433
+ nn.Conv2d(
434
+ in_channels=vit_features,
435
+ out_channels=features[2],
436
+ kernel_size=1,
437
+ stride=1,
438
+ padding=0,
439
+ ),
440
+ )
441
+
442
+ pretrained.act_postprocess4 = nn.Sequential(
443
+ readout_oper[3],
444
+ Transpose(1, 2),
445
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
446
+ nn.Conv2d(
447
+ in_channels=vit_features,
448
+ out_channels=features[3],
449
+ kernel_size=1,
450
+ stride=1,
451
+ padding=0,
452
+ ),
453
+ nn.Conv2d(
454
+ in_channels=features[3],
455
+ out_channels=features[3],
456
+ kernel_size=3,
457
+ stride=2,
458
+ padding=1,
459
+ ),
460
+ )
461
+
462
+ pretrained.model.start_index = start_index
463
+ pretrained.model.patch_size = [16, 16]
464
+
465
+ # We inject this function into the VisionTransformer instances so that
466
+ # we can use it with interpolated position embeddings without modifying the library source.
467
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
468
+
469
+ # We inject this function into the VisionTransformer instances so that
470
+ # we can use it with interpolated position embeddings without modifying the library source.
471
+ pretrained.model._resize_pos_embed = types.MethodType(
472
+ _resize_pos_embed, pretrained.model
473
+ )
474
+
475
+ return pretrained
476
+
477
+
478
+ def _make_pretrained_vitb_rn50_384(
479
+ pretrained, use_readout="ignore", hooks=None, use_vit_only=False
480
+ ):
481
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
482
+
483
+ hooks = [0, 1, 8, 11] if hooks == None else hooks
484
+ return _make_vit_b_rn50_backbone(
485
+ model,
486
+ features=[256, 512, 768, 768],
487
+ size=[384, 384],
488
+ hooks=hooks,
489
+ use_vit_only=use_vit_only,
490
+ use_readout=use_readout,
491
+ )
src/flux/annotator/midas/utils.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utils for monoDepth."""
2
+ import sys
3
+ import re
4
+ import numpy as np
5
+ import cv2
6
+ import torch
7
+
8
+
9
+ def read_pfm(path):
10
+ """Read pfm file.
11
+
12
+ Args:
13
+ path (str): path to file
14
+
15
+ Returns:
16
+ tuple: (data, scale)
17
+ """
18
+ with open(path, "rb") as file:
19
+
20
+ color = None
21
+ width = None
22
+ height = None
23
+ scale = None
24
+ endian = None
25
+
26
+ header = file.readline().rstrip()
27
+ if header.decode("ascii") == "PF":
28
+ color = True
29
+ elif header.decode("ascii") == "Pf":
30
+ color = False
31
+ else:
32
+ raise Exception("Not a PFM file: " + path)
33
+
34
+ dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
35
+ if dim_match:
36
+ width, height = list(map(int, dim_match.groups()))
37
+ else:
38
+ raise Exception("Malformed PFM header.")
39
+
40
+ scale = float(file.readline().decode("ascii").rstrip())
41
+ if scale < 0:
42
+ # little-endian
43
+ endian = "<"
44
+ scale = -scale
45
+ else:
46
+ # big-endian
47
+ endian = ">"
48
+
49
+ data = np.fromfile(file, endian + "f")
50
+ shape = (height, width, 3) if color else (height, width)
51
+
52
+ data = np.reshape(data, shape)
53
+ data = np.flipud(data)
54
+
55
+ return data, scale
56
+
57
+
58
+ def write_pfm(path, image, scale=1):
59
+ """Write pfm file.
60
+
61
+ Args:
62
+ path (str): pathto file
63
+ image (array): data
64
+ scale (int, optional): Scale. Defaults to 1.
65
+ """
66
+
67
+ with open(path, "wb") as file:
68
+ color = None
69
+
70
+ if image.dtype.name != "float32":
71
+ raise Exception("Image dtype must be float32.")
72
+
73
+ image = np.flipud(image)
74
+
75
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
76
+ color = True
77
+ elif (
78
+ len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
79
+ ): # greyscale
80
+ color = False
81
+ else:
82
+ raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
83
+
84
+ file.write("PF\n" if color else "Pf\n".encode())
85
+ file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
86
+
87
+ endian = image.dtype.byteorder
88
+
89
+ if endian == "<" or endian == "=" and sys.byteorder == "little":
90
+ scale = -scale
91
+
92
+ file.write("%f\n".encode() % scale)
93
+
94
+ image.tofile(file)
95
+
96
+
97
+ def read_image(path):
98
+ """Read image and output RGB image (0-1).
99
+
100
+ Args:
101
+ path (str): path to file
102
+
103
+ Returns:
104
+ array: RGB image (0-1)
105
+ """
106
+ img = cv2.imread(path)
107
+
108
+ if img.ndim == 2:
109
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
110
+
111
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
112
+
113
+ return img
114
+
115
+
116
+ def resize_image(img):
117
+ """Resize image and make it fit for network.
118
+
119
+ Args:
120
+ img (array): image
121
+
122
+ Returns:
123
+ tensor: data ready for network
124
+ """
125
+ height_orig = img.shape[0]
126
+ width_orig = img.shape[1]
127
+
128
+ if width_orig > height_orig:
129
+ scale = width_orig / 384
130
+ else:
131
+ scale = height_orig / 384
132
+
133
+ height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
134
+ width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
135
+
136
+ img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
137
+
138
+ img_resized = (
139
+ torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
140
+ )
141
+ img_resized = img_resized.unsqueeze(0)
142
+
143
+ return img_resized
144
+
145
+
146
+ def resize_depth(depth, width, height):
147
+ """Resize depth map and bring to CPU (numpy).
148
+
149
+ Args:
150
+ depth (tensor): depth
151
+ width (int): image width
152
+ height (int): image height
153
+
154
+ Returns:
155
+ array: processed depth
156
+ """
157
+ depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
158
+
159
+ depth_resized = cv2.resize(
160
+ depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
161
+ )
162
+
163
+ return depth_resized
164
+
165
+ def write_depth(path, depth, bits=1):
166
+ """Write depth map to pfm and png file.
167
+
168
+ Args:
169
+ path (str): filepath without extension
170
+ depth (array): depth
171
+ """
172
+ write_pfm(path + ".pfm", depth.astype(np.float32))
173
+
174
+ depth_min = depth.min()
175
+ depth_max = depth.max()
176
+
177
+ max_val = (2**(8*bits))-1
178
+
179
+ if depth_max - depth_min > np.finfo("float").eps:
180
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
181
+ else:
182
+ out = np.zeros(depth.shape, dtype=depth.type)
183
+
184
+ if bits == 1:
185
+ cv2.imwrite(path + ".png", out.astype("uint8"))
186
+ elif bits == 2:
187
+ cv2.imwrite(path + ".png", out.astype("uint16"))
188
+
189
+ return
src/flux/annotator/mlsd/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "{}"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2021-present NAVER Corp.
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
src/flux/annotator/mlsd/__init__.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MLSD Line Detection
2
+ # From https://github.com/navervision/mlsd
3
+ # Apache-2.0 license
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import os
9
+
10
+ from einops import rearrange
11
+ from huggingface_hub import hf_hub_download
12
+ from .models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny
13
+ from .models.mbv2_mlsd_large import MobileV2_MLSD_Large
14
+ from .utils import pred_lines
15
+
16
+ from ...annotator.util import annotator_ckpts_path
17
+
18
+
19
+ class MLSDdetector:
20
+ def __init__(self):
21
+ model_path = os.path.join(annotator_ckpts_path, "mlsd_large_512_fp32.pth")
22
+ if not os.path.exists(model_path):
23
+ model_path = hf_hub_download("lllyasviel/Annotators", "mlsd_large_512_fp32.pth")
24
+ model = MobileV2_MLSD_Large()
25
+ model.load_state_dict(torch.load(model_path), strict=True)
26
+ self.model = model.cuda().eval()
27
+
28
+ def __call__(self, input_image, thr_v, thr_d):
29
+ assert input_image.ndim == 3
30
+ img = input_image
31
+ img_output = np.zeros_like(img)
32
+ try:
33
+ with torch.no_grad():
34
+ lines = pred_lines(img, self.model, [img.shape[0], img.shape[1]], thr_v, thr_d)
35
+ for line in lines:
36
+ x_start, y_start, x_end, y_end = [int(val) for val in line]
37
+ cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1)
38
+ except Exception as e:
39
+ pass
40
+ return img_output[:, :, 0]
src/flux/annotator/mlsd/models/mbv2_mlsd_large.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.utils.model_zoo as model_zoo
6
+ from torch.nn import functional as F
7
+
8
+
9
+ class BlockTypeA(nn.Module):
10
+ def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
11
+ super(BlockTypeA, self).__init__()
12
+ self.conv1 = nn.Sequential(
13
+ nn.Conv2d(in_c2, out_c2, kernel_size=1),
14
+ nn.BatchNorm2d(out_c2),
15
+ nn.ReLU(inplace=True)
16
+ )
17
+ self.conv2 = nn.Sequential(
18
+ nn.Conv2d(in_c1, out_c1, kernel_size=1),
19
+ nn.BatchNorm2d(out_c1),
20
+ nn.ReLU(inplace=True)
21
+ )
22
+ self.upscale = upscale
23
+
24
+ def forward(self, a, b):
25
+ b = self.conv1(b)
26
+ a = self.conv2(a)
27
+ if self.upscale:
28
+ b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
29
+ return torch.cat((a, b), dim=1)
30
+
31
+
32
+ class BlockTypeB(nn.Module):
33
+ def __init__(self, in_c, out_c):
34
+ super(BlockTypeB, self).__init__()
35
+ self.conv1 = nn.Sequential(
36
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
37
+ nn.BatchNorm2d(in_c),
38
+ nn.ReLU()
39
+ )
40
+ self.conv2 = nn.Sequential(
41
+ nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
42
+ nn.BatchNorm2d(out_c),
43
+ nn.ReLU()
44
+ )
45
+
46
+ def forward(self, x):
47
+ x = self.conv1(x) + x
48
+ x = self.conv2(x)
49
+ return x
50
+
51
+ class BlockTypeC(nn.Module):
52
+ def __init__(self, in_c, out_c):
53
+ super(BlockTypeC, self).__init__()
54
+ self.conv1 = nn.Sequential(
55
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
56
+ nn.BatchNorm2d(in_c),
57
+ nn.ReLU()
58
+ )
59
+ self.conv2 = nn.Sequential(
60
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
61
+ nn.BatchNorm2d(in_c),
62
+ nn.ReLU()
63
+ )
64
+ self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
65
+
66
+ def forward(self, x):
67
+ x = self.conv1(x)
68
+ x = self.conv2(x)
69
+ x = self.conv3(x)
70
+ return x
71
+
72
+ def _make_divisible(v, divisor, min_value=None):
73
+ """
74
+ This function is taken from the original tf repo.
75
+ It ensures that all layers have a channel number that is divisible by 8
76
+ It can be seen here:
77
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
78
+ :param v:
79
+ :param divisor:
80
+ :param min_value:
81
+ :return:
82
+ """
83
+ if min_value is None:
84
+ min_value = divisor
85
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
86
+ # Make sure that round down does not go down by more than 10%.
87
+ if new_v < 0.9 * v:
88
+ new_v += divisor
89
+ return new_v
90
+
91
+
92
+ class ConvBNReLU(nn.Sequential):
93
+ def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
94
+ self.channel_pad = out_planes - in_planes
95
+ self.stride = stride
96
+ #padding = (kernel_size - 1) // 2
97
+
98
+ # TFLite uses slightly different padding than PyTorch
99
+ if stride == 2:
100
+ padding = 0
101
+ else:
102
+ padding = (kernel_size - 1) // 2
103
+
104
+ super(ConvBNReLU, self).__init__(
105
+ nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
106
+ nn.BatchNorm2d(out_planes),
107
+ nn.ReLU6(inplace=True)
108
+ )
109
+ self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
110
+
111
+
112
+ def forward(self, x):
113
+ # TFLite uses different padding
114
+ if self.stride == 2:
115
+ x = F.pad(x, (0, 1, 0, 1), "constant", 0)
116
+ #print(x.shape)
117
+
118
+ for module in self:
119
+ if not isinstance(module, nn.MaxPool2d):
120
+ x = module(x)
121
+ return x
122
+
123
+
124
+ class InvertedResidual(nn.Module):
125
+ def __init__(self, inp, oup, stride, expand_ratio):
126
+ super(InvertedResidual, self).__init__()
127
+ self.stride = stride
128
+ assert stride in [1, 2]
129
+
130
+ hidden_dim = int(round(inp * expand_ratio))
131
+ self.use_res_connect = self.stride == 1 and inp == oup
132
+
133
+ layers = []
134
+ if expand_ratio != 1:
135
+ # pw
136
+ layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
137
+ layers.extend([
138
+ # dw
139
+ ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
140
+ # pw-linear
141
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
142
+ nn.BatchNorm2d(oup),
143
+ ])
144
+ self.conv = nn.Sequential(*layers)
145
+
146
+ def forward(self, x):
147
+ if self.use_res_connect:
148
+ return x + self.conv(x)
149
+ else:
150
+ return self.conv(x)
151
+
152
+
153
+ class MobileNetV2(nn.Module):
154
+ def __init__(self, pretrained=True):
155
+ """
156
+ MobileNet V2 main class
157
+ Args:
158
+ num_classes (int): Number of classes
159
+ width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
160
+ inverted_residual_setting: Network structure
161
+ round_nearest (int): Round the number of channels in each layer to be a multiple of this number
162
+ Set to 1 to turn off rounding
163
+ block: Module specifying inverted residual building block for mobilenet
164
+ """
165
+ super(MobileNetV2, self).__init__()
166
+
167
+ block = InvertedResidual
168
+ input_channel = 32
169
+ last_channel = 1280
170
+ width_mult = 1.0
171
+ round_nearest = 8
172
+
173
+ inverted_residual_setting = [
174
+ # t, c, n, s
175
+ [1, 16, 1, 1],
176
+ [6, 24, 2, 2],
177
+ [6, 32, 3, 2],
178
+ [6, 64, 4, 2],
179
+ [6, 96, 3, 1],
180
+ #[6, 160, 3, 2],
181
+ #[6, 320, 1, 1],
182
+ ]
183
+
184
+ # only check the first element, assuming user knows t,c,n,s are required
185
+ if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
186
+ raise ValueError("inverted_residual_setting should be non-empty "
187
+ "or a 4-element list, got {}".format(inverted_residual_setting))
188
+
189
+ # building first layer
190
+ input_channel = _make_divisible(input_channel * width_mult, round_nearest)
191
+ self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
192
+ features = [ConvBNReLU(4, input_channel, stride=2)]
193
+ # building inverted residual blocks
194
+ for t, c, n, s in inverted_residual_setting:
195
+ output_channel = _make_divisible(c * width_mult, round_nearest)
196
+ for i in range(n):
197
+ stride = s if i == 0 else 1
198
+ features.append(block(input_channel, output_channel, stride, expand_ratio=t))
199
+ input_channel = output_channel
200
+
201
+ self.features = nn.Sequential(*features)
202
+ self.fpn_selected = [1, 3, 6, 10, 13]
203
+ # weight initialization
204
+ for m in self.modules():
205
+ if isinstance(m, nn.Conv2d):
206
+ nn.init.kaiming_normal_(m.weight, mode='fan_out')
207
+ if m.bias is not None:
208
+ nn.init.zeros_(m.bias)
209
+ elif isinstance(m, nn.BatchNorm2d):
210
+ nn.init.ones_(m.weight)
211
+ nn.init.zeros_(m.bias)
212
+ elif isinstance(m, nn.Linear):
213
+ nn.init.normal_(m.weight, 0, 0.01)
214
+ nn.init.zeros_(m.bias)
215
+ if pretrained:
216
+ self._load_pretrained_model()
217
+
218
+ def _forward_impl(self, x):
219
+ # This exists since TorchScript doesn't support inheritance, so the superclass method
220
+ # (this one) needs to have a name other than `forward` that can be accessed in a subclass
221
+ fpn_features = []
222
+ for i, f in enumerate(self.features):
223
+ if i > self.fpn_selected[-1]:
224
+ break
225
+ x = f(x)
226
+ if i in self.fpn_selected:
227
+ fpn_features.append(x)
228
+
229
+ c1, c2, c3, c4, c5 = fpn_features
230
+ return c1, c2, c3, c4, c5
231
+
232
+
233
+ def forward(self, x):
234
+ return self._forward_impl(x)
235
+
236
+ def _load_pretrained_model(self):
237
+ pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
238
+ model_dict = {}
239
+ state_dict = self.state_dict()
240
+ for k, v in pretrain_dict.items():
241
+ if k in state_dict:
242
+ model_dict[k] = v
243
+ state_dict.update(model_dict)
244
+ self.load_state_dict(state_dict)
245
+
246
+
247
+ class MobileV2_MLSD_Large(nn.Module):
248
+ def __init__(self):
249
+ super(MobileV2_MLSD_Large, self).__init__()
250
+
251
+ self.backbone = MobileNetV2(pretrained=False)
252
+ ## A, B
253
+ self.block15 = BlockTypeA(in_c1= 64, in_c2= 96,
254
+ out_c1= 64, out_c2=64,
255
+ upscale=False)
256
+ self.block16 = BlockTypeB(128, 64)
257
+
258
+ ## A, B
259
+ self.block17 = BlockTypeA(in_c1 = 32, in_c2 = 64,
260
+ out_c1= 64, out_c2= 64)
261
+ self.block18 = BlockTypeB(128, 64)
262
+
263
+ ## A, B
264
+ self.block19 = BlockTypeA(in_c1=24, in_c2=64,
265
+ out_c1=64, out_c2=64)
266
+ self.block20 = BlockTypeB(128, 64)
267
+
268
+ ## A, B, C
269
+ self.block21 = BlockTypeA(in_c1=16, in_c2=64,
270
+ out_c1=64, out_c2=64)
271
+ self.block22 = BlockTypeB(128, 64)
272
+
273
+ self.block23 = BlockTypeC(64, 16)
274
+
275
+ def forward(self, x):
276
+ c1, c2, c3, c4, c5 = self.backbone(x)
277
+
278
+ x = self.block15(c4, c5)
279
+ x = self.block16(x)
280
+
281
+ x = self.block17(c3, x)
282
+ x = self.block18(x)
283
+
284
+ x = self.block19(c2, x)
285
+ x = self.block20(x)
286
+
287
+ x = self.block21(c1, x)
288
+ x = self.block22(x)
289
+ x = self.block23(x)
290
+ x = x[:, 7:, :, :]
291
+
292
+ return x
src/flux/annotator/mlsd/models/mbv2_mlsd_tiny.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.utils.model_zoo as model_zoo
6
+ from torch.nn import functional as F
7
+
8
+
9
+ class BlockTypeA(nn.Module):
10
+ def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
11
+ super(BlockTypeA, self).__init__()
12
+ self.conv1 = nn.Sequential(
13
+ nn.Conv2d(in_c2, out_c2, kernel_size=1),
14
+ nn.BatchNorm2d(out_c2),
15
+ nn.ReLU(inplace=True)
16
+ )
17
+ self.conv2 = nn.Sequential(
18
+ nn.Conv2d(in_c1, out_c1, kernel_size=1),
19
+ nn.BatchNorm2d(out_c1),
20
+ nn.ReLU(inplace=True)
21
+ )
22
+ self.upscale = upscale
23
+
24
+ def forward(self, a, b):
25
+ b = self.conv1(b)
26
+ a = self.conv2(a)
27
+ b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
28
+ return torch.cat((a, b), dim=1)
29
+
30
+
31
+ class BlockTypeB(nn.Module):
32
+ def __init__(self, in_c, out_c):
33
+ super(BlockTypeB, self).__init__()
34
+ self.conv1 = nn.Sequential(
35
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
36
+ nn.BatchNorm2d(in_c),
37
+ nn.ReLU()
38
+ )
39
+ self.conv2 = nn.Sequential(
40
+ nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
41
+ nn.BatchNorm2d(out_c),
42
+ nn.ReLU()
43
+ )
44
+
45
+ def forward(self, x):
46
+ x = self.conv1(x) + x
47
+ x = self.conv2(x)
48
+ return x
49
+
50
+ class BlockTypeC(nn.Module):
51
+ def __init__(self, in_c, out_c):
52
+ super(BlockTypeC, self).__init__()
53
+ self.conv1 = nn.Sequential(
54
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
55
+ nn.BatchNorm2d(in_c),
56
+ nn.ReLU()
57
+ )
58
+ self.conv2 = nn.Sequential(
59
+ nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
60
+ nn.BatchNorm2d(in_c),
61
+ nn.ReLU()
62
+ )
63
+ self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
64
+
65
+ def forward(self, x):
66
+ x = self.conv1(x)
67
+ x = self.conv2(x)
68
+ x = self.conv3(x)
69
+ return x
70
+
71
+ def _make_divisible(v, divisor, min_value=None):
72
+ """
73
+ This function is taken from the original tf repo.
74
+ It ensures that all layers have a channel number that is divisible by 8
75
+ It can be seen here:
76
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
77
+ :param v:
78
+ :param divisor:
79
+ :param min_value:
80
+ :return:
81
+ """
82
+ if min_value is None:
83
+ min_value = divisor
84
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
85
+ # Make sure that round down does not go down by more than 10%.
86
+ if new_v < 0.9 * v:
87
+ new_v += divisor
88
+ return new_v
89
+
90
+
91
+ class ConvBNReLU(nn.Sequential):
92
+ def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
93
+ self.channel_pad = out_planes - in_planes
94
+ self.stride = stride
95
+ #padding = (kernel_size - 1) // 2
96
+
97
+ # TFLite uses slightly different padding than PyTorch
98
+ if stride == 2:
99
+ padding = 0
100
+ else:
101
+ padding = (kernel_size - 1) // 2
102
+
103
+ super(ConvBNReLU, self).__init__(
104
+ nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
105
+ nn.BatchNorm2d(out_planes),
106
+ nn.ReLU6(inplace=True)
107
+ )
108
+ self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
109
+
110
+
111
+ def forward(self, x):
112
+ # TFLite uses different padding
113
+ if self.stride == 2:
114
+ x = F.pad(x, (0, 1, 0, 1), "constant", 0)
115
+ #print(x.shape)
116
+
117
+ for module in self:
118
+ if not isinstance(module, nn.MaxPool2d):
119
+ x = module(x)
120
+ return x
121
+
122
+
123
+ class InvertedResidual(nn.Module):
124
+ def __init__(self, inp, oup, stride, expand_ratio):
125
+ super(InvertedResidual, self).__init__()
126
+ self.stride = stride
127
+ assert stride in [1, 2]
128
+
129
+ hidden_dim = int(round(inp * expand_ratio))
130
+ self.use_res_connect = self.stride == 1 and inp == oup
131
+
132
+ layers = []
133
+ if expand_ratio != 1:
134
+ # pw
135
+ layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
136
+ layers.extend([
137
+ # dw
138
+ ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
139
+ # pw-linear
140
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
141
+ nn.BatchNorm2d(oup),
142
+ ])
143
+ self.conv = nn.Sequential(*layers)
144
+
145
+ def forward(self, x):
146
+ if self.use_res_connect:
147
+ return x + self.conv(x)
148
+ else:
149
+ return self.conv(x)
150
+
151
+
152
+ class MobileNetV2(nn.Module):
153
+ def __init__(self, pretrained=True):
154
+ """
155
+ MobileNet V2 main class
156
+ Args:
157
+ num_classes (int): Number of classes
158
+ width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
159
+ inverted_residual_setting: Network structure
160
+ round_nearest (int): Round the number of channels in each layer to be a multiple of this number
161
+ Set to 1 to turn off rounding
162
+ block: Module specifying inverted residual building block for mobilenet
163
+ """
164
+ super(MobileNetV2, self).__init__()
165
+
166
+ block = InvertedResidual
167
+ input_channel = 32
168
+ last_channel = 1280
169
+ width_mult = 1.0
170
+ round_nearest = 8
171
+
172
+ inverted_residual_setting = [
173
+ # t, c, n, s
174
+ [1, 16, 1, 1],
175
+ [6, 24, 2, 2],
176
+ [6, 32, 3, 2],
177
+ [6, 64, 4, 2],
178
+ #[6, 96, 3, 1],
179
+ #[6, 160, 3, 2],
180
+ #[6, 320, 1, 1],
181
+ ]
182
+
183
+ # only check the first element, assuming user knows t,c,n,s are required
184
+ if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
185
+ raise ValueError("inverted_residual_setting should be non-empty "
186
+ "or a 4-element list, got {}".format(inverted_residual_setting))
187
+
188
+ # building first layer
189
+ input_channel = _make_divisible(input_channel * width_mult, round_nearest)
190
+ self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
191
+ features = [ConvBNReLU(4, input_channel, stride=2)]
192
+ # building inverted residual blocks
193
+ for t, c, n, s in inverted_residual_setting:
194
+ output_channel = _make_divisible(c * width_mult, round_nearest)
195
+ for i in range(n):
196
+ stride = s if i == 0 else 1
197
+ features.append(block(input_channel, output_channel, stride, expand_ratio=t))
198
+ input_channel = output_channel
199
+ self.features = nn.Sequential(*features)
200
+
201
+ self.fpn_selected = [3, 6, 10]
202
+ # weight initialization
203
+ for m in self.modules():
204
+ if isinstance(m, nn.Conv2d):
205
+ nn.init.kaiming_normal_(m.weight, mode='fan_out')
206
+ if m.bias is not None:
207
+ nn.init.zeros_(m.bias)
208
+ elif isinstance(m, nn.BatchNorm2d):
209
+ nn.init.ones_(m.weight)
210
+ nn.init.zeros_(m.bias)
211
+ elif isinstance(m, nn.Linear):
212
+ nn.init.normal_(m.weight, 0, 0.01)
213
+ nn.init.zeros_(m.bias)
214
+
215
+ #if pretrained:
216
+ # self._load_pretrained_model()
217
+
218
+ def _forward_impl(self, x):
219
+ # This exists since TorchScript doesn't support inheritance, so the superclass method
220
+ # (this one) needs to have a name other than `forward` that can be accessed in a subclass
221
+ fpn_features = []
222
+ for i, f in enumerate(self.features):
223
+ if i > self.fpn_selected[-1]:
224
+ break
225
+ x = f(x)
226
+ if i in self.fpn_selected:
227
+ fpn_features.append(x)
228
+
229
+ c2, c3, c4 = fpn_features
230
+ return c2, c3, c4
231
+
232
+
233
+ def forward(self, x):
234
+ return self._forward_impl(x)
235
+
236
+ def _load_pretrained_model(self):
237
+ pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
238
+ model_dict = {}
239
+ state_dict = self.state_dict()
240
+ for k, v in pretrain_dict.items():
241
+ if k in state_dict:
242
+ model_dict[k] = v
243
+ state_dict.update(model_dict)
244
+ self.load_state_dict(state_dict)
245
+
246
+
247
+ class MobileV2_MLSD_Tiny(nn.Module):
248
+ def __init__(self):
249
+ super(MobileV2_MLSD_Tiny, self).__init__()
250
+
251
+ self.backbone = MobileNetV2(pretrained=True)
252
+
253
+ self.block12 = BlockTypeA(in_c1= 32, in_c2= 64,
254
+ out_c1= 64, out_c2=64)
255
+ self.block13 = BlockTypeB(128, 64)
256
+
257
+ self.block14 = BlockTypeA(in_c1 = 24, in_c2 = 64,
258
+ out_c1= 32, out_c2= 32)
259
+ self.block15 = BlockTypeB(64, 64)
260
+
261
+ self.block16 = BlockTypeC(64, 16)
262
+
263
+ def forward(self, x):
264
+ c2, c3, c4 = self.backbone(x)
265
+
266
+ x = self.block12(c3, c4)
267
+ x = self.block13(x)
268
+ x = self.block14(c2, x)
269
+ x = self.block15(x)
270
+ x = self.block16(x)
271
+ x = x[:, 7:, :, :]
272
+ #print(x.shape)
273
+ x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True)
274
+
275
+ return x
src/flux/annotator/mlsd/utils.py ADDED
@@ -0,0 +1,580 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ modified by lihaoweicv
3
+ pytorch version
4
+ '''
5
+
6
+ '''
7
+ M-LSD
8
+ Copyright 2021-present NAVER Corp.
9
+ Apache License v2.0
10
+ '''
11
+
12
+ import os
13
+ import numpy as np
14
+ import cv2
15
+ import torch
16
+ from torch.nn import functional as F
17
+
18
+
19
+ def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5):
20
+ '''
21
+ tpMap:
22
+ center: tpMap[1, 0, :, :]
23
+ displacement: tpMap[1, 1:5, :, :]
24
+ '''
25
+ b, c, h, w = tpMap.shape
26
+ assert b==1, 'only support bsize==1'
27
+ displacement = tpMap[:, 1:5, :, :][0]
28
+ center = tpMap[:, 0, :, :]
29
+ heat = torch.sigmoid(center)
30
+ hmax = F.max_pool2d( heat, (ksize, ksize), stride=1, padding=(ksize-1)//2)
31
+ keep = (hmax == heat).float()
32
+ heat = heat * keep
33
+ heat = heat.reshape(-1, )
34
+
35
+ scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True)
36
+ yy = torch.floor_divide(indices, w).unsqueeze(-1)
37
+ xx = torch.fmod(indices, w).unsqueeze(-1)
38
+ ptss = torch.cat((yy, xx),dim=-1)
39
+
40
+ ptss = ptss.detach().cpu().numpy()
41
+ scores = scores.detach().cpu().numpy()
42
+ displacement = displacement.detach().cpu().numpy()
43
+ displacement = displacement.transpose((1,2,0))
44
+ return ptss, scores, displacement
45
+
46
+
47
+ def pred_lines(image, model,
48
+ input_shape=[512, 512],
49
+ score_thr=0.10,
50
+ dist_thr=20.0):
51
+ h, w, _ = image.shape
52
+ h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]]
53
+
54
+ resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA),
55
+ np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
56
+
57
+ resized_image = resized_image.transpose((2,0,1))
58
+ batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
59
+ batch_image = (batch_image / 127.5) - 1.0
60
+
61
+ batch_image = torch.from_numpy(batch_image).float().to("cuda:4")
62
+ outputs = model(batch_image)
63
+ pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
64
+ start = vmap[:, :, :2]
65
+ end = vmap[:, :, 2:]
66
+ dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
67
+
68
+ segments_list = []
69
+ for center, score in zip(pts, pts_score):
70
+ y, x = center
71
+ distance = dist_map[y, x]
72
+ if score > score_thr and distance > dist_thr:
73
+ disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
74
+ x_start = x + disp_x_start
75
+ y_start = y + disp_y_start
76
+ x_end = x + disp_x_end
77
+ y_end = y + disp_y_end
78
+ segments_list.append([x_start, y_start, x_end, y_end])
79
+
80
+ lines = 2 * np.array(segments_list) # 256 > 512
81
+ lines[:, 0] = lines[:, 0] * w_ratio
82
+ lines[:, 1] = lines[:, 1] * h_ratio
83
+ lines[:, 2] = lines[:, 2] * w_ratio
84
+ lines[:, 3] = lines[:, 3] * h_ratio
85
+
86
+ return lines
87
+
88
+
89
+ def pred_squares(image,
90
+ model,
91
+ input_shape=[512, 512],
92
+ params={'score': 0.06,
93
+ 'outside_ratio': 0.28,
94
+ 'inside_ratio': 0.45,
95
+ 'w_overlap': 0.0,
96
+ 'w_degree': 1.95,
97
+ 'w_length': 0.0,
98
+ 'w_area': 1.86,
99
+ 'w_center': 0.14}):
100
+ '''
101
+ shape = [height, width]
102
+ '''
103
+ h, w, _ = image.shape
104
+ original_shape = [h, w]
105
+
106
+ resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA),
107
+ np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
108
+ resized_image = resized_image.transpose((2, 0, 1))
109
+ batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
110
+ batch_image = (batch_image / 127.5) - 1.0
111
+
112
+ batch_image = torch.from_numpy(batch_image).float().cuda()
113
+ outputs = model(batch_image)
114
+
115
+ pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
116
+ start = vmap[:, :, :2] # (x, y)
117
+ end = vmap[:, :, 2:] # (x, y)
118
+ dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
119
+
120
+ junc_list = []
121
+ segments_list = []
122
+ for junc, score in zip(pts, pts_score):
123
+ y, x = junc
124
+ distance = dist_map[y, x]
125
+ if score > params['score'] and distance > 20.0:
126
+ junc_list.append([x, y])
127
+ disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
128
+ d_arrow = 1.0
129
+ x_start = x + d_arrow * disp_x_start
130
+ y_start = y + d_arrow * disp_y_start
131
+ x_end = x + d_arrow * disp_x_end
132
+ y_end = y + d_arrow * disp_y_end
133
+ segments_list.append([x_start, y_start, x_end, y_end])
134
+
135
+ segments = np.array(segments_list)
136
+
137
+ ####### post processing for squares
138
+ # 1. get unique lines
139
+ point = np.array([[0, 0]])
140
+ point = point[0]
141
+ start = segments[:, :2]
142
+ end = segments[:, 2:]
143
+ diff = start - end
144
+ a = diff[:, 1]
145
+ b = -diff[:, 0]
146
+ c = a * start[:, 0] + b * start[:, 1]
147
+
148
+ d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10)
149
+ theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi
150
+ theta[theta < 0.0] += 180
151
+ hough = np.concatenate([d[:, None], theta[:, None]], axis=-1)
152
+
153
+ d_quant = 1
154
+ theta_quant = 2
155
+ hough[:, 0] //= d_quant
156
+ hough[:, 1] //= theta_quant
157
+ _, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True)
158
+
159
+ acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32')
160
+ idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1
161
+ yx_indices = hough[indices, :].astype('int32')
162
+ acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts
163
+ idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices
164
+
165
+ acc_map_np = acc_map
166
+ # acc_map = acc_map[None, :, :, None]
167
+ #
168
+ # ### fast suppression using tensorflow op
169
+ # acc_map = tf.constant(acc_map, dtype=tf.float32)
170
+ # max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map)
171
+ # acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32)
172
+ # flatten_acc_map = tf.reshape(acc_map, [1, -1])
173
+ # topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts))
174
+ # _, h, w, _ = acc_map.shape
175
+ # y = tf.expand_dims(topk_indices // w, axis=-1)
176
+ # x = tf.expand_dims(topk_indices % w, axis=-1)
177
+ # yx = tf.concat([y, x], axis=-1)
178
+
179
+ ### fast suppression using pytorch op
180
+ acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0)
181
+ _,_, h, w = acc_map.shape
182
+ max_acc_map = F.max_pool2d(acc_map,kernel_size=5, stride=1, padding=2)
183
+ acc_map = acc_map * ( (acc_map == max_acc_map).float() )
184
+ flatten_acc_map = acc_map.reshape([-1, ])
185
+
186
+ scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True)
187
+ yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1)
188
+ xx = torch.fmod(indices, w).unsqueeze(-1)
189
+ yx = torch.cat((yy, xx), dim=-1)
190
+
191
+ yx = yx.detach().cpu().numpy()
192
+
193
+ topk_values = scores.detach().cpu().numpy()
194
+ indices = idx_map[yx[:, 0], yx[:, 1]]
195
+ basis = 5 // 2
196
+
197
+ merged_segments = []
198
+ for yx_pt, max_indice, value in zip(yx, indices, topk_values):
199
+ y, x = yx_pt
200
+ if max_indice == -1 or value == 0:
201
+ continue
202
+ segment_list = []
203
+ for y_offset in range(-basis, basis + 1):
204
+ for x_offset in range(-basis, basis + 1):
205
+ indice = idx_map[y + y_offset, x + x_offset]
206
+ cnt = int(acc_map_np[y + y_offset, x + x_offset])
207
+ if indice != -1:
208
+ segment_list.append(segments[indice])
209
+ if cnt > 1:
210
+ check_cnt = 1
211
+ current_hough = hough[indice]
212
+ for new_indice, new_hough in enumerate(hough):
213
+ if (current_hough == new_hough).all() and indice != new_indice:
214
+ segment_list.append(segments[new_indice])
215
+ check_cnt += 1
216
+ if check_cnt == cnt:
217
+ break
218
+ group_segments = np.array(segment_list).reshape([-1, 2])
219
+ sorted_group_segments = np.sort(group_segments, axis=0)
220
+ x_min, y_min = sorted_group_segments[0, :]
221
+ x_max, y_max = sorted_group_segments[-1, :]
222
+
223
+ deg = theta[max_indice]
224
+ if deg >= 90:
225
+ merged_segments.append([x_min, y_max, x_max, y_min])
226
+ else:
227
+ merged_segments.append([x_min, y_min, x_max, y_max])
228
+
229
+ # 2. get intersections
230
+ new_segments = np.array(merged_segments) # (x1, y1, x2, y2)
231
+ start = new_segments[:, :2] # (x1, y1)
232
+ end = new_segments[:, 2:] # (x2, y2)
233
+ new_centers = (start + end) / 2.0
234
+ diff = start - end
235
+ dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1))
236
+
237
+ # ax + by = c
238
+ a = diff[:, 1]
239
+ b = -diff[:, 0]
240
+ c = a * start[:, 0] + b * start[:, 1]
241
+ pre_det = a[:, None] * b[None, :]
242
+ det = pre_det - np.transpose(pre_det)
243
+
244
+ pre_inter_y = a[:, None] * c[None, :]
245
+ inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10)
246
+ pre_inter_x = c[:, None] * b[None, :]
247
+ inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10)
248
+ inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32')
249
+
250
+ # 3. get corner information
251
+ # 3.1 get distance
252
+ '''
253
+ dist_segments:
254
+ | dist(0), dist(1), dist(2), ...|
255
+ dist_inter_to_segment1:
256
+ | dist(inter,0), dist(inter,0), dist(inter,0), ... |
257
+ | dist(inter,1), dist(inter,1), dist(inter,1), ... |
258
+ ...
259
+ dist_inter_to_semgnet2:
260
+ | dist(inter,0), dist(inter,1), dist(inter,2), ... |
261
+ | dist(inter,0), dist(inter,1), dist(inter,2), ... |
262
+ ...
263
+ '''
264
+
265
+ dist_inter_to_segment1_start = np.sqrt(
266
+ np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
267
+ dist_inter_to_segment1_end = np.sqrt(
268
+ np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
269
+ dist_inter_to_segment2_start = np.sqrt(
270
+ np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
271
+ dist_inter_to_segment2_end = np.sqrt(
272
+ np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
273
+
274
+ # sort ascending
275
+ dist_inter_to_segment1 = np.sort(
276
+ np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1),
277
+ axis=-1) # [n_batch, n_batch, 2]
278
+ dist_inter_to_segment2 = np.sort(
279
+ np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1),
280
+ axis=-1) # [n_batch, n_batch, 2]
281
+
282
+ # 3.2 get degree
283
+ inter_to_start = new_centers[:, None, :] - inter_pts
284
+ deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi
285
+ deg_inter_to_start[deg_inter_to_start < 0.0] += 360
286
+ inter_to_end = new_centers[None, :, :] - inter_pts
287
+ deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi
288
+ deg_inter_to_end[deg_inter_to_end < 0.0] += 360
289
+
290
+ '''
291
+ B -- G
292
+ | |
293
+ C -- R
294
+ B : blue / G: green / C: cyan / R: red
295
+
296
+ 0 -- 1
297
+ | |
298
+ 3 -- 2
299
+ '''
300
+ # rename variables
301
+ deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end
302
+ # sort deg ascending
303
+ deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1)
304
+
305
+ deg_diff_map = np.abs(deg1_map - deg2_map)
306
+ # we only consider the smallest degree of intersect
307
+ deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180]
308
+
309
+ # define available degree range
310
+ deg_range = [60, 120]
311
+
312
+ corner_dict = {corner_info: [] for corner_info in range(4)}
313
+ inter_points = []
314
+ for i in range(inter_pts.shape[0]):
315
+ for j in range(i + 1, inter_pts.shape[1]):
316
+ # i, j > line index, always i < j
317
+ x, y = inter_pts[i, j, :]
318
+ deg1, deg2 = deg_sort[i, j, :]
319
+ deg_diff = deg_diff_map[i, j]
320
+
321
+ check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1]
322
+
323
+ outside_ratio = params['outside_ratio'] # over ratio >>> drop it!
324
+ inside_ratio = params['inside_ratio'] # over ratio >>> drop it!
325
+ check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and \
326
+ dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or \
327
+ (dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and \
328
+ dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \
329
+ ((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and \
330
+ dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or \
331
+ (dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and \
332
+ dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio))
333
+
334
+ if check_degree and check_distance:
335
+ corner_info = None
336
+
337
+ if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \
338
+ (deg2 >= 315 and deg1 >= 45 and deg1 <= 120):
339
+ corner_info, color_info = 0, 'blue'
340
+ elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225):
341
+ corner_info, color_info = 1, 'green'
342
+ elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315):
343
+ corner_info, color_info = 2, 'black'
344
+ elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \
345
+ (deg2 >= 315 and deg1 >= 225 and deg1 <= 315):
346
+ corner_info, color_info = 3, 'cyan'
347
+ else:
348
+ corner_info, color_info = 4, 'red' # we don't use it
349
+ continue
350
+
351
+ corner_dict[corner_info].append([x, y, i, j])
352
+ inter_points.append([x, y])
353
+
354
+ square_list = []
355
+ connect_list = []
356
+ segments_list = []
357
+ for corner0 in corner_dict[0]:
358
+ for corner1 in corner_dict[1]:
359
+ connect01 = False
360
+ for corner0_line in corner0[2:]:
361
+ if corner0_line in corner1[2:]:
362
+ connect01 = True
363
+ break
364
+ if connect01:
365
+ for corner2 in corner_dict[2]:
366
+ connect12 = False
367
+ for corner1_line in corner1[2:]:
368
+ if corner1_line in corner2[2:]:
369
+ connect12 = True
370
+ break
371
+ if connect12:
372
+ for corner3 in corner_dict[3]:
373
+ connect23 = False
374
+ for corner2_line in corner2[2:]:
375
+ if corner2_line in corner3[2:]:
376
+ connect23 = True
377
+ break
378
+ if connect23:
379
+ for corner3_line in corner3[2:]:
380
+ if corner3_line in corner0[2:]:
381
+ # SQUARE!!!
382
+ '''
383
+ 0 -- 1
384
+ | |
385
+ 3 -- 2
386
+ square_list:
387
+ order: 0 > 1 > 2 > 3
388
+ | x0, y0, x1, y1, x2, y2, x3, y3 |
389
+ | x0, y0, x1, y1, x2, y2, x3, y3 |
390
+ ...
391
+ connect_list:
392
+ order: 01 > 12 > 23 > 30
393
+ | line_idx01, line_idx12, line_idx23, line_idx30 |
394
+ | line_idx01, line_idx12, line_idx23, line_idx30 |
395
+ ...
396
+ segments_list:
397
+ order: 0 > 1 > 2 > 3
398
+ | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
399
+ | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
400
+ ...
401
+ '''
402
+ square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2])
403
+ connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line])
404
+ segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:])
405
+
406
+ def check_outside_inside(segments_info, connect_idx):
407
+ # return 'outside or inside', min distance, cover_param, peri_param
408
+ if connect_idx == segments_info[0]:
409
+ check_dist_mat = dist_inter_to_segment1
410
+ else:
411
+ check_dist_mat = dist_inter_to_segment2
412
+
413
+ i, j = segments_info
414
+ min_dist, max_dist = check_dist_mat[i, j, :]
415
+ connect_dist = dist_segments[connect_idx]
416
+ if max_dist > connect_dist:
417
+ return 'outside', min_dist, 0, 1
418
+ else:
419
+ return 'inside', min_dist, -1, -1
420
+
421
+ top_square = None
422
+
423
+ try:
424
+ map_size = input_shape[0] / 2
425
+ squares = np.array(square_list).reshape([-1, 4, 2])
426
+ score_array = []
427
+ connect_array = np.array(connect_list)
428
+ segments_array = np.array(segments_list).reshape([-1, 4, 2])
429
+
430
+ # get degree of corners:
431
+ squares_rollup = np.roll(squares, 1, axis=1)
432
+ squares_rolldown = np.roll(squares, -1, axis=1)
433
+ vec1 = squares_rollup - squares
434
+ normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10)
435
+ vec2 = squares_rolldown - squares
436
+ normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10)
437
+ inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1) # [n_squares, 4]
438
+ squares_degree = np.arccos(inner_products) * 180 / np.pi # [n_squares, 4]
439
+
440
+ # get square score
441
+ overlap_scores = []
442
+ degree_scores = []
443
+ length_scores = []
444
+
445
+ for connects, segments, square, degree in zip(connect_array, segments_array, squares, squares_degree):
446
+ '''
447
+ 0 -- 1
448
+ | |
449
+ 3 -- 2
450
+
451
+ # segments: [4, 2]
452
+ # connects: [4]
453
+ '''
454
+
455
+ ###################################### OVERLAP SCORES
456
+ cover = 0
457
+ perimeter = 0
458
+ # check 0 > 1 > 2 > 3
459
+ square_length = []
460
+
461
+ for start_idx in range(4):
462
+ end_idx = (start_idx + 1) % 4
463
+
464
+ connect_idx = connects[start_idx] # segment idx of segment01
465
+ start_segments = segments[start_idx]
466
+ end_segments = segments[end_idx]
467
+
468
+ start_point = square[start_idx]
469
+ end_point = square[end_idx]
470
+
471
+ # check whether outside or inside
472
+ start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments,
473
+ connect_idx)
474
+ end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx)
475
+
476
+ cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min
477
+ perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min
478
+
479
+ square_length.append(
480
+ dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min)
481
+
482
+ overlap_scores.append(cover / perimeter)
483
+ ######################################
484
+ ###################################### DEGREE SCORES
485
+ '''
486
+ deg0 vs deg2
487
+ deg1 vs deg3
488
+ '''
489
+ deg0, deg1, deg2, deg3 = degree
490
+ deg_ratio1 = deg0 / deg2
491
+ if deg_ratio1 > 1.0:
492
+ deg_ratio1 = 1 / deg_ratio1
493
+ deg_ratio2 = deg1 / deg3
494
+ if deg_ratio2 > 1.0:
495
+ deg_ratio2 = 1 / deg_ratio2
496
+ degree_scores.append((deg_ratio1 + deg_ratio2) / 2)
497
+ ######################################
498
+ ###################################### LENGTH SCORES
499
+ '''
500
+ len0 vs len2
501
+ len1 vs len3
502
+ '''
503
+ len0, len1, len2, len3 = square_length
504
+ len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0
505
+ len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1
506
+ length_scores.append((len_ratio1 + len_ratio2) / 2)
507
+
508
+ ######################################
509
+
510
+ overlap_scores = np.array(overlap_scores)
511
+ overlap_scores /= np.max(overlap_scores)
512
+
513
+ degree_scores = np.array(degree_scores)
514
+ # degree_scores /= np.max(degree_scores)
515
+
516
+ length_scores = np.array(length_scores)
517
+
518
+ ###################################### AREA SCORES
519
+ area_scores = np.reshape(squares, [-1, 4, 2])
520
+ area_x = area_scores[:, :, 0]
521
+ area_y = area_scores[:, :, 1]
522
+ correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0]
523
+ area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1)
524
+ area_scores = 0.5 * np.abs(area_scores + correction)
525
+ area_scores /= (map_size * map_size) # np.max(area_scores)
526
+ ######################################
527
+
528
+ ###################################### CENTER SCORES
529
+ centers = np.array([[256 // 2, 256 // 2]], dtype='float32') # [1, 2]
530
+ # squares: [n, 4, 2]
531
+ square_centers = np.mean(squares, axis=1) # [n, 2]
532
+ center2center = np.sqrt(np.sum((centers - square_centers) ** 2))
533
+ center_scores = center2center / (map_size / np.sqrt(2.0))
534
+
535
+ '''
536
+ score_w = [overlap, degree, area, center, length]
537
+ '''
538
+ score_w = [0.0, 1.0, 10.0, 0.5, 1.0]
539
+ score_array = params['w_overlap'] * overlap_scores \
540
+ + params['w_degree'] * degree_scores \
541
+ + params['w_area'] * area_scores \
542
+ - params['w_center'] * center_scores \
543
+ + params['w_length'] * length_scores
544
+
545
+ best_square = []
546
+
547
+ sorted_idx = np.argsort(score_array)[::-1]
548
+ score_array = score_array[sorted_idx]
549
+ squares = squares[sorted_idx]
550
+
551
+ except Exception as e:
552
+ pass
553
+
554
+ '''return list
555
+ merged_lines, squares, scores
556
+ '''
557
+
558
+ try:
559
+ new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1]
560
+ new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0]
561
+ new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1]
562
+ new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0]
563
+ except:
564
+ new_segments = []
565
+
566
+ try:
567
+ squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1]
568
+ squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0]
569
+ except:
570
+ squares = []
571
+ score_array = []
572
+
573
+ try:
574
+ inter_points = np.array(inter_points)
575
+ inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1]
576
+ inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0]
577
+ except:
578
+ inter_points = []
579
+
580
+ return new_segments, squares, score_array, inter_points
src/flux/annotator/tile/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import cv2
3
+ from .guided_filter import FastGuidedFilter
4
+
5
+
6
+ class TileDetector:
7
+ # https://huggingface.co/xinsir/controlnet-tile-sdxl-1.0
8
+ def __init__(self):
9
+ pass
10
+
11
+ def __call__(self, image):
12
+ blur_strength = random.sample([i / 10. for i in range(10, 201, 2)], k=1)[0]
13
+ radius = random.sample([i for i in range(1, 40, 2)], k=1)[0]
14
+ eps = random.sample([i / 1000. for i in range(1, 101, 2)], k=1)[0]
15
+ scale_factor = random.sample([i / 10. for i in range(10, 181, 5)], k=1)[0]
16
+
17
+ ksize = int(blur_strength)
18
+ if ksize % 2 == 0:
19
+ ksize += 1
20
+
21
+ if random.random() > 0.5:
22
+ image = cv2.GaussianBlur(image, (ksize, ksize), blur_strength / 2)
23
+ if random.random() > 0.5:
24
+ filter = FastGuidedFilter(image, radius, eps, scale_factor)
25
+ image = filter.filter(image)
26
+ return image
src/flux/annotator/tile/guided_filter.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ ## @package guided_filter.core.filters
3
+ #
4
+ # Implementation of guided filter.
5
+ # * GuidedFilter: Original guided filter.
6
+ # * FastGuidedFilter: Fast version of the guided filter.
7
+ # @author tody
8
+ # @date 2015/08/26
9
+
10
+ import numpy as np
11
+ import cv2
12
+
13
+ ## Convert image into float32 type.
14
+ def to32F(img):
15
+ if img.dtype == np.float32:
16
+ return img
17
+ return (1.0 / 255.0) * np.float32(img)
18
+
19
+ ## Convert image into uint8 type.
20
+ def to8U(img):
21
+ if img.dtype == np.uint8:
22
+ return img
23
+ return np.clip(np.uint8(255.0 * img), 0, 255)
24
+
25
+ ## Return if the input image is gray or not.
26
+ def _isGray(I):
27
+ return len(I.shape) == 2
28
+
29
+
30
+ ## Return down sampled image.
31
+ # @param scale (w/s, h/s) image will be created.
32
+ # @param shape I.shape[:2]=(h, w). numpy friendly size parameter.
33
+ def _downSample(I, scale=4, shape=None):
34
+ if shape is not None:
35
+ h, w = shape
36
+ return cv2.resize(I, (w, h), interpolation=cv2.INTER_NEAREST)
37
+
38
+ h, w = I.shape[:2]
39
+ return cv2.resize(I, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_NEAREST)
40
+
41
+
42
+ ## Return up sampled image.
43
+ # @param scale (w*s, h*s) image will be created.
44
+ # @param shape I.shape[:2]=(h, w). numpy friendly size parameter.
45
+ def _upSample(I, scale=2, shape=None):
46
+ if shape is not None:
47
+ h, w = shape
48
+ return cv2.resize(I, (w, h), interpolation=cv2.INTER_LINEAR)
49
+
50
+ h, w = I.shape[:2]
51
+ return cv2.resize(I, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR)
52
+
53
+ ## Fast guide filter.
54
+ class FastGuidedFilter:
55
+ ## Constructor.
56
+ # @param I Input guidance image. Color or gray.
57
+ # @param radius Radius of Guided Filter.
58
+ # @param epsilon Regularization term of Guided Filter.
59
+ # @param scale Down sampled scale.
60
+ def __init__(self, I, radius=5, epsilon=0.4, scale=4):
61
+ I_32F = to32F(I)
62
+ self._I = I_32F
63
+ h, w = I.shape[:2]
64
+
65
+ I_sub = _downSample(I_32F, scale)
66
+
67
+ self._I_sub = I_sub
68
+ radius = int(radius / scale)
69
+
70
+ if _isGray(I):
71
+ self._guided_filter = GuidedFilterGray(I_sub, radius, epsilon)
72
+ else:
73
+ self._guided_filter = GuidedFilterColor(I_sub, radius, epsilon)
74
+
75
+ ## Apply filter for the input image.
76
+ # @param p Input image for the filtering.
77
+ def filter(self, p):
78
+ p_32F = to32F(p)
79
+ shape_original = p.shape[:2]
80
+
81
+ p_sub = _downSample(p_32F, shape=self._I_sub.shape[:2])
82
+
83
+ if _isGray(p_sub):
84
+ return self._filterGray(p_sub, shape_original)
85
+
86
+ cs = p.shape[2]
87
+ q = np.array(p_32F)
88
+
89
+ for ci in range(cs):
90
+ q[:, :, ci] = self._filterGray(p_sub[:, :, ci], shape_original)
91
+ return to8U(q)
92
+
93
+ def _filterGray(self, p_sub, shape_original):
94
+ ab_sub = self._guided_filter._computeCoefficients(p_sub)
95
+ ab = [_upSample(abi, shape=shape_original) for abi in ab_sub]
96
+ return self._guided_filter._computeOutput(ab, self._I)
97
+
98
+
99
+ ## Guide filter.
100
+ class GuidedFilter:
101
+ ## Constructor.
102
+ # @param I Input guidance image. Color or gray.
103
+ # @param radius Radius of Guided Filter.
104
+ # @param epsilon Regularization term of Guided Filter.
105
+ def __init__(self, I, radius=5, epsilon=0.4):
106
+ I_32F = to32F(I)
107
+
108
+ if _isGray(I):
109
+ self._guided_filter = GuidedFilterGray(I_32F, radius, epsilon)
110
+ else:
111
+ self._guided_filter = GuidedFilterColor(I_32F, radius, epsilon)
112
+
113
+ ## Apply filter for the input image.
114
+ # @param p Input image for the filtering.
115
+ def filter(self, p):
116
+ return to8U(self._guided_filter.filter(p))
117
+
118
+
119
+ ## Common parts of guided filter.
120
+ #
121
+ # This class is used by guided_filter class. GuidedFilterGray and GuidedFilterColor.
122
+ # Based on guided_filter._computeCoefficients, guided_filter._computeOutput,
123
+ # GuidedFilterCommon.filter computes filtered image for color and gray.
124
+ class GuidedFilterCommon:
125
+ def __init__(self, guided_filter):
126
+ self._guided_filter = guided_filter
127
+
128
+ ## Apply filter for the input image.
129
+ # @param p Input image for the filtering.
130
+ def filter(self, p):
131
+ p_32F = to32F(p)
132
+ if _isGray(p_32F):
133
+ return self._filterGray(p_32F)
134
+
135
+ cs = p.shape[2]
136
+ q = np.array(p_32F)
137
+
138
+ for ci in range(cs):
139
+ q[:, :, ci] = self._filterGray(p_32F[:, :, ci])
140
+ return q
141
+
142
+ def _filterGray(self, p):
143
+ ab = self._guided_filter._computeCoefficients(p)
144
+ return self._guided_filter._computeOutput(ab, self._guided_filter._I)
145
+
146
+
147
+ ## Guided filter for gray guidance image.
148
+ class GuidedFilterGray:
149
+ # @param I Input gray guidance image.
150
+ # @param radius Radius of Guided Filter.
151
+ # @param epsilon Regularization term of Guided Filter.
152
+ def __init__(self, I, radius=5, epsilon=0.4):
153
+ self._radius = 2 * radius + 1
154
+ self._epsilon = epsilon
155
+ self._I = to32F(I)
156
+ self._initFilter()
157
+ self._filter_common = GuidedFilterCommon(self)
158
+
159
+ ## Apply filter for the input image.
160
+ # @param p Input image for the filtering.
161
+ def filter(self, p):
162
+ return self._filter_common.filter(p)
163
+
164
+ def _initFilter(self):
165
+ I = self._I
166
+ r = self._radius
167
+ self._I_mean = cv2.blur(I, (r, r))
168
+ I_mean_sq = cv2.blur(I ** 2, (r, r))
169
+ self._I_var = I_mean_sq - self._I_mean ** 2
170
+
171
+ def _computeCoefficients(self, p):
172
+ r = self._radius
173
+ p_mean = cv2.blur(p, (r, r))
174
+ p_cov = p_mean - self._I_mean * p_mean
175
+ a = p_cov / (self._I_var + self._epsilon)
176
+ b = p_mean - a * self._I_mean
177
+ a_mean = cv2.blur(a, (r, r))
178
+ b_mean = cv2.blur(b, (r, r))
179
+ return a_mean, b_mean
180
+
181
+ def _computeOutput(self, ab, I):
182
+ a_mean, b_mean = ab
183
+ return a_mean * I + b_mean
184
+
185
+
186
+ ## Guided filter for color guidance image.
187
+ class GuidedFilterColor:
188
+ # @param I Input color guidance image.
189
+ # @param radius Radius of Guided Filter.
190
+ # @param epsilon Regularization term of Guided Filter.
191
+ def __init__(self, I, radius=5, epsilon=0.2):
192
+ self._radius = 2 * radius + 1
193
+ self._epsilon = epsilon
194
+ self._I = to32F(I)
195
+ self._initFilter()
196
+ self._filter_common = GuidedFilterCommon(self)
197
+
198
+ ## Apply filter for the input image.
199
+ # @param p Input image for the filtering.
200
+ def filter(self, p):
201
+ return self._filter_common.filter(p)
202
+
203
+ def _initFilter(self):
204
+ I = self._I
205
+ r = self._radius
206
+ eps = self._epsilon
207
+
208
+ Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]
209
+
210
+ self._Ir_mean = cv2.blur(Ir, (r, r))
211
+ self._Ig_mean = cv2.blur(Ig, (r, r))
212
+ self._Ib_mean = cv2.blur(Ib, (r, r))
213
+
214
+ Irr_var = cv2.blur(Ir ** 2, (r, r)) - self._Ir_mean ** 2 + eps
215
+ Irg_var = cv2.blur(Ir * Ig, (r, r)) - self._Ir_mean * self._Ig_mean
216
+ Irb_var = cv2.blur(Ir * Ib, (r, r)) - self._Ir_mean * self._Ib_mean
217
+ Igg_var = cv2.blur(Ig * Ig, (r, r)) - self._Ig_mean * self._Ig_mean + eps
218
+ Igb_var = cv2.blur(Ig * Ib, (r, r)) - self._Ig_mean * self._Ib_mean
219
+ Ibb_var = cv2.blur(Ib * Ib, (r, r)) - self._Ib_mean * self._Ib_mean + eps
220
+
221
+ Irr_inv = Igg_var * Ibb_var - Igb_var * Igb_var
222
+ Irg_inv = Igb_var * Irb_var - Irg_var * Ibb_var
223
+ Irb_inv = Irg_var * Igb_var - Igg_var * Irb_var
224
+ Igg_inv = Irr_var * Ibb_var - Irb_var * Irb_var
225
+ Igb_inv = Irb_var * Irg_var - Irr_var * Igb_var
226
+ Ibb_inv = Irr_var * Igg_var - Irg_var * Irg_var
227
+
228
+ I_cov = Irr_inv * Irr_var + Irg_inv * Irg_var + Irb_inv * Irb_var
229
+ Irr_inv /= I_cov
230
+ Irg_inv /= I_cov
231
+ Irb_inv /= I_cov
232
+ Igg_inv /= I_cov
233
+ Igb_inv /= I_cov
234
+ Ibb_inv /= I_cov
235
+
236
+ self._Irr_inv = Irr_inv
237
+ self._Irg_inv = Irg_inv
238
+ self._Irb_inv = Irb_inv
239
+ self._Igg_inv = Igg_inv
240
+ self._Igb_inv = Igb_inv
241
+ self._Ibb_inv = Ibb_inv
242
+
243
+ def _computeCoefficients(self, p):
244
+ r = self._radius
245
+ I = self._I
246
+ Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]
247
+
248
+ p_mean = cv2.blur(p, (r, r))
249
+
250
+ Ipr_mean = cv2.blur(Ir * p, (r, r))
251
+ Ipg_mean = cv2.blur(Ig * p, (r, r))
252
+ Ipb_mean = cv2.blur(Ib * p, (r, r))
253
+
254
+ Ipr_cov = Ipr_mean - self._Ir_mean * p_mean
255
+ Ipg_cov = Ipg_mean - self._Ig_mean * p_mean
256
+ Ipb_cov = Ipb_mean - self._Ib_mean * p_mean
257
+
258
+ ar = self._Irr_inv * Ipr_cov + self._Irg_inv * Ipg_cov + self._Irb_inv * Ipb_cov
259
+ ag = self._Irg_inv * Ipr_cov + self._Igg_inv * Ipg_cov + self._Igb_inv * Ipb_cov
260
+ ab = self._Irb_inv * Ipr_cov + self._Igb_inv * Ipg_cov + self._Ibb_inv * Ipb_cov
261
+ b = p_mean - ar * self._Ir_mean - ag * self._Ig_mean - ab * self._Ib_mean
262
+
263
+ ar_mean = cv2.blur(ar, (r, r))
264
+ ag_mean = cv2.blur(ag, (r, r))
265
+ ab_mean = cv2.blur(ab, (r, r))
266
+ b_mean = cv2.blur(b, (r, r))
267
+
268
+ return ar_mean, ag_mean, ab_mean, b_mean
269
+
270
+ def _computeOutput(self, ab, I):
271
+ ar_mean, ag_mean, ab_mean, b_mean = ab
272
+
273
+ Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2]
274
+
275
+ q = (ar_mean * Ir +
276
+ ag_mean * Ig +
277
+ ab_mean * Ib +
278
+ b_mean)
279
+
280
+ return q
src/flux/annotator/util.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import os
4
+
5
+
6
+ annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
7
+
8
+
9
+ def HWC3(x):
10
+ assert x.dtype == np.uint8
11
+ if x.ndim == 2:
12
+ x = x[:, :, None]
13
+ assert x.ndim == 3
14
+ H, W, C = x.shape
15
+ assert C == 1 or C == 3 or C == 4
16
+ if C == 3:
17
+ return x
18
+ if C == 1:
19
+ return np.concatenate([x, x, x], axis=2)
20
+ if C == 4:
21
+ color = x[:, :, 0:3].astype(np.float32)
22
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
23
+ y = color * alpha + 255.0 * (1.0 - alpha)
24
+ y = y.clip(0, 255).astype(np.uint8)
25
+ return y
26
+
27
+
28
+ def resize_image(input_image, resolution):
29
+ H, W, C = input_image.shape
30
+ H = float(H)
31
+ W = float(W)
32
+ k = float(resolution) / min(H, W)
33
+ H *= k
34
+ W *= k
35
+ H = int(np.round(H / 64.0)) * 64
36
+ W = int(np.round(W / 64.0)) * 64
37
+ img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
38
+ return img
src/flux/api.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import time
4
+ from pathlib import Path
5
+
6
+ import requests
7
+ from PIL import Image
8
+
9
+ API_ENDPOINT = "https://api.bfl.ml"
10
+
11
+
12
+ class ApiException(Exception):
13
+ def __init__(self, status_code: int, detail: str | list[dict] | None = None):
14
+ super().__init__()
15
+ self.detail = detail
16
+ self.status_code = status_code
17
+
18
+ def __str__(self) -> str:
19
+ return self.__repr__()
20
+
21
+ def __repr__(self) -> str:
22
+ if self.detail is None:
23
+ message = None
24
+ elif isinstance(self.detail, str):
25
+ message = self.detail
26
+ else:
27
+ message = "[" + ",".join(d["msg"] for d in self.detail) + "]"
28
+ return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"
29
+
30
+
31
+ class ImageRequest:
32
+ def __init__(
33
+ self,
34
+ prompt: str,
35
+ width: int = 1024,
36
+ height: int = 1024,
37
+ name: str = "flux.1-pro",
38
+ num_steps: int = 50,
39
+ prompt_upsampling: bool = False,
40
+ seed: int | None = None,
41
+ validate: bool = True,
42
+ launch: bool = True,
43
+ api_key: str | None = None,
44
+ ):
45
+ """
46
+ Manages an image generation request to the API.
47
+
48
+ Args:
49
+ prompt: Prompt to sample
50
+ width: Width of the image in pixel
51
+ height: Height of the image in pixel
52
+ name: Name of the model
53
+ num_steps: Number of network evaluations
54
+ prompt_upsampling: Use prompt upsampling
55
+ seed: Fix the generation seed
56
+ validate: Run input validation
57
+ launch: Directly launches request
58
+ api_key: Your API key if not provided by the environment
59
+
60
+ Raises:
61
+ ValueError: For invalid input
62
+ ApiException: For errors raised from the API
63
+ """
64
+ if validate:
65
+ if name not in ["flux.1-pro"]:
66
+ raise ValueError(f"Invalid model {name}")
67
+ elif width % 32 != 0:
68
+ raise ValueError(f"width must be divisible by 32, got {width}")
69
+ elif not (256 <= width <= 1440):
70
+ raise ValueError(f"width must be between 256 and 1440, got {width}")
71
+ elif height % 32 != 0:
72
+ raise ValueError(f"height must be divisible by 32, got {height}")
73
+ elif not (256 <= height <= 1440):
74
+ raise ValueError(f"height must be between 256 and 1440, got {height}")
75
+ elif not (1 <= num_steps <= 50):
76
+ raise ValueError(f"steps must be between 1 and 50, got {num_steps}")
77
+
78
+ self.request_json = {
79
+ "prompt": prompt,
80
+ "width": width,
81
+ "height": height,
82
+ "variant": name,
83
+ "steps": num_steps,
84
+ "prompt_upsampling": prompt_upsampling,
85
+ }
86
+ if seed is not None:
87
+ self.request_json["seed"] = seed
88
+
89
+ self.request_id: str | None = None
90
+ self.result: dict | None = None
91
+ self._image_bytes: bytes | None = None
92
+ self._url: str | None = None
93
+ if api_key is None:
94
+ self.api_key = os.environ.get("BFL_API_KEY")
95
+ else:
96
+ self.api_key = api_key
97
+
98
+ if launch:
99
+ self.request()
100
+
101
+ def request(self):
102
+ """
103
+ Request to generate the image.
104
+ """
105
+ if self.request_id is not None:
106
+ return
107
+ response = requests.post(
108
+ f"{API_ENDPOINT}/v1/image",
109
+ headers={
110
+ "accept": "application/json",
111
+ "x-key": self.api_key,
112
+ "Content-Type": "application/json",
113
+ },
114
+ json=self.request_json,
115
+ )
116
+ result = response.json()
117
+ if response.status_code != 200:
118
+ raise ApiException(status_code=response.status_code, detail=result.get("detail"))
119
+ self.request_id = response.json()["id"]
120
+
121
+ def retrieve(self) -> dict:
122
+ """
123
+ Wait for the generation to finish and retrieve response.
124
+ """
125
+ if self.request_id is None:
126
+ self.request()
127
+ while self.result is None:
128
+ response = requests.get(
129
+ f"{API_ENDPOINT}/v1/get_result",
130
+ headers={
131
+ "accept": "application/json",
132
+ "x-key": self.api_key,
133
+ },
134
+ params={
135
+ "id": self.request_id,
136
+ },
137
+ )
138
+ result = response.json()
139
+ if "status" not in result:
140
+ raise ApiException(status_code=response.status_code, detail=result.get("detail"))
141
+ elif result["status"] == "Ready":
142
+ self.result = result["result"]
143
+ elif result["status"] == "Pending":
144
+ time.sleep(0.5)
145
+ else:
146
+ raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
147
+ return self.result
148
+
149
+ @property
150
+ def bytes(self) -> bytes:
151
+ """
152
+ Generated image as bytes.
153
+ """
154
+ if self._image_bytes is None:
155
+ response = requests.get(self.url)
156
+ if response.status_code == 200:
157
+ self._image_bytes = response.content
158
+ else:
159
+ raise ApiException(status_code=response.status_code)
160
+ return self._image_bytes
161
+
162
+ @property
163
+ def url(self) -> str:
164
+ """
165
+ Public url to retrieve the image from
166
+ """
167
+ if self._url is None:
168
+ result = self.retrieve()
169
+ self._url = result["sample"]
170
+ return self._url
171
+
172
+ @property
173
+ def image(self) -> Image.Image:
174
+ """
175
+ Load the image as a PIL Image
176
+ """
177
+ return Image.open(io.BytesIO(self.bytes))
178
+
179
+ def save(self, path: str):
180
+ """
181
+ Save the generated image to a local path
182
+ """
183
+ suffix = Path(self.url).suffix
184
+ if not path.endswith(suffix):
185
+ path = path + suffix
186
+ Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
187
+ with open(path, "wb") as file:
188
+ file.write(self.bytes)
189
+
190
+
191
+ if __name__ == "__main__":
192
+ from fire import Fire
193
+
194
+ Fire(ImageRequest)
src/flux/cli.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import time
4
+ from dataclasses import dataclass
5
+ from glob import iglob
6
+
7
+ import torch
8
+ from einops import rearrange
9
+ from fire import Fire
10
+ from PIL import ExifTags, Image
11
+
12
+ from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
13
+ from flux.util import (configs, embed_watermark, load_ae, load_clip,
14
+ load_flow_model, load_t5)
15
+ from transformers import pipeline
16
+
17
+ NSFW_THRESHOLD = 0.85
18
+
19
+ @dataclass
20
+ class SamplingOptions:
21
+ prompt: str
22
+ width: int
23
+ height: int
24
+ num_steps: int
25
+ guidance: float
26
+ seed: int | None
27
+
28
+
29
+ def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
30
+ user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
31
+ usage = (
32
+ "Usage: Either write your prompt directly, leave this field empty "
33
+ "to repeat the prompt or write a command starting with a slash:\n"
34
+ "- '/w <width>' will set the width of the generated image\n"
35
+ "- '/h <height>' will set the height of the generated image\n"
36
+ "- '/s <seed>' sets the next seed\n"
37
+ "- '/g <guidance>' sets the guidance (flux-dev only)\n"
38
+ "- '/n <steps>' sets the number of steps\n"
39
+ "- '/q' to quit"
40
+ )
41
+
42
+ while (prompt := input(user_question)).startswith("/"):
43
+ if prompt.startswith("/w"):
44
+ if prompt.count(" ") != 1:
45
+ print(f"Got invalid command '{prompt}'\n{usage}")
46
+ continue
47
+ _, width = prompt.split()
48
+ options.width = 16 * (int(width) // 16)
49
+ print(
50
+ f"Setting resolution to {options.width} x {options.height} "
51
+ f"({options.height *options.width/1e6:.2f}MP)"
52
+ )
53
+ elif prompt.startswith("/h"):
54
+ if prompt.count(" ") != 1:
55
+ print(f"Got invalid command '{prompt}'\n{usage}")
56
+ continue
57
+ _, height = prompt.split()
58
+ options.height = 16 * (int(height) // 16)
59
+ print(
60
+ f"Setting resolution to {options.width} x {options.height} "
61
+ f"({options.height *options.width/1e6:.2f}MP)"
62
+ )
63
+ elif prompt.startswith("/g"):
64
+ if prompt.count(" ") != 1:
65
+ print(f"Got invalid command '{prompt}'\n{usage}")
66
+ continue
67
+ _, guidance = prompt.split()
68
+ options.guidance = float(guidance)
69
+ print(f"Setting guidance to {options.guidance}")
70
+ elif prompt.startswith("/s"):
71
+ if prompt.count(" ") != 1:
72
+ print(f"Got invalid command '{prompt}'\n{usage}")
73
+ continue
74
+ _, seed = prompt.split()
75
+ options.seed = int(seed)
76
+ print(f"Setting seed to {options.seed}")
77
+ elif prompt.startswith("/n"):
78
+ if prompt.count(" ") != 1:
79
+ print(f"Got invalid command '{prompt}'\n{usage}")
80
+ continue
81
+ _, steps = prompt.split()
82
+ options.num_steps = int(steps)
83
+ print(f"Setting seed to {options.num_steps}")
84
+ elif prompt.startswith("/q"):
85
+ print("Quitting")
86
+ return None
87
+ else:
88
+ if not prompt.startswith("/h"):
89
+ print(f"Got invalid command '{prompt}'\n{usage}")
90
+ print(usage)
91
+ if prompt != "":
92
+ options.prompt = prompt
93
+ return options
94
+
95
+
96
+ @torch.inference_mode()
97
+ def main(
98
+ name: str = "flux-schnell",
99
+ width: int = 1360,
100
+ height: int = 768,
101
+ seed: int | None = None,
102
+ prompt: str = (
103
+ "a photo of a forest with mist swirling around the tree trunks. The word "
104
+ '"FLUX" is painted over it in big, red brush strokes with visible texture'
105
+ ),
106
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
107
+ num_steps: int | None = None,
108
+ loop: bool = False,
109
+ guidance: float = 3.5,
110
+ offload: bool = False,
111
+ output_dir: str = "output",
112
+ add_sampling_metadata: bool = True,
113
+ ):
114
+ """
115
+ Sample the flux model. Either interactively (set `--loop`) or run for a
116
+ single image.
117
+
118
+ Args:
119
+ name: Name of the model to load
120
+ height: height of the sample in pixels (should be a multiple of 16)
121
+ width: width of the sample in pixels (should be a multiple of 16)
122
+ seed: Set a seed for sampling
123
+ output_name: where to save the output image, `{idx}` will be replaced
124
+ by the index of the sample
125
+ prompt: Prompt used for sampling
126
+ device: Pytorch device
127
+ num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
128
+ loop: start an interactive session and sample multiple times
129
+ guidance: guidance value used for guidance distillation
130
+ add_sampling_metadata: Add the prompt to the image Exif metadata
131
+ """
132
+ nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
133
+
134
+ if name not in configs:
135
+ available = ", ".join(configs.keys())
136
+ raise ValueError(f"Got unknown model name: {name}, chose from {available}")
137
+
138
+ torch_device = torch.device(device)
139
+ if num_steps is None:
140
+ num_steps = 4 if name == "flux-schnell" else 50
141
+
142
+ # allow for packing and conversion to latent space
143
+ height = 16 * (height // 16)
144
+ width = 16 * (width // 16)
145
+
146
+ output_name = os.path.join(output_dir, "img_{idx}.jpg")
147
+ if not os.path.exists(output_dir):
148
+ os.makedirs(output_dir)
149
+ idx = 0
150
+ else:
151
+ fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]\.jpg$", fn)]
152
+ if len(fns) > 0:
153
+ idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
154
+ else:
155
+ idx = 0
156
+
157
+ # init all components
158
+ t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
159
+ clip = load_clip(torch_device)
160
+ model = load_flow_model(name, device="cpu" if offload else torch_device)
161
+ ae = load_ae(name, device="cpu" if offload else torch_device)
162
+
163
+ rng = torch.Generator(device="cpu")
164
+ opts = SamplingOptions(
165
+ prompt=prompt,
166
+ width=width,
167
+ height=height,
168
+ num_steps=num_steps,
169
+ guidance=guidance,
170
+ seed=seed,
171
+ )
172
+
173
+ if loop:
174
+ opts = parse_prompt(opts)
175
+
176
+ while opts is not None:
177
+ if opts.seed is None:
178
+ opts.seed = rng.seed()
179
+ print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
180
+ t0 = time.perf_counter()
181
+
182
+ # prepare input
183
+ x = get_noise(
184
+ 1,
185
+ opts.height,
186
+ opts.width,
187
+ device=torch_device,
188
+ dtype=torch.bfloat16,
189
+ seed=opts.seed,
190
+ )
191
+ opts.seed = None
192
+ if offload:
193
+ ae = ae.cpu()
194
+ torch.cuda.empty_cache()
195
+ t5, clip = t5.to(torch_device), clip.to(torch_device)
196
+ inp = prepare(t5, clip, x, prompt=opts.prompt)
197
+ timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
198
+
199
+ # offload TEs to CPU, load model to gpu
200
+ if offload:
201
+ t5, clip = t5.cpu(), clip.cpu()
202
+ torch.cuda.empty_cache()
203
+ model = model.to(torch_device)
204
+
205
+ # denoise initial noise
206
+ x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
207
+
208
+ # offload model, load autoencoder to gpu
209
+ if offload:
210
+ model.cpu()
211
+ torch.cuda.empty_cache()
212
+ ae.decoder.to(x.device)
213
+
214
+ # decode latents to pixel space
215
+ x = unpack(x.float(), opts.height, opts.width)
216
+ with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
217
+ x = ae.decode(x)
218
+ t1 = time.perf_counter()
219
+
220
+ fn = output_name.format(idx=idx)
221
+ print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
222
+ # bring into PIL format and save
223
+ x = x.clamp(-1, 1)
224
+ x = embed_watermark(x.float())
225
+ x = rearrange(x[0], "c h w -> h w c")
226
+
227
+ img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
228
+ nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]
229
+
230
+ if nsfw_score < NSFW_THRESHOLD:
231
+ exif_data = Image.Exif()
232
+ exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
233
+ exif_data[ExifTags.Base.Make] = "Black Forest Labs"
234
+ exif_data[ExifTags.Base.Model] = name
235
+ if add_sampling_metadata:
236
+ exif_data[ExifTags.Base.ImageDescription] = prompt
237
+ img.save(fn, exif=exif_data, quality=95, subsampling=0)
238
+ idx += 1
239
+ else:
240
+ print("Your generated image may contain NSFW content.")
241
+
242
+ if loop:
243
+ print("-" * 80)
244
+ opts = parse_prompt(opts)
245
+ else:
246
+ opts = None
247
+
248
+
249
+ def app():
250
+ Fire(main)
251
+
252
+
253
+ if __name__ == "__main__":
254
+ app()
src/flux/controlnet.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ from torch import Tensor, nn
5
+ from einops import rearrange
6
+
7
+ from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
8
+ MLPEmbedder, SingleStreamBlock,
9
+ timestep_embedding)
10
+
11
+
12
+ @dataclass
13
+ class FluxParams:
14
+ in_channels: int
15
+ vec_in_dim: int
16
+ context_in_dim: int
17
+ hidden_size: int
18
+ mlp_ratio: float
19
+ num_heads: int
20
+ depth: int
21
+ depth_single_blocks: int
22
+ axes_dim: list[int]
23
+ theta: int
24
+ qkv_bias: bool
25
+ guidance_embed: bool
26
+
27
+ def zero_module(module):
28
+ for p in module.parameters():
29
+ nn.init.zeros_(p)
30
+ return module
31
+
32
+
33
+ class ControlNetFlux(nn.Module):
34
+ """
35
+ Transformer model for flow matching on sequences.
36
+ """
37
+ _supports_gradient_checkpointing = True
38
+
39
+ def __init__(self, params: FluxParams, controlnet_depth=2):
40
+ super().__init__()
41
+
42
+ self.params = params
43
+ self.in_channels = params.in_channels
44
+ self.out_channels = self.in_channels
45
+ if params.hidden_size % params.num_heads != 0:
46
+ raise ValueError(
47
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
48
+ )
49
+ pe_dim = params.hidden_size // params.num_heads
50
+ if sum(params.axes_dim) != pe_dim:
51
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
52
+ self.hidden_size = params.hidden_size
53
+ self.num_heads = params.num_heads
54
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
55
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
56
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
57
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
58
+ self.guidance_in = (
59
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
60
+ )
61
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
62
+
63
+ self.double_blocks = nn.ModuleList(
64
+ [
65
+ DoubleStreamBlock(
66
+ self.hidden_size,
67
+ self.num_heads,
68
+ mlp_ratio=params.mlp_ratio,
69
+ qkv_bias=params.qkv_bias,
70
+ )
71
+ for _ in range(controlnet_depth)
72
+ ]
73
+ )
74
+
75
+ # add ControlNet blocks
76
+ self.controlnet_blocks = nn.ModuleList([])
77
+ for _ in range(controlnet_depth):
78
+ controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
79
+ controlnet_block = zero_module(controlnet_block)
80
+ self.controlnet_blocks.append(controlnet_block)
81
+ self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True)
82
+ self.gradient_checkpointing = False
83
+ self.input_hint_block = nn.Sequential(
84
+ nn.Conv2d(3, 16, 3, padding=1),
85
+ nn.SiLU(),
86
+ nn.Conv2d(16, 16, 3, padding=1),
87
+ nn.SiLU(),
88
+ nn.Conv2d(16, 16, 3, padding=1, stride=2),
89
+ nn.SiLU(),
90
+ nn.Conv2d(16, 16, 3, padding=1),
91
+ nn.SiLU(),
92
+ nn.Conv2d(16, 16, 3, padding=1, stride=2),
93
+ nn.SiLU(),
94
+ nn.Conv2d(16, 16, 3, padding=1),
95
+ nn.SiLU(),
96
+ nn.Conv2d(16, 16, 3, padding=1, stride=2),
97
+ nn.SiLU(),
98
+ zero_module(nn.Conv2d(16, 16, 3, padding=1))
99
+ )
100
+
101
+ def _set_gradient_checkpointing(self, module, value=False):
102
+ if hasattr(module, "gradient_checkpointing"):
103
+ module.gradient_checkpointing = value
104
+
105
+
106
+ @property
107
+ def attn_processors(self):
108
+ # set recursively
109
+ processors = {}
110
+
111
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
112
+ if hasattr(module, "set_processor"):
113
+ processors[f"{name}.processor"] = module.processor
114
+
115
+ for sub_name, child in module.named_children():
116
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
117
+
118
+ return processors
119
+
120
+ for name, module in self.named_children():
121
+ fn_recursive_add_processors(name, module, processors)
122
+
123
+ return processors
124
+
125
+ def set_attn_processor(self, processor):
126
+ r"""
127
+ Sets the attention processor to use to compute attention.
128
+
129
+ Parameters:
130
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
131
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
132
+ for **all** `Attention` layers.
133
+
134
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
135
+ processor. This is strongly recommended when setting trainable attention processors.
136
+
137
+ """
138
+ count = len(self.attn_processors.keys())
139
+
140
+ if isinstance(processor, dict) and len(processor) != count:
141
+ raise ValueError(
142
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
143
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
144
+ )
145
+
146
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
147
+ if hasattr(module, "set_processor"):
148
+ if not isinstance(processor, dict):
149
+ module.set_processor(processor)
150
+ else:
151
+ module.set_processor(processor.pop(f"{name}.processor"))
152
+
153
+ for sub_name, child in module.named_children():
154
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
155
+
156
+ for name, module in self.named_children():
157
+ fn_recursive_attn_processor(name, module, processor)
158
+
159
+ def forward(
160
+ self,
161
+ img: Tensor,
162
+ img_ids: Tensor,
163
+ controlnet_cond: Tensor,
164
+ txt: Tensor,
165
+ txt_ids: Tensor,
166
+ timesteps: Tensor,
167
+ y: Tensor,
168
+ guidance: Tensor | None = None,
169
+ ) -> Tensor:
170
+ if img.ndim != 3 or txt.ndim != 3:
171
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
172
+
173
+ # running on sequences img
174
+ img = self.img_in(img)
175
+ controlnet_cond = self.input_hint_block(controlnet_cond)
176
+ controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
177
+ controlnet_cond = self.pos_embed_input(controlnet_cond)
178
+ img = img + controlnet_cond
179
+ vec = self.time_in(timestep_embedding(timesteps, 256))
180
+ if self.params.guidance_embed:
181
+ if guidance is None:
182
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
183
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
184
+ vec = vec + self.vector_in(y)
185
+ txt = self.txt_in(txt)
186
+
187
+ ids = torch.cat((txt_ids, img_ids), dim=1)
188
+ pe = self.pe_embedder(ids)
189
+
190
+ block_res_samples = ()
191
+
192
+ for block in self.double_blocks:
193
+ if self.training and self.gradient_checkpointing:
194
+
195
+ def create_custom_forward(module, return_dict=None):
196
+ def custom_forward(*inputs):
197
+ if return_dict is not None:
198
+ return module(*inputs, return_dict=return_dict)
199
+ else:
200
+ return module(*inputs)
201
+
202
+ return custom_forward
203
+
204
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
205
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
206
+ create_custom_forward(block),
207
+ img,
208
+ txt,
209
+ vec,
210
+ pe,
211
+ )
212
+ else:
213
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
214
+
215
+ block_res_samples = block_res_samples + (img,)
216
+
217
+ controlnet_block_res_samples = ()
218
+ for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
219
+ block_res_sample = controlnet_block(block_res_sample)
220
+ controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
221
+
222
+ return controlnet_block_res_samples
src/flux/math.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from einops import rearrange
3
+ from torch import Tensor
4
+
5
+
6
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
7
+ q, k = apply_rope(q, k, pe)
8
+
9
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
10
+ x = rearrange(x, "B H L D -> B L (H D)")
11
+
12
+ return x
13
+
14
+
15
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
16
+ assert dim % 2 == 0
17
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
18
+ omega = 1.0 / (theta**scale)
19
+ out = torch.einsum("...n,d->...nd", pos, omega)
20
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
21
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
22
+ return out.float()
23
+
24
+
25
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
26
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
27
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
28
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
29
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
30
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
src/flux/model.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ from torch import Tensor, nn
5
+ from einops import rearrange
6
+
7
+ from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
8
+ MLPEmbedder, SingleStreamBlock,
9
+ timestep_embedding)
10
+
11
+
12
+ @dataclass
13
+ class FluxParams:
14
+ in_channels: int
15
+ vec_in_dim: int
16
+ context_in_dim: int
17
+ hidden_size: int
18
+ mlp_ratio: float
19
+ num_heads: int
20
+ depth: int
21
+ depth_single_blocks: int
22
+ axes_dim: list[int]
23
+ theta: int
24
+ qkv_bias: bool
25
+ guidance_embed: bool
26
+
27
+
28
+ class Flux(nn.Module):
29
+ """
30
+ Transformer model for flow matching on sequences.
31
+ """
32
+ _supports_gradient_checkpointing = True
33
+
34
+ def __init__(self, params: FluxParams):
35
+ super().__init__()
36
+
37
+ self.params = params
38
+ self.in_channels = params.in_channels
39
+ self.out_channels = self.in_channels
40
+ if params.hidden_size % params.num_heads != 0:
41
+ raise ValueError(
42
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
43
+ )
44
+ pe_dim = params.hidden_size // params.num_heads
45
+ if sum(params.axes_dim) != pe_dim:
46
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
47
+ self.hidden_size = params.hidden_size
48
+ self.num_heads = params.num_heads
49
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
50
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
51
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
52
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
53
+ self.guidance_in = (
54
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
55
+ )
56
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
57
+
58
+ self.double_blocks = nn.ModuleList(
59
+ [
60
+ DoubleStreamBlock(
61
+ self.hidden_size,
62
+ self.num_heads,
63
+ mlp_ratio=params.mlp_ratio,
64
+ qkv_bias=params.qkv_bias,
65
+ )
66
+ for _ in range(params.depth)
67
+ ]
68
+ )
69
+
70
+ self.single_blocks = nn.ModuleList(
71
+ [
72
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
73
+ for _ in range(params.depth_single_blocks)
74
+ ]
75
+ )
76
+
77
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
78
+ self.gradient_checkpointing = False
79
+
80
+ def _set_gradient_checkpointing(self, module, value=False):
81
+ if hasattr(module, "gradient_checkpointing"):
82
+ module.gradient_checkpointing = value
83
+
84
+ @property
85
+ def attn_processors(self):
86
+ # set recursively
87
+ processors = {}
88
+
89
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
90
+ if hasattr(module, "set_processor"):
91
+ processors[f"{name}.processor"] = module.processor
92
+
93
+ for sub_name, child in module.named_children():
94
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
95
+
96
+ return processors
97
+
98
+ for name, module in self.named_children():
99
+ fn_recursive_add_processors(name, module, processors)
100
+
101
+ return processors
102
+
103
+ def set_attn_processor(self, processor):
104
+ r"""
105
+ Sets the attention processor to use to compute attention.
106
+
107
+ Parameters:
108
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
109
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
110
+ for **all** `Attention` layers.
111
+
112
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
113
+ processor. This is strongly recommended when setting trainable attention processors.
114
+
115
+ """
116
+ count = len(self.attn_processors.keys())
117
+
118
+ if isinstance(processor, dict) and len(processor) != count:
119
+ raise ValueError(
120
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
121
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
122
+ )
123
+
124
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
125
+ if hasattr(module, "set_processor"):
126
+ if not isinstance(processor, dict):
127
+ module.set_processor(processor)
128
+ else:
129
+ module.set_processor(processor.pop(f"{name}.processor"))
130
+
131
+ for sub_name, child in module.named_children():
132
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
133
+
134
+ for name, module in self.named_children():
135
+ fn_recursive_attn_processor(name, module, processor)
136
+
137
+ def forward(
138
+ self,
139
+ img: Tensor,
140
+ img_ids: Tensor,
141
+ txt: Tensor,
142
+ txt_ids: Tensor,
143
+ timesteps: Tensor,
144
+ y: Tensor,
145
+ block_controlnet_hidden_states=None,
146
+ guidance: Tensor | None = None,
147
+ image_proj: Tensor | None = None,
148
+ ip_scale: Tensor | float = 1.0,
149
+ ) -> Tensor:
150
+ if img.ndim != 3 or txt.ndim != 3:
151
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
152
+
153
+ # running on sequences img
154
+ img = self.img_in(img)
155
+ vec = self.time_in(timestep_embedding(timesteps, 256))
156
+ if self.params.guidance_embed:
157
+ if guidance is None:
158
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
159
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
160
+ vec = vec + self.vector_in(y)
161
+ txt = self.txt_in(txt)
162
+
163
+ ids = torch.cat((txt_ids, img_ids), dim=1)
164
+ pe = self.pe_embedder(ids)
165
+ if block_controlnet_hidden_states is not None:
166
+ controlnet_depth = len(block_controlnet_hidden_states)
167
+ for index_block, block in enumerate(self.double_blocks):
168
+ if self.training and self.gradient_checkpointing:
169
+
170
+ def create_custom_forward(module, return_dict=None):
171
+ def custom_forward(*inputs):
172
+ if return_dict is not None:
173
+ return module(*inputs, return_dict=return_dict)
174
+ else:
175
+ return module(*inputs)
176
+
177
+ return custom_forward
178
+
179
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
180
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
181
+ create_custom_forward(block),
182
+ img,
183
+ txt,
184
+ vec,
185
+ pe,
186
+ image_proj,
187
+ ip_scale,
188
+ )
189
+ else:
190
+ img, txt = block(
191
+ img=img,
192
+ txt=txt,
193
+ vec=vec,
194
+ pe=pe,
195
+ image_proj=image_proj,
196
+ ip_scale=ip_scale,
197
+ )
198
+ # controlnet residual
199
+ if block_controlnet_hidden_states is not None:
200
+ img = img + block_controlnet_hidden_states[index_block % 2]
201
+
202
+
203
+ img = torch.cat((txt, img), 1)
204
+ for block in self.single_blocks:
205
+ if self.training and self.gradient_checkpointing:
206
+
207
+ def create_custom_forward(module, return_dict=None):
208
+ def custom_forward(*inputs):
209
+ if return_dict is not None:
210
+ return module(*inputs, return_dict=return_dict)
211
+ else:
212
+ return module(*inputs)
213
+
214
+ return custom_forward
215
+
216
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
217
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
218
+ create_custom_forward(block),
219
+ img,
220
+ vec,
221
+ pe,
222
+ )
223
+ else:
224
+ img = block(img, vec=vec, pe=pe)
225
+ img = img[:, txt.shape[1] :, ...]
226
+
227
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
228
+ return img
src/flux/modules/autoencoder.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ from einops import rearrange
5
+ from torch import Tensor, nn
6
+
7
+
8
+ @dataclass
9
+ class AutoEncoderParams:
10
+ resolution: int
11
+ in_channels: int
12
+ ch: int
13
+ out_ch: int
14
+ ch_mult: list[int]
15
+ num_res_blocks: int
16
+ z_channels: int
17
+ scale_factor: float
18
+ shift_factor: float
19
+
20
+
21
+ def swish(x: Tensor) -> Tensor:
22
+ return x * torch.sigmoid(x)
23
+
24
+
25
+ class AttnBlock(nn.Module):
26
+ def __init__(self, in_channels: int):
27
+ super().__init__()
28
+ self.in_channels = in_channels
29
+
30
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
31
+
32
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
33
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
34
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
35
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
36
+
37
+ def attention(self, h_: Tensor) -> Tensor:
38
+ h_ = self.norm(h_)
39
+ q = self.q(h_)
40
+ k = self.k(h_)
41
+ v = self.v(h_)
42
+
43
+ b, c, h, w = q.shape
44
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
45
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
46
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
47
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
48
+
49
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
50
+
51
+ def forward(self, x: Tensor) -> Tensor:
52
+ return x + self.proj_out(self.attention(x))
53
+
54
+
55
+ class ResnetBlock(nn.Module):
56
+ def __init__(self, in_channels: int, out_channels: int):
57
+ super().__init__()
58
+ self.in_channels = in_channels
59
+ out_channels = in_channels if out_channels is None else out_channels
60
+ self.out_channels = out_channels
61
+
62
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
63
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
64
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
65
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
66
+ if self.in_channels != self.out_channels:
67
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
68
+
69
+ def forward(self, x):
70
+ h = x
71
+ h = self.norm1(h)
72
+ h = swish(h)
73
+ h = self.conv1(h)
74
+
75
+ h = self.norm2(h)
76
+ h = swish(h)
77
+ h = self.conv2(h)
78
+
79
+ if self.in_channels != self.out_channels:
80
+ x = self.nin_shortcut(x)
81
+
82
+ return x + h
83
+
84
+
85
+ class Downsample(nn.Module):
86
+ def __init__(self, in_channels: int):
87
+ super().__init__()
88
+ # no asymmetric padding in torch conv, must do it ourselves
89
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
90
+
91
+ def forward(self, x: Tensor):
92
+ pad = (0, 1, 0, 1)
93
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
94
+ x = self.conv(x)
95
+ return x
96
+
97
+
98
+ class Upsample(nn.Module):
99
+ def __init__(self, in_channels: int):
100
+ super().__init__()
101
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
102
+
103
+ def forward(self, x: Tensor):
104
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
105
+ x = self.conv(x)
106
+ return x
107
+
108
+
109
+ class Encoder(nn.Module):
110
+ def __init__(
111
+ self,
112
+ resolution: int,
113
+ in_channels: int,
114
+ ch: int,
115
+ ch_mult: list[int],
116
+ num_res_blocks: int,
117
+ z_channels: int,
118
+ ):
119
+ super().__init__()
120
+ self.ch = ch
121
+ self.num_resolutions = len(ch_mult)
122
+ self.num_res_blocks = num_res_blocks
123
+ self.resolution = resolution
124
+ self.in_channels = in_channels
125
+ # downsampling
126
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
127
+
128
+ curr_res = resolution
129
+ in_ch_mult = (1,) + tuple(ch_mult)
130
+ self.in_ch_mult = in_ch_mult
131
+ self.down = nn.ModuleList()
132
+ block_in = self.ch
133
+ for i_level in range(self.num_resolutions):
134
+ block = nn.ModuleList()
135
+ attn = nn.ModuleList()
136
+ block_in = ch * in_ch_mult[i_level]
137
+ block_out = ch * ch_mult[i_level]
138
+ for _ in range(self.num_res_blocks):
139
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
140
+ block_in = block_out
141
+ down = nn.Module()
142
+ down.block = block
143
+ down.attn = attn
144
+ if i_level != self.num_resolutions - 1:
145
+ down.downsample = Downsample(block_in)
146
+ curr_res = curr_res // 2
147
+ self.down.append(down)
148
+
149
+ # middle
150
+ self.mid = nn.Module()
151
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
152
+ self.mid.attn_1 = AttnBlock(block_in)
153
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
154
+
155
+ # end
156
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
157
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
158
+
159
+ def forward(self, x: Tensor) -> Tensor:
160
+ # downsampling
161
+ hs = [self.conv_in(x)]
162
+ for i_level in range(self.num_resolutions):
163
+ for i_block in range(self.num_res_blocks):
164
+ h = self.down[i_level].block[i_block](hs[-1])
165
+ if len(self.down[i_level].attn) > 0:
166
+ h = self.down[i_level].attn[i_block](h)
167
+ hs.append(h)
168
+ if i_level != self.num_resolutions - 1:
169
+ hs.append(self.down[i_level].downsample(hs[-1]))
170
+
171
+ # middle
172
+ h = hs[-1]
173
+ h = self.mid.block_1(h)
174
+ h = self.mid.attn_1(h)
175
+ h = self.mid.block_2(h)
176
+ # end
177
+ h = self.norm_out(h)
178
+ h = swish(h)
179
+ h = self.conv_out(h)
180
+ return h
181
+
182
+
183
+ class Decoder(nn.Module):
184
+ def __init__(
185
+ self,
186
+ ch: int,
187
+ out_ch: int,
188
+ ch_mult: list[int],
189
+ num_res_blocks: int,
190
+ in_channels: int,
191
+ resolution: int,
192
+ z_channels: int,
193
+ ):
194
+ super().__init__()
195
+ self.ch = ch
196
+ self.num_resolutions = len(ch_mult)
197
+ self.num_res_blocks = num_res_blocks
198
+ self.resolution = resolution
199
+ self.in_channels = in_channels
200
+ self.ffactor = 2 ** (self.num_resolutions - 1)
201
+
202
+ # compute in_ch_mult, block_in and curr_res at lowest res
203
+ block_in = ch * ch_mult[self.num_resolutions - 1]
204
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
205
+ self.z_shape = (1, z_channels, curr_res, curr_res)
206
+
207
+ # z to block_in
208
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
209
+
210
+ # middle
211
+ self.mid = nn.Module()
212
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
213
+ self.mid.attn_1 = AttnBlock(block_in)
214
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
215
+
216
+ # upsampling
217
+ self.up = nn.ModuleList()
218
+ for i_level in reversed(range(self.num_resolutions)):
219
+ block = nn.ModuleList()
220
+ attn = nn.ModuleList()
221
+ block_out = ch * ch_mult[i_level]
222
+ for _ in range(self.num_res_blocks + 1):
223
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
224
+ block_in = block_out
225
+ up = nn.Module()
226
+ up.block = block
227
+ up.attn = attn
228
+ if i_level != 0:
229
+ up.upsample = Upsample(block_in)
230
+ curr_res = curr_res * 2
231
+ self.up.insert(0, up) # prepend to get consistent order
232
+
233
+ # end
234
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
235
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
236
+
237
+ def forward(self, z: Tensor) -> Tensor:
238
+ # z to block_in
239
+ h = self.conv_in(z)
240
+
241
+ # middle
242
+ h = self.mid.block_1(h)
243
+ h = self.mid.attn_1(h)
244
+ h = self.mid.block_2(h)
245
+
246
+ # upsampling
247
+ for i_level in reversed(range(self.num_resolutions)):
248
+ for i_block in range(self.num_res_blocks + 1):
249
+ h = self.up[i_level].block[i_block](h)
250
+ if len(self.up[i_level].attn) > 0:
251
+ h = self.up[i_level].attn[i_block](h)
252
+ if i_level != 0:
253
+ h = self.up[i_level].upsample(h)
254
+
255
+ # end
256
+ h = self.norm_out(h)
257
+ h = swish(h)
258
+ h = self.conv_out(h)
259
+ return h
260
+
261
+
262
+ class DiagonalGaussian(nn.Module):
263
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
264
+ super().__init__()
265
+ self.sample = sample
266
+ self.chunk_dim = chunk_dim
267
+
268
+ def forward(self, z: Tensor) -> Tensor:
269
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
270
+ if self.sample:
271
+ std = torch.exp(0.5 * logvar)
272
+ return mean + std * torch.randn_like(mean)
273
+ else:
274
+ return mean
275
+
276
+
277
+ class AutoEncoder(nn.Module):
278
+ def __init__(self, params: AutoEncoderParams):
279
+ super().__init__()
280
+ self.encoder = Encoder(
281
+ resolution=params.resolution,
282
+ in_channels=params.in_channels,
283
+ ch=params.ch,
284
+ ch_mult=params.ch_mult,
285
+ num_res_blocks=params.num_res_blocks,
286
+ z_channels=params.z_channels,
287
+ )
288
+ self.decoder = Decoder(
289
+ resolution=params.resolution,
290
+ in_channels=params.in_channels,
291
+ ch=params.ch,
292
+ out_ch=params.out_ch,
293
+ ch_mult=params.ch_mult,
294
+ num_res_blocks=params.num_res_blocks,
295
+ z_channels=params.z_channels,
296
+ )
297
+ self.reg = DiagonalGaussian()
298
+
299
+ self.scale_factor = params.scale_factor
300
+ self.shift_factor = params.shift_factor
301
+
302
+ def encode(self, x: Tensor) -> Tensor:
303
+ z = self.reg(self.encoder(x))
304
+ z = self.scale_factor * (z - self.shift_factor)
305
+ return z
306
+
307
+ def decode(self, z: Tensor) -> Tensor:
308
+ z = z / self.scale_factor + self.shift_factor
309
+ return self.decoder(z)
310
+
311
+ def forward(self, x: Tensor) -> Tensor:
312
+ return self.decode(self.encode(x))
src/flux/modules/conditioner.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import Tensor, nn
2
+ from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
3
+ T5Tokenizer)
4
+
5
+
6
+ class HFEmbedder(nn.Module):
7
+ def __init__(self, version: str, max_length: int, **hf_kwargs):
8
+ super().__init__()
9
+ self.is_clip = version.startswith("openai")
10
+ self.max_length = max_length
11
+ self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
12
+
13
+ if self.is_clip:
14
+ self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
15
+ self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
16
+ else:
17
+ self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
18
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
19
+
20
+ self.hf_module = self.hf_module.eval().requires_grad_(False)
21
+
22
+ def forward(self, text: list[str]) -> Tensor:
23
+ batch_encoding = self.tokenizer(
24
+ text,
25
+ truncation=True,
26
+ max_length=self.max_length,
27
+ return_length=False,
28
+ return_overflowing_tokens=False,
29
+ padding="max_length",
30
+ return_tensors="pt",
31
+ )
32
+
33
+ outputs = self.hf_module(
34
+ input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
35
+ attention_mask=None,
36
+ output_hidden_states=False,
37
+ )
38
+ return outputs[self.output_key]
src/flux/modules/layers.py ADDED
@@ -0,0 +1,567 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from torch import Tensor, nn
7
+
8
+ from ..math import attention, rope
9
+ import torch.nn.functional as F
10
+
11
+ class EmbedND(nn.Module):
12
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
13
+ super().__init__()
14
+ self.dim = dim
15
+ self.theta = theta
16
+ self.axes_dim = axes_dim
17
+
18
+ def forward(self, ids: Tensor) -> Tensor:
19
+ n_axes = ids.shape[-1]
20
+ emb = torch.cat(
21
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
22
+ dim=-3,
23
+ )
24
+
25
+ return emb.unsqueeze(1)
26
+
27
+
28
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
29
+ """
30
+ Create sinusoidal timestep embeddings.
31
+ :param t: a 1-D Tensor of N indices, one per batch element.
32
+ These may be fractional.
33
+ :param dim: the dimension of the output.
34
+ :param max_period: controls the minimum frequency of the embeddings.
35
+ :return: an (N, D) Tensor of positional embeddings.
36
+ """
37
+ t = time_factor * t
38
+ half = dim // 2
39
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
40
+ t.device
41
+ )
42
+
43
+ args = t[:, None].float() * freqs[None]
44
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
45
+ if dim % 2:
46
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
47
+ if torch.is_floating_point(t):
48
+ embedding = embedding.to(t)
49
+ return embedding
50
+
51
+
52
+ class MLPEmbedder(nn.Module):
53
+ def __init__(self, in_dim: int, hidden_dim: int):
54
+ super().__init__()
55
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
56
+ self.silu = nn.SiLU()
57
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
58
+
59
+ def forward(self, x: Tensor) -> Tensor:
60
+ return self.out_layer(self.silu(self.in_layer(x)))
61
+
62
+
63
+ class RMSNorm(torch.nn.Module):
64
+ def __init__(self, dim: int):
65
+ super().__init__()
66
+ self.scale = nn.Parameter(torch.ones(dim))
67
+
68
+ def forward(self, x: Tensor):
69
+ x_dtype = x.dtype
70
+ x = x.float()
71
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
72
+ return (x * rrms).to(dtype=x_dtype) * self.scale
73
+
74
+
75
+ class QKNorm(torch.nn.Module):
76
+ def __init__(self, dim: int):
77
+ super().__init__()
78
+ self.query_norm = RMSNorm(dim)
79
+ self.key_norm = RMSNorm(dim)
80
+
81
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
82
+ q = self.query_norm(q)
83
+ k = self.key_norm(k)
84
+ return q.to(v), k.to(v)
85
+
86
+ class LoRALinearLayer(nn.Module):
87
+ def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
88
+ super().__init__()
89
+
90
+ self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
91
+ self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
92
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
93
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
94
+ self.network_alpha = network_alpha
95
+ self.rank = rank
96
+
97
+ nn.init.normal_(self.down.weight, std=1 / rank)
98
+ nn.init.zeros_(self.up.weight)
99
+
100
+ def forward(self, hidden_states):
101
+ orig_dtype = hidden_states.dtype
102
+ dtype = self.down.weight.dtype
103
+
104
+ down_hidden_states = self.down(hidden_states.to(dtype))
105
+ up_hidden_states = self.up(down_hidden_states)
106
+
107
+ if self.network_alpha is not None:
108
+ up_hidden_states *= self.network_alpha / self.rank
109
+
110
+ return up_hidden_states.to(orig_dtype)
111
+
112
+ class FLuxSelfAttnProcessor:
113
+ def __call__(self, attn, x, pe, **attention_kwargs):
114
+ print('2' * 30)
115
+
116
+ qkv = attn.qkv(x)
117
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
118
+ q, k = attn.norm(q, k, v)
119
+ x = attention(q, k, v, pe=pe)
120
+ x = attn.proj(x)
121
+ return x
122
+
123
+ class LoraFluxAttnProcessor(nn.Module):
124
+
125
+ def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
126
+ super().__init__()
127
+ self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
128
+ self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
129
+ self.lora_weight = lora_weight
130
+
131
+
132
+ def __call__(self, attn, x, pe, **attention_kwargs):
133
+ qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight
134
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
135
+ q, k = attn.norm(q, k, v)
136
+ x = attention(q, k, v, pe=pe)
137
+ x = attn.proj(x) + self.proj_lora(x) * self.lora_weight
138
+ print('1' * 30)
139
+ print(x.norm(), (self.proj_lora(x) * self.lora_weight).norm(), 'norm')
140
+ return x
141
+
142
+ class SelfAttention(nn.Module):
143
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
144
+ super().__init__()
145
+ self.num_heads = num_heads
146
+ head_dim = dim // num_heads
147
+
148
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
149
+ self.norm = QKNorm(head_dim)
150
+ self.proj = nn.Linear(dim, dim)
151
+ def forward():
152
+ pass
153
+
154
+
155
+ @dataclass
156
+ class ModulationOut:
157
+ shift: Tensor
158
+ scale: Tensor
159
+ gate: Tensor
160
+
161
+
162
+ class Modulation(nn.Module):
163
+ def __init__(self, dim: int, double: bool):
164
+ super().__init__()
165
+ self.is_double = double
166
+ self.multiplier = 6 if double else 3
167
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
168
+
169
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
170
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
171
+
172
+ return (
173
+ ModulationOut(*out[:3]),
174
+ ModulationOut(*out[3:]) if self.is_double else None,
175
+ )
176
+
177
+ class DoubleStreamBlockLoraProcessor(nn.Module):
178
+ def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
179
+ super().__init__()
180
+ self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
181
+ self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
182
+ self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
183
+ self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
184
+ self.lora_weight = lora_weight
185
+
186
+ def forward(self, attn, img, txt, vec, pe, **attention_kwargs):
187
+ img_mod1, img_mod2 = attn.img_mod(vec)
188
+ txt_mod1, txt_mod2 = attn.txt_mod(vec)
189
+
190
+ # prepare image for attention
191
+ img_modulated = attn.img_norm1(img)
192
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
193
+ img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight
194
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
195
+ img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
196
+
197
+ # prepare txt for attention
198
+ txt_modulated = attn.txt_norm1(txt)
199
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
200
+ txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight
201
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
202
+ txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
203
+
204
+ # run actual attention
205
+ q = torch.cat((txt_q, img_q), dim=2)
206
+ k = torch.cat((txt_k, img_k), dim=2)
207
+ v = torch.cat((txt_v, img_v), dim=2)
208
+
209
+ attn1 = attention(q, k, v, pe=pe)
210
+ txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
211
+
212
+ # calculate the img bloks
213
+ img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img_mod1.gate * self.proj_lora1(img_attn) * self.lora_weight
214
+ img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
215
+
216
+ # calculate the txt bloks
217
+ txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt_mod1.gate * self.proj_lora2(txt_attn) * self.lora_weight
218
+ txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
219
+ return img, txt
220
+
221
+ class IPDoubleStreamBlockProcessor(nn.Module):
222
+ """Attention processor for handling IP-adapter with double stream block."""
223
+
224
+ def __init__(self, context_dim, hidden_dim):
225
+ super().__init__()
226
+ if not hasattr(F, "scaled_dot_product_attention"):
227
+ raise ImportError(
228
+ "IPDoubleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch."
229
+ )
230
+
231
+ # Ensure context_dim matches the dimension of image_proj
232
+ self.context_dim = context_dim
233
+ self.hidden_dim = hidden_dim
234
+
235
+ # Initialize projections for IP-adapter
236
+ self.ip_adapter_double_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=True)
237
+ self.ip_adapter_double_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=True)
238
+
239
+ nn.init.zeros_(self.ip_adapter_double_stream_k_proj.weight)
240
+ nn.init.zeros_(self.ip_adapter_double_stream_k_proj.bias)
241
+
242
+ nn.init.zeros_(self.ip_adapter_double_stream_v_proj.weight)
243
+ nn.init.zeros_(self.ip_adapter_double_stream_v_proj.bias)
244
+
245
+ def __call__(self, attn, img, txt, vec, pe, image_proj, ip_scale=1.0, **attention_kwargs):
246
+
247
+ # Prepare image for attention
248
+ img_mod1, img_mod2 = attn.img_mod(vec)
249
+ txt_mod1, txt_mod2 = attn.txt_mod(vec)
250
+
251
+ img_modulated = attn.img_norm1(img)
252
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
253
+ img_qkv = attn.img_attn.qkv(img_modulated)
254
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
255
+ img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
256
+
257
+ txt_modulated = attn.txt_norm1(txt)
258
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
259
+ txt_qkv = attn.txt_attn.qkv(txt_modulated)
260
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
261
+ txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
262
+
263
+ q = torch.cat((txt_q, img_q), dim=2)
264
+ k = torch.cat((txt_k, img_k), dim=2)
265
+ v = torch.cat((txt_v, img_v), dim=2)
266
+
267
+ attn1 = attention(q, k, v, pe=pe)
268
+ txt_attn, img_attn = attn1[:, :txt.shape[1]], attn1[:, txt.shape[1]:]
269
+
270
+ # print(f"txt_attn shape: {txt_attn.size()}")
271
+ # print(f"img_attn shape: {img_attn.size()}")
272
+
273
+ img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
274
+ img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
275
+
276
+ txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
277
+ txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
278
+
279
+
280
+ # IP-adapter processing
281
+ ip_query = img_q # latent sample query
282
+ ip_key = self.ip_adapter_double_stream_k_proj(image_proj)
283
+ ip_value = self.ip_adapter_double_stream_v_proj(image_proj)
284
+
285
+ # Reshape projections for multi-head attention
286
+ ip_key = rearrange(ip_key, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim)
287
+ ip_value = rearrange(ip_value, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim)
288
+
289
+ # Compute attention between IP projections and the latent query
290
+ ip_attention = F.scaled_dot_product_attention(
291
+ ip_query,
292
+ ip_key,
293
+ ip_value,
294
+ dropout_p=0.0,
295
+ is_causal=False
296
+ )
297
+ ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)", H=attn.num_heads, D=attn.head_dim)
298
+
299
+ img = img + ip_scale * ip_attention
300
+
301
+ return img, txt
302
+
303
+ class DoubleStreamBlockProcessor:
304
+ def __call__(self, attn, img, txt, vec, pe, **attention_kwargs):
305
+ img_mod1, img_mod2 = attn.img_mod(vec)
306
+ txt_mod1, txt_mod2 = attn.txt_mod(vec)
307
+
308
+ # prepare image for attention
309
+ img_modulated = attn.img_norm1(img)
310
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
311
+ img_qkv = attn.img_attn.qkv(img_modulated)
312
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
313
+ img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
314
+
315
+ # prepare txt for attention
316
+ txt_modulated = attn.txt_norm1(txt)
317
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
318
+ txt_qkv = attn.txt_attn.qkv(txt_modulated)
319
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
320
+ txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
321
+
322
+ # run actual attention
323
+ q = torch.cat((txt_q, img_q), dim=2)
324
+ k = torch.cat((txt_k, img_k), dim=2)
325
+ v = torch.cat((txt_v, img_v), dim=2)
326
+
327
+ attn1 = attention(q, k, v, pe=pe)
328
+ txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
329
+
330
+ # calculate the img bloks
331
+ img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
332
+ img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
333
+
334
+ # calculate the txt bloks
335
+ txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
336
+ txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
337
+ return img, txt
338
+
339
+ class DoubleStreamBlock(nn.Module):
340
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
341
+ super().__init__()
342
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
343
+ self.num_heads = num_heads
344
+ self.hidden_size = hidden_size
345
+ self.head_dim = hidden_size // num_heads
346
+
347
+ self.img_mod = Modulation(hidden_size, double=True)
348
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
349
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
350
+
351
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
352
+ self.img_mlp = nn.Sequential(
353
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
354
+ nn.GELU(approximate="tanh"),
355
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
356
+ )
357
+
358
+ self.txt_mod = Modulation(hidden_size, double=True)
359
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
360
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
361
+
362
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
363
+ self.txt_mlp = nn.Sequential(
364
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
365
+ nn.GELU(approximate="tanh"),
366
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
367
+ )
368
+ processor = DoubleStreamBlockProcessor()
369
+ self.set_processor(processor)
370
+
371
+ def set_processor(self, processor) -> None:
372
+ self.processor = processor
373
+
374
+ def get_processor(self):
375
+ return self.processor
376
+
377
+ def forward(
378
+ self,
379
+ img: Tensor,
380
+ txt: Tensor,
381
+ vec: Tensor,
382
+ pe: Tensor,
383
+ image_proj: Tensor = None,
384
+ ip_scale: float =1.0,
385
+ ) -> tuple[Tensor, Tensor]:
386
+ if image_proj is None:
387
+ return self.processor(self, img, txt, vec, pe)
388
+ else:
389
+ return self.processor(self, img, txt, vec, pe, image_proj, ip_scale)
390
+
391
+ class IPSingleStreamBlockProcessor(nn.Module):
392
+ """Attention processor for handling IP-adapter with single stream block."""
393
+ def __init__(self, context_dim, hidden_dim):
394
+ super().__init__()
395
+ if not hasattr(F, "scaled_dot_product_attention"):
396
+ raise ImportError(
397
+ "IPSingleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch."
398
+ )
399
+
400
+ # Ensure context_dim matches the dimension of image_proj
401
+ self.context_dim = context_dim
402
+ self.hidden_dim = hidden_dim
403
+
404
+ # Initialize projections for IP-adapter
405
+ self.ip_adapter_single_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
406
+ self.ip_adapter_single_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=False)
407
+
408
+ nn.init.zeros_(self.ip_adapter_single_stream_k_proj.weight)
409
+ nn.init.zeros_(self.ip_adapter_single_stream_v_proj.weight)
410
+
411
+ def __call__(
412
+ self,
413
+ attn: nn.Module,
414
+ x: Tensor,
415
+ vec: Tensor,
416
+ pe: Tensor,
417
+ image_proj: Tensor | None = None,
418
+ ip_scale: float = 1.0
419
+ ) -> Tensor:
420
+
421
+ mod, _ = attn.modulation(vec)
422
+ x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
423
+ qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
424
+
425
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
426
+ q, k = attn.norm(q, k, v)
427
+
428
+ # compute attention
429
+ attn_1 = attention(q, k, v, pe=pe)
430
+
431
+ # IP-adapter processing
432
+ ip_query = q
433
+ ip_key = self.ip_adapter_single_stream_k_proj(image_proj)
434
+ ip_value = self.ip_adapter_single_stream_v_proj(image_proj)
435
+
436
+ # Reshape projections for multi-head attention
437
+ ip_key = rearrange(ip_key, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim)
438
+ ip_value = rearrange(ip_value, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim)
439
+
440
+
441
+ # Compute attention between IP projections and the latent query
442
+ ip_attention = F.scaled_dot_product_attention(
443
+ ip_query,
444
+ ip_key,
445
+ ip_value
446
+ )
447
+ ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)")
448
+
449
+ attn_out = attn_1 + ip_scale * ip_attention
450
+
451
+ # compute activation in mlp stream, cat again and run second linear layer
452
+ output = attn.linear2(torch.cat((attn_out, attn.mlp_act(mlp)), 2))
453
+ out = x + mod.gate * output
454
+
455
+ return out
456
+
457
+ class SingleStreamBlockProcessor:
458
+ def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
459
+
460
+ mod, _ = attn.modulation(vec)
461
+ x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
462
+ qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
463
+
464
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
465
+ q, k = attn.norm(q, k, v)
466
+
467
+ # compute attention
468
+ attn_1 = attention(q, k, v, pe=pe)
469
+
470
+ # compute activation in mlp stream, cat again and run second linear layer
471
+ output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
472
+ output = x + mod.gate * output
473
+ return output
474
+
475
+ class SingleStreamBlock(nn.Module):
476
+ """
477
+ A DiT block with parallel linear layers as described in
478
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
479
+ """
480
+
481
+ def __init__(
482
+ self,
483
+ hidden_size: int,
484
+ num_heads: int,
485
+ mlp_ratio: float = 4.0,
486
+ qk_scale: float | None = None,
487
+ ):
488
+ super().__init__()
489
+ self.hidden_dim = hidden_size
490
+ self.num_heads = num_heads
491
+ self.head_dim = hidden_size // num_heads
492
+ self.scale = qk_scale or self.head_dim**-0.5
493
+
494
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
495
+ # qkv and mlp_in
496
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
497
+ # proj and mlp_out
498
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
499
+
500
+ self.norm = QKNorm(self.head_dim)
501
+
502
+ self.hidden_size = hidden_size
503
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
504
+
505
+ self.mlp_act = nn.GELU(approximate="tanh")
506
+ self.modulation = Modulation(hidden_size, double=False)
507
+
508
+ processor = SingleStreamBlockProcessor()
509
+ self.set_processor(processor)
510
+
511
+
512
+ def set_processor(self, processor) -> None:
513
+ self.processor = processor
514
+
515
+ def get_processor(self):
516
+ return self.processor
517
+
518
+ def forward(
519
+ self,
520
+ x: Tensor,
521
+ vec: Tensor,
522
+ pe: Tensor,
523
+ image_proj: Tensor | None = None,
524
+ ip_scale: float = 1.0
525
+ ) -> Tensor:
526
+ if image_proj is None:
527
+ return self.processor(self, x, vec, pe)
528
+ else:
529
+ return self.processor(self, x, vec, pe, image_proj, ip_scale)
530
+
531
+
532
+
533
+ class LastLayer(nn.Module):
534
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
535
+ super().__init__()
536
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
537
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
538
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
539
+
540
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
541
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
542
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
543
+ x = self.linear(x)
544
+ return x
545
+
546
+ class ImageProjModel(torch.nn.Module):
547
+ """Projection Model
548
+ https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter.py#L28
549
+ """
550
+
551
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
552
+ super().__init__()
553
+
554
+ self.generator = None
555
+ self.cross_attention_dim = cross_attention_dim
556
+ self.clip_extra_context_tokens = clip_extra_context_tokens
557
+ self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
558
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
559
+
560
+ def forward(self, image_embeds):
561
+ embeds = image_embeds
562
+ clip_extra_context_tokens = self.proj(embeds).reshape(
563
+ -1, self.clip_extra_context_tokens, self.cross_attention_dim
564
+ )
565
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
566
+ return clip_extra_context_tokens
567
+
src/flux/sampling.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Callable
3
+
4
+ import torch
5
+ from einops import rearrange, repeat
6
+ from torch import Tensor
7
+
8
+ from .model import Flux
9
+ from .modules.conditioner import HFEmbedder
10
+
11
+
12
+ def get_noise(
13
+ num_samples: int,
14
+ height: int,
15
+ width: int,
16
+ device: torch.device,
17
+ dtype: torch.dtype,
18
+ seed: int,
19
+ ):
20
+ return torch.randn(
21
+ num_samples,
22
+ 16,
23
+ # allow for packing
24
+ 2 * math.ceil(height / 16),
25
+ 2 * math.ceil(width / 16),
26
+ device=device,
27
+ dtype=dtype,
28
+ generator=torch.Generator(device=device).manual_seed(seed),
29
+ )
30
+
31
+
32
+ def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
33
+ bs, c, h, w = img.shape
34
+ if bs == 1 and not isinstance(prompt, str):
35
+ bs = len(prompt)
36
+
37
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
38
+ if img.shape[0] == 1 and bs > 1:
39
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
40
+
41
+ img_ids = torch.zeros(h // 2, w // 2, 3)
42
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
43
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
44
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
45
+
46
+ if isinstance(prompt, str):
47
+ prompt = [prompt]
48
+ txt = t5(prompt)
49
+ if txt.shape[0] == 1 and bs > 1:
50
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
51
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
52
+
53
+ vec = clip(prompt)
54
+ if vec.shape[0] == 1 and bs > 1:
55
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
56
+
57
+ return {
58
+ "img": img,
59
+ "img_ids": img_ids.to(img.device),
60
+ "txt": txt.to(img.device),
61
+ "txt_ids": txt_ids.to(img.device),
62
+ "vec": vec.to(img.device),
63
+ }
64
+
65
+
66
+ def time_shift(mu: float, sigma: float, t: Tensor):
67
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
68
+
69
+
70
+ def get_lin_function(
71
+ x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
72
+ ) -> Callable[[float], float]:
73
+ m = (y2 - y1) / (x2 - x1)
74
+ b = y1 - m * x1
75
+ return lambda x: m * x + b
76
+
77
+
78
+ def get_schedule(
79
+ num_steps: int,
80
+ image_seq_len: int,
81
+ base_shift: float = 0.5,
82
+ max_shift: float = 1.15,
83
+ shift: bool = True,
84
+ ) -> list[float]:
85
+ # extra step for zero
86
+ timesteps = torch.linspace(1, 0, num_steps + 1)
87
+
88
+ # shifting the schedule to favor high timesteps for higher signal images
89
+ if shift:
90
+ # eastimate mu based on linear estimation between two points
91
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
92
+ timesteps = time_shift(mu, 1.0, timesteps)
93
+
94
+ return timesteps.tolist()
95
+
96
+
97
+ def denoise(
98
+ model: Flux,
99
+ # model input
100
+ img: Tensor,
101
+ img_ids: Tensor,
102
+ txt: Tensor,
103
+ txt_ids: Tensor,
104
+ vec: Tensor,
105
+ neg_txt: Tensor,
106
+ neg_txt_ids: Tensor,
107
+ neg_vec: Tensor,
108
+ # sampling parameters
109
+ timesteps: list[float],
110
+ guidance: float = 4.0,
111
+ true_gs = 1,
112
+ timestep_to_start_cfg=0,
113
+ # ip-adapter parameters
114
+ image_proj: Tensor=None,
115
+ neg_image_proj: Tensor=None,
116
+ ip_scale: Tensor | float = 1.0,
117
+ neg_ip_scale: Tensor | float = 1.0
118
+ ):
119
+ i = 0
120
+ # this is ignored for schnell
121
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
122
+ for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
123
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
124
+ pred = model(
125
+ img=img,
126
+ img_ids=img_ids,
127
+ txt=txt,
128
+ txt_ids=txt_ids,
129
+ y=vec,
130
+ timesteps=t_vec,
131
+ guidance=guidance_vec,
132
+ image_proj=image_proj,
133
+ ip_scale=ip_scale,
134
+ )
135
+ if i >= timestep_to_start_cfg:
136
+ neg_pred = model(
137
+ img=img,
138
+ img_ids=img_ids,
139
+ txt=neg_txt,
140
+ txt_ids=neg_txt_ids,
141
+ y=neg_vec,
142
+ timesteps=t_vec,
143
+ guidance=guidance_vec,
144
+ image_proj=neg_image_proj,
145
+ ip_scale=neg_ip_scale,
146
+ )
147
+ pred = neg_pred + true_gs * (pred - neg_pred)
148
+ img = img + (t_prev - t_curr) * pred
149
+ i += 1
150
+ return img
151
+
152
+ def denoise_controlnet(
153
+ model: Flux,
154
+ controlnet:None,
155
+ # model input
156
+ img: Tensor,
157
+ img_ids: Tensor,
158
+ txt: Tensor,
159
+ txt_ids: Tensor,
160
+ vec: Tensor,
161
+ neg_txt: Tensor,
162
+ neg_txt_ids: Tensor,
163
+ neg_vec: Tensor,
164
+ controlnet_cond,
165
+ # sampling parameters
166
+ timesteps: list[float],
167
+ guidance: float = 4.0,
168
+ true_gs = 1,
169
+ controlnet_gs=0.7,
170
+ timestep_to_start_cfg=0,
171
+ # ip-adapter parameters
172
+ image_proj: Tensor=None,
173
+ neg_image_proj: Tensor=None,
174
+ ip_scale: Tensor | float = 1,
175
+ neg_ip_scale: Tensor | float = 1,
176
+ ):
177
+ # this is ignored for schnell
178
+ i = 0
179
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
180
+ for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
181
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
182
+ block_res_samples = controlnet(
183
+ img=img,
184
+ img_ids=img_ids,
185
+ controlnet_cond=controlnet_cond,
186
+ txt=txt,
187
+ txt_ids=txt_ids,
188
+ y=vec,
189
+ timesteps=t_vec,
190
+ guidance=guidance_vec,
191
+ )
192
+ pred = model(
193
+ img=img,
194
+ img_ids=img_ids,
195
+ txt=txt,
196
+ txt_ids=txt_ids,
197
+ y=vec,
198
+ timesteps=t_vec,
199
+ guidance=guidance_vec,
200
+ block_controlnet_hidden_states=[i * controlnet_gs for i in block_res_samples],
201
+ image_proj=image_proj,
202
+ ip_scale=ip_scale,
203
+ )
204
+ if i >= timestep_to_start_cfg:
205
+ neg_block_res_samples = controlnet(
206
+ img=img,
207
+ img_ids=img_ids,
208
+ controlnet_cond=controlnet_cond,
209
+ txt=neg_txt,
210
+ txt_ids=neg_txt_ids,
211
+ y=neg_vec,
212
+ timesteps=t_vec,
213
+ guidance=guidance_vec,
214
+ )
215
+ neg_pred = model(
216
+ img=img,
217
+ img_ids=img_ids,
218
+ txt=neg_txt,
219
+ txt_ids=neg_txt_ids,
220
+ y=neg_vec,
221
+ timesteps=t_vec,
222
+ guidance=guidance_vec,
223
+ block_controlnet_hidden_states=[i * controlnet_gs for i in neg_block_res_samples],
224
+ image_proj=neg_image_proj,
225
+ ip_scale=neg_ip_scale,
226
+ )
227
+ pred = neg_pred + true_gs * (pred - neg_pred)
228
+
229
+ img = img + (t_prev - t_curr) * pred
230
+
231
+ i += 1
232
+ return img
233
+
234
+ def unpack(x: Tensor, height: int, width: int) -> Tensor:
235
+ return rearrange(
236
+ x,
237
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
238
+ h=math.ceil(height / 16),
239
+ w=math.ceil(width / 16),
240
+ ph=2,
241
+ pw=2,
242
+ )
src/flux/util.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ import json
6
+ import cv2
7
+ import numpy as np
8
+ from PIL import Image
9
+ from huggingface_hub import hf_hub_download
10
+ from safetensors import safe_open
11
+ from safetensors.torch import load_file as load_sft
12
+
13
+ from optimum.quanto import requantize
14
+
15
+ from .model import Flux, FluxParams
16
+ from .controlnet import ControlNetFlux
17
+ from .modules.autoencoder import AutoEncoder, AutoEncoderParams
18
+ from .modules.conditioner import HFEmbedder
19
+ from .annotator.dwpose import DWposeDetector
20
+ from .annotator.mlsd import MLSDdetector
21
+ from .annotator.canny import CannyDetector
22
+ from .annotator.midas import MidasDetector
23
+ from .annotator.hed import HEDdetector
24
+ from .annotator.tile import TileDetector
25
+
26
+
27
+ def load_safetensors(path):
28
+ tensors = {}
29
+ with safe_open(path, framework="pt", device="cpu") as f:
30
+ for key in f.keys():
31
+ tensors[key] = f.get_tensor(key)
32
+ return tensors
33
+
34
+ def get_lora_rank(checkpoint):
35
+ for k in checkpoint.keys():
36
+ if k.endswith(".down.weight"):
37
+ return checkpoint[k].shape[0]
38
+
39
+ def load_checkpoint(local_path, repo_id, name):
40
+ if local_path is not None:
41
+ if '.safetensors' in local_path:
42
+ print("Loading .safetensors checkpoint...")
43
+ checkpoint = load_safetensors(local_path)
44
+ else:
45
+ print("Loading checkpoint...")
46
+ checkpoint = torch.load(local_path, map_location='cpu')
47
+ elif repo_id is not None and name is not None:
48
+ print("Loading checkpoint from repo id...")
49
+ checkpoint = load_from_repo_id(repo_id, name)
50
+ else:
51
+ raise ValueError(
52
+ "LOADING ERROR: you must specify local_path or repo_id with name in HF to download"
53
+ )
54
+ return checkpoint
55
+
56
+
57
+ def c_crop(image):
58
+ width, height = image.size
59
+ new_size = min(width, height)
60
+ left = (width - new_size) / 2
61
+ top = (height - new_size) / 2
62
+ right = (width + new_size) / 2
63
+ bottom = (height + new_size) / 2
64
+ return image.crop((left, top, right, bottom))
65
+
66
+
67
+ class Annotator:
68
+ def __init__(self, name: str, device: str):
69
+ if name == "canny":
70
+ processor = CannyDetector()
71
+ elif name == "openpose":
72
+ processor = DWposeDetector(device)
73
+ elif name == "depth":
74
+ processor = MidasDetector()
75
+ elif name == "hed":
76
+ processor = HEDdetector()
77
+ elif name == "hough":
78
+ processor = MLSDdetector()
79
+ elif name == "tile":
80
+ processor = TileDetector()
81
+ self.name = name
82
+ self.processor = processor
83
+
84
+ def __call__(self, image: Image, width: int, height: int):
85
+ image = c_crop(image)
86
+ image = image.resize((width, height))
87
+ image = np.array(image)
88
+ if self.name == "canny":
89
+ result = self.processor(image, low_threshold=100, high_threshold=200)
90
+ elif self.name == "hough":
91
+ result = self.processor(image, thr_v=0.05, thr_d=5)
92
+ elif self.name == "depth":
93
+ result = self.processor(image)
94
+ result, _ = result
95
+ else:
96
+ result = self.processor(image)
97
+
98
+ if result.ndim != 3:
99
+ result = result[:, :, None]
100
+ result = np.concatenate([result, result, result], axis=2)
101
+ return result
102
+
103
+
104
+ @dataclass
105
+ class ModelSpec:
106
+ params: FluxParams
107
+ ae_params: AutoEncoderParams
108
+ ckpt_path: str | None
109
+ ae_path: str | None
110
+ repo_id: str | None
111
+ repo_flow: str | None
112
+ repo_ae: str | None
113
+ repo_id_ae: str | None
114
+
115
+
116
+ configs = {
117
+ "flux-dev": ModelSpec(
118
+ repo_id="black-forest-labs/FLUX.1-dev",
119
+ repo_id_ae="black-forest-labs/FLUX.1-dev",
120
+ repo_flow="flux1-dev.safetensors",
121
+ repo_ae="ae.safetensors",
122
+ ckpt_path=os.getenv("FLUX_DEV"),
123
+ params=FluxParams(
124
+ in_channels=64,
125
+ vec_in_dim=768,
126
+ context_in_dim=4096,
127
+ hidden_size=3072,
128
+ mlp_ratio=4.0,
129
+ num_heads=24,
130
+ depth=19,
131
+ depth_single_blocks=38,
132
+ axes_dim=[16, 56, 56],
133
+ theta=10_000,
134
+ qkv_bias=True,
135
+ guidance_embed=True,
136
+ ),
137
+ ae_path=os.getenv("AE"),
138
+ ae_params=AutoEncoderParams(
139
+ resolution=256,
140
+ in_channels=3,
141
+ ch=128,
142
+ out_ch=3,
143
+ ch_mult=[1, 2, 4, 4],
144
+ num_res_blocks=2,
145
+ z_channels=16,
146
+ scale_factor=0.3611,
147
+ shift_factor=0.1159,
148
+ ),
149
+ ),
150
+ "flux-dev-fp8": ModelSpec(
151
+ repo_id="XLabs-AI/flux-dev-fp8",
152
+ repo_id_ae="black-forest-labs/FLUX.1-dev",
153
+ repo_flow="flux-dev-fp8.safetensors",
154
+ repo_ae="ae.safetensors",
155
+ ckpt_path=os.getenv("FLUX_DEV_FP8"),
156
+ params=FluxParams(
157
+ in_channels=64,
158
+ vec_in_dim=768,
159
+ context_in_dim=4096,
160
+ hidden_size=3072,
161
+ mlp_ratio=4.0,
162
+ num_heads=24,
163
+ depth=19,
164
+ depth_single_blocks=38,
165
+ axes_dim=[16, 56, 56],
166
+ theta=10_000,
167
+ qkv_bias=True,
168
+ guidance_embed=True,
169
+ ),
170
+ ae_path=os.getenv("AE"),
171
+ ae_params=AutoEncoderParams(
172
+ resolution=256,
173
+ in_channels=3,
174
+ ch=128,
175
+ out_ch=3,
176
+ ch_mult=[1, 2, 4, 4],
177
+ num_res_blocks=2,
178
+ z_channels=16,
179
+ scale_factor=0.3611,
180
+ shift_factor=0.1159,
181
+ ),
182
+ ),
183
+ "flux-schnell": ModelSpec(
184
+ repo_id="black-forest-labs/FLUX.1-schnell",
185
+ repo_id_ae="black-forest-labs/FLUX.1-dev",
186
+ repo_flow="flux1-schnell.safetensors",
187
+ repo_ae="ae.safetensors",
188
+ ckpt_path=os.getenv("FLUX_SCHNELL"),
189
+ params=FluxParams(
190
+ in_channels=64,
191
+ vec_in_dim=768,
192
+ context_in_dim=4096,
193
+ hidden_size=3072,
194
+ mlp_ratio=4.0,
195
+ num_heads=24,
196
+ depth=19,
197
+ depth_single_blocks=38,
198
+ axes_dim=[16, 56, 56],
199
+ theta=10_000,
200
+ qkv_bias=True,
201
+ guidance_embed=False,
202
+ ),
203
+ ae_path=os.getenv("AE"),
204
+ ae_params=AutoEncoderParams(
205
+ resolution=256,
206
+ in_channels=3,
207
+ ch=128,
208
+ out_ch=3,
209
+ ch_mult=[1, 2, 4, 4],
210
+ num_res_blocks=2,
211
+ z_channels=16,
212
+ scale_factor=0.3611,
213
+ shift_factor=0.1159,
214
+ ),
215
+ ),
216
+ }
217
+
218
+
219
+ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
220
+ if len(missing) > 0 and len(unexpected) > 0:
221
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
222
+ print("\n" + "-" * 79 + "\n")
223
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
224
+ elif len(missing) > 0:
225
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
226
+ elif len(unexpected) > 0:
227
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
228
+
229
+ def load_from_repo_id(repo_id, checkpoint_name):
230
+ ckpt_path = hf_hub_download(repo_id, checkpoint_name)
231
+ sd = load_sft(ckpt_path, device='cpu')
232
+ return sd
233
+
234
+ def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
235
+ # Loading Flux
236
+ print("Init model")
237
+ ckpt_path = configs[name].ckpt_path
238
+ if (
239
+ ckpt_path is None
240
+ and configs[name].repo_id is not None
241
+ and configs[name].repo_flow is not None
242
+ and hf_download
243
+ ):
244
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
245
+
246
+ with torch.device("meta" if ckpt_path is not None else device):
247
+ model = Flux(configs[name].params).to(torch.bfloat16)
248
+
249
+ if ckpt_path is not None:
250
+ print("Loading checkpoint")
251
+ # load_sft doesn't support torch.device
252
+ sd = load_sft(ckpt_path, device=str(device))
253
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
254
+ print_load_warning(missing, unexpected)
255
+ return model
256
+
257
+ def load_flow_model2(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
258
+ # Loading Flux
259
+ print("Init model")
260
+ ckpt_path = configs[name].ckpt_path
261
+ if (
262
+ ckpt_path is None
263
+ and configs[name].repo_id is not None
264
+ and configs[name].repo_flow is not None
265
+ and hf_download
266
+ ):
267
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))
268
+
269
+ with torch.device("meta" if ckpt_path is not None else device):
270
+ model = Flux(configs[name].params)
271
+
272
+ if ckpt_path is not None:
273
+ print("Loading checkpoint")
274
+ # load_sft doesn't support torch.device
275
+ sd = load_sft(ckpt_path, device=str(device))
276
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
277
+ print_load_warning(missing, unexpected)
278
+ return model
279
+
280
+ def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
281
+ # Loading Flux
282
+ print("Init model")
283
+ ckpt_path = configs[name].ckpt_path
284
+ if (
285
+ ckpt_path is None
286
+ and configs[name].repo_id is not None
287
+ and configs[name].repo_flow is not None
288
+ and hf_download
289
+ ):
290
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
291
+ json_path = hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json')
292
+
293
+
294
+ model = Flux(configs[name].params).to(torch.bfloat16)
295
+
296
+ print("Loading checkpoint")
297
+ # load_sft doesn't support torch.device
298
+ sd = load_sft(ckpt_path, device='cpu')
299
+ with open(json_path, "r") as f:
300
+ quantization_map = json.load(f)
301
+ print("Start a quantization process...")
302
+ requantize(model, sd, quantization_map, device=device)
303
+ print("Model is quantized!")
304
+ return model
305
+
306
+ def load_controlnet(name, device, transformer=None):
307
+ with torch.device(device):
308
+ controlnet = ControlNetFlux(configs[name].params)
309
+ if transformer is not None:
310
+ controlnet.load_state_dict(transformer.state_dict(), strict=False)
311
+ return controlnet
312
+
313
+ def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
314
+ # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
315
+ return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
316
+
317
+ def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
318
+ return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
319
+
320
+
321
+ def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
322
+ ckpt_path = configs[name].ae_path
323
+ if (
324
+ ckpt_path is None
325
+ and configs[name].repo_id is not None
326
+ and configs[name].repo_ae is not None
327
+ and hf_download
328
+ ):
329
+ ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae)
330
+
331
+ # Loading the autoencoder
332
+ print("Init AE")
333
+ with torch.device("meta" if ckpt_path is not None else device):
334
+ ae = AutoEncoder(configs[name].ae_params)
335
+
336
+ if ckpt_path is not None:
337
+ sd = load_sft(ckpt_path, device=str(device))
338
+ missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
339
+ print_load_warning(missing, unexpected)
340
+ return ae
341
+
342
+
343
+ class WatermarkEmbedder:
344
+ def __init__(self, watermark):
345
+ self.watermark = watermark
346
+ self.num_bits = len(WATERMARK_BITS)
347
+ self.encoder = WatermarkEncoder()
348
+ self.encoder.set_watermark("bits", self.watermark)
349
+
350
+ def __call__(self, image: torch.Tensor) -> torch.Tensor:
351
+ """
352
+ Adds a predefined watermark to the input image
353
+
354
+ Args:
355
+ image: ([N,] B, RGB, H, W) in range [-1, 1]
356
+
357
+ Returns:
358
+ same as input but watermarked
359
+ """
360
+ image = 0.5 * image + 0.5
361
+ squeeze = len(image.shape) == 4
362
+ if squeeze:
363
+ image = image[None, ...]
364
+ n = image.shape[0]
365
+ image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
366
+ # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
367
+ # watermarking libary expects input as cv2 BGR format
368
+ for k in range(image_np.shape[0]):
369
+ image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
370
+ image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
371
+ image.device
372
+ )
373
+ image = torch.clamp(image / 255, min=0.0, max=1.0)
374
+ if squeeze:
375
+ image = image[0]
376
+ image = 2 * image - 1
377
+ return image
378
+
379
+
380
+ # A fixed 48-bit message that was choosen at random
381
+ WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
382
+ # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
383
+ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]