adorabook commited on
Commit
92a699d
·
verified ·
1 Parent(s): b43927c

Upload 86 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -2
  2. .gitignore +142 -0
  3. LICENSE +201 -0
  4. README.md +128 -0
  5. app.py +224 -0
  6. app_flux.py +327 -0
  7. app_v1_1.py +504 -0
  8. docs/pulid_for_flux.md +89 -0
  9. docs/pulid_v1.1.md +28 -0
  10. docs/v1.1_preview.md +14 -0
  11. eva_clip/__init__.py +11 -0
  12. eva_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  13. eva_clip/constants.py +2 -0
  14. eva_clip/eva_vit_model.py +548 -0
  15. eva_clip/factory.py +517 -0
  16. eva_clip/hf_configs.py +57 -0
  17. eva_clip/hf_model.py +248 -0
  18. eva_clip/loss.py +138 -0
  19. eva_clip/model.py +439 -0
  20. eva_clip/model_configs/EVA01-CLIP-B-16.json +19 -0
  21. eva_clip/model_configs/EVA01-CLIP-g-14-plus.json +24 -0
  22. eva_clip/model_configs/EVA01-CLIP-g-14.json +24 -0
  23. eva_clip/model_configs/EVA02-CLIP-B-16.json +29 -0
  24. eva_clip/model_configs/EVA02-CLIP-L-14-336.json +29 -0
  25. eva_clip/model_configs/EVA02-CLIP-L-14.json +29 -0
  26. eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json +25 -0
  27. eva_clip/model_configs/EVA02-CLIP-bigE-14.json +25 -0
  28. eva_clip/modified_resnet.py +181 -0
  29. eva_clip/openai.py +144 -0
  30. eva_clip/pretrained.py +332 -0
  31. eva_clip/rope.py +137 -0
  32. eva_clip/timm_model.py +122 -0
  33. eva_clip/tokenizer.py +201 -0
  34. eva_clip/transform.py +103 -0
  35. eva_clip/transformer.py +737 -0
  36. eva_clip/utils.py +326 -0
  37. example_inputs/hinton.jpeg +0 -0
  38. example_inputs/lecun.jpg +0 -0
  39. example_inputs/lifeifei.jpg +0 -0
  40. example_inputs/liuyifei.png +0 -0
  41. example_inputs/pengwei.jpg +3 -0
  42. example_inputs/rihanna.webp +0 -0
  43. example_inputs/zcy.webp +0 -0
  44. flux/__init__.py +11 -0
  45. flux/math.py +31 -0
  46. flux/model.py +165 -0
  47. flux/modules/__init__.py +0 -0
  48. flux/modules/autoencoder.py +312 -0
  49. flux/modules/conditioner.py +37 -0
  50. flux/modules/layers.py +253 -0
.gitattributes CHANGED
@@ -32,5 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.xz filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
36
- pulid-flux-adorabook/example_inputs/pengwei.jpg filter=lfs diff=lfs merge=lfs -text
 
32
  *.xz filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -textexample_inputs/pengwei.jpg filter=lfs diff=lfs merge=lfs -text
 
.gitignore ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets/*
2
+ experiments/*
3
+ results/*
4
+ tb_logger/*
5
+ wandb/*
6
+ tmp/*
7
+ weights/*
8
+ inputs/*
9
+
10
+ *.DS_Store
11
+
12
+ # Byte-compiled / optimized / DLL files
13
+ __pycache__/
14
+ *.py[cod]
15
+ *$py.class
16
+
17
+ # C extensions
18
+ *.so
19
+
20
+ # Distribution / packaging
21
+ .Python
22
+ build/
23
+ develop-eggs/
24
+ dist/
25
+ downloads/
26
+ eggs/
27
+ .eggs/
28
+ lib/
29
+ lib64/
30
+ parts/
31
+ sdist/
32
+ var/
33
+ wheels/
34
+ pip-wheel-metadata/
35
+ share/python-wheels/
36
+ *.egg-info/
37
+ .installed.cfg
38
+ *.egg
39
+ MANIFEST
40
+
41
+ # PyInstaller
42
+ # Usually these files are written by a python script from a template
43
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
44
+ *.manifest
45
+ *.spec
46
+
47
+ # Installer logs
48
+ pip-log.txt
49
+ pip-delete-this-directory.txt
50
+
51
+ # Unit test / coverage reports
52
+ htmlcov/
53
+ .tox/
54
+ .nox/
55
+ .coverage
56
+ .coverage.*
57
+ .cache
58
+ nosetests.xml
59
+ coverage.xml
60
+ *.cover
61
+ *.py,cover
62
+ .hypothesis/
63
+ .pytest_cache/
64
+
65
+ # Translations
66
+ *.mo
67
+ *.pot
68
+
69
+ # Django stuff:
70
+ *.log
71
+ local_settings.py
72
+ db.sqlite3
73
+ db.sqlite3-journal
74
+
75
+ # Flask stuff:
76
+ instance/
77
+ .webassets-cache
78
+
79
+ # Scrapy stuff:
80
+ .scrapy
81
+
82
+ # Sphinx documentation
83
+ docs/_build/
84
+
85
+ # PyBuilder
86
+ target/
87
+
88
+ # Jupyter Notebook
89
+ .ipynb_checkpoints
90
+
91
+ # IPython
92
+ profile_default/
93
+ ipython_config.py
94
+
95
+ # pyenv
96
+ .python-version
97
+
98
+ # pipenv
99
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
101
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
102
+ # install all needed dependencies.
103
+ #Pipfile.lock
104
+
105
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
106
+ __pypackages__/
107
+
108
+ # Celery stuff
109
+ celerybeat-schedule
110
+ celerybeat.pid
111
+
112
+ # SageMath parsed files
113
+ *.sage.py
114
+
115
+ # Environments
116
+ .env
117
+ .venv
118
+ env/
119
+ venv/
120
+ ENV/
121
+ env.bak/
122
+ venv.bak/
123
+
124
+ # Spyder project settings
125
+ .spyderproject
126
+ .spyproject
127
+
128
+ # Rope project settings
129
+ .ropeproject
130
+
131
+ # mkdocs documentation
132
+ /site
133
+
134
+ # mypy
135
+ .mypy_cache/
136
+ .dmypy.json
137
+ dmypy.json
138
+
139
+ # Pyre type checker
140
+ .pyre/
141
+
142
+ .idea/
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PuLID (NeurIPS 2024)
2
+
3
+ ### :open_book: PuLID: Pure and Lightning ID Customization via Contrastive Alignment
4
+ > [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2404.16022) [![xl](https://img.shields.io/badge/🤗-HuggingFaceDemo-orange)](https://huggingface.co/spaces/yanze/PuLID) [![flux](https://img.shields.io/badge/🤗-PuLID_FLUX_demo-orange)](https://huggingface.co/spaces/yanze/PuLID-FLUX) [![Replicate](https://img.shields.io/badge/Replicate-Demo_for_PuLID-blue)](https://replicate.com/zsxkib/pulid) [![Replicate](https://img.shields.io/badge/Replicate-PuLID_FLUX-blue)](https://replicate.com/zsxkib/flux-pulid)<br>
5
+ > Zinan Guo*, Yanze Wu*✝, Zhuowei Chen, Lang Chen, Peng Zhang, Qian He <br>
6
+ > (*Equal Contribution, ✝Corresponding Author) <br>
7
+ > ByteDance Inc <br>
8
+
9
+ ### :triangular_flag_on_post: Updates
10
+ * **2024.10.31**: 🔥 We are happy to release our latest [models](https://huggingface.co/guozinan/PuLID), **PuLID-v1.1** and **PuLID-FLUX-v0.9.1**. See more in [Model Zoo](#european_castle-model-zoo) and [pulid v1.1 model](docs/pulid_v1.1.md). We also update a new revision for the [arXiv paper](https://arxiv.org/abs/2404.16022), which includes more results, details, and analysis, please check it out.
11
+ * **2024.09.26**: 🎉 PuLID accepted by NeurIPS 2024
12
+ * **2024.09.12**: We're thrilled to announce the release of the **PuLID-FLUX-v0.9.0 model**. Enjoy exploring its capabilities! 😊 [Learn more about this model](docs/pulid_for_flux.md)
13
+ * **2024.05.23**: share the [preview of our upcoming v1.1 model](docs/v1.1_preview.md), please stay tuned
14
+ * **2024.05.01**: release v1 codes&models, also the [🤗HuggingFace Demo](https://huggingface.co/spaces/yanze/PuLID)
15
+ * **2024.04.25**: release arXiv paper.
16
+
17
+
18
+ ## PuLID for FLUX
19
+ Please check the doc and demo of PuLID-FLUX [here](docs/pulid_for_flux.md).
20
+
21
+ ### updates
22
+ - [x] Local gradio demo is ready now
23
+ - [x] Online HuggingFace demo is ready now [![flux](https://img.shields.io/badge/🤗-PuLID_FLUX_demo-orange)](https://huggingface.co/spaces/yanze/PuLID-FLUX)
24
+ - [x] We have optimized the codes to support consumer-grade GPUS, and now **PuLID-FLUX can run on a 16GB graphic card**. Check the details [here](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md#local-gradio-demo)
25
+ - [x] (Community Implementation) Online Replicate demo is ready now [![Replicate](https://replicate.com/zsxkib/flux-pulid/badge)](https://replicate.com/zsxkib/flux-pulid)
26
+ - [x] Local gradio demo supports 12GB graphic card now
27
+ - [x] v0.9.1 is ready now
28
+
29
+
30
+ Below results are generated with PuLID-FLUX.
31
+ ![pulid_flux_results](https://github.com/user-attachments/assets/7eafb90a-fdd1-4ae7-bc41-8c428d568848)
32
+
33
+
34
+ ## Examples
35
+ Images generated with our PuLID
36
+ ![examples](https://github.com/ToTheBeginning/PuLID/assets/11482921/65610b0d-ba4f-4dc3-a74d-bd60f8f5ce37)
37
+ Applications
38
+
39
+ https://github.com/ToTheBeginning/PuLID/assets/11482921/9bdd0c8a-99e8-4eab-ab9e-39bf796cc6b8
40
+
41
+ ## :european_castle: Model Zoo
42
+
43
+ | Version | Base Model | Description |
44
+ |:--------------------------------------------------------------------------------------------------:|:----------:|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|
45
+ | [PuLID-v1](https://huggingface.co/guozinan/PuLID/blob/main/pulid_v1.bin) | SDXL | Paper model. |
46
+ | [PuLID-v1.1](https://huggingface.co/guozinan/PuLID/blob/main/pulid_v1.1.safetensors) | SDXL | Compared to PuLID-v1, better compatibility, editability, facial naturalness, and similarity. |
47
+ | [PuLID-FLUX-v0.9.0](https://huggingface.co/guozinan/PuLID/blob/main/pulid_flux_v0.9.0.safetensors) | FLUX | Our first version for PuLID-FLUX, better prompt-following and image quality (since FLUX is more powerful than SDXL). But ID fidelity is not high enough for some male inputs |
48
+ | [PuLID-FLUX-v0.9.1](https://huggingface.co/guozinan/PuLID/blob/main/pulid_flux_v0.9.1.safetensors) | FLUX | Compared to PuLID-FLUX-v0.9.0, better ID fidelity. From the quantitative metric of ID similarity, the improvement is about 5 percentage points. Meanwhile, the editability remains similar as before. |
49
+
50
+
51
+ ## :wrench: Dependencies and Installation
52
+ - Python >= 3.9 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
53
+ - [PyTorch >= 2.0](https://pytorch.org/) if you don't need flux-dev-fp8, otherwise [PyTorch >= 2.4.1](https://pytorch.org/)
54
+ ```bash
55
+ # clone PuLID repo
56
+ git clone https://github.com/ToTheBeginning/PuLID.git
57
+ cd PuLID
58
+ # create conda env
59
+ conda create --name pulid python=3.10
60
+ # activate env
61
+ conda activate pulid
62
+ # Install dependent packages
63
+ # 1. if you don't need flux-fp8, e.g., you are using xl or flux-bf16, install the following requirements.txt
64
+ pip install -r requirements.txt
65
+ # 2. if you need flux-fp8 (to put flux on consumer-grade gpu), install the following requirements_fp8.txt
66
+ pip install -r requirements_fp8.txt
67
+ ```
68
+
69
+ ## :zap: Quick Inference
70
+ ### Local Gradio Demo
71
+ ```bash
72
+ # for v1 version
73
+ python app.py
74
+
75
+ # for v1.1 version
76
+ python app_v1.1.py --base BASE_MODEL
77
+ Usage:
78
+ -base: can be RunDiffusion/Juggernaut-XL-v9 or Lykon/dreamshaper-xl-lightning
79
+ ```
80
+
81
+ ### Online HuggingFace Demo
82
+ Thanks for the GPU grant from HuggingFace team, you can try PuLID HF demo in
83
+ - [https://huggingface.co/spaces/yanze/PuLID](https://huggingface.co/spaces/yanze/PuLID) for SDXL
84
+ - [https://huggingface.co/spaces/yanze/PuLID-FLUX](https://huggingface.co/spaces/yanze/PuLID-FLUX) for FLUX
85
+
86
+ ## :paperclip: Related Resources
87
+ Following are some third-party implementations of PuLID we have found in the Internet.
88
+ We appreciate the efforts of the respective developers for making PuLID accessible to a wider audience.
89
+ If there are any PuLID based resources and applications that we have not mentioned here, please let us know,
90
+ and we will include them in this list.
91
+
92
+ #### Online Demo
93
+ - **Colab**: https://github.com/camenduru/PuLID-jupyter provided by [camenduru](https://github.com/camenduru)
94
+ - **Replicate (PuLID)**: https://replicate.com/zsxkib/pulid provided by [zsxkib](https://github.com/zsxkib)
95
+ - **Replicate (PuLID-FLUX)**: https://replicate.com/zsxkib/flux-pulid provided by [zsxkib](https://github.com/zsxkib)
96
+
97
+ #### ComfyUI
98
+ - https://github.com/cubiq/PuLID_ComfyUI provided by [cubiq](https://github.com/cubiq), native ComfyUI implementation
99
+ - https://github.com/ZHO-ZHO-ZHO/ComfyUI-PuLID-ZHO provided by [ZHO](https://github.com/ZHO-ZHO-ZHO), diffusers-based implementation
100
+
101
+ #### WebUI
102
+ - [SD.Next](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md#update-for-2024-11-21) Implementation provided by [vladmandic](https://github.com/vladmandic)
103
+ - https://github.com/Mikubill/sd-webui-controlnet/pull/2838 provided by [huchenlei](https://github.com/huchenlei)
104
+
105
+ #### Other Applications
106
+ - PuLID-FLUX multi-person generation with [Regional-Prompting-FLUX](https://github.com/instantX-research/Regional-Prompting-FLUX), provided by [Anthony](https://github.com/antonioo-c)
107
+
108
+ ## Disclaimer
109
+ This project strives to impact the domain of AI-driven image generation positively. Users are granted the freedom to
110
+ create images using this tool, but they are expected to comply with local laws and utilize it responsibly.
111
+ The developers do not assume any responsibility for potential misuse by users.
112
+
113
+
114
+ ## Citation
115
+ If PuLID is helpful, please help to ⭐ the repo.
116
+
117
+ If you find this project useful for your research, please consider citing our paper:
118
+ ```bibtex
119
+ @InProceedings{guo2024pulid,
120
+ title={PuLID: Pure and Lightning ID Customization via Contrastive Alignment},
121
+ author={Guo, Zinan and Wu, Yanze and Chen, Zhuowei and Chen, Lang and Zhang, Peng and He, Qian},
122
+ booktitle={Advances in Neural Information Processing Systems},
123
+ year={2024}
124
+ }
125
+ ```
126
+
127
+ ## :e-mail: Contact
128
+ If you have any comments or questions, please [open a new issue](https://github.com/ToTheBeginning/PuLID/issues/new/choose) or feel free to contact [Yanze Wu](https://tothebeginning.github.io/) and [Zinan Guo](mailto:[email protected]).
app.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+
5
+ from pulid import attention_processor as attention
6
+ from pulid.pipeline import PuLIDPipeline
7
+ from pulid.utils import resize_numpy_image_long, seed_everything
8
+
9
+ torch.set_grad_enabled(False)
10
+
11
+ pipeline = PuLIDPipeline()
12
+
13
+ # other params
14
+ DEFAULT_NEGATIVE_PROMPT = (
15
+ 'flaws in the eyes, flaws in the face, flaws, lowres, non-HDRi, low quality, worst quality,'
16
+ 'artifacts noise, text, watermark, glitch, deformed, mutated, ugly, disfigured, hands, '
17
+ 'low resolution, partially rendered objects, deformed or partially rendered eyes, '
18
+ 'deformed, deformed eyeballs, cross-eyed,blurry'
19
+ )
20
+
21
+
22
+ def run(*args):
23
+ id_image = args[0]
24
+ supp_images = args[1:4]
25
+ prompt, neg_prompt, scale, n_samples, seed, steps, H, W, id_scale, mode, id_mix = args[4:]
26
+
27
+ pipeline.debug_img_list = []
28
+ if mode == 'fidelity':
29
+ attention.NUM_ZERO = 8
30
+ attention.ORTHO = False
31
+ attention.ORTHO_v2 = True
32
+ elif mode == 'extremely style':
33
+ attention.NUM_ZERO = 16
34
+ attention.ORTHO = True
35
+ attention.ORTHO_v2 = False
36
+ else:
37
+ raise ValueError
38
+
39
+ if id_image is not None:
40
+ id_image = resize_numpy_image_long(id_image, 1024)
41
+ id_embeddings = pipeline.get_id_embedding(id_image)
42
+ for supp_id_image in supp_images:
43
+ if supp_id_image is not None:
44
+ supp_id_image = resize_numpy_image_long(supp_id_image, 1024)
45
+ supp_id_embeddings = pipeline.get_id_embedding(supp_id_image)
46
+ id_embeddings = torch.cat(
47
+ (id_embeddings, supp_id_embeddings if id_mix else supp_id_embeddings[:, :5]), dim=1
48
+ )
49
+ else:
50
+ id_embeddings = None
51
+
52
+ seed_everything(seed)
53
+ ims = []
54
+ for _ in range(n_samples):
55
+ img = pipeline.inference(prompt, (1, H, W), neg_prompt, id_embeddings, id_scale, scale, steps)[0]
56
+ ims.append(np.array(img))
57
+
58
+ return ims, pipeline.debug_img_list
59
+
60
+
61
+ _HEADER_ = '''
62
+ <h2><b>Official Gradio Demo</b></h2><h2><a href='https://github.com/ToTheBeginning/PuLID' target='_blank'><b>PuLID: Pure and Lightning ID Customization via Contrastive Alignment</b></a></h2>
63
+
64
+ **PuLID** is a tuning-free ID customization approach. PuLID maintains high ID fidelity while effectively reducing interference with the original model’s behavior.
65
+
66
+ Code: <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>GitHub</a>. Techenical report: <a href='https://arxiv.org/abs/2404.16022' target='_blank'>ArXiv</a>.
67
+
68
+ ❗️❗️❗️**Tips:**
69
+ - we provide some examples in the bottom, you can try these example prompts first
70
+ - a single ID image is usually sufficient, you can also supplement with additional auxiliary images
71
+ - We offer two modes: fidelity mode and extremely style mode. In most cases, the default fidelity mode should suffice. If you find that the generated results are not stylized enough, you can choose the extremely style mode.
72
+
73
+ ''' # noqa E501
74
+
75
+ _CITE_ = r"""
76
+ If PuLID is helpful, please help to ⭐ the <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>Github Repo</a>. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/ToTheBeginning/PuLID?style=social)](https://github.com/ToTheBeginning/PuLID)
77
+ ---
78
+ 🚀 **Share**
79
+ If you have generated satisfying or interesting images with PuLID, please share them with us or your friends!
80
+
81
+ 📝 **Citation**
82
+ If you find our work useful for your research or applications, please cite using this bibtex:
83
+ ```bibtex
84
+ @article{guo2024pulid,
85
+ title={PuLID: Pure and Lightning ID Customization via Contrastive Alignment},
86
+ author={Guo, Zinan and Wu, Yanze and Chen, Zhuowei and Chen, Lang and He, Qian},
87
+ journal={arXiv preprint arXiv:2404.16022},
88
+ year={2024}
89
+ }
90
+ ```
91
+
92
+ 📋 **License**
93
+ Apache-2.0 LICENSE. Please refer to the [LICENSE file](placeholder) for details.
94
+
95
+ 📧 **Contact**
96
+ If you have any questions, feel free to open a discussion or contact us at <b>[email protected]</b> or <b>[email protected]</b>.
97
+ """ # noqa E501
98
+
99
+
100
+ with gr.Blocks(title="PuLID", css=".gr-box {border-color: #8136e2}") as demo:
101
+ gr.Markdown(_HEADER_)
102
+ with gr.Row():
103
+ with gr.Column():
104
+ with gr.Row():
105
+ face_image = gr.Image(label="ID image (main)", sources="upload", type="numpy", height=256)
106
+ supp_image1 = gr.Image(
107
+ label="Additional ID image (auxiliary)", sources="upload", type="numpy", height=256
108
+ )
109
+ supp_image2 = gr.Image(
110
+ label="Additional ID image (auxiliary)", sources="upload", type="numpy", height=256
111
+ )
112
+ supp_image3 = gr.Image(
113
+ label="Additional ID image (auxiliary)", sources="upload", type="numpy", height=256
114
+ )
115
+ prompt = gr.Textbox(label="Prompt", value='portrait,color,cinematic,in garden,soft light,detailed face')
116
+ submit = gr.Button("Generate")
117
+ neg_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT)
118
+ scale = gr.Slider(
119
+ label="CFG, recommend value range [1, 1.5], 1 will be faster ",
120
+ value=1.2,
121
+ minimum=1,
122
+ maximum=1.5,
123
+ step=0.1,
124
+ )
125
+ n_samples = gr.Slider(label="Num samples", value=4, minimum=1, maximum=8, step=1)
126
+ seed = gr.Slider(
127
+ label="Seed", value=42, minimum=np.iinfo(np.uint32).min, maximum=np.iinfo(np.uint32).max, step=1
128
+ )
129
+ steps = gr.Slider(label="Steps", value=4, minimum=1, maximum=100, step=1)
130
+ with gr.Row():
131
+ H = gr.Slider(label="Height", value=1024, minimum=512, maximum=2024, step=64)
132
+ W = gr.Slider(label="Width", value=768, minimum=512, maximum=2024, step=64)
133
+ with gr.Row():
134
+ id_scale = gr.Slider(label="ID scale", minimum=0, maximum=5, step=0.05, value=0.8, interactive=True)
135
+ mode = gr.Dropdown(label="mode", choices=['fidelity', 'extremely style'], value='fidelity')
136
+ id_mix = gr.Checkbox(
137
+ label="ID Mix (if you want to mix two ID image, please turn this on, otherwise, turn this off)",
138
+ value=False,
139
+ )
140
+
141
+ gr.Markdown("## Examples")
142
+ example_inps = [
143
+ [
144
+ 'portrait,cinematic,wolf ears,white hair',
145
+ 'example_inputs/liuyifei.png',
146
+ 'fidelity',
147
+ ]
148
+ ]
149
+ gr.Examples(examples=example_inps, inputs=[prompt, face_image, mode], label='realistic')
150
+
151
+ example_inps = [
152
+ [
153
+ 'portrait, impressionist painting, loose brushwork, vibrant color, light and shadow play',
154
+ 'example_inputs/zcy.webp',
155
+ 'fidelity',
156
+ ]
157
+ ]
158
+ gr.Examples(examples=example_inps, inputs=[prompt, face_image, mode], label='painting style')
159
+
160
+ example_inps = [
161
+ [
162
+ 'portrait, flat papercut style, silhouette, clean cuts, paper, sharp edges, minimalist,color block,man', # noqa E501
163
+ 'example_inputs/lecun.jpg',
164
+ 'fidelity',
165
+ ]
166
+ ]
167
+ gr.Examples(examples=example_inps, inputs=[prompt, face_image, mode], label='papercut style')
168
+
169
+ example_inps = [
170
+ [
171
+ 'woman,cartoon,solo,Popmart Blind Box, Super Mario, 3d',
172
+ 'example_inputs/rihanna.webp',
173
+ 'fidelity',
174
+ ]
175
+ ]
176
+ gr.Examples(examples=example_inps, inputs=[prompt, face_image, mode], label='3d style')
177
+
178
+ example_inps = [
179
+ [
180
+ 'portrait, the legend of zelda, anime',
181
+ 'example_inputs/liuyifei.png',
182
+ 'extremely style',
183
+ ]
184
+ ]
185
+ gr.Examples(examples=example_inps, inputs=[prompt, face_image, mode], label='anime style')
186
+
187
+ example_inps = [
188
+ [
189
+ 'portrait, superman',
190
+ 'example_inputs/lecun.jpg',
191
+ 'example_inputs/lifeifei.jpg',
192
+ 'fidelity',
193
+ True,
194
+ ]
195
+ ]
196
+ gr.Examples(examples=example_inps, inputs=[prompt, face_image, supp_image1, mode, id_mix], label='id mix')
197
+
198
+ with gr.Column():
199
+ output = gr.Gallery(label='Output', elem_id="gallery")
200
+ intermediate_output = gr.Gallery(label='DebugImage', elem_id="gallery", visible=False)
201
+ gr.Markdown(_CITE_)
202
+
203
+ inps = [
204
+ face_image,
205
+ supp_image1,
206
+ supp_image2,
207
+ supp_image3,
208
+ prompt,
209
+ neg_prompt,
210
+ scale,
211
+ n_samples,
212
+ seed,
213
+ steps,
214
+ H,
215
+ W,
216
+ id_scale,
217
+ mode,
218
+ id_mix,
219
+ ]
220
+ submit.click(fn=run, inputs=inps, outputs=[output, intermediate_output])
221
+
222
+
223
+ demo.queue(max_size=3)
224
+ demo.launch(server_name='0.0.0.0')
app_flux.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import gradio as gr
4
+ import torch
5
+ from einops import rearrange
6
+ from PIL import Image
7
+
8
+ from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
9
+ from flux.util import (
10
+ SamplingOptions,
11
+ load_ae,
12
+ load_clip,
13
+ load_flow_model,
14
+ load_flow_model_quintized,
15
+ load_t5,
16
+ )
17
+ from pulid.pipeline_flux import PuLIDPipeline
18
+ from pulid.utils import resize_numpy_image_long
19
+
20
+
21
+ def get_models(name: str, device: torch.device, offload: bool, fp8: bool):
22
+ t5 = load_t5(device, max_length=128)
23
+ clip = load_clip(device)
24
+ if fp8:
25
+ model = load_flow_model_quintized(name, device="cpu" if offload else device)
26
+ else:
27
+ model = load_flow_model(name, device="cpu" if offload else device)
28
+ model.eval()
29
+ ae = load_ae(name, device="cpu" if offload else device)
30
+ return model, ae, t5, clip
31
+
32
+
33
+ class FluxGenerator:
34
+ def __init__(self, model_name: str, device: str, offload: bool, aggressive_offload: bool, args):
35
+ self.device = torch.device(device)
36
+ self.offload = offload
37
+ self.aggressive_offload = aggressive_offload
38
+ self.model_name = model_name
39
+ self.model, self.ae, self.t5, self.clip = get_models(
40
+ model_name,
41
+ device=self.device,
42
+ offload=self.offload,
43
+ fp8=args.fp8,
44
+ )
45
+ self.pulid_model = PuLIDPipeline(self.model, device="cpu" if offload else device, weight_dtype=torch.bfloat16,
46
+ onnx_provider=args.onnx_provider)
47
+ if offload:
48
+ self.pulid_model.face_helper.face_det.mean_tensor = self.pulid_model.face_helper.face_det.mean_tensor.to(torch.device("cuda"))
49
+ self.pulid_model.face_helper.face_det.device = torch.device("cuda")
50
+ self.pulid_model.face_helper.device = torch.device("cuda")
51
+ self.pulid_model.device = torch.device("cuda")
52
+ self.pulid_model.load_pretrain(args.pretrained_model, version=args.version)
53
+
54
+ @torch.inference_mode()
55
+ def generate_image(
56
+ self,
57
+ width,
58
+ height,
59
+ num_steps,
60
+ start_step,
61
+ guidance,
62
+ seed,
63
+ prompt,
64
+ id_image=None,
65
+ id_weight=1.0,
66
+ neg_prompt="",
67
+ true_cfg=1.0,
68
+ timestep_to_start_cfg=1,
69
+ max_sequence_length=128,
70
+ ):
71
+ self.t5.max_length = max_sequence_length
72
+
73
+ seed = int(seed)
74
+ if seed == -1:
75
+ seed = None
76
+
77
+ opts = SamplingOptions(
78
+ prompt=prompt,
79
+ width=width,
80
+ height=height,
81
+ num_steps=num_steps,
82
+ guidance=guidance,
83
+ seed=seed,
84
+ )
85
+
86
+ if opts.seed is None:
87
+ opts.seed = torch.Generator(device="cpu").seed()
88
+ print(f"Generating '{opts.prompt}' with seed {opts.seed}")
89
+ t0 = time.perf_counter()
90
+
91
+ use_true_cfg = abs(true_cfg - 1.0) > 1e-2
92
+
93
+ # prepare input
94
+ x = get_noise(
95
+ 1,
96
+ opts.height,
97
+ opts.width,
98
+ device=self.device,
99
+ dtype=torch.bfloat16,
100
+ seed=opts.seed,
101
+ )
102
+ timesteps = get_schedule(
103
+ opts.num_steps,
104
+ x.shape[-1] * x.shape[-2] // 4,
105
+ shift=True,
106
+ )
107
+
108
+ if self.offload:
109
+ self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
110
+ inp = prepare(t5=self.t5, clip=self.clip, img=x, prompt=opts.prompt)
111
+ inp_neg = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt) if use_true_cfg else None
112
+
113
+ # offload TEs to CPU, load processor models and id encoder to gpu
114
+ if self.offload:
115
+ self.t5, self.clip = self.t5.cpu(), self.clip.cpu()
116
+ torch.cuda.empty_cache()
117
+ self.pulid_model.components_to_device(torch.device("cuda"))
118
+
119
+ if id_image is not None:
120
+ id_image = resize_numpy_image_long(id_image, 1024)
121
+ id_embeddings, uncond_id_embeddings = self.pulid_model.get_id_embedding(id_image, cal_uncond=use_true_cfg)
122
+ else:
123
+ id_embeddings = None
124
+ uncond_id_embeddings = None
125
+
126
+ # offload processor models and id encoder to CPU, load dit model to gpu
127
+ if self.offload:
128
+ self.pulid_model.components_to_device(torch.device("cpu"))
129
+ torch.cuda.empty_cache()
130
+ if self.aggressive_offload:
131
+ self.model.components_to_gpu()
132
+ else:
133
+ self.model = self.model.to(self.device)
134
+
135
+ # denoise initial noise
136
+ x = denoise(
137
+ self.model, **inp, timesteps=timesteps, guidance=opts.guidance, id=id_embeddings, id_weight=id_weight,
138
+ start_step=start_step, uncond_id=uncond_id_embeddings, true_cfg=true_cfg,
139
+ timestep_to_start_cfg=timestep_to_start_cfg,
140
+ neg_txt=inp_neg["txt"] if use_true_cfg else None,
141
+ neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
142
+ neg_vec=inp_neg["vec"] if use_true_cfg else None,
143
+ aggressive_offload=self.aggressive_offload,
144
+ )
145
+
146
+ # offload model, load autoencoder to gpu
147
+ if self.offload:
148
+ self.model.cpu()
149
+ torch.cuda.empty_cache()
150
+ self.ae.decoder.to(x.device)
151
+
152
+ # decode latents to pixel space
153
+ x = unpack(x.float(), opts.height, opts.width)
154
+ with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
155
+ x = self.ae.decode(x)
156
+
157
+ if self.offload:
158
+ self.ae.decoder.cpu()
159
+ torch.cuda.empty_cache()
160
+
161
+ t1 = time.perf_counter()
162
+
163
+ print(f"Done in {t1 - t0:.1f}s.")
164
+ # bring into PIL format
165
+ x = x.clamp(-1, 1)
166
+ # x = embed_watermark(x.float())
167
+ x = rearrange(x[0], "c h w -> h w c")
168
+
169
+ img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
170
+ return img, str(opts.seed), self.pulid_model.debug_img_list
171
+
172
+ _HEADER_ = '''
173
+ <div style="text-align: center; max-width: 650px; margin: 0 auto;">
174
+ <h1 style="font-size: 2.5rem; font-weight: 700; margin-bottom: 1rem; display: contents;">PuLID for FLUX</h1>
175
+ <p style="font-size: 1rem; margin-bottom: 1.5rem;">Paper: <a href='https://arxiv.org/abs/2404.16022' target='_blank'>PuLID: Pure and Lightning ID Customization via Contrastive Alignment</a> | Codes: <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>GitHub</a></p>
176
+ </div>
177
+
178
+ ❗️❗️❗️**Tips:**
179
+ - `timestep to start inserting ID:` The smaller the value, the higher the fidelity, but the lower the editability; the higher the value, the lower the fidelity, but the higher the editability. **The recommended range for this value is between 0 and 4**. For photorealistic scenes, we recommend using 4; for stylized scenes, we recommend using 0-1. If you are not satisfied with the similarity, you can lower this value; conversely, if you are not satisfied with the editability, you can increase this value.
180
+ - `true CFG scale:` In most scenarios, it is recommended to use a fake CFG, i.e., setting the true CFG scale to 1, and just adjusting the guidance scale. This is also more efficiency. However, in a few cases, utilizing a true CFG can yield better results. For more detaileds, please refer to the [doc](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md#useful-tips).
181
+ - please refer to the <a href='https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md' target='_blank'>github doc</a> for more details and info about the model, we provide the detail explanation about the above two parameters in the doc.
182
+ - we provide some examples in the bottom, you can try these example prompts first
183
+
184
+ ''' # noqa E501
185
+
186
+ _CITE_ = r"""
187
+ If PuLID is helpful, please help to ⭐ the <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'> Github Repo</a>. Thanks!
188
+ ---
189
+
190
+ 📧 **Contact**
191
+ If you have any questions or feedbacks, feel free to open a discussion or contact <b>[email protected]</b>.
192
+ """ # noqa E501
193
+
194
+
195
+ def create_demo(args, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu",
196
+ offload: bool = False, aggressive_offload: bool = False):
197
+ generator = FluxGenerator(model_name, device, offload, aggressive_offload, args)
198
+
199
+ with gr.Blocks() as demo:
200
+ gr.Markdown(_HEADER_)
201
+
202
+ with gr.Row():
203
+ with gr.Column():
204
+ prompt = gr.Textbox(label="Prompt", value="portrait, color, cinematic")
205
+ id_image = gr.Image(label="ID Image")
206
+ id_weight = gr.Slider(0.0, 3.0, 1, step=0.05, label="id weight")
207
+
208
+ width = gr.Slider(256, 1536, 896, step=16, label="Width")
209
+ height = gr.Slider(256, 1536, 1152, step=16, label="Height")
210
+ num_steps = gr.Slider(1, 20, 20, step=1, label="Number of steps")
211
+ start_step = gr.Slider(0, 10, 0, step=1, label="timestep to start inserting ID")
212
+ guidance = gr.Slider(1.0, 10.0, 4, step=0.1, label="Guidance")
213
+ seed = gr.Textbox(-1, label="Seed (-1 for random)")
214
+ max_sequence_length = gr.Slider(128, 512, 128, step=128,
215
+ label="max_sequence_length for prompt (T5), small will be faster")
216
+
217
+ with gr.Accordion("Advanced Options (True CFG, true_cfg_scale=1 means use fake CFG, >1 means use true CFG, if using true CFG, we recommend set the guidance scale to 1)", open=False): # noqa E501
218
+ neg_prompt = gr.Textbox(
219
+ label="Negative Prompt",
220
+ value="bad quality, worst quality, text, signature, watermark, extra limbs")
221
+ true_cfg = gr.Slider(1.0, 10.0, 1, step=0.1, label="true CFG scale")
222
+ timestep_to_start_cfg = gr.Slider(0, 20, 1, step=1, label="timestep to start cfg", visible=args.dev)
223
+
224
+ generate_btn = gr.Button("Generate")
225
+
226
+ with gr.Column():
227
+ output_image = gr.Image(label="Generated Image")
228
+ seed_output = gr.Textbox(label="Used Seed")
229
+ intermediate_output = gr.Gallery(label='Output', elem_id="gallery", visible=args.dev)
230
+ gr.Markdown(_CITE_)
231
+
232
+ with gr.Row(), gr.Column():
233
+ gr.Markdown("## Examples")
234
+ example_inps = [
235
+ [
236
+ 'a woman holding sign with glowing green text \"PuLID for FLUX\"',
237
+ 'example_inputs/liuyifei.png',
238
+ 4, 4, 2680261499100305976, 1
239
+ ],
240
+ [
241
+ 'portrait, side view',
242
+ 'example_inputs/liuyifei.png',
243
+ 4, 4, 1205240166692517553, 1
244
+ ],
245
+ [
246
+ 'white-haired woman with vr technology atmosphere, revolutionary exceptional magnum with remarkable details', # noqa E501
247
+ 'example_inputs/liuyifei.png',
248
+ 4, 4, 6349424134217931066, 1
249
+ ],
250
+ [
251
+ 'a young child is eating Icecream',
252
+ 'example_inputs/liuyifei.png',
253
+ 4, 4, 10606046113565776207, 1
254
+ ],
255
+ [
256
+ 'a man is holding a sign with text \"PuLID for FLUX\", winter, snowing, top of the mountain',
257
+ 'example_inputs/pengwei.jpg',
258
+ 4, 4, 2410129802683836089, 1
259
+ ],
260
+ [
261
+ 'portrait, candle light',
262
+ 'example_inputs/pengwei.jpg',
263
+ 4, 4, 17522759474323955700, 1
264
+ ],
265
+ [
266
+ 'profile shot dark photo of a 25-year-old male with smoke escaping from his mouth, the backlit smoke gives the image an ephemeral quality, natural face, natural eyebrows, natural skin texture, award winning photo, highly detailed face, atmospheric lighting, film grain, monochrome', # noqa E501
267
+ 'example_inputs/pengwei.jpg',
268
+ 4, 4, 17733156847328193625, 1
269
+ ],
270
+ [
271
+ 'American Comics, 1boy',
272
+ 'example_inputs/pengwei.jpg',
273
+ 1, 4, 13223174453874179686, 1
274
+ ],
275
+ [
276
+ 'portrait, pixar',
277
+ 'example_inputs/pengwei.jpg',
278
+ 1, 4, 9445036702517583939, 1
279
+ ],
280
+ ]
281
+ gr.Examples(examples=example_inps, inputs=[prompt, id_image, start_step, guidance, seed, true_cfg],
282
+ label='fake CFG')
283
+
284
+ example_inps = [
285
+ [
286
+ 'portrait, made of ice sculpture',
287
+ 'example_inputs/lecun.jpg',
288
+ 1, 1, 3811899118709451814, 5
289
+ ],
290
+ ]
291
+ gr.Examples(examples=example_inps, inputs=[prompt, id_image, start_step, guidance, seed, true_cfg],
292
+ label='true CFG')
293
+
294
+ generate_btn.click(
295
+ fn=generator.generate_image,
296
+ inputs=[width, height, num_steps, start_step, guidance, seed, prompt, id_image, id_weight, neg_prompt,
297
+ true_cfg, timestep_to_start_cfg, max_sequence_length],
298
+ outputs=[output_image, seed_output, intermediate_output],
299
+ )
300
+
301
+ return demo
302
+
303
+
304
+ if __name__ == "__main__":
305
+ import argparse
306
+
307
+ parser = argparse.ArgumentParser(description="PuLID for FLUX.1-dev")
308
+ parser.add_argument('--version', type=str, default='v0.9.1', help='version of the model', choices=['v0.9.0', 'v0.9.1'])
309
+ parser.add_argument("--name", type=str, default="flux-dev", choices=list('flux-dev'),
310
+ help="currently only support flux-dev")
311
+ parser.add_argument("--device", type=str, default="cuda", help="Device to use")
312
+ parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
313
+ parser.add_argument("--aggressive_offload", action="store_true", help="Offload model more aggressively to CPU when not in use, for 24G GPUs")
314
+ parser.add_argument("--fp8", action="store_true", help="use flux-dev-fp8 model")
315
+ parser.add_argument("--onnx_provider", type=str, default="gpu", choices=["gpu", "cpu"],
316
+ help="set onnx_provider to cpu (default gpu) can help reduce RAM usage, and when combined with"
317
+ "fp8 option, the peak RAM is under 15GB")
318
+ parser.add_argument("--port", type=int, default=8080, help="Port to use")
319
+ parser.add_argument("--dev", action='store_true', help="Development mode")
320
+ parser.add_argument("--pretrained_model", type=str, help='for development')
321
+ args = parser.parse_args()
322
+
323
+ if args.aggressive_offload:
324
+ args.offload = True
325
+
326
+ demo = create_demo(args, args.name, args.device, args.offload, args.aggressive_offload)
327
+ demo.launch(server_name='0.0.0.0', server_port=args.port)
app_v1_1.py ADDED
@@ -0,0 +1,504 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import gradio as gr
4
+ import numpy as np
5
+ import torch
6
+
7
+ from pulid import attention_processor as attention
8
+ from pulid.pipeline_v1_1 import PuLIDPipeline
9
+ from pulid.utils import resize_numpy_image_long
10
+
11
+ torch.set_grad_enabled(False)
12
+
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument(
15
+ '--base',
16
+ type=str,
17
+ default='RunDiffusion/Juggernaut-XL-v9',
18
+ choices=[
19
+ 'Lykon/dreamshaper-xl-lightning',
20
+ # 'SG161222/RealVisXL_V4.0', will add it later
21
+ 'RunDiffusion/Juggernaut-XL-v9',
22
+ ],
23
+ )
24
+ # parser.add_argument('--sampler', type=str, default='dpmpp_2m', choices=['dpmpp_sde', 'dpmpp_2m'])
25
+ parser.add_argument('--port', type=int, default=7860)
26
+ args = parser.parse_args()
27
+
28
+ use_lightning_model = 'lightning' in args.base.lower()
29
+ # currently we only support two commonly used sampler
30
+ args.sampler = 'dpmpp_sde' if use_lightning_model else 'dpmpp_2m'
31
+ if use_lightning_model:
32
+ default_cfg = 2.0
33
+ default_steps = 5
34
+ else:
35
+ default_cfg = 7.0
36
+ default_steps = 25
37
+
38
+ pipeline = PuLIDPipeline(sdxl_repo=args.base, sampler=args.sampler)
39
+
40
+ # other params
41
+ DEFAULT_NEGATIVE_PROMPT = (
42
+ 'flaws in the eyes, flaws in the face, flaws, lowres, non-HDRi, low quality, worst quality,'
43
+ 'artifacts noise, text, watermark, glitch, deformed, mutated, ugly, disfigured, hands, '
44
+ 'low resolution, partially rendered objects, deformed or partially rendered eyes, '
45
+ 'deformed, deformed eyeballs, cross-eyed,blurry'
46
+ )
47
+
48
+ dreamshaper_example_inps = [
49
+ ['portrait, blacklight', 'example_inputs/liuyifei.png', 42, 0.8, 10],
50
+ ['pixel art, 1boy', 'example_inputs/lecun.jpg', 42, 0.8, 10],
51
+ [
52
+ 'cinematic film still, close up, photo of redheaded girl near grasses, fictional landscapes, (intense sunlight:1.4), realist detail, brooding mood, ue5, detailed character expressions, light amber and red, amazing quality, wallpaper, analog film grain',
53
+ 'example_inputs/liuyifei.png',
54
+ 42,
55
+ 0.8,
56
+ 10,
57
+ ],
58
+ [
59
+ 'A minimalist line art depiction of an Artificial Intelligence being\'s thought process, lines and nodes forming intricate patterns.',
60
+ 'example_inputs/hinton.jpeg',
61
+ 42,
62
+ 0.8,
63
+ 10,
64
+ ],
65
+ [
66
+ 'instagram photo, photo of 23 y.o man in black sweater, pale skin, (smile:0.4), hard shadows',
67
+ 'example_inputs/pengwei.jpg',
68
+ 42,
69
+ 0.8,
70
+ 10,
71
+ ],
72
+ [
73
+ 'by Tsutomu Nihei,(strange but extremely beautiful:1.4),(masterpiece, best quality:1.4),in the style of nicola samori,The Joker,',
74
+ 'example_inputs/lecun.jpg',
75
+ 1675432759740519133,
76
+ 0.8,
77
+ 10,
78
+ ],
79
+ ]
80
+
81
+ jugger_example_inps = [
82
+ [
83
+ 'robot,simple robot,robot with glass face,ellipse head robot,(made partially out of glass),hexagonal shapes,ferns growing inside head,butterflies on head,butterflies flying around',
84
+ 'example_inputs/hinton.jpeg',
85
+ 15022214902832471291,
86
+ 0.8,
87
+ 20,
88
+ ],
89
+ ['sticker art, 1girl', 'example_inputs/liuyifei.png', 42, 0.8, 20],
90
+ [
91
+ '1girl, cute model, Long thick Maxi Skirt, Knit sweater, swept back hair, alluring smile, working at a clothing store, perfect eyes, highly detailed beautiful expressive eyes, detailed eyes, 35mm photograph, film, bokeh, professional, 4k, highly detailed dynamic lighting, photorealistic, 8k, raw, rich, intricate details,',
92
+ 'example_inputs/liuyifei.png',
93
+ 42,
94
+ 0.8,
95
+ 20,
96
+ ],
97
+ ['Chinese paper-cut, 1girl', 'example_inputs/liuyifei.png', 42, 0.8, 20],
98
+ ['Studio Ghibli, 1boy', 'example_inputs/hinton.jpeg', 42, 0.8, 20],
99
+ ['1man made of ice sculpture', 'example_inputs/lecun.jpg', 42, 0.8, 20],
100
+ ['portrait of green-skinned shrek, wearing lacoste purple sweater', 'example_inputs/lecun.jpg', 42, 0.8, 20],
101
+ ['1990s Japanese anime, 1girl', 'example_inputs/liuyifei.png', 42, 0.8, 20],
102
+ ['made of little stones, portrait', 'example_inputs/hinton.jpeg', 42, 0.8, 20],
103
+ ]
104
+
105
+
106
+ @torch.inference_mode()
107
+ def run(*args):
108
+ id_image = args[0]
109
+ supp_images = args[1:4]
110
+ prompt, neg_prompt, scale, seed, steps, H, W, id_scale, num_zero, ortho = args[4:]
111
+ seed = int(seed)
112
+ if seed == -1:
113
+ seed = torch.Generator(device="cpu").seed()
114
+
115
+ pipeline.debug_img_list = []
116
+
117
+ attention.NUM_ZERO = num_zero
118
+ if ortho == 'v2':
119
+ attention.ORTHO = False
120
+ attention.ORTHO_v2 = True
121
+ elif ortho == 'v1':
122
+ attention.ORTHO = True
123
+ attention.ORTHO_v2 = False
124
+ else:
125
+ attention.ORTHO = False
126
+ attention.ORTHO_v2 = False
127
+
128
+ if id_image is not None:
129
+ id_image = resize_numpy_image_long(id_image, 1024)
130
+ supp_id_image_list = [
131
+ resize_numpy_image_long(supp_id_image, 1024) for supp_id_image in supp_images if supp_id_image is not None
132
+ ]
133
+ id_image_list = [id_image] + supp_id_image_list
134
+ uncond_id_embedding, id_embedding = pipeline.get_id_embedding(id_image_list)
135
+ else:
136
+ uncond_id_embedding = None
137
+ id_embedding = None
138
+
139
+ img = pipeline.inference(
140
+ prompt, (1, H, W), neg_prompt, id_embedding, uncond_id_embedding, id_scale, scale, steps, seed
141
+ )[0]
142
+
143
+ return np.array(img), str(seed), pipeline.debug_img_list
144
+
145
+
146
+ _HEADER_ = '''
147
+ <h2><b>Official Gradio Demo</b></h2><h2><a href='https://github.com/ToTheBeginning/PuLID' target='_blank'><b>PuLID: Pure and Lightning ID Customization via Contrastive Alignment</b></a></h2>
148
+
149
+ **PuLID** is a tuning-free ID customization approach. PuLID maintains high ID fidelity while effectively reducing interference with the original model’s behavior.
150
+
151
+ Code: <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>GitHub</a>. Paper: <a href='https://arxiv.org/abs/2404.16022' target='_blank'>ArXiv</a>.
152
+
153
+ ❗️❗️❗️**Tips:**
154
+ - we provide some examples in the bottom, you can try these example prompts first
155
+ - a single ID image is usually sufficient, you can also supplement with additional auxiliary images
156
+ - You can adjust the trade-off between ID fidelity and editability in the advanced options, but generally, the default settings are good enough.
157
+
158
+ ''' # noqa E501
159
+
160
+ _CITE_ = r"""
161
+ If PuLID is helpful, please help to ⭐ the <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>Github Repo</a>. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/ToTheBeginning/PuLID?style=social)](https://github.com/ToTheBeginning/PuLID)
162
+ ---
163
+ 📧 **Contact**
164
+ If you have any questions, feel free to open a discussion or contact us at <b>[email protected]</b> or <b>[email protected]</b>.
165
+ """ # noqa E501
166
+
167
+
168
+ with gr.Blocks(title="PuLID", css=".gr-box {border-color: #8136e2}") as demo:
169
+ gr.Markdown(_HEADER_)
170
+ with gr.Row():
171
+ with gr.Column():
172
+ with gr.Row():
173
+ face_image = gr.Image(label="ID image (main)", height=256)
174
+ supp_image1 = gr.Image(label="Additional ID image (auxiliary)", height=256)
175
+ supp_image2 = gr.Image(label="Additional ID image (auxiliary)", height=256)
176
+ supp_image3 = gr.Image(label="Additional ID image (auxiliary)", height=256)
177
+ prompt = gr.Textbox(label="Prompt", value='portrait,color,cinematic,in garden,soft light,detailed face')
178
+ submit = gr.Button("Generate")
179
+ neg_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT)
180
+ scale = gr.Slider(
181
+ label="CFG (recommend 2 for lightning model and 7 for non-accelerated model)",
182
+ value=default_cfg,
183
+ minimum=1,
184
+ maximum=10,
185
+ step=0.1,
186
+ )
187
+ seed = gr.Textbox(-1, label="Seed (-1 for random)")
188
+ steps = gr.Slider(label="Steps", value=default_steps, minimum=1, maximum=30, step=1)
189
+ with gr.Row():
190
+ H = gr.Slider(label="Height", value=1152, minimum=512, maximum=2024, step=64)
191
+ W = gr.Slider(label="Width", value=896, minimum=512, maximum=2024, step=64)
192
+ with gr.Row(), gr.Accordion(
193
+ "Advanced Options (adjust the trade-off between ID fidelity and editability)", open=False
194
+ ):
195
+ id_scale = gr.Slider(
196
+ label="ID scale (Increasing it enhances ID similarity but reduces editability)",
197
+ minimum=0,
198
+ maximum=5,
199
+ step=0.05,
200
+ value=0.8,
201
+ interactive=True,
202
+ )
203
+ num_zero = gr.Slider(
204
+ label="num zero (Increasing it enhances ID editability but reduces similarity)",
205
+ minimum=0,
206
+ maximum=80,
207
+ step=1,
208
+ value=20,
209
+ interactive=True,
210
+ )
211
+ ortho = gr.Dropdown(label="ortho", choices=['off', 'v1', 'v2'], value='v2', visible=False)
212
+
213
+ with gr.Column():
214
+ output = gr.Image(label="Generated Image")
215
+ seed_output = gr.Textbox(label="Used Seed")
216
+ intermediate_output = gr.Gallery(label='DebugImage', elem_id="gallery", visible=False)
217
+ gr.Markdown(_CITE_)
218
+
219
+ with gr.Row(), gr.Column():
220
+ gr.Markdown("## Examples")
221
+ if args.base == 'Lykon/dreamshaper-xl-lightning':
222
+ gr.Examples(
223
+ examples=dreamshaper_example_inps,
224
+ inputs=[prompt, face_image, seed, id_scale, num_zero],
225
+ label='dreamshaper-xl-lightning examples',
226
+ )
227
+ elif args.base == 'RunDiffusion/Juggernaut-XL-v9':
228
+ gr.Examples(
229
+ examples=jugger_example_inps,
230
+ inputs=[prompt, face_image, seed, id_scale, num_zero],
231
+ label='Juggernaut-XL-v9 examples',
232
+ )
233
+
234
+ inps = [
235
+ face_image,
236
+ supp_image1,
237
+ supp_image2,
238
+ supp_image3,
239
+ prompt,
240
+ neg_prompt,
241
+ scale,
242
+ seed,
243
+ steps,
244
+ H,
245
+ W,
246
+ id_scale,
247
+ num_zero,
248
+ ortho,
249
+ ]
250
+ submit.click(fn=run, inputs=inps, outputs=[output, seed_output, intermediate_output])
251
+ import argparse
252
+
253
+ import gradio as gr
254
+ import numpy as np
255
+ import torch
256
+
257
+ from pulid import attention_processor as attention
258
+ from pulid.pipeline_v1_1 import PuLIDPipeline
259
+ from pulid.utils import resize_numpy_image_long
260
+
261
+ torch.set_grad_enabled(False)
262
+
263
+ parser = argparse.ArgumentParser()
264
+ parser.add_argument(
265
+ '--base',
266
+ type=str,
267
+ default='RunDiffusion/Juggernaut-XL-v9',
268
+ choices=[
269
+ 'Lykon/dreamshaper-xl-lightning',
270
+ # 'SG161222/RealVisXL_V4.0', will add it later
271
+ 'RunDiffusion/Juggernaut-XL-v9',
272
+ ],
273
+ )
274
+ # parser.add_argument('--sampler', type=str, default='dpmpp_2m', choices=['dpmpp_sde', 'dpmpp_2m'])
275
+ parser.add_argument('--port', type=int, default=7860)
276
+ args = parser.parse_args()
277
+
278
+ use_lightning_model = 'lightning' in args.base.lower()
279
+ # currently we only support two commonly used sampler
280
+ args.sampler = 'dpmpp_sde' if use_lightning_model else 'dpmpp_2m'
281
+ if use_lightning_model:
282
+ default_cfg = 2.0
283
+ default_steps = 5
284
+ else:
285
+ default_cfg = 7.0
286
+ default_steps = 25
287
+
288
+ pipeline = PuLIDPipeline(sdxl_repo=args.base, sampler=args.sampler)
289
+
290
+ # other params
291
+ DEFAULT_NEGATIVE_PROMPT = (
292
+ 'flaws in the eyes, flaws in the face, flaws, lowres, non-HDRi, low quality, worst quality,'
293
+ 'artifacts noise, text, watermark, glitch, deformed, mutated, ugly, disfigured, hands, '
294
+ 'low resolution, partially rendered objects, deformed or partially rendered eyes, '
295
+ 'deformed, deformed eyeballs, cross-eyed,blurry'
296
+ )
297
+
298
+ dreamshaper_example_inps = [
299
+ ['portrait, blacklight', 'example_inputs/liuyifei.png', 42, 0.8, 10],
300
+ ['pixel art, 1boy', 'example_inputs/lecun.jpg', 42, 0.8, 10],
301
+ [
302
+ 'cinematic film still, close up, photo of redheaded girl near grasses, fictional landscapes, (intense sunlight:1.4), realist detail, brooding mood, ue5, detailed character expressions, light amber and red, amazing quality, wallpaper, analog film grain',
303
+ 'example_inputs/liuyifei.png',
304
+ 42,
305
+ 0.8,
306
+ 10,
307
+ ],
308
+ [
309
+ 'A minimalist line art depiction of an Artificial Intelligence being\'s thought process, lines and nodes forming intricate patterns.',
310
+ 'example_inputs/hinton.jpeg',
311
+ 42,
312
+ 0.8,
313
+ 10,
314
+ ],
315
+ [
316
+ 'instagram photo, photo of 23 y.o man in black sweater, pale skin, (smile:0.4), hard shadows',
317
+ 'example_inputs/pengwei.jpg',
318
+ 42,
319
+ 0.8,
320
+ 10,
321
+ ],
322
+ [
323
+ 'by Tsutomu Nihei,(strange but extremely beautiful:1.4),(masterpiece, best quality:1.4),in the style of nicola samori,The Joker,',
324
+ 'example_inputs/lecun.jpg',
325
+ 1675432759740519133,
326
+ 0.8,
327
+ 10,
328
+ ],
329
+ ]
330
+
331
+ jugger_example_inps = [
332
+ [
333
+ 'robot,simple robot,robot with glass face,ellipse head robot,(made partially out of glass),hexagonal shapes,ferns growing inside head,butterflies on head,butterflies flying around',
334
+ 'example_inputs/hinton.jpeg',
335
+ 15022214902832471291,
336
+ 0.8,
337
+ 20,
338
+ ],
339
+ ['sticker art, 1girl', 'example_inputs/liuyifei.png', 42, 0.8, 20],
340
+ [
341
+ '1girl, cute model, Long thick Maxi Skirt, Knit sweater, swept back hair, alluring smile, working at a clothing store, perfect eyes, highly detailed beautiful expressive eyes, detailed eyes, 35mm photograph, film, bokeh, professional, 4k, highly detailed dynamic lighting, photorealistic, 8k, raw, rich, intricate details,',
342
+ 'example_inputs/liuyifei.png',
343
+ 42,
344
+ 0.8,
345
+ 20,
346
+ ],
347
+ ['Chinese paper-cut, 1girl', 'example_inputs/liuyifei.png', 42, 0.8, 20],
348
+ ['Studio Ghibli, 1boy', 'example_inputs/hinton.jpeg', 42, 0.8, 20],
349
+ ['1man made of ice sculpture', 'example_inputs/lecun.jpg', 42, 0.8, 20],
350
+ ['portrait of green-skinned shrek, wearing lacoste purple sweater', 'example_inputs/lecun.jpg', 42, 0.8, 20],
351
+ ['1990s Japanese anime, 1girl', 'example_inputs/liuyifei.png', 42, 0.8, 20],
352
+ ['made of little stones, portrait', 'example_inputs/hinton.jpeg', 42, 0.8, 20],
353
+ ]
354
+
355
+
356
+ @torch.inference_mode()
357
+ def run(*args):
358
+ id_image = args[0]
359
+ supp_images = args[1:4]
360
+ prompt, neg_prompt, scale, seed, steps, H, W, id_scale, num_zero, ortho = args[4:]
361
+ seed = int(seed)
362
+ if seed == -1:
363
+ seed = torch.Generator(device="cpu").seed()
364
+
365
+ pipeline.debug_img_list = []
366
+
367
+ attention.NUM_ZERO = num_zero
368
+ if ortho == 'v2':
369
+ attention.ORTHO = False
370
+ attention.ORTHO_v2 = True
371
+ elif ortho == 'v1':
372
+ attention.ORTHO = True
373
+ attention.ORTHO_v2 = False
374
+ else:
375
+ attention.ORTHO = False
376
+ attention.ORTHO_v2 = False
377
+
378
+ if id_image is not None:
379
+ id_image = resize_numpy_image_long(id_image, 1024)
380
+ supp_id_image_list = [
381
+ resize_numpy_image_long(supp_id_image, 1024) for supp_id_image in supp_images if supp_id_image is not None
382
+ ]
383
+ id_image_list = [id_image] + supp_id_image_list
384
+ uncond_id_embedding, id_embedding = pipeline.get_id_embedding(id_image_list)
385
+ else:
386
+ uncond_id_embedding = None
387
+ id_embedding = None
388
+
389
+ img = pipeline.inference(
390
+ prompt, (1, H, W), neg_prompt, id_embedding, uncond_id_embedding, id_scale, scale, steps, seed
391
+ )[0]
392
+
393
+ return np.array(img), str(seed), pipeline.debug_img_list
394
+
395
+
396
+ _HEADER_ = '''
397
+ <h2><b>Official Gradio Demo</b></h2><h2><a href='https://github.com/ToTheBeginning/PuLID' target='_blank'><b>PuLID: Pure and Lightning ID Customization via Contrastive Alignment</b></a></h2>
398
+
399
+ **PuLID** is a tuning-free ID customization approach. PuLID maintains high ID fidelity while effectively reducing interference with the original model’s behavior.
400
+
401
+ Code: <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>GitHub</a>. Paper: <a href='https://arxiv.org/abs/2404.16022' target='_blank'>ArXiv</a>.
402
+
403
+ ❗️❗️❗️**Tips:**
404
+ - we provide some examples in the bottom, you can try these example prompts first
405
+ - a single ID image is usually sufficient, you can also supplement with additional auxiliary images
406
+ - You can adjust the trade-off between ID fidelity and editability in the advanced options, but generally, the default settings are good enough.
407
+
408
+ ''' # noqa E501
409
+
410
+ _CITE_ = r"""
411
+ If PuLID is helpful, please help to ⭐ the <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>Github Repo</a>. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/ToTheBeginning/PuLID?style=social)](https://github.com/ToTheBeginning/PuLID)
412
+ ---
413
+ 📧 **Contact**
414
+ If you have any questions, feel free to open a discussion or contact us at <b>[email protected]</b> or <b>[email protected]</b>.
415
+ """ # noqa E501
416
+
417
+
418
+ with gr.Blocks(title="PuLID", css=".gr-box {border-color: #8136e2}") as demo:
419
+ gr.Markdown(_HEADER_)
420
+ with gr.Row():
421
+ with gr.Column():
422
+ with gr.Row():
423
+ face_image = gr.Image(label="ID image (main)", height=256)
424
+ supp_image1 = gr.Image(label="Additional ID image (auxiliary)", height=256)
425
+ supp_image2 = gr.Image(label="Additional ID image (auxiliary)", height=256)
426
+ supp_image3 = gr.Image(label="Additional ID image (auxiliary)", height=256)
427
+ prompt = gr.Textbox(label="Prompt", value='portrait,color,cinematic,in garden,soft light,detailed face')
428
+ submit = gr.Button("Generate")
429
+ neg_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT)
430
+ scale = gr.Slider(
431
+ label="CFG (recommend 2 for lightning model and 7 for non-accelerated model)",
432
+ value=default_cfg,
433
+ minimum=1,
434
+ maximum=10,
435
+ step=0.1,
436
+ )
437
+ seed = gr.Textbox(-1, label="Seed (-1 for random)")
438
+ steps = gr.Slider(label="Steps", value=default_steps, minimum=1, maximum=30, step=1)
439
+ with gr.Row():
440
+ H = gr.Slider(label="Height", value=1152, minimum=512, maximum=2024, step=64)
441
+ W = gr.Slider(label="Width", value=896, minimum=512, maximum=2024, step=64)
442
+ with gr.Row(), gr.Accordion(
443
+ "Advanced Options (adjust the trade-off between ID fidelity and editability)", open=False
444
+ ):
445
+ id_scale = gr.Slider(
446
+ label="ID scale (Increasing it enhances ID similarity but reduces editability)",
447
+ minimum=0,
448
+ maximum=5,
449
+ step=0.05,
450
+ value=0.8,
451
+ interactive=True,
452
+ )
453
+ num_zero = gr.Slider(
454
+ label="num zero (Increasing it enhances ID editability but reduces similarity)",
455
+ minimum=0,
456
+ maximum=80,
457
+ step=1,
458
+ value=20,
459
+ interactive=True,
460
+ )
461
+ ortho = gr.Dropdown(label="ortho", choices=['off', 'v1', 'v2'], value='v2', visible=False)
462
+
463
+ with gr.Column():
464
+ output = gr.Image(label="Generated Image")
465
+ seed_output = gr.Textbox(label="Used Seed")
466
+ intermediate_output = gr.Gallery(label='DebugImage', elem_id="gallery", visible=False)
467
+ gr.Markdown(_CITE_)
468
+
469
+ with gr.Row(), gr.Column():
470
+ gr.Markdown("## Examples")
471
+ if args.base == 'Lykon/dreamshaper-xl-lightning':
472
+ gr.Examples(
473
+ examples=dreamshaper_example_inps,
474
+ inputs=[prompt, face_image, seed, id_scale, num_zero],
475
+ label='dreamshaper-xl-lightning examples',
476
+ )
477
+ elif args.base == 'RunDiffusion/Juggernaut-XL-v9':
478
+ gr.Examples(
479
+ examples=jugger_example_inps,
480
+ inputs=[prompt, face_image, seed, id_scale, num_zero],
481
+ label='Juggernaut-XL-v9 examples',
482
+ )
483
+
484
+ inps = [
485
+ face_image,
486
+ supp_image1,
487
+ supp_image2,
488
+ supp_image3,
489
+ prompt,
490
+ neg_prompt,
491
+ scale,
492
+ seed,
493
+ steps,
494
+ H,
495
+ W,
496
+ id_scale,
497
+ num_zero,
498
+ ortho,
499
+ ]
500
+ submit.click(fn=run, inputs=inps, outputs=[output, seed_output, intermediate_output])
501
+
502
+ demo.launch(server_name='0.0.0.0', server_port=args.port)
503
+
504
+ demo.launch(server_name='0.0.0.0', server_port=args.port)
docs/pulid_for_flux.md ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PuLID for FLUX
2
+ We are happy to release the **PuLID-FLUX-v0.9.0** model, which provides a tuning-free ID customization solution for FLUX.1-dev.
3
+
4
+ If PuLID-FLUX is helpful, please help to ⭐ this repo or recommend it to your friends 😊
5
+
6
+ ## Inference
7
+ ### :triangular_flag_on_post: Update
8
+ - 2024.10.31: We release the **PuLID-FLUX.v0.9.1** model. Compared to the previous version, v0.9.1 has improved the ID fidelity, with an increase of about 5 percentage points in quantitative metrics of facial similarity.
9
+
10
+ ### Local Gradio Demo
11
+ You first need to follow the [dependencies-and-installation](../README.md#wrench-dependencies-and-installation) to set
12
+ up the environment, and download the `flux1-dev.safetensors` (if you want to use bf16 rather than fp8) and `ae.safetensors` from [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev/tree/main).
13
+ The PuLID-FLUX model will be automatically downloaded from [huggingface](https://huggingface.co/guozinan/PuLID/tree/main).
14
+
15
+ There are following four options to run the gradio demo:
16
+
17
+ :notes: Note: The Gradio demo defaults to using the latest version of the model. If you need to switch to an older version or a specific version, please append `--version SPECIFIC_VERSION` to the following command lines.
18
+
19
+ #### naive bf16
20
+ simply run `python app_flux.py`, the peak memory is under 45GB.
21
+
22
+ #### bf16 + offload
23
+ run `python app_flux.py --offload`, the peak memory is under 30GB.
24
+
25
+ #### fp8 + offload (for consumer-grade GPUs)
26
+ To use fp8, you need to make sure you have installed `requirements-fp8.txt`, it includes `optimum-quanto` and higher version of PyTorch.
27
+ We use `flux-dev-fp8` checkpoint from [XLabs-AI/flux-dev-fp8](https://huggingface.co/XLabs-AI/flux-dev-fp8), it will be automatically downloaded. You can also download it manually and put it in the models folder
28
+
29
+ Run `python app_flux.py --offload --fp8 --onnx_provider cpu`, the peak memory is under 15GB, this is for GPU with 16GB memory.
30
+
31
+ For 24GB graphic memory users, you can run `python app_flux.py --offload --fp8`, the peak memory is under 17GB.
32
+
33
+ For 12GB graphic memory users, you can run `python app_flux.py --aggressive_offload --fp8 --onnx_provider cpu`, the peak memory is about 11GB.
34
+ However, using aggressive offload (like sequential offload), the speed will be very slow due to the frequent need for memory transfers between CPU and GPU at each timestep.
35
+
36
+ Please note that, there is a difference in image quality between fp8 and bf16, with some degradation in the former.
37
+ Specifically, the details of the face may be slightly worse, but the layout is similar. If you want the best results
38
+ of PuLID-FLUX or you have the resources, please use bf16 rather than fp8.
39
+ We have included a comparison in the table below.
40
+
41
+ | | case1 | case2 | case3 | case4 |
42
+ |------|:-------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------:|:-------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------:|
43
+ | bf16 | ![c1_bf16](https://github.com/user-attachments/assets/781b2102-d5fe-4786-b4d3-7b8df501c781) | ![c2_bf16](https://github.com/user-attachments/assets/6218a6ca-f07e-4a9a-ac63-896526ff52cf) | ![c3_bf16](https://github.com/user-attachments/assets/3b6675e5-d26e-4799-b0f3-72e4a7f9a771) |![c4_bf16](https://github.com/user-attachments/assets/b4e162ca-da8b-4e68-8d6b-ba1a674b2a0b)|
44
+ | fp8 | ![c1_fp8](https://github.com/user-attachments/assets/8547f020-bd39-4e9b-aa82-b85be4efc41c) | ![c2_fp8](https://github.com/user-attachments/assets/00d3d485-0298-4966-82e1-a31946797ac8) | ![c3_fp8](https://github.com/user-attachments/assets/b1c6a6b6-1140-49a3-93bd-1245ee5fef4c) |![c4_fp8](https://github.com/user-attachments/assets/62e512ca-6315-4a89-9350-430e20b86b36)|
45
+
46
+
47
+ #### bf16 + more agreesive offload
48
+ run `python app_flux.py --aggressive_offload`, the peak memory is around 23GB.
49
+ But it will be very, very slow. If you have better solution to run bf16 under 24GB, please let us know.
50
+
51
+ ### Online Demo
52
+ - huggingface demo:
53
+ [https://huggingface.co/spaces/yanze/PuLID-FLUX](https://huggingface.co/spaces/yanze/PuLID-FLUX)
54
+
55
+ ### ComfyUI
56
+ Please stay tuned for the community implementation
57
+
58
+ ## Visual Results
59
+ ![pulid_flux_results](https://github.com/user-attachments/assets/7eafb90a-fdd1-4ae7-bc41-8c428d568848)
60
+
61
+
62
+ ## Useful Tips
63
+ There are two parameters that are crucial and need to be set carefully:
64
+
65
+ 1. `timestep to start inserting ID`: This parameter controls the timing of ID insertion. If set to 0, the ID starts being inserted to the DIT from the first timestep. The earlier it is inserted, the higher the ID fidelity will be, but the editability may decrease. The later it is inserted, the lower the fidelity to the ID, but the editability will increase, and the disruption to the original model behavior will also be smaller. For generating realistic images, we suggest setting this to 4. If you found the ID similarity is not high enough, you could try lowering this parameter accordingly. For generating stylized images, we suggest setting it to 0-1.
66
+ ![start_id](https://github.com/user-attachments/assets/3866ffab-542d-4e2f-9a0c-6877c9158d49)
67
+
68
+ 2. `true CFG scale`: FLUX.1-dev is a guidance distill model. The original CFG process, which required twice the number of inference steps, is distilled into a guidance scale, thereby modulating the DIT through the guidance scale to simulate the true CFG process with half the inference steps. We will refer to this as fake CFG in the following doc. Our PuLID-FLUX model can be tested under the fake CFG settings, and the guidance scale can be set to a commonly used value, such as 4. However, the model also supports using the real CFG for inference. We compare the results of using true CFG with the fake CFG in photorealistic scenarios below.
69
+ ![fake_cfg_vs_true_cfg_fidelity](https://github.com/user-attachments/assets/73b44dc8-37c7-48c8-8f55-73882731126d)
70
+ As shown in the above image, in terms of ID fidelity, using fake CFG is similar to true CFG in most cases, except that in a few cases, true CFG achieves higher ID similarity. In terms of image aesthetics and facial naturalness, fake CFG performs better. However, by carefully adjusting hyperparameters, the performance of true CFG may be further improved, we leave this to the community to explore. Therefore, we recommend using fake CFG for photorealistic scenes. If you are not satisfy about the ID fidelity, you can try switching to true CFG. Additionally, as shown below, we have found that using fake CFG in stylized scenes sometimes results in lower ID similarity and poorer style response, so if you encounter these two issues in stylized scenes, please consider switching to true CFG.
71
+ ![fake_cfg_vs_true_cfg_style](https://github.com/user-attachments/assets/fb042639-64e6-4bb3-a3a4-5c138793318e)
72
+
73
+
74
+
75
+ ## Some Technical Details
76
+ - We switch the ID encoder from an MLP structure to a Transformer structure. Interested users can refer to [source code](https://github.com/ToTheBeginning/PuLID/blob/cce7cdd65b5bf283c1a39c29f2726902a3c135ca/pulid/encoders_flux.py#L122)
77
+ - Inspired by [Flamingo](https://arxiv.org/abs/2204.14198), we insert additional cross-attention blocks every few DIT blocks to interact ID features with DIT image features
78
+ - We would like to clarify that the acceleration method (lile SDXL-Lightning) serves as an
79
+ optional acceleration trick, but it is not indispensable for training PuLID. We will update the arxiv paper with the relevant details in the near future. Please stay tuned.
80
+
81
+
82
+ ## limitation
83
+ The model is currently in beta version, and we have observed that the ID fidelity may not be high for some male inputs, maybe the model requires more training. If the improved model is ready, we will release it here, so please stay tuned.
84
+
85
+ ## License
86
+ As long as you use FLUX.1-dev model, you should follow the [FLUX.1-dev model license](https://github.com/black-forest-labs/flux/tree/main/model_licenses)
87
+
88
+ ## contact
89
+ If you have any questions or suggestions about the model, please contact [Yanze Wu](https://tothebeginning.github.io/) or open an issue/discussion here.
docs/pulid_v1.1.md ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PuLID v1.1
2
+
3
+ ✅ Release PuLID v1.1 for SDXL models. In this version, we have made the following improvements:
4
+ - **stronger ID editability and prompt-following ability**. As shown in the examples below, pulid v1.1 can edit style, viewpoint, expression, lighting, etc., while maintaining a high degree of ID similarity.
5
+ - **better compatibility with community models**. According to our tests, PuLID-v1.1 performs well on some popular and powerful base models in the community, such as Juggernaut-XL, RealVisXL, and DreamShaper-XL-Lightning.
6
+ - **better naturalness**. We have noted that without any tricks, the faces generated by PuLID-v1 can be somewhat blurry and lack detail. The v1.1 version, when combined with more powerful base models, has shown improvements in facial clarity and naturalness.
7
+ - ID similarity has beed slightly improved.
8
+
9
+ Note: The v1.1 model does not perform as well as v1 on the SDXL-lightning-4step base model, so if you want to use SDXL-lightning-4step as the base model, we still recommend using the v1 model.
10
+
11
+
12
+ Following are some examples generated by PuLID-v1.1, you can reproduce these results from our Gradio demo.
13
+ ![release_v1 1](https://github.com/user-attachments/assets/3b24e9b2-59c0-4bfb-8658-a26e730c298d)
14
+
15
+ ## How to use
16
+
17
+ ### Local Gradio demo
18
+ ```bash
19
+ python app_v1.1.py --base BASE_MODEL
20
+ Usage:
21
+ -base: can be RunDiffusion/Juggernaut-XL-v9 or Lykon/dreamshaper-xl-lightning
22
+ for each supported base model, we prepare some examples in the bottom of the gradio demo, please try these examples first.
23
+ ```
24
+
25
+ ### Online demo
26
+ We plan to upgrade the [PuLID HF demo](https://huggingface.co/spaces/yanze/PuLID) to v1.1 soon, please stay tuned.
27
+
28
+
docs/v1.1_preview.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PuLID v1.1 preview
2
+ ## The improvements of PuLID v1.1
3
+
4
+ In PuLID v1.1, we have made the following improvements:
5
+ - **better naturalness**
6
+ - **stronger editability**
7
+ - **more compatible with community models**
8
+
9
+ ### PuLID with RealVis-XL as base model. Zoom in for best view
10
+ ![realvis](https://github.com/ToTheBeginning/PuLID/assets/169147031/d6aa288b-b826-41bb-a512-96f9d54b448f)
11
+ ### PuLID with Juggernaut-XL-Lightning as base model. Zoom in for best view
12
+ ![juggernautXL_lightning](https://github.com/ToTheBeginning/PuLID/assets/169147031/4371d6b2-1063-49be-9ff1-56db58140cfe)
13
+ ### PuLID with Dreamshaper-XL-Lightning as base model. Zoom in for best view
14
+ ![dreamshaper](https://github.com/ToTheBeginning/PuLID/assets/169147031/89a21ee0-25c1-4098-a868-59e3149fe10c)
eva_clip/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
2
+ from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_transforms
3
+ from .factory import list_models, add_model_config, get_model_config, load_checkpoint
4
+ from .loss import ClipLoss
5
+ from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\
6
+ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
7
+ from .openai import load_openai_model, list_openai_models
8
+ from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
9
+ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
10
+ from .tokenizer import SimpleTokenizer, tokenize
11
+ from .transform import image_transform
eva_clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
eva_clip/constants.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
eva_clip/eva_vit_model.py ADDED
@@ -0,0 +1,548 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Adapted from https://github.com/microsoft/unilm/tree/master/beit
3
+ # --------------------------------------------------------
4
+ import math
5
+ import os
6
+ from functools import partial
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ try:
11
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
12
+ except:
13
+ from timm.layers import drop_path, to_2tuple, trunc_normal_
14
+
15
+ from .transformer import PatchDropout
16
+ from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
17
+
18
+ if os.getenv('ENV_TYPE') == 'deepspeed':
19
+ try:
20
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
21
+ except:
22
+ from torch.utils.checkpoint import checkpoint
23
+ else:
24
+ from torch.utils.checkpoint import checkpoint
25
+
26
+ try:
27
+ import xformers
28
+ import xformers.ops as xops
29
+ XFORMERS_IS_AVAILBLE = True
30
+ except:
31
+ XFORMERS_IS_AVAILBLE = False
32
+
33
+ class DropPath(nn.Module):
34
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
35
+ """
36
+ def __init__(self, drop_prob=None):
37
+ super(DropPath, self).__init__()
38
+ self.drop_prob = drop_prob
39
+
40
+ def forward(self, x):
41
+ return drop_path(x, self.drop_prob, self.training)
42
+
43
+ def extra_repr(self) -> str:
44
+ return 'p={}'.format(self.drop_prob)
45
+
46
+
47
+ class Mlp(nn.Module):
48
+ def __init__(
49
+ self,
50
+ in_features,
51
+ hidden_features=None,
52
+ out_features=None,
53
+ act_layer=nn.GELU,
54
+ norm_layer=nn.LayerNorm,
55
+ drop=0.,
56
+ subln=False,
57
+
58
+ ):
59
+ super().__init__()
60
+ out_features = out_features or in_features
61
+ hidden_features = hidden_features or in_features
62
+ self.fc1 = nn.Linear(in_features, hidden_features)
63
+ self.act = act_layer()
64
+
65
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
66
+
67
+ self.fc2 = nn.Linear(hidden_features, out_features)
68
+ self.drop = nn.Dropout(drop)
69
+
70
+ def forward(self, x):
71
+ x = self.fc1(x)
72
+ x = self.act(x)
73
+ # x = self.drop(x)
74
+ # commit this for the orignal BERT implement
75
+ x = self.ffn_ln(x)
76
+
77
+ x = self.fc2(x)
78
+ x = self.drop(x)
79
+ return x
80
+
81
+ class SwiGLU(nn.Module):
82
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
83
+ norm_layer=nn.LayerNorm, subln=False):
84
+ super().__init__()
85
+ out_features = out_features or in_features
86
+ hidden_features = hidden_features or in_features
87
+
88
+ self.w1 = nn.Linear(in_features, hidden_features)
89
+ self.w2 = nn.Linear(in_features, hidden_features)
90
+
91
+ self.act = act_layer()
92
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
93
+ self.w3 = nn.Linear(hidden_features, out_features)
94
+
95
+ self.drop = nn.Dropout(drop)
96
+
97
+ def forward(self, x):
98
+ x1 = self.w1(x)
99
+ x2 = self.w2(x)
100
+ hidden = self.act(x1) * x2
101
+ x = self.ffn_ln(hidden)
102
+ x = self.w3(x)
103
+ x = self.drop(x)
104
+ return x
105
+
106
+ class Attention(nn.Module):
107
+ def __init__(
108
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
109
+ proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
110
+ super().__init__()
111
+ self.num_heads = num_heads
112
+ head_dim = dim // num_heads
113
+ if attn_head_dim is not None:
114
+ head_dim = attn_head_dim
115
+ all_head_dim = head_dim * self.num_heads
116
+ self.scale = qk_scale or head_dim ** -0.5
117
+
118
+ self.subln = subln
119
+ if self.subln:
120
+ self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
121
+ self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
122
+ self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
123
+ else:
124
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
125
+
126
+ if qkv_bias:
127
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
128
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
129
+ else:
130
+ self.q_bias = None
131
+ self.v_bias = None
132
+
133
+ if window_size:
134
+ self.window_size = window_size
135
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
136
+ self.relative_position_bias_table = nn.Parameter(
137
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
138
+ # cls to token & token 2 cls & cls to cls
139
+
140
+ # get pair-wise relative position index for each token inside the window
141
+ coords_h = torch.arange(window_size[0])
142
+ coords_w = torch.arange(window_size[1])
143
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
144
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
145
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
146
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
147
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
148
+ relative_coords[:, :, 1] += window_size[1] - 1
149
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
150
+ relative_position_index = \
151
+ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
152
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
153
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
154
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
155
+ relative_position_index[0, 0] = self.num_relative_distance - 1
156
+
157
+ self.register_buffer("relative_position_index", relative_position_index)
158
+ else:
159
+ self.window_size = None
160
+ self.relative_position_bias_table = None
161
+ self.relative_position_index = None
162
+
163
+ self.attn_drop = nn.Dropout(attn_drop)
164
+ self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
165
+ # self.proj = nn.Linear(all_head_dim, all_head_dim)
166
+ self.proj = nn.Linear(all_head_dim, dim)
167
+ self.proj_drop = nn.Dropout(proj_drop)
168
+ self.xattn = xattn
169
+ self.xattn_drop = attn_drop
170
+
171
+ self.rope = rope
172
+
173
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
174
+ B, N, C = x.shape
175
+ if self.subln:
176
+ q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
177
+ k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
178
+ v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
179
+
180
+ q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
181
+ k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
182
+ v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
183
+ else:
184
+
185
+ qkv_bias = None
186
+ if self.q_bias is not None:
187
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
188
+
189
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
190
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
191
+ q, k, v = qkv[0], qkv[1], qkv[2]
192
+
193
+ if self.rope:
194
+ # slightly fast impl
195
+ q_t = q[:, :, 1:, :]
196
+ ro_q_t = self.rope(q_t)
197
+ q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
198
+
199
+ k_t = k[:, :, 1:, :]
200
+ ro_k_t = self.rope(k_t)
201
+ k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
202
+
203
+ if self.xattn:
204
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
205
+ k = k.permute(0, 2, 1, 3)
206
+ v = v.permute(0, 2, 1, 3)
207
+
208
+ x = xops.memory_efficient_attention(
209
+ q, k, v,
210
+ p=self.xattn_drop,
211
+ scale=self.scale,
212
+ )
213
+ x = x.reshape(B, N, -1)
214
+ x = self.inner_attn_ln(x)
215
+ x = self.proj(x)
216
+ x = self.proj_drop(x)
217
+ else:
218
+ q = q * self.scale
219
+ attn = (q @ k.transpose(-2, -1))
220
+
221
+ if self.relative_position_bias_table is not None:
222
+ relative_position_bias = \
223
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
224
+ self.window_size[0] * self.window_size[1] + 1,
225
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
226
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
227
+ attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
228
+
229
+ if rel_pos_bias is not None:
230
+ attn = attn + rel_pos_bias.type_as(attn)
231
+
232
+ if attn_mask is not None:
233
+ attn_mask = attn_mask.bool()
234
+ attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
235
+
236
+ attn = attn.softmax(dim=-1)
237
+ attn = self.attn_drop(attn)
238
+
239
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
240
+ x = self.inner_attn_ln(x)
241
+ x = self.proj(x)
242
+ x = self.proj_drop(x)
243
+ return x
244
+
245
+
246
+ class Block(nn.Module):
247
+
248
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
249
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
250
+ window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
251
+ subln=False, naiveswiglu=False):
252
+ super().__init__()
253
+ self.norm1 = norm_layer(dim)
254
+ self.attn = Attention(
255
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
256
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,
257
+ xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)
258
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
259
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
260
+ self.norm2 = norm_layer(dim)
261
+ mlp_hidden_dim = int(dim * mlp_ratio)
262
+
263
+ if naiveswiglu:
264
+ self.mlp = SwiGLU(
265
+ in_features=dim,
266
+ hidden_features=mlp_hidden_dim,
267
+ subln=subln,
268
+ norm_layer=norm_layer,
269
+ )
270
+ else:
271
+ self.mlp = Mlp(
272
+ in_features=dim,
273
+ hidden_features=mlp_hidden_dim,
274
+ act_layer=act_layer,
275
+ subln=subln,
276
+ drop=drop
277
+ )
278
+
279
+ if init_values is not None and init_values > 0:
280
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
281
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
282
+ else:
283
+ self.gamma_1, self.gamma_2 = None, None
284
+
285
+ self.postnorm = postnorm
286
+
287
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
288
+ if self.gamma_1 is None:
289
+ if self.postnorm:
290
+ x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
291
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
292
+ else:
293
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
294
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
295
+ else:
296
+ if self.postnorm:
297
+ x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
298
+ x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
299
+ else:
300
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
301
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
302
+ return x
303
+
304
+
305
+ class PatchEmbed(nn.Module):
306
+ """ Image to Patch Embedding
307
+ """
308
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
309
+ super().__init__()
310
+ img_size = to_2tuple(img_size)
311
+ patch_size = to_2tuple(patch_size)
312
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
313
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
314
+ self.img_size = img_size
315
+ self.patch_size = patch_size
316
+ self.num_patches = num_patches
317
+
318
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
319
+
320
+ def forward(self, x, **kwargs):
321
+ B, C, H, W = x.shape
322
+ # FIXME look at relaxing size constraints
323
+ assert H == self.img_size[0] and W == self.img_size[1], \
324
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
325
+ x = self.proj(x).flatten(2).transpose(1, 2)
326
+ return x
327
+
328
+
329
+ class RelativePositionBias(nn.Module):
330
+
331
+ def __init__(self, window_size, num_heads):
332
+ super().__init__()
333
+ self.window_size = window_size
334
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
335
+ self.relative_position_bias_table = nn.Parameter(
336
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
337
+ # cls to token & token 2 cls & cls to cls
338
+
339
+ # get pair-wise relative position index for each token inside the window
340
+ coords_h = torch.arange(window_size[0])
341
+ coords_w = torch.arange(window_size[1])
342
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
343
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
344
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
345
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
346
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
347
+ relative_coords[:, :, 1] += window_size[1] - 1
348
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
349
+ relative_position_index = \
350
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
351
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
352
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
353
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
354
+ relative_position_index[0, 0] = self.num_relative_distance - 1
355
+
356
+ self.register_buffer("relative_position_index", relative_position_index)
357
+
358
+ def forward(self):
359
+ relative_position_bias = \
360
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
361
+ self.window_size[0] * self.window_size[1] + 1,
362
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
363
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
364
+
365
+
366
+ class EVAVisionTransformer(nn.Module):
367
+ """ Vision Transformer with support for patch or hybrid CNN input stage
368
+ """
369
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
370
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
371
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
372
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,
373
+ use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,
374
+ pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False):
375
+ super().__init__()
376
+
377
+ if not XFORMERS_IS_AVAILBLE:
378
+ xattn = False
379
+
380
+ self.image_size = img_size
381
+ self.num_classes = num_classes
382
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
383
+
384
+ self.patch_embed = PatchEmbed(
385
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
386
+ num_patches = self.patch_embed.num_patches
387
+
388
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
389
+ # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
390
+ if use_abs_pos_emb:
391
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
392
+ else:
393
+ self.pos_embed = None
394
+ self.pos_drop = nn.Dropout(p=drop_rate)
395
+
396
+ if use_shared_rel_pos_bias:
397
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
398
+ else:
399
+ self.rel_pos_bias = None
400
+
401
+ if rope:
402
+ half_head_dim = embed_dim // num_heads // 2
403
+ hw_seq_len = img_size // patch_size
404
+ self.rope = VisionRotaryEmbeddingFast(
405
+ dim=half_head_dim,
406
+ pt_seq_len=pt_hw_seq_len,
407
+ ft_seq_len=hw_seq_len if intp_freq else None,
408
+ # patch_dropout=patch_dropout
409
+ )
410
+ else:
411
+ self.rope = None
412
+
413
+ self.naiveswiglu = naiveswiglu
414
+
415
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
416
+ self.use_rel_pos_bias = use_rel_pos_bias
417
+ self.blocks = nn.ModuleList([
418
+ Block(
419
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
420
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
421
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
422
+ xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)
423
+ for i in range(depth)])
424
+ self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
425
+ self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
426
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
427
+
428
+ if self.pos_embed is not None:
429
+ trunc_normal_(self.pos_embed, std=.02)
430
+
431
+ trunc_normal_(self.cls_token, std=.02)
432
+ # trunc_normal_(self.mask_token, std=.02)
433
+
434
+ self.apply(self._init_weights)
435
+ self.fix_init_weight()
436
+
437
+ if isinstance(self.head, nn.Linear):
438
+ trunc_normal_(self.head.weight, std=.02)
439
+ self.head.weight.data.mul_(init_scale)
440
+ self.head.bias.data.mul_(init_scale)
441
+
442
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
443
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
444
+
445
+ self.grad_checkpointing = grad_checkpointing
446
+
447
+ def fix_init_weight(self):
448
+ def rescale(param, layer_id):
449
+ param.div_(math.sqrt(2.0 * layer_id))
450
+
451
+ for layer_id, layer in enumerate(self.blocks):
452
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
453
+ if self.naiveswiglu:
454
+ rescale(layer.mlp.w3.weight.data, layer_id + 1)
455
+ else:
456
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
457
+
458
+ def get_cast_dtype(self) -> torch.dtype:
459
+ return self.blocks[0].mlp.fc2.weight.dtype
460
+
461
+ def _init_weights(self, m):
462
+ if isinstance(m, nn.Linear):
463
+ trunc_normal_(m.weight, std=.02)
464
+ if m.bias is not None:
465
+ nn.init.constant_(m.bias, 0)
466
+ elif isinstance(m, nn.LayerNorm):
467
+ nn.init.constant_(m.bias, 0)
468
+ nn.init.constant_(m.weight, 1.0)
469
+
470
+ def get_num_layers(self):
471
+ return len(self.blocks)
472
+
473
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
474
+ assert unlocked_groups == 0, 'partial locking not currently supported for this model'
475
+ for param in self.parameters():
476
+ param.requires_grad = False
477
+
478
+ @torch.jit.ignore
479
+ def set_grad_checkpointing(self, enable=True):
480
+ self.grad_checkpointing = enable
481
+
482
+ @torch.jit.ignore
483
+ def no_weight_decay(self):
484
+ return {'pos_embed', 'cls_token'}
485
+
486
+ def get_classifier(self):
487
+ return self.head
488
+
489
+ def reset_classifier(self, num_classes, global_pool=''):
490
+ self.num_classes = num_classes
491
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
492
+
493
+ def forward_features(self, x, return_all_features=False, return_hidden=False, shuffle=False):
494
+
495
+ x = self.patch_embed(x)
496
+ batch_size, seq_len, _ = x.size()
497
+
498
+ if shuffle:
499
+ idx = torch.randperm(x.shape[1]) + 1
500
+ zero = torch.LongTensor([0, ])
501
+ idx = torch.cat([zero, idx])
502
+ pos_embed = self.pos_embed[:, idx]
503
+
504
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
505
+ x = torch.cat((cls_tokens, x), dim=1)
506
+ if shuffle:
507
+ x = x + pos_embed
508
+ elif self.pos_embed is not None:
509
+ x = x + self.pos_embed
510
+ x = self.pos_drop(x)
511
+
512
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
513
+ if os.getenv('RoPE') == '1':
514
+ if self.training and not isinstance(self.patch_dropout, nn.Identity):
515
+ x, patch_indices_keep = self.patch_dropout(x)
516
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
517
+ else:
518
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
519
+ x = self.patch_dropout(x)
520
+ else:
521
+ x = self.patch_dropout(x)
522
+
523
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
524
+ hidden_states = []
525
+ for idx, blk in enumerate(self.blocks):
526
+ if (0 < idx <= 20) and (idx % 4 == 0) and return_hidden:
527
+ hidden_states.append(x)
528
+ if self.grad_checkpointing:
529
+ x = checkpoint(blk, x, (rel_pos_bias,))
530
+ else:
531
+ x = blk(x, rel_pos_bias=rel_pos_bias)
532
+
533
+ if not return_all_features:
534
+ x = self.norm(x)
535
+ if self.fc_norm is not None:
536
+ return self.fc_norm(x.mean(1)), hidden_states
537
+ else:
538
+ return x[:, 0], hidden_states
539
+ return x
540
+
541
+ def forward(self, x, return_all_features=False, return_hidden=False, shuffle=False):
542
+ if return_all_features:
543
+ return self.forward_features(x, return_all_features, return_hidden, shuffle)
544
+ x, hidden_states = self.forward_features(x, return_all_features, return_hidden, shuffle)
545
+ x = self.head(x)
546
+ if return_hidden:
547
+ return x, hidden_states
548
+ return x
eva_clip/factory.py ADDED
@@ -0,0 +1,517 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ import re
6
+ from copy import deepcopy
7
+ from pathlib import Path
8
+ from typing import Optional, Tuple, Union, Dict, Any
9
+ import torch
10
+
11
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
12
+ from .model import CLIP, CustomCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
13
+ get_cast_dtype
14
+ from .openai import load_openai_model
15
+ from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
16
+ from .transform import image_transform
17
+ from .tokenizer import HFTokenizer, tokenize
18
+ from .utils import resize_clip_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed, resize_eva_pos_embed
19
+
20
+
21
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
22
+ _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
23
+
24
+
25
+ def _natural_key(string_):
26
+ return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
27
+
28
+
29
+ def _rescan_model_configs():
30
+ global _MODEL_CONFIGS
31
+
32
+ config_ext = ('.json',)
33
+ config_files = []
34
+ for config_path in _MODEL_CONFIG_PATHS:
35
+ if config_path.is_file() and config_path.suffix in config_ext:
36
+ config_files.append(config_path)
37
+ elif config_path.is_dir():
38
+ for ext in config_ext:
39
+ config_files.extend(config_path.glob(f'*{ext}'))
40
+
41
+ for cf in config_files:
42
+ with open(cf, "r", encoding="utf8") as f:
43
+ model_cfg = json.load(f)
44
+ if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
45
+ _MODEL_CONFIGS[cf.stem] = model_cfg
46
+
47
+ _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))
48
+
49
+
50
+ _rescan_model_configs() # initial populate of model config registry
51
+
52
+
53
+ def list_models():
54
+ """ enumerate available model architectures based on config files """
55
+ return list(_MODEL_CONFIGS.keys())
56
+
57
+
58
+ def add_model_config(path):
59
+ """ add model config path or file and update registry """
60
+ if not isinstance(path, Path):
61
+ path = Path(path)
62
+ _MODEL_CONFIG_PATHS.append(path)
63
+ _rescan_model_configs()
64
+
65
+
66
+ def get_model_config(model_name):
67
+ if model_name in _MODEL_CONFIGS:
68
+ return deepcopy(_MODEL_CONFIGS[model_name])
69
+ else:
70
+ return None
71
+
72
+
73
+ def get_tokenizer(model_name):
74
+ config = get_model_config(model_name)
75
+ tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
76
+ return tokenizer
77
+
78
+
79
+ # loading openai CLIP weights when is_openai=True for training
80
+ def load_state_dict(checkpoint_path: str, map_location: str='cpu', model_key: str='model|module|state_dict', is_openai: bool=False, skip_list: list=[]):
81
+ if is_openai:
82
+ model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
83
+ state_dict = model.state_dict()
84
+ for key in ["input_resolution", "context_length", "vocab_size"]:
85
+ state_dict.pop(key, None)
86
+ else:
87
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
88
+ for mk in model_key.split('|'):
89
+ if isinstance(checkpoint, dict) and mk in checkpoint:
90
+ state_dict = checkpoint[mk]
91
+ break
92
+ else:
93
+ state_dict = checkpoint
94
+ if next(iter(state_dict.items()))[0].startswith('module'):
95
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
96
+
97
+ for k in skip_list:
98
+ if k in list(state_dict.keys()):
99
+ logging.info(f"Removing key {k} from pretrained checkpoint")
100
+ del state_dict[k]
101
+
102
+ if os.getenv('RoPE') == '1':
103
+ for k in list(state_dict.keys()):
104
+ if 'freqs_cos' in k or 'freqs_sin' in k:
105
+ del state_dict[k]
106
+ return state_dict
107
+
108
+
109
+
110
+ def load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=True):
111
+ state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False)
112
+ # detect old format and make compatible with new format
113
+ if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
114
+ state_dict = convert_to_custom_text_state_dict(state_dict)
115
+ if 'text.logit_scale' in state_dict and hasattr(model, 'logit_scale'):
116
+ state_dict['logit_scale'] = state_dict['text.logit_scale']
117
+ del state_dict['text.logit_scale']
118
+
119
+ # resize_clip_pos_embed for CLIP and open CLIP
120
+ if 'visual.positional_embedding' in state_dict:
121
+ resize_clip_pos_embed(state_dict, model)
122
+ # specified to eva_vit_model
123
+ elif 'visual.pos_embed' in state_dict:
124
+ resize_evaclip_pos_embed(state_dict, model)
125
+
126
+ # resize_clip_pos_embed(state_dict, model)
127
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict)
128
+ logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}")
129
+ return incompatible_keys
130
+
131
+ def load_clip_visual_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]):
132
+ state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
133
+
134
+ for k in list(state_dict.keys()):
135
+ if not k.startswith('visual.'):
136
+ del state_dict[k]
137
+ for k in list(state_dict.keys()):
138
+ if k.startswith('visual.'):
139
+ new_k = k[7:]
140
+ state_dict[new_k] = state_dict[k]
141
+ del state_dict[k]
142
+ return state_dict
143
+
144
+ def load_clip_text_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]):
145
+ state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
146
+
147
+ for k in list(state_dict.keys()):
148
+ if k.startswith('visual.'):
149
+ del state_dict[k]
150
+ return state_dict
151
+
152
+ def get_pretrained_tag(pretrained_model):
153
+ pretrained_model = pretrained_model.lower()
154
+ if "laion" in pretrained_model or "open_clip" in pretrained_model:
155
+ return "open_clip"
156
+ elif "openai" in pretrained_model:
157
+ return "clip"
158
+ elif "eva" in pretrained_model and "clip" in pretrained_model:
159
+ return "eva_clip"
160
+ else:
161
+ return "other"
162
+
163
+ def load_pretrained_checkpoint(
164
+ model,
165
+ visual_checkpoint_path,
166
+ text_checkpoint_path,
167
+ strict=True,
168
+ visual_model=None,
169
+ text_model=None,
170
+ model_key="model|module|state_dict",
171
+ skip_list=[]):
172
+ visual_tag = get_pretrained_tag(visual_model)
173
+ text_tag = get_pretrained_tag(text_model)
174
+
175
+ logging.info(f"num of model state_dict keys: {len(model.state_dict().keys())}")
176
+ visual_incompatible_keys, text_incompatible_keys = None, None
177
+ if visual_checkpoint_path:
178
+ if visual_tag == "eva_clip" or visual_tag == "open_clip":
179
+ visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=False, skip_list=skip_list)
180
+ elif visual_tag == "clip":
181
+ visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list)
182
+ else:
183
+ visual_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
184
+
185
+ # resize_clip_pos_embed for CLIP and open CLIP
186
+ if 'positional_embedding' in visual_state_dict:
187
+ resize_visual_pos_embed(visual_state_dict, model)
188
+ # specified to EVA model
189
+ elif 'pos_embed' in visual_state_dict:
190
+ resize_eva_pos_embed(visual_state_dict, model)
191
+
192
+ visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict)
193
+ logging.info(f"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}")
194
+ logging.info(f"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}")
195
+
196
+ if text_checkpoint_path:
197
+ if text_tag == "eva_clip" or text_tag == "open_clip":
198
+ text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list)
199
+ elif text_tag == "clip":
200
+ text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list)
201
+ else:
202
+ text_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
203
+
204
+ text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict)
205
+
206
+ logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}")
207
+ logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}")
208
+
209
+ return visual_incompatible_keys, text_incompatible_keys
210
+
211
+ def create_model(
212
+ model_name: str,
213
+ pretrained: Optional[str] = None,
214
+ precision: str = 'fp32',
215
+ device: Union[str, torch.device] = 'cpu',
216
+ jit: bool = False,
217
+ force_quick_gelu: bool = False,
218
+ force_custom_clip: bool = False,
219
+ force_patch_dropout: Optional[float] = None,
220
+ pretrained_image: str = '',
221
+ pretrained_text: str = '',
222
+ pretrained_hf: bool = True,
223
+ pretrained_visual_model: str = None,
224
+ pretrained_text_model: str = None,
225
+ cache_dir: Optional[str] = None,
226
+ skip_list: list = [],
227
+ ):
228
+ model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
229
+ if isinstance(device, str):
230
+ device = torch.device(device)
231
+
232
+ if pretrained and pretrained.lower() == 'openai':
233
+ logging.info(f'Loading pretrained {model_name} from OpenAI.')
234
+ model = load_openai_model(
235
+ model_name,
236
+ precision=precision,
237
+ device=device,
238
+ jit=jit,
239
+ cache_dir=cache_dir,
240
+ )
241
+ else:
242
+ model_cfg = get_model_config(model_name)
243
+ if model_cfg is not None:
244
+ logging.info(f'Loaded {model_name} model config.')
245
+ else:
246
+ logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
247
+ raise RuntimeError(f'Model config for {model_name} not found.')
248
+
249
+ if 'rope' in model_cfg.get('vision_cfg', {}):
250
+ if model_cfg['vision_cfg']['rope']:
251
+ os.environ['RoPE'] = "1"
252
+ else:
253
+ os.environ['RoPE'] = "0"
254
+
255
+ if force_quick_gelu:
256
+ # override for use of QuickGELU on non-OpenAI transformer models
257
+ model_cfg["quick_gelu"] = True
258
+
259
+ if force_patch_dropout is not None:
260
+ # override the default patch dropout value
261
+ model_cfg['vision_cfg']["patch_dropout"] = force_patch_dropout
262
+
263
+ cast_dtype = get_cast_dtype(precision)
264
+ custom_clip = model_cfg.pop('custom_text', False) or force_custom_clip or ('hf_model_name' in model_cfg['text_cfg'])
265
+
266
+
267
+ if custom_clip:
268
+ if 'hf_model_name' in model_cfg.get('text_cfg', {}):
269
+ model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
270
+ model = CustomCLIP(**model_cfg, cast_dtype=cast_dtype)
271
+ else:
272
+ model = CLIP(**model_cfg, cast_dtype=cast_dtype)
273
+
274
+ pretrained_cfg = {}
275
+ if pretrained:
276
+ checkpoint_path = ''
277
+ pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
278
+ if pretrained_cfg:
279
+ checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
280
+ elif os.path.exists(pretrained):
281
+ checkpoint_path = pretrained
282
+
283
+ if checkpoint_path:
284
+ logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
285
+ load_checkpoint(model,
286
+ checkpoint_path,
287
+ model_key="model|module|state_dict",
288
+ strict=False
289
+ )
290
+ else:
291
+ error_str = (
292
+ f'Pretrained weights ({pretrained}) not found for model {model_name}.'
293
+ f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
294
+ logging.warning(error_str)
295
+ raise RuntimeError(error_str)
296
+ else:
297
+ visual_checkpoint_path = ''
298
+ text_checkpoint_path = ''
299
+
300
+ if pretrained_image:
301
+ pretrained_visual_model = pretrained_visual_model.replace('/', '-') # for callers using old naming with / in ViT names
302
+ pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image)
303
+ if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
304
+ # pretrained weight loading for timm models set via vision_cfg
305
+ model_cfg['vision_cfg']['timm_model_pretrained'] = True
306
+ elif pretrained_image_cfg:
307
+ visual_checkpoint_path = download_pretrained(pretrained_image_cfg, cache_dir=cache_dir)
308
+ elif os.path.exists(pretrained_image):
309
+ visual_checkpoint_path = pretrained_image
310
+ else:
311
+ logging.warning(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.')
312
+ raise RuntimeError(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.')
313
+
314
+ if pretrained_text:
315
+ pretrained_text_model = pretrained_text_model.replace('/', '-') # for callers using old naming with / in ViT names
316
+ pretrained_text_cfg = get_pretrained_cfg(pretrained_text_model, pretrained_text)
317
+ if pretrained_image_cfg:
318
+ text_checkpoint_path = download_pretrained(pretrained_text_cfg, cache_dir=cache_dir)
319
+ elif os.path.exists(pretrained_text):
320
+ text_checkpoint_path = pretrained_text
321
+ else:
322
+ logging.warning(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.')
323
+ raise RuntimeError(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.')
324
+
325
+ if visual_checkpoint_path:
326
+ logging.info(f'Loading pretrained {model_name}.visual weights ({visual_checkpoint_path}).')
327
+ if text_checkpoint_path:
328
+ logging.info(f'Loading pretrained {model_name}.text weights ({text_checkpoint_path}).')
329
+
330
+ if visual_checkpoint_path or text_checkpoint_path:
331
+ load_pretrained_checkpoint(
332
+ model,
333
+ visual_checkpoint_path,
334
+ text_checkpoint_path,
335
+ strict=False,
336
+ visual_model=pretrained_visual_model,
337
+ text_model=pretrained_text_model,
338
+ model_key="model|module|state_dict",
339
+ skip_list=skip_list
340
+ )
341
+
342
+ if "fp16" in precision or "bf16" in precision:
343
+ logging.info(f'convert precision to {precision}')
344
+ model = model.to(torch.bfloat16) if 'bf16' in precision else model.to(torch.float16)
345
+
346
+ model.to(device=device)
347
+
348
+ # set image / mean metadata from pretrained_cfg if available, or use default
349
+ model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
350
+ model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
351
+
352
+ if jit:
353
+ model = torch.jit.script(model)
354
+
355
+ return model
356
+
357
+
358
+ def create_model_and_transforms(
359
+ model_name: str,
360
+ pretrained: Optional[str] = None,
361
+ precision: str = 'fp32',
362
+ device: Union[str, torch.device] = 'cpu',
363
+ jit: bool = False,
364
+ force_quick_gelu: bool = False,
365
+ force_custom_clip: bool = False,
366
+ force_patch_dropout: Optional[float] = None,
367
+ pretrained_image: str = '',
368
+ pretrained_text: str = '',
369
+ pretrained_hf: bool = True,
370
+ pretrained_visual_model: str = None,
371
+ pretrained_text_model: str = None,
372
+ image_mean: Optional[Tuple[float, ...]] = None,
373
+ image_std: Optional[Tuple[float, ...]] = None,
374
+ cache_dir: Optional[str] = None,
375
+ skip_list: list = [],
376
+ ):
377
+ model = create_model(
378
+ model_name,
379
+ pretrained,
380
+ precision=precision,
381
+ device=device,
382
+ jit=jit,
383
+ force_quick_gelu=force_quick_gelu,
384
+ force_custom_clip=force_custom_clip,
385
+ force_patch_dropout=force_patch_dropout,
386
+ pretrained_image=pretrained_image,
387
+ pretrained_text=pretrained_text,
388
+ pretrained_hf=pretrained_hf,
389
+ pretrained_visual_model=pretrained_visual_model,
390
+ pretrained_text_model=pretrained_text_model,
391
+ cache_dir=cache_dir,
392
+ skip_list=skip_list,
393
+ )
394
+
395
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
396
+ image_std = image_std or getattr(model.visual, 'image_std', None)
397
+ preprocess_train = image_transform(
398
+ model.visual.image_size,
399
+ is_train=True,
400
+ mean=image_mean,
401
+ std=image_std
402
+ )
403
+ preprocess_val = image_transform(
404
+ model.visual.image_size,
405
+ is_train=False,
406
+ mean=image_mean,
407
+ std=image_std
408
+ )
409
+
410
+ return model, preprocess_train, preprocess_val
411
+
412
+
413
+ def create_transforms(
414
+ model_name: str,
415
+ pretrained: Optional[str] = None,
416
+ precision: str = 'fp32',
417
+ device: Union[str, torch.device] = 'cpu',
418
+ jit: bool = False,
419
+ force_quick_gelu: bool = False,
420
+ force_custom_clip: bool = False,
421
+ force_patch_dropout: Optional[float] = None,
422
+ pretrained_image: str = '',
423
+ pretrained_text: str = '',
424
+ pretrained_hf: bool = True,
425
+ pretrained_visual_model: str = None,
426
+ pretrained_text_model: str = None,
427
+ image_mean: Optional[Tuple[float, ...]] = None,
428
+ image_std: Optional[Tuple[float, ...]] = None,
429
+ cache_dir: Optional[str] = None,
430
+ skip_list: list = [],
431
+ ):
432
+ model = create_model(
433
+ model_name,
434
+ pretrained,
435
+ precision=precision,
436
+ device=device,
437
+ jit=jit,
438
+ force_quick_gelu=force_quick_gelu,
439
+ force_custom_clip=force_custom_clip,
440
+ force_patch_dropout=force_patch_dropout,
441
+ pretrained_image=pretrained_image,
442
+ pretrained_text=pretrained_text,
443
+ pretrained_hf=pretrained_hf,
444
+ pretrained_visual_model=pretrained_visual_model,
445
+ pretrained_text_model=pretrained_text_model,
446
+ cache_dir=cache_dir,
447
+ skip_list=skip_list,
448
+ )
449
+
450
+
451
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
452
+ image_std = image_std or getattr(model.visual, 'image_std', None)
453
+ preprocess_train = image_transform(
454
+ model.visual.image_size,
455
+ is_train=True,
456
+ mean=image_mean,
457
+ std=image_std
458
+ )
459
+ preprocess_val = image_transform(
460
+ model.visual.image_size,
461
+ is_train=False,
462
+ mean=image_mean,
463
+ std=image_std
464
+ )
465
+ del model
466
+
467
+ return preprocess_train, preprocess_val
468
+
469
+ def create_model_from_pretrained(
470
+ model_name: str,
471
+ pretrained: str,
472
+ precision: str = 'fp32',
473
+ device: Union[str, torch.device] = 'cpu',
474
+ jit: bool = False,
475
+ force_quick_gelu: bool = False,
476
+ force_custom_clip: bool = False,
477
+ force_patch_dropout: Optional[float] = None,
478
+ return_transform: bool = True,
479
+ image_mean: Optional[Tuple[float, ...]] = None,
480
+ image_std: Optional[Tuple[float, ...]] = None,
481
+ cache_dir: Optional[str] = None,
482
+ is_frozen: bool = False,
483
+ ):
484
+ if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained):
485
+ raise RuntimeError(
486
+ f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.'
487
+ f' Use open_clip.list_pretrained() to find one.')
488
+
489
+ model = create_model(
490
+ model_name,
491
+ pretrained,
492
+ precision=precision,
493
+ device=device,
494
+ jit=jit,
495
+ force_quick_gelu=force_quick_gelu,
496
+ force_custom_clip=force_custom_clip,
497
+ force_patch_dropout=force_patch_dropout,
498
+ cache_dir=cache_dir,
499
+ )
500
+
501
+ if is_frozen:
502
+ for param in model.parameters():
503
+ param.requires_grad = False
504
+
505
+ if not return_transform:
506
+ return model
507
+
508
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
509
+ image_std = image_std or getattr(model.visual, 'image_std', None)
510
+ preprocess = image_transform(
511
+ model.visual.image_size,
512
+ is_train=False,
513
+ mean=image_mean,
514
+ std=image_std
515
+ )
516
+
517
+ return model, preprocess
eva_clip/hf_configs.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HF architecture dict:
2
+ arch_dict = {
3
+ # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
4
+ "roberta": {
5
+ "config_names": {
6
+ "context_length": "max_position_embeddings",
7
+ "vocab_size": "vocab_size",
8
+ "width": "hidden_size",
9
+ "heads": "num_attention_heads",
10
+ "layers": "num_hidden_layers",
11
+ "layer_attr": "layer",
12
+ "token_embeddings_attr": "embeddings"
13
+ },
14
+ "pooler": "mean_pooler",
15
+ },
16
+ # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
17
+ "xlm-roberta": {
18
+ "config_names": {
19
+ "context_length": "max_position_embeddings",
20
+ "vocab_size": "vocab_size",
21
+ "width": "hidden_size",
22
+ "heads": "num_attention_heads",
23
+ "layers": "num_hidden_layers",
24
+ "layer_attr": "layer",
25
+ "token_embeddings_attr": "embeddings"
26
+ },
27
+ "pooler": "mean_pooler",
28
+ },
29
+ # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
30
+ "mt5": {
31
+ "config_names": {
32
+ # unlimited seqlen
33
+ # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
34
+ # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
35
+ "context_length": "",
36
+ "vocab_size": "vocab_size",
37
+ "width": "d_model",
38
+ "heads": "num_heads",
39
+ "layers": "num_layers",
40
+ "layer_attr": "block",
41
+ "token_embeddings_attr": "embed_tokens"
42
+ },
43
+ "pooler": "mean_pooler",
44
+ },
45
+ "bert": {
46
+ "config_names": {
47
+ "context_length": "max_position_embeddings",
48
+ "vocab_size": "vocab_size",
49
+ "width": "hidden_size",
50
+ "heads": "num_attention_heads",
51
+ "layers": "num_hidden_layers",
52
+ "layer_attr": "layer",
53
+ "token_embeddings_attr": "embeddings"
54
+ },
55
+ "pooler": "mean_pooler",
56
+ }
57
+ }
eva_clip/hf_model.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ huggingface model adapter
2
+
3
+ Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
4
+ """
5
+
6
+ import re
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import functional as F
11
+ from torch import TensorType
12
+ try:
13
+ import transformers
14
+ from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig
15
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
16
+ BaseModelOutputWithPoolingAndCrossAttentions
17
+ except ImportError as e:
18
+ transformers = None
19
+
20
+
21
+ class BaseModelOutput:
22
+ pass
23
+
24
+
25
+ class PretrainedConfig:
26
+ pass
27
+
28
+ from .hf_configs import arch_dict
29
+
30
+ # utils
31
+ def _camel2snake(s):
32
+ return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
33
+
34
+ # TODO: ?last - for gpt-like models
35
+ _POOLERS = {}
36
+
37
+ def register_pooler(cls):
38
+ """Decorator registering pooler class"""
39
+ _POOLERS[_camel2snake(cls.__name__)] = cls
40
+ return cls
41
+
42
+
43
+ @register_pooler
44
+ class MeanPooler(nn.Module):
45
+ """Mean pooling"""
46
+ def forward(self, x:BaseModelOutput, attention_mask:TensorType):
47
+ masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
48
+ return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
49
+
50
+ @register_pooler
51
+ class MaxPooler(nn.Module):
52
+ """Max pooling"""
53
+ def forward(self, x:BaseModelOutput, attention_mask:TensorType):
54
+ masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
55
+ return masked_output.max(1).values
56
+
57
+ @register_pooler
58
+ class ClsPooler(nn.Module):
59
+ """CLS token pooling"""
60
+ def __init__(self, use_pooler_output=True):
61
+ super().__init__()
62
+ self.cls_token_position = 0
63
+ self.use_pooler_output = use_pooler_output
64
+
65
+ def forward(self, x:BaseModelOutput, attention_mask:TensorType):
66
+
67
+ if (self.use_pooler_output and
68
+ isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
69
+ (x.pooler_output is not None)
70
+ ):
71
+ return x.pooler_output
72
+
73
+ return x.last_hidden_state[:, self.cls_token_position, :]
74
+
75
+ class HFTextEncoder(nn.Module):
76
+ """HuggingFace model adapter"""
77
+ def __init__(
78
+ self,
79
+ model_name_or_path: str,
80
+ output_dim: int,
81
+ tokenizer_name: str = None,
82
+ config: PretrainedConfig = None,
83
+ pooler_type: str = None,
84
+ proj: str = None,
85
+ pretrained: bool = True,
86
+ masked_language_modeling: bool = False):
87
+ super().__init__()
88
+
89
+ self.output_dim = output_dim
90
+
91
+ # TODO: find better way to get this information
92
+ uses_transformer_pooler = (pooler_type == "cls_pooler")
93
+
94
+ if transformers is None:
95
+ raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
96
+ if config is None:
97
+ self.config = AutoConfig.from_pretrained(model_name_or_path)
98
+ if masked_language_modeling:
99
+ create_func, model_args = (AutoModelForMaskedLM.from_pretrained, model_name_or_path) if pretrained else (
100
+ AutoModelForMaskedLM.from_config, self.config)
101
+ else:
102
+ create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
103
+ AutoModel.from_config, self.config)
104
+ # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
105
+ if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
106
+ self.transformer = create_func(model_args)
107
+ self.transformer = self.transformer.encoder
108
+ else:
109
+ self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
110
+ else:
111
+ self.config = config
112
+ if masked_language_modeling:
113
+ self.transformer = AutoModelForMaskedLM.from_config(config)
114
+ else:
115
+ self.transformer = AutoModel.from_config(config)
116
+
117
+ if pooler_type is None: # get default arch pooler
118
+ self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
119
+ else:
120
+ self.pooler = _POOLERS[pooler_type]()
121
+
122
+ d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
123
+ if (d_model == output_dim) and (proj is None): # do we always need a proj?
124
+ self.proj = nn.Identity()
125
+ elif proj == 'linear':
126
+ self.proj = nn.Linear(d_model, output_dim, bias=False)
127
+ elif proj == 'mlp':
128
+ hidden_size = (d_model + output_dim) // 2
129
+ self.proj = nn.Sequential(
130
+ nn.Linear(d_model, hidden_size, bias=False),
131
+ nn.GELU(),
132
+ nn.Linear(hidden_size, output_dim, bias=False),
133
+ )
134
+
135
+ # self.itm_proj = nn.Linear(d_model, 2, bias=False)
136
+ # self.mlm_proj = nn.Linear(d_model, self.config.vocab_size), bias=False)
137
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
138
+
139
+ # def forward_itm(self, x:TensorType, image_embeds:TensorType) -> TensorType:
140
+ # image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device)
141
+ # attn_mask = (x != self.config.pad_token_id).long()
142
+ # out = self.transformer(
143
+ # input_ids=x,
144
+ # attention_mask=attn_mask,
145
+ # encoder_hidden_states = image_embeds,
146
+ # encoder_attention_mask = image_atts,
147
+ # )
148
+ # pooled_out = self.pooler(out, attn_mask)
149
+
150
+ # return self.itm_proj(pooled_out)
151
+
152
+ def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):
153
+ if masked_indices is None:
154
+ masked_indices = torch.bernoulli(probability_matrix).bool()
155
+
156
+ masked_indices[input_ids == self.tokenizer.pad_token_id] = False
157
+ masked_indices[input_ids == self.tokenizer.cls_token_id] = False
158
+
159
+ if targets is not None:
160
+ targets[~masked_indices] = -100 # We only compute loss on masked tokens
161
+
162
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
163
+ indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
164
+ input_ids[indices_replaced] = self.tokenizer.mask_token_id
165
+
166
+ # 10% of the time, we replace masked input tokens with random word
167
+ indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
168
+ random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
169
+ input_ids[indices_random] = random_words[indices_random]
170
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
171
+
172
+ if targets is not None:
173
+ return input_ids, targets
174
+ else:
175
+ return input_ids
176
+
177
+ def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25):
178
+ labels = input_ids.clone()
179
+ attn_mask = (input_ids != self.config.pad_token_id).long()
180
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(input_ids.device)
181
+ vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"])
182
+ probability_matrix = torch.full(labels.shape, mlm_probability)
183
+ input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels,
184
+ probability_matrix = probability_matrix)
185
+ mlm_output = self.transformer(input_ids,
186
+ attention_mask = attn_mask,
187
+ encoder_hidden_states = image_embeds,
188
+ encoder_attention_mask = image_atts,
189
+ return_dict = True,
190
+ labels = labels,
191
+ )
192
+ return mlm_output.loss
193
+ # mlm_output = self.transformer(input_ids,
194
+ # attention_mask = attn_mask,
195
+ # encoder_hidden_states = image_embeds,
196
+ # encoder_attention_mask = image_atts,
197
+ # return_dict = True,
198
+ # ).last_hidden_state
199
+ # logits = self.mlm_proj(mlm_output)
200
+
201
+ # # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size)
202
+ # logits = logits[:, 1:, :].contiguous().view(-1, vocab_size)
203
+ # labels = labels[:, 1:].contiguous().view(-1)
204
+
205
+ # mlm_loss = F.cross_entropy(
206
+ # logits,
207
+ # labels,
208
+ # # label_smoothing=0.1,
209
+ # )
210
+ # return mlm_loss
211
+
212
+
213
+ def forward(self, x:TensorType) -> TensorType:
214
+ attn_mask = (x != self.config.pad_token_id).long()
215
+ out = self.transformer(input_ids=x, attention_mask=attn_mask)
216
+ pooled_out = self.pooler(out, attn_mask)
217
+
218
+ return self.proj(pooled_out)
219
+
220
+ def lock(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
221
+ if not unlocked_layers: # full freezing
222
+ for n, p in self.transformer.named_parameters():
223
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
224
+ return
225
+
226
+ encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
227
+ layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
228
+ print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
229
+ embeddings = getattr(
230
+ self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
231
+ modules = [embeddings, *layer_list][:-unlocked_layers]
232
+ # freeze layers
233
+ for module in modules:
234
+ for n, p in module.named_parameters():
235
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
236
+
237
+
238
+ @torch.jit.ignore
239
+ def set_grad_checkpointing(self, enable=True):
240
+ self.transformer.gradient_checkpointing_enable()
241
+
242
+ def get_num_layers(self):
243
+ encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
244
+ layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
245
+ return len(layer_list)
246
+
247
+ def init_parameters(self):
248
+ pass
eva_clip/loss.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+
6
+ try:
7
+ import torch.distributed.nn
8
+ from torch import distributed as dist
9
+ has_distributed = True
10
+ except ImportError:
11
+ has_distributed = False
12
+
13
+ try:
14
+ import horovod.torch as hvd
15
+ except ImportError:
16
+ hvd = None
17
+
18
+ from timm.loss import LabelSmoothingCrossEntropy
19
+
20
+
21
+ def gather_features(
22
+ image_features,
23
+ text_features,
24
+ local_loss=False,
25
+ gather_with_grad=False,
26
+ rank=0,
27
+ world_size=1,
28
+ use_horovod=False
29
+ ):
30
+ assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
31
+ if use_horovod:
32
+ assert hvd is not None, 'Please install horovod'
33
+ if gather_with_grad:
34
+ all_image_features = hvd.allgather(image_features)
35
+ all_text_features = hvd.allgather(text_features)
36
+ else:
37
+ with torch.no_grad():
38
+ all_image_features = hvd.allgather(image_features)
39
+ all_text_features = hvd.allgather(text_features)
40
+ if not local_loss:
41
+ # ensure grads for local rank when all_* features don't have a gradient
42
+ gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
43
+ gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
44
+ gathered_image_features[rank] = image_features
45
+ gathered_text_features[rank] = text_features
46
+ all_image_features = torch.cat(gathered_image_features, dim=0)
47
+ all_text_features = torch.cat(gathered_text_features, dim=0)
48
+ else:
49
+ # We gather tensors from all gpus
50
+ if gather_with_grad:
51
+ all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
52
+ all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
53
+ # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0)
54
+ # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0)
55
+ else:
56
+ gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
57
+ gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
58
+ dist.all_gather(gathered_image_features, image_features)
59
+ dist.all_gather(gathered_text_features, text_features)
60
+ if not local_loss:
61
+ # ensure grads for local rank when all_* features don't have a gradient
62
+ gathered_image_features[rank] = image_features
63
+ gathered_text_features[rank] = text_features
64
+ all_image_features = torch.cat(gathered_image_features, dim=0)
65
+ all_text_features = torch.cat(gathered_text_features, dim=0)
66
+
67
+ return all_image_features, all_text_features
68
+
69
+
70
+ class ClipLoss(nn.Module):
71
+
72
+ def __init__(
73
+ self,
74
+ local_loss=False,
75
+ gather_with_grad=False,
76
+ cache_labels=False,
77
+ rank=0,
78
+ world_size=1,
79
+ use_horovod=False,
80
+ smoothing=0.,
81
+ ):
82
+ super().__init__()
83
+ self.local_loss = local_loss
84
+ self.gather_with_grad = gather_with_grad
85
+ self.cache_labels = cache_labels
86
+ self.rank = rank
87
+ self.world_size = world_size
88
+ self.use_horovod = use_horovod
89
+ self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None
90
+
91
+ # cache state
92
+ self.prev_num_logits = 0
93
+ self.labels = {}
94
+
95
+ def forward(self, image_features, text_features, logit_scale=1.):
96
+ device = image_features.device
97
+ if self.world_size > 1:
98
+ all_image_features, all_text_features = gather_features(
99
+ image_features, text_features,
100
+ self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
101
+
102
+ if self.local_loss:
103
+ logits_per_image = logit_scale * image_features @ all_text_features.T
104
+ logits_per_text = logit_scale * text_features @ all_image_features.T
105
+ else:
106
+ logits_per_image = logit_scale * all_image_features @ all_text_features.T
107
+ logits_per_text = logits_per_image.T
108
+ else:
109
+ logits_per_image = logit_scale * image_features @ text_features.T
110
+ logits_per_text = logit_scale * text_features @ image_features.T
111
+ # calculated ground-truth and cache if enabled
112
+ num_logits = logits_per_image.shape[0]
113
+ if self.prev_num_logits != num_logits or device not in self.labels:
114
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
115
+ if self.world_size > 1 and self.local_loss:
116
+ labels = labels + num_logits * self.rank
117
+ if self.cache_labels:
118
+ self.labels[device] = labels
119
+ self.prev_num_logits = num_logits
120
+ else:
121
+ labels = self.labels[device]
122
+
123
+ if self.label_smoothing_cross_entropy:
124
+ total_loss = (
125
+ self.label_smoothing_cross_entropy(logits_per_image, labels) +
126
+ self.label_smoothing_cross_entropy(logits_per_text, labels)
127
+ ) / 2
128
+ else:
129
+ total_loss = (
130
+ F.cross_entropy(logits_per_image, labels) +
131
+ F.cross_entropy(logits_per_text, labels)
132
+ ) / 2
133
+
134
+ acc = None
135
+ i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image)
136
+ t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text)
137
+ acc = {"i2t": i2t_acc, "t2i": t2i_acc}
138
+ return total_loss, acc
eva_clip/model.py ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ CLIP Model
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ import os
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple, Union
8
+ from functools import partial
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import nn
14
+
15
+ try:
16
+ from .hf_model import HFTextEncoder
17
+ except:
18
+ HFTextEncoder = None
19
+ from .modified_resnet import ModifiedResNet
20
+ from .timm_model import TimmModel
21
+ from .eva_vit_model import EVAVisionTransformer
22
+ from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
23
+
24
+ try:
25
+ from apex.normalization import FusedLayerNorm
26
+ except:
27
+ FusedLayerNorm = LayerNorm
28
+ print("Please 'pip install apex'")
29
+
30
+ try:
31
+ import xformers.ops as xops
32
+ except ImportError:
33
+ xops = None
34
+ print("Please 'pip install xformers'")
35
+
36
+ @dataclass
37
+ class CLIPVisionCfg:
38
+ layers: Union[Tuple[int, int, int, int], int] = 12
39
+ width: int = 768
40
+ head_width: int = 64
41
+ mlp_ratio: float = 4.0
42
+ patch_size: int = 16
43
+ image_size: Union[Tuple[int, int], int] = 224
44
+ ls_init_value: Optional[float] = None # layer scale initial value
45
+ patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
46
+ global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
47
+ drop_path_rate: Optional[float] = None # drop path rate
48
+ timm_model_name: str = None # a valid model name overrides layers, width, patch_size
49
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
50
+ timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
51
+ timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
52
+ timm_proj_bias: bool = False # enable bias final projection
53
+ eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
54
+ qkv_bias: bool = True
55
+ fusedLN: bool = False
56
+ xattn: bool = False
57
+ postnorm: bool = False
58
+ rope: bool = False
59
+ pt_hw_seq_len: int = 16 # 224/14
60
+ intp_freq: bool = False
61
+ naiveswiglu: bool = False
62
+ subln: bool = False
63
+
64
+
65
+ @dataclass
66
+ class CLIPTextCfg:
67
+ context_length: int = 77
68
+ vocab_size: int = 49408
69
+ width: int = 512
70
+ heads: int = 8
71
+ layers: int = 12
72
+ ls_init_value: Optional[float] = None # layer scale initial value
73
+ hf_model_name: str = None
74
+ hf_tokenizer_name: str = None
75
+ hf_model_pretrained: bool = True
76
+ proj: str = 'mlp'
77
+ pooler_type: str = 'mean_pooler'
78
+ masked_language_modeling: bool = False
79
+ fusedLN: bool = False
80
+ xattn: bool = False
81
+ attn_mask: bool = True
82
+
83
+ def get_cast_dtype(precision: str):
84
+ cast_dtype = None
85
+ if precision == 'bf16':
86
+ cast_dtype = torch.bfloat16
87
+ elif precision == 'fp16':
88
+ cast_dtype = torch.float16
89
+ return cast_dtype
90
+
91
+
92
+ def _build_vision_tower(
93
+ embed_dim: int,
94
+ vision_cfg: CLIPVisionCfg,
95
+ quick_gelu: bool = False,
96
+ cast_dtype: Optional[torch.dtype] = None
97
+ ):
98
+ if isinstance(vision_cfg, dict):
99
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
100
+
101
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
102
+ # memory efficient in recent PyTorch releases (>= 1.10).
103
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
104
+ act_layer = QuickGELU if quick_gelu else nn.GELU
105
+
106
+ if vision_cfg.eva_model_name:
107
+ vision_heads = vision_cfg.width // vision_cfg.head_width
108
+ norm_layer = LayerNorm
109
+
110
+ visual = EVAVisionTransformer(
111
+ img_size=vision_cfg.image_size,
112
+ patch_size=vision_cfg.patch_size,
113
+ num_classes=embed_dim,
114
+ use_mean_pooling=vision_cfg.global_average_pool, #False
115
+ init_values=vision_cfg.ls_init_value,
116
+ patch_dropout=vision_cfg.patch_dropout,
117
+ embed_dim=vision_cfg.width,
118
+ depth=vision_cfg.layers,
119
+ num_heads=vision_heads,
120
+ mlp_ratio=vision_cfg.mlp_ratio,
121
+ qkv_bias=vision_cfg.qkv_bias,
122
+ drop_path_rate=vision_cfg.drop_path_rate,
123
+ norm_layer= partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6),
124
+ xattn=vision_cfg.xattn,
125
+ rope=vision_cfg.rope,
126
+ postnorm=vision_cfg.postnorm,
127
+ pt_hw_seq_len= vision_cfg.pt_hw_seq_len, # 224/14
128
+ intp_freq= vision_cfg.intp_freq,
129
+ naiveswiglu= vision_cfg.naiveswiglu,
130
+ subln= vision_cfg.subln
131
+ )
132
+ elif vision_cfg.timm_model_name:
133
+ visual = TimmModel(
134
+ vision_cfg.timm_model_name,
135
+ pretrained=vision_cfg.timm_model_pretrained,
136
+ pool=vision_cfg.timm_pool,
137
+ proj=vision_cfg.timm_proj,
138
+ proj_bias=vision_cfg.timm_proj_bias,
139
+ embed_dim=embed_dim,
140
+ image_size=vision_cfg.image_size
141
+ )
142
+ act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
143
+ elif isinstance(vision_cfg.layers, (tuple, list)):
144
+ vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
145
+ visual = ModifiedResNet(
146
+ layers=vision_cfg.layers,
147
+ output_dim=embed_dim,
148
+ heads=vision_heads,
149
+ image_size=vision_cfg.image_size,
150
+ width=vision_cfg.width
151
+ )
152
+ else:
153
+ vision_heads = vision_cfg.width // vision_cfg.head_width
154
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
155
+ visual = VisionTransformer(
156
+ image_size=vision_cfg.image_size,
157
+ patch_size=vision_cfg.patch_size,
158
+ width=vision_cfg.width,
159
+ layers=vision_cfg.layers,
160
+ heads=vision_heads,
161
+ mlp_ratio=vision_cfg.mlp_ratio,
162
+ ls_init_value=vision_cfg.ls_init_value,
163
+ patch_dropout=vision_cfg.patch_dropout,
164
+ global_average_pool=vision_cfg.global_average_pool,
165
+ output_dim=embed_dim,
166
+ act_layer=act_layer,
167
+ norm_layer=norm_layer,
168
+ )
169
+
170
+ return visual
171
+
172
+
173
+ def _build_text_tower(
174
+ embed_dim: int,
175
+ text_cfg: CLIPTextCfg,
176
+ quick_gelu: bool = False,
177
+ cast_dtype: Optional[torch.dtype] = None,
178
+ ):
179
+ if isinstance(text_cfg, dict):
180
+ text_cfg = CLIPTextCfg(**text_cfg)
181
+
182
+ if text_cfg.hf_model_name:
183
+ text = HFTextEncoder(
184
+ text_cfg.hf_model_name,
185
+ output_dim=embed_dim,
186
+ tokenizer_name=text_cfg.hf_tokenizer_name,
187
+ proj=text_cfg.proj,
188
+ pooler_type=text_cfg.pooler_type,
189
+ masked_language_modeling=text_cfg.masked_language_modeling
190
+ )
191
+ else:
192
+ act_layer = QuickGELU if quick_gelu else nn.GELU
193
+ norm_layer = LayerNorm
194
+
195
+ text = TextTransformer(
196
+ context_length=text_cfg.context_length,
197
+ vocab_size=text_cfg.vocab_size,
198
+ width=text_cfg.width,
199
+ heads=text_cfg.heads,
200
+ layers=text_cfg.layers,
201
+ ls_init_value=text_cfg.ls_init_value,
202
+ output_dim=embed_dim,
203
+ act_layer=act_layer,
204
+ norm_layer= FusedLayerNorm if text_cfg.fusedLN else norm_layer,
205
+ xattn=text_cfg.xattn,
206
+ attn_mask=text_cfg.attn_mask,
207
+ )
208
+ return text
209
+
210
+ class CLIP(nn.Module):
211
+ def __init__(
212
+ self,
213
+ embed_dim: int,
214
+ vision_cfg: CLIPVisionCfg,
215
+ text_cfg: CLIPTextCfg,
216
+ quick_gelu: bool = False,
217
+ cast_dtype: Optional[torch.dtype] = None,
218
+ ):
219
+ super().__init__()
220
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
221
+
222
+ text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
223
+ self.transformer = text.transformer
224
+ self.vocab_size = text.vocab_size
225
+ self.token_embedding = text.token_embedding
226
+ self.positional_embedding = text.positional_embedding
227
+ self.ln_final = text.ln_final
228
+ self.text_projection = text.text_projection
229
+ self.register_buffer('attn_mask', text.attn_mask, persistent=False)
230
+
231
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
232
+
233
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
234
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
235
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
236
+
237
+ @torch.jit.ignore
238
+ def set_grad_checkpointing(self, enable=True):
239
+ self.visual.set_grad_checkpointing(enable)
240
+ self.transformer.grad_checkpointing = enable
241
+
242
+ @torch.jit.ignore
243
+ def no_weight_decay(self):
244
+ return {'logit_scale'}
245
+
246
+ def encode_image(self, image, normalize: bool = False):
247
+ features = self.visual(image)
248
+ return F.normalize(features, dim=-1) if normalize else features
249
+
250
+ def encode_text(self, text, normalize: bool = False):
251
+ cast_dtype = self.transformer.get_cast_dtype()
252
+
253
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
254
+
255
+ x = x + self.positional_embedding.to(cast_dtype)
256
+ x = x.permute(1, 0, 2) # NLD -> LND
257
+ x = self.transformer(x, attn_mask=self.attn_mask)
258
+ x = x.permute(1, 0, 2) # LND -> NLD
259
+ x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
260
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
261
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
262
+ return F.normalize(x, dim=-1) if normalize else x
263
+
264
+ def forward(self, image, text):
265
+ image_features = self.encode_image(image, normalize=True)
266
+ text_features = self.encode_text(text, normalize=True)
267
+ return image_features, text_features, self.logit_scale.exp()
268
+
269
+
270
+ class CustomCLIP(nn.Module):
271
+ def __init__(
272
+ self,
273
+ embed_dim: int,
274
+ vision_cfg: CLIPVisionCfg,
275
+ text_cfg: CLIPTextCfg,
276
+ quick_gelu: bool = False,
277
+ cast_dtype: Optional[torch.dtype] = None,
278
+ itm_task: bool = False,
279
+ ):
280
+ super().__init__()
281
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
282
+ self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
283
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
284
+
285
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
286
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
287
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
288
+
289
+ def lock_text_tower(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
290
+ self.text.lock(unlocked_layers, freeze_layer_norm)
291
+
292
+ @torch.jit.ignore
293
+ def set_grad_checkpointing(self, enable=True):
294
+ self.visual.set_grad_checkpointing(enable)
295
+ self.text.set_grad_checkpointing(enable)
296
+
297
+ @torch.jit.ignore
298
+ def no_weight_decay(self):
299
+ return {'logit_scale'}
300
+
301
+ def encode_image(self, image, normalize: bool = False):
302
+ features = self.visual(image)
303
+ return F.normalize(features, dim=-1) if normalize else features
304
+
305
+ def encode_text(self, text, normalize: bool = False):
306
+ features = self.text(text)
307
+ return F.normalize(features, dim=-1) if normalize else features
308
+
309
+ def forward(self, image, text):
310
+ image_features = self.encode_image(image, normalize=True)
311
+ text_features = self.encode_text(text, normalize=True)
312
+ return image_features, text_features, self.logit_scale.exp()
313
+
314
+
315
+ def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
316
+ """Convert applicable model parameters to low-precision (bf16 or fp16)"""
317
+
318
+ def _convert_weights(l):
319
+
320
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
321
+ l.weight.data = l.weight.data.to(dtype)
322
+ if l.bias is not None:
323
+ l.bias.data = l.bias.data.to(dtype)
324
+
325
+ if isinstance(l, (nn.MultiheadAttention, Attention)):
326
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
327
+ tensor = getattr(l, attr, None)
328
+ if tensor is not None:
329
+ tensor.data = tensor.data.to(dtype)
330
+
331
+ if isinstance(l, nn.Parameter):
332
+ l.data = l.data.to(dtype)
333
+
334
+ for name in ["text_projection", "proj"]:
335
+ if hasattr(l, name) and isinstance(l, nn.Parameter):
336
+ attr = getattr(l, name, None)
337
+ if attr is not None:
338
+ attr.data = attr.data.to(dtype)
339
+
340
+ model.apply(_convert_weights)
341
+
342
+
343
+ convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
344
+
345
+
346
+ # used to maintain checkpoint compatibility
347
+ def convert_to_custom_text_state_dict(state_dict: dict):
348
+ if 'text_projection' in state_dict:
349
+ # old format state_dict, move text tower -> .text
350
+ new_state_dict = {}
351
+ for k, v in state_dict.items():
352
+ if any(k.startswith(p) for p in (
353
+ 'text_projection',
354
+ 'positional_embedding',
355
+ 'token_embedding',
356
+ 'transformer',
357
+ 'ln_final',
358
+ 'logit_scale'
359
+ )):
360
+ k = 'text.' + k
361
+ new_state_dict[k] = v
362
+ return new_state_dict
363
+ return state_dict
364
+
365
+
366
+ def build_model_from_openai_state_dict(
367
+ state_dict: dict,
368
+ quick_gelu=True,
369
+ cast_dtype=torch.float16,
370
+ ):
371
+ vit = "visual.proj" in state_dict
372
+
373
+ if vit:
374
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
375
+ vision_layers = len(
376
+ [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
377
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
378
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
379
+ image_size = vision_patch_size * grid_size
380
+ else:
381
+ counts: list = [
382
+ len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
383
+ vision_layers = tuple(counts)
384
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
385
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
386
+ vision_patch_size = None
387
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
388
+ image_size = output_width * 32
389
+
390
+ embed_dim = state_dict["text_projection"].shape[1]
391
+ context_length = state_dict["positional_embedding"].shape[0]
392
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
393
+ transformer_width = state_dict["ln_final.weight"].shape[0]
394
+ transformer_heads = transformer_width // 64
395
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
396
+
397
+ vision_cfg = CLIPVisionCfg(
398
+ layers=vision_layers,
399
+ width=vision_width,
400
+ patch_size=vision_patch_size,
401
+ image_size=image_size,
402
+ )
403
+ text_cfg = CLIPTextCfg(
404
+ context_length=context_length,
405
+ vocab_size=vocab_size,
406
+ width=transformer_width,
407
+ heads=transformer_heads,
408
+ layers=transformer_layers
409
+ )
410
+ model = CLIP(
411
+ embed_dim,
412
+ vision_cfg=vision_cfg,
413
+ text_cfg=text_cfg,
414
+ quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
415
+ cast_dtype=cast_dtype,
416
+ )
417
+
418
+ for key in ["input_resolution", "context_length", "vocab_size"]:
419
+ state_dict.pop(key, None)
420
+
421
+ convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
422
+ model.load_state_dict(state_dict)
423
+ return model.eval()
424
+
425
+
426
+ def trace_model(model, batch_size=256, device=torch.device('cpu')):
427
+ model.eval()
428
+ image_size = model.visual.image_size
429
+ example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
430
+ example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
431
+ model = torch.jit.trace_module(
432
+ model,
433
+ inputs=dict(
434
+ forward=(example_images, example_text),
435
+ encode_text=(example_text,),
436
+ encode_image=(example_images,)
437
+ ))
438
+ model.visual.image_size = image_size
439
+ return model
eva_clip/model_configs/EVA01-CLIP-B-16.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "patch_size": 16,
8
+ "eva_model_name": "eva-clip-b-16",
9
+ "ls_init_value": 0.1,
10
+ "drop_path_rate": 0.0
11
+ },
12
+ "text_cfg": {
13
+ "context_length": 77,
14
+ "vocab_size": 49408,
15
+ "width": 512,
16
+ "heads": 8,
17
+ "layers": 12
18
+ }
19
+ }
eva_clip/model_configs/EVA01-CLIP-g-14-plus.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 40,
6
+ "width": 1408,
7
+ "head_width": 88,
8
+ "mlp_ratio": 4.3637,
9
+ "patch_size": 14,
10
+ "eva_model_name": "eva-clip-g-14-x",
11
+ "drop_path_rate": 0,
12
+ "xattn": true,
13
+ "fusedLN": true
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 1024,
19
+ "heads": 16,
20
+ "layers": 24,
21
+ "xattn": false,
22
+ "fusedLN": true
23
+ }
24
+ }
eva_clip/model_configs/EVA01-CLIP-g-14.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 40,
6
+ "width": 1408,
7
+ "head_width": 88,
8
+ "mlp_ratio": 4.3637,
9
+ "patch_size": 14,
10
+ "eva_model_name": "eva-clip-g-14-x",
11
+ "drop_path_rate": 0.4,
12
+ "xattn": true,
13
+ "fusedLN": true
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 768,
19
+ "heads": 12,
20
+ "layers": 12,
21
+ "xattn": false,
22
+ "fusedLN": true
23
+ }
24
+ }
eva_clip/model_configs/EVA02-CLIP-B-16.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "head_width": 64,
8
+ "patch_size": 16,
9
+ "mlp_ratio": 2.6667,
10
+ "eva_model_name": "eva-clip-b-16-X",
11
+ "drop_path_rate": 0.0,
12
+ "xattn": true,
13
+ "fusedLN": true,
14
+ "rope": true,
15
+ "pt_hw_seq_len": 16,
16
+ "intp_freq": true,
17
+ "naiveswiglu": true,
18
+ "subln": true
19
+ },
20
+ "text_cfg": {
21
+ "context_length": 77,
22
+ "vocab_size": 49408,
23
+ "width": 512,
24
+ "heads": 8,
25
+ "layers": 12,
26
+ "xattn": true,
27
+ "fusedLN": true
28
+ }
29
+ }
eva_clip/model_configs/EVA02-CLIP-L-14-336.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 336,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "drop_path_rate": 0,
8
+ "head_width": 64,
9
+ "mlp_ratio": 2.6667,
10
+ "patch_size": 14,
11
+ "eva_model_name": "eva-clip-l-14-336",
12
+ "xattn": true,
13
+ "fusedLN": true,
14
+ "rope": true,
15
+ "pt_hw_seq_len": 16,
16
+ "intp_freq": true,
17
+ "naiveswiglu": true,
18
+ "subln": true
19
+ },
20
+ "text_cfg": {
21
+ "context_length": 77,
22
+ "vocab_size": 49408,
23
+ "width": 768,
24
+ "heads": 12,
25
+ "layers": 12,
26
+ "xattn": false,
27
+ "fusedLN": true
28
+ }
29
+ }
eva_clip/model_configs/EVA02-CLIP-L-14.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "drop_path_rate": 0,
8
+ "head_width": 64,
9
+ "mlp_ratio": 2.6667,
10
+ "patch_size": 14,
11
+ "eva_model_name": "eva-clip-l-14",
12
+ "xattn": true,
13
+ "fusedLN": true,
14
+ "rope": true,
15
+ "pt_hw_seq_len": 16,
16
+ "intp_freq": true,
17
+ "naiveswiglu": true,
18
+ "subln": true
19
+ },
20
+ "text_cfg": {
21
+ "context_length": 77,
22
+ "vocab_size": 49408,
23
+ "width": 768,
24
+ "heads": 12,
25
+ "layers": 12,
26
+ "xattn": false,
27
+ "fusedLN": true
28
+ }
29
+ }
eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 64,
6
+ "width": 1792,
7
+ "head_width": 112,
8
+ "mlp_ratio": 8.571428571428571,
9
+ "patch_size": 14,
10
+ "eva_model_name": "eva-clip-4b-14-x",
11
+ "drop_path_rate": 0,
12
+ "xattn": true,
13
+ "postnorm": true,
14
+ "fusedLN": true
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 1280,
20
+ "heads": 20,
21
+ "layers": 32,
22
+ "xattn": false,
23
+ "fusedLN": true
24
+ }
25
+ }
eva_clip/model_configs/EVA02-CLIP-bigE-14.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 64,
6
+ "width": 1792,
7
+ "head_width": 112,
8
+ "mlp_ratio": 8.571428571428571,
9
+ "patch_size": 14,
10
+ "eva_model_name": "eva-clip-4b-14-x",
11
+ "drop_path_rate": 0,
12
+ "xattn": true,
13
+ "postnorm": true,
14
+ "fusedLN": true
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 1024,
20
+ "heads": 16,
21
+ "layers": 24,
22
+ "xattn": false,
23
+ "fusedLN": true
24
+ }
25
+ }
eva_clip/modified_resnet.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from eva_clip.utils import freeze_batch_norm_2d
8
+
9
+
10
+ class Bottleneck(nn.Module):
11
+ expansion = 4
12
+
13
+ def __init__(self, inplanes, planes, stride=1):
14
+ super().__init__()
15
+
16
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18
+ self.bn1 = nn.BatchNorm2d(planes)
19
+ self.act1 = nn.ReLU(inplace=True)
20
+
21
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22
+ self.bn2 = nn.BatchNorm2d(planes)
23
+ self.act2 = nn.ReLU(inplace=True)
24
+
25
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26
+
27
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29
+ self.act3 = nn.ReLU(inplace=True)
30
+
31
+ self.downsample = None
32
+ self.stride = stride
33
+
34
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
35
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36
+ self.downsample = nn.Sequential(OrderedDict([
37
+ ("-1", nn.AvgPool2d(stride)),
38
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39
+ ("1", nn.BatchNorm2d(planes * self.expansion))
40
+ ]))
41
+
42
+ def forward(self, x: torch.Tensor):
43
+ identity = x
44
+
45
+ out = self.act1(self.bn1(self.conv1(x)))
46
+ out = self.act2(self.bn2(self.conv2(out)))
47
+ out = self.avgpool(out)
48
+ out = self.bn3(self.conv3(out))
49
+
50
+ if self.downsample is not None:
51
+ identity = self.downsample(x)
52
+
53
+ out += identity
54
+ out = self.act3(out)
55
+ return out
56
+
57
+
58
+ class AttentionPool2d(nn.Module):
59
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60
+ super().__init__()
61
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
63
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
64
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
65
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66
+ self.num_heads = num_heads
67
+
68
+ def forward(self, x):
69
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
70
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72
+ x, _ = F.multi_head_attention_forward(
73
+ query=x, key=x, value=x,
74
+ embed_dim_to_check=x.shape[-1],
75
+ num_heads=self.num_heads,
76
+ q_proj_weight=self.q_proj.weight,
77
+ k_proj_weight=self.k_proj.weight,
78
+ v_proj_weight=self.v_proj.weight,
79
+ in_proj_weight=None,
80
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81
+ bias_k=None,
82
+ bias_v=None,
83
+ add_zero_attn=False,
84
+ dropout_p=0.,
85
+ out_proj_weight=self.c_proj.weight,
86
+ out_proj_bias=self.c_proj.bias,
87
+ use_separate_proj_weight=True,
88
+ training=self.training,
89
+ need_weights=False
90
+ )
91
+
92
+ return x[0]
93
+
94
+
95
+ class ModifiedResNet(nn.Module):
96
+ """
97
+ A ResNet class that is similar to torchvision's but contains the following changes:
98
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
99
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
100
+ - The final pooling layer is a QKV attention instead of an average pool
101
+ """
102
+
103
+ def __init__(self, layers, output_dim, heads, image_size=224, width=64):
104
+ super().__init__()
105
+ self.output_dim = output_dim
106
+ self.image_size = image_size
107
+
108
+ # the 3-layer stem
109
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
110
+ self.bn1 = nn.BatchNorm2d(width // 2)
111
+ self.act1 = nn.ReLU(inplace=True)
112
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
113
+ self.bn2 = nn.BatchNorm2d(width // 2)
114
+ self.act2 = nn.ReLU(inplace=True)
115
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
116
+ self.bn3 = nn.BatchNorm2d(width)
117
+ self.act3 = nn.ReLU(inplace=True)
118
+ self.avgpool = nn.AvgPool2d(2)
119
+
120
+ # residual layers
121
+ self._inplanes = width # this is a *mutable* variable used during construction
122
+ self.layer1 = self._make_layer(width, layers[0])
123
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
124
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
125
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
126
+
127
+ embed_dim = width * 32 # the ResNet feature dimension
128
+ self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
129
+
130
+ self.init_parameters()
131
+
132
+ def _make_layer(self, planes, blocks, stride=1):
133
+ layers = [Bottleneck(self._inplanes, planes, stride)]
134
+
135
+ self._inplanes = planes * Bottleneck.expansion
136
+ for _ in range(1, blocks):
137
+ layers.append(Bottleneck(self._inplanes, planes))
138
+
139
+ return nn.Sequential(*layers)
140
+
141
+ def init_parameters(self):
142
+ if self.attnpool is not None:
143
+ std = self.attnpool.c_proj.in_features ** -0.5
144
+ nn.init.normal_(self.attnpool.q_proj.weight, std=std)
145
+ nn.init.normal_(self.attnpool.k_proj.weight, std=std)
146
+ nn.init.normal_(self.attnpool.v_proj.weight, std=std)
147
+ nn.init.normal_(self.attnpool.c_proj.weight, std=std)
148
+
149
+ for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
150
+ for name, param in resnet_block.named_parameters():
151
+ if name.endswith("bn3.weight"):
152
+ nn.init.zeros_(param)
153
+
154
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
155
+ assert unlocked_groups == 0, 'partial locking not currently supported for this model'
156
+ for param in self.parameters():
157
+ param.requires_grad = False
158
+ if freeze_bn_stats:
159
+ freeze_batch_norm_2d(self)
160
+
161
+ @torch.jit.ignore
162
+ def set_grad_checkpointing(self, enable=True):
163
+ # FIXME support for non-transformer
164
+ pass
165
+
166
+ def stem(self, x):
167
+ x = self.act1(self.bn1(self.conv1(x)))
168
+ x = self.act2(self.bn2(self.conv2(x)))
169
+ x = self.act3(self.bn3(self.conv3(x)))
170
+ x = self.avgpool(x)
171
+ return x
172
+
173
+ def forward(self, x):
174
+ x = self.stem(x)
175
+ x = self.layer1(x)
176
+ x = self.layer2(x)
177
+ x = self.layer3(x)
178
+ x = self.layer4(x)
179
+ x = self.attnpool(x)
180
+
181
+ return x
eva_clip/openai.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ OpenAI pretrained model functions
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+
6
+ import os
7
+ import warnings
8
+ from typing import List, Optional, Union
9
+
10
+ import torch
11
+
12
+ from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
13
+ from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
14
+
15
+ __all__ = ["list_openai_models", "load_openai_model"]
16
+
17
+
18
+ def list_openai_models() -> List[str]:
19
+ """Returns the names of available CLIP models"""
20
+ return list_pretrained_models_by_tag('openai')
21
+
22
+
23
+ def load_openai_model(
24
+ name: str,
25
+ precision: Optional[str] = None,
26
+ device: Optional[Union[str, torch.device]] = None,
27
+ jit: bool = True,
28
+ cache_dir: Optional[str] = None,
29
+ ):
30
+ """Load a CLIP model
31
+
32
+ Parameters
33
+ ----------
34
+ name : str
35
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
36
+ precision: str
37
+ Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
38
+ device : Union[str, torch.device]
39
+ The device to put the loaded model
40
+ jit : bool
41
+ Whether to load the optimized JIT model (default) or more hackable non-JIT model.
42
+ cache_dir : Optional[str]
43
+ The directory to cache the downloaded model weights
44
+
45
+ Returns
46
+ -------
47
+ model : torch.nn.Module
48
+ The CLIP model
49
+ preprocess : Callable[[PIL.Image], torch.Tensor]
50
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
51
+ """
52
+ if device is None:
53
+ device = "cuda" if torch.cuda.is_available() else "cpu"
54
+ if precision is None:
55
+ precision = 'fp32' if device == 'cpu' else 'fp16'
56
+
57
+ if get_pretrained_url(name, 'openai'):
58
+ model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
59
+ elif os.path.isfile(name):
60
+ model_path = name
61
+ else:
62
+ raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
63
+
64
+ try:
65
+ # loading JIT archive
66
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
67
+ state_dict = None
68
+ except RuntimeError:
69
+ # loading saved state dict
70
+ if jit:
71
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
72
+ jit = False
73
+ state_dict = torch.load(model_path, map_location="cpu")
74
+
75
+ if not jit:
76
+ # Build a non-jit model from the OpenAI jitted model state dict
77
+ cast_dtype = get_cast_dtype(precision)
78
+ try:
79
+ model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
80
+ except KeyError:
81
+ sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
82
+ model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
83
+
84
+ # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
85
+ model = model.to(device)
86
+ if precision.startswith('amp') or precision == 'fp32':
87
+ model.float()
88
+ elif precision == 'bf16':
89
+ convert_weights_to_lp(model, dtype=torch.bfloat16)
90
+
91
+ return model
92
+
93
+ # patch the device names
94
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
95
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
96
+
97
+ def patch_device(module):
98
+ try:
99
+ graphs = [module.graph] if hasattr(module, "graph") else []
100
+ except RuntimeError:
101
+ graphs = []
102
+
103
+ if hasattr(module, "forward1"):
104
+ graphs.append(module.forward1.graph)
105
+
106
+ for graph in graphs:
107
+ for node in graph.findAllNodes("prim::Constant"):
108
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
109
+ node.copyAttributes(device_node)
110
+
111
+ model.apply(patch_device)
112
+ patch_device(model.encode_image)
113
+ patch_device(model.encode_text)
114
+
115
+ # patch dtype to float32 (typically for CPU)
116
+ if precision == 'fp32':
117
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
118
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
119
+ float_node = float_input.node()
120
+
121
+ def patch_float(module):
122
+ try:
123
+ graphs = [module.graph] if hasattr(module, "graph") else []
124
+ except RuntimeError:
125
+ graphs = []
126
+
127
+ if hasattr(module, "forward1"):
128
+ graphs.append(module.forward1.graph)
129
+
130
+ for graph in graphs:
131
+ for node in graph.findAllNodes("aten::to"):
132
+ inputs = list(node.inputs())
133
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
134
+ if inputs[i].node()["value"] == 5:
135
+ inputs[i].node().copyAttributes(float_node)
136
+
137
+ model.apply(patch_float)
138
+ patch_float(model.encode_image)
139
+ patch_float(model.encode_text)
140
+ model.float()
141
+
142
+ # ensure image_size attr available at consistent location for both jit and non-jit
143
+ model.visual.image_size = model.input_resolution.item()
144
+ return model
eva_clip/pretrained.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from functools import partial
6
+ from typing import Dict, Union
7
+
8
+ from tqdm import tqdm
9
+
10
+ try:
11
+ from huggingface_hub import hf_hub_download
12
+ _has_hf_hub = True
13
+ except ImportError:
14
+ hf_hub_download = None
15
+ _has_hf_hub = False
16
+
17
+
18
+ def _pcfg(url='', hf_hub='', filename='', mean=None, std=None):
19
+ return dict(
20
+ url=url,
21
+ hf_hub=hf_hub,
22
+ mean=mean,
23
+ std=std,
24
+ )
25
+
26
+ _VITB32 = dict(
27
+ openai=_pcfg(
28
+ "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
29
+ laion400m_e31=_pcfg(
30
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
31
+ laion400m_e32=_pcfg(
32
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
33
+ laion2b_e16=_pcfg(
34
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
35
+ laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
36
+ )
37
+
38
+ _VITB32_quickgelu = dict(
39
+ openai=_pcfg(
40
+ "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
41
+ laion400m_e31=_pcfg(
42
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
43
+ laion400m_e32=_pcfg(
44
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
45
+ )
46
+
47
+ _VITB16 = dict(
48
+ openai=_pcfg(
49
+ "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
50
+ laion400m_e31=_pcfg(
51
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
52
+ laion400m_e32=_pcfg(
53
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
54
+ laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
55
+ )
56
+
57
+ _EVAB16 = dict(
58
+ eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'),
59
+ eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'),
60
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'),
61
+ eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'),
62
+ )
63
+
64
+ _VITB16_PLUS_240 = dict(
65
+ laion400m_e31=_pcfg(
66
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
67
+ laion400m_e32=_pcfg(
68
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
69
+ )
70
+
71
+ _VITL14 = dict(
72
+ openai=_pcfg(
73
+ "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
74
+ laion400m_e31=_pcfg(
75
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
76
+ laion400m_e32=_pcfg(
77
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
78
+ laion2b_s32b_b82k=_pcfg(
79
+ hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
80
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
81
+ )
82
+
83
+ _EVAL14 = dict(
84
+ eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'),
85
+ eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'),
86
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'),
87
+ eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'),
88
+ )
89
+
90
+ _VITL14_336 = dict(
91
+ openai=_pcfg(
92
+ "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
93
+ )
94
+
95
+ _EVAL14_336 = dict(
96
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'),
97
+ eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'),
98
+ eva_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'),
99
+ eva02_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'),
100
+ )
101
+
102
+ _VITH14 = dict(
103
+ laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
104
+ )
105
+
106
+ _VITg14 = dict(
107
+ laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
108
+ laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
109
+ )
110
+
111
+ _EVAg14 = dict(
112
+ eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'),
113
+ eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'),
114
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'),
115
+ eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'),
116
+ )
117
+
118
+ _EVAg14_PLUS = dict(
119
+ eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'),
120
+ eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'),
121
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'),
122
+ eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'),
123
+ )
124
+
125
+ _VITbigG14 = dict(
126
+ laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
127
+ )
128
+
129
+ _EVAbigE14 = dict(
130
+ eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
131
+ eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
132
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'),
133
+ eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'),
134
+ )
135
+
136
+ _EVAbigE14_PLUS = dict(
137
+ eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
138
+ eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
139
+ eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'),
140
+ eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'),
141
+ )
142
+
143
+
144
+ _PRETRAINED = {
145
+ # "ViT-B-32": _VITB32,
146
+ "OpenaiCLIP-B-32": _VITB32,
147
+ "OpenCLIP-B-32": _VITB32,
148
+
149
+ # "ViT-B-32-quickgelu": _VITB32_quickgelu,
150
+ "OpenaiCLIP-B-32-quickgelu": _VITB32_quickgelu,
151
+ "OpenCLIP-B-32-quickgelu": _VITB32_quickgelu,
152
+
153
+ # "ViT-B-16": _VITB16,
154
+ "OpenaiCLIP-B-16": _VITB16,
155
+ "OpenCLIP-B-16": _VITB16,
156
+
157
+ "EVA02-B-16": _EVAB16,
158
+ "EVA02-CLIP-B-16": _EVAB16,
159
+
160
+ # "ViT-B-16-plus-240": _VITB16_PLUS_240,
161
+ "OpenCLIP-B-16-plus-240": _VITB16_PLUS_240,
162
+
163
+ # "ViT-L-14": _VITL14,
164
+ "OpenaiCLIP-L-14": _VITL14,
165
+ "OpenCLIP-L-14": _VITL14,
166
+
167
+ "EVA02-L-14": _EVAL14,
168
+ "EVA02-CLIP-L-14": _EVAL14,
169
+
170
+ # "ViT-L-14-336": _VITL14_336,
171
+ "OpenaiCLIP-L-14-336": _VITL14_336,
172
+
173
+ "EVA02-CLIP-L-14-336": _EVAL14_336,
174
+
175
+ # "ViT-H-14": _VITH14,
176
+ # "ViT-g-14": _VITg14,
177
+ "OpenCLIP-H-14": _VITH14,
178
+ "OpenCLIP-g-14": _VITg14,
179
+
180
+ "EVA01-CLIP-g-14": _EVAg14,
181
+ "EVA01-CLIP-g-14-plus": _EVAg14_PLUS,
182
+
183
+ # "ViT-bigG-14": _VITbigG14,
184
+ "OpenCLIP-bigG-14": _VITbigG14,
185
+
186
+ "EVA02-CLIP-bigE-14": _EVAbigE14,
187
+ "EVA02-CLIP-bigE-14-plus": _EVAbigE14_PLUS,
188
+ }
189
+
190
+
191
+ def _clean_tag(tag: str):
192
+ # normalize pretrained tags
193
+ return tag.lower().replace('-', '_')
194
+
195
+
196
+ def list_pretrained(as_str: bool = False):
197
+ """ returns list of pretrained models
198
+ Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
199
+ """
200
+ return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
201
+
202
+
203
+ def list_pretrained_models_by_tag(tag: str):
204
+ """ return all models having the specified pretrain tag """
205
+ models = []
206
+ tag = _clean_tag(tag)
207
+ for k in _PRETRAINED.keys():
208
+ if tag in _PRETRAINED[k]:
209
+ models.append(k)
210
+ return models
211
+
212
+
213
+ def list_pretrained_tags_by_model(model: str):
214
+ """ return all pretrain tags for the specified model architecture """
215
+ tags = []
216
+ if model in _PRETRAINED:
217
+ tags.extend(_PRETRAINED[model].keys())
218
+ return tags
219
+
220
+
221
+ def is_pretrained_cfg(model: str, tag: str):
222
+ if model not in _PRETRAINED:
223
+ return False
224
+ return _clean_tag(tag) in _PRETRAINED[model]
225
+
226
+
227
+ def get_pretrained_cfg(model: str, tag: str):
228
+ if model not in _PRETRAINED:
229
+ return {}
230
+ model_pretrained = _PRETRAINED[model]
231
+ return model_pretrained.get(_clean_tag(tag), {})
232
+
233
+
234
+ def get_pretrained_url(model: str, tag: str):
235
+ cfg = get_pretrained_cfg(model, _clean_tag(tag))
236
+ return cfg.get('url', '')
237
+
238
+
239
+ def download_pretrained_from_url(
240
+ url: str,
241
+ cache_dir: Union[str, None] = None,
242
+ ):
243
+ if not cache_dir:
244
+ cache_dir = os.path.expanduser("~/.cache/clip")
245
+ os.makedirs(cache_dir, exist_ok=True)
246
+ filename = os.path.basename(url)
247
+
248
+ if 'openaipublic' in url:
249
+ expected_sha256 = url.split("/")[-2]
250
+ elif 'mlfoundations' in url:
251
+ expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
252
+ else:
253
+ expected_sha256 = ''
254
+
255
+ download_target = os.path.join(cache_dir, filename)
256
+
257
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
258
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
259
+
260
+ if os.path.isfile(download_target):
261
+ if expected_sha256:
262
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
263
+ return download_target
264
+ else:
265
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
266
+ else:
267
+ return download_target
268
+
269
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
270
+ with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
271
+ while True:
272
+ buffer = source.read(8192)
273
+ if not buffer:
274
+ break
275
+
276
+ output.write(buffer)
277
+ loop.update(len(buffer))
278
+
279
+ if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
280
+ raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
281
+
282
+ return download_target
283
+
284
+
285
+ def has_hf_hub(necessary=False):
286
+ if not _has_hf_hub and necessary:
287
+ # if no HF Hub module installed, and it is necessary to continue, raise error
288
+ raise RuntimeError(
289
+ 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
290
+ return _has_hf_hub
291
+
292
+
293
+ def download_pretrained_from_hf(
294
+ model_id: str,
295
+ filename: str = 'open_clip_pytorch_model.bin',
296
+ revision=None,
297
+ cache_dir: Union[str, None] = None,
298
+ ):
299
+ has_hf_hub(True)
300
+ cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
301
+ return cached_file
302
+
303
+
304
+ def download_pretrained(
305
+ cfg: Dict,
306
+ force_hf_hub: bool = False,
307
+ cache_dir: Union[str, None] = None,
308
+ ):
309
+ target = ''
310
+ if not cfg:
311
+ return target
312
+
313
+ download_url = cfg.get('url', '')
314
+ download_hf_hub = cfg.get('hf_hub', '')
315
+ if download_hf_hub and force_hf_hub:
316
+ # use HF hub even if url exists
317
+ download_url = ''
318
+
319
+ if download_url:
320
+ target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
321
+ elif download_hf_hub:
322
+ has_hf_hub(True)
323
+ # we assume the hf_hub entries in pretrained config combine model_id + filename in
324
+ # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
325
+ # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
326
+ model_id, filename = os.path.split(download_hf_hub)
327
+ if filename:
328
+ target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
329
+ else:
330
+ target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
331
+
332
+ return target
eva_clip/rope.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import pi
2
+ import torch
3
+ from torch import nn
4
+ from einops import rearrange, repeat
5
+ import logging
6
+
7
+ def broadcat(tensors, dim = -1):
8
+ num_tensors = len(tensors)
9
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
10
+ assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
11
+ shape_len = list(shape_lens)[0]
12
+ dim = (dim + shape_len) if dim < 0 else dim
13
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
14
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
15
+ assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation'
16
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
17
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
18
+ expanded_dims.insert(dim, (dim, dims[dim]))
19
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
20
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
21
+ return torch.cat(tensors, dim = dim)
22
+
23
+ def rotate_half(x):
24
+ x = rearrange(x, '... (d r) -> ... d r', r = 2)
25
+ x1, x2 = x.unbind(dim = -1)
26
+ x = torch.stack((-x2, x1), dim = -1)
27
+ return rearrange(x, '... d r -> ... (d r)')
28
+
29
+
30
+ class VisionRotaryEmbedding(nn.Module):
31
+ def __init__(
32
+ self,
33
+ dim,
34
+ pt_seq_len,
35
+ ft_seq_len=None,
36
+ custom_freqs = None,
37
+ freqs_for = 'lang',
38
+ theta = 10000,
39
+ max_freq = 10,
40
+ num_freqs = 1,
41
+ ):
42
+ super().__init__()
43
+ if custom_freqs:
44
+ freqs = custom_freqs
45
+ elif freqs_for == 'lang':
46
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
47
+ elif freqs_for == 'pixel':
48
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
49
+ elif freqs_for == 'constant':
50
+ freqs = torch.ones(num_freqs).float()
51
+ else:
52
+ raise ValueError(f'unknown modality {freqs_for}')
53
+
54
+ if ft_seq_len is None: ft_seq_len = pt_seq_len
55
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
56
+
57
+ freqs_h = torch.einsum('..., f -> ... f', t, freqs)
58
+ freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
59
+
60
+ freqs_w = torch.einsum('..., f -> ... f', t, freqs)
61
+ freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
62
+
63
+ freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)
64
+
65
+ self.register_buffer("freqs_cos", freqs.cos())
66
+ self.register_buffer("freqs_sin", freqs.sin())
67
+
68
+ logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
69
+
70
+ def forward(self, t, start_index = 0):
71
+ rot_dim = self.freqs_cos.shape[-1]
72
+ end_index = start_index + rot_dim
73
+ assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
74
+ t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
75
+ t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
76
+
77
+ return torch.cat((t_left, t, t_right), dim = -1)
78
+
79
+ class VisionRotaryEmbeddingFast(nn.Module):
80
+ def __init__(
81
+ self,
82
+ dim,
83
+ pt_seq_len,
84
+ ft_seq_len=None,
85
+ custom_freqs = None,
86
+ freqs_for = 'lang',
87
+ theta = 10000,
88
+ max_freq = 10,
89
+ num_freqs = 1,
90
+ patch_dropout = 0.
91
+ ):
92
+ super().__init__()
93
+ if custom_freqs:
94
+ freqs = custom_freqs
95
+ elif freqs_for == 'lang':
96
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
97
+ elif freqs_for == 'pixel':
98
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
99
+ elif freqs_for == 'constant':
100
+ freqs = torch.ones(num_freqs).float()
101
+ else:
102
+ raise ValueError(f'unknown modality {freqs_for}')
103
+
104
+ if ft_seq_len is None: ft_seq_len = pt_seq_len
105
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
106
+
107
+ freqs = torch.einsum('..., f -> ... f', t, freqs)
108
+ freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
109
+ freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
110
+
111
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
112
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
113
+
114
+ self.patch_dropout = patch_dropout
115
+
116
+ self.register_buffer("freqs_cos", freqs_cos)
117
+ self.register_buffer("freqs_sin", freqs_sin)
118
+
119
+ logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
120
+
121
+ def forward(self, t, patch_indices_keep=None):
122
+ if patch_indices_keep is not None:
123
+ batch = t.size()[0]
124
+ batch_indices = torch.arange(batch)
125
+ batch_indices = batch_indices[..., None]
126
+
127
+ freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
128
+ freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
129
+
130
+ freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
131
+ freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
132
+ freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
133
+ freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
134
+
135
+ return t * freqs_cos + rotate_half(t) * freqs_sin
136
+
137
+ return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
eva_clip/timm_model.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ timm model adapter
2
+
3
+ Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4
+ """
5
+ import logging
6
+ from collections import OrderedDict
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ try:
12
+ import timm
13
+ from timm.models.layers import Mlp, to_2tuple
14
+ try:
15
+ # old timm imports < 0.8.1
16
+ from timm.models.layers.attention_pool2d import RotAttentionPool2d
17
+ from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
18
+ except ImportError:
19
+ # new timm imports >= 0.8.1
20
+ from timm.layers import RotAttentionPool2d
21
+ from timm.layers import AttentionPool2d as AbsAttentionPool2d
22
+ except ImportError:
23
+ timm = None
24
+
25
+ from .utils import freeze_batch_norm_2d
26
+
27
+
28
+ class TimmModel(nn.Module):
29
+ """ timm model adapter
30
+ # FIXME this adapter is a work in progress, may change in ways that break weight compat
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ model_name,
36
+ embed_dim,
37
+ image_size=224,
38
+ pool='avg',
39
+ proj='linear',
40
+ proj_bias=False,
41
+ drop=0.,
42
+ pretrained=False):
43
+ super().__init__()
44
+ if timm is None:
45
+ raise RuntimeError("Please `pip install timm` to use timm models.")
46
+
47
+ self.image_size = to_2tuple(image_size)
48
+ self.trunk = timm.create_model(model_name, pretrained=pretrained)
49
+ feat_size = self.trunk.default_cfg.get('pool_size', None)
50
+ feature_ndim = 1 if not feat_size else 2
51
+ if pool in ('abs_attn', 'rot_attn'):
52
+ assert feature_ndim == 2
53
+ # if attn pooling used, remove both classifier and default pool
54
+ self.trunk.reset_classifier(0, global_pool='')
55
+ else:
56
+ # reset global pool if pool config set, otherwise leave as network default
57
+ reset_kwargs = dict(global_pool=pool) if pool else {}
58
+ self.trunk.reset_classifier(0, **reset_kwargs)
59
+ prev_chs = self.trunk.num_features
60
+
61
+ head_layers = OrderedDict()
62
+ if pool == 'abs_attn':
63
+ head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
64
+ prev_chs = embed_dim
65
+ elif pool == 'rot_attn':
66
+ head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
67
+ prev_chs = embed_dim
68
+ else:
69
+ assert proj, 'projection layer needed if non-attention pooling is used.'
70
+
71
+ # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
72
+ if proj == 'linear':
73
+ head_layers['drop'] = nn.Dropout(drop)
74
+ head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
75
+ elif proj == 'mlp':
76
+ head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias))
77
+
78
+ self.head = nn.Sequential(head_layers)
79
+
80
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
81
+ """ lock modules
82
+ Args:
83
+ unlocked_groups (int): leave last n layer groups unlocked (default: 0)
84
+ """
85
+ if not unlocked_groups:
86
+ # lock full model
87
+ for param in self.trunk.parameters():
88
+ param.requires_grad = False
89
+ if freeze_bn_stats:
90
+ freeze_batch_norm_2d(self.trunk)
91
+ else:
92
+ # NOTE: partial freeze requires latest timm (master) branch and is subject to change
93
+ try:
94
+ # FIXME import here until API stable and in an official release
95
+ from timm.models.helpers import group_parameters, group_modules
96
+ except ImportError:
97
+ raise RuntimeError(
98
+ 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
99
+ matcher = self.trunk.group_matcher()
100
+ gparams = group_parameters(self.trunk, matcher)
101
+ max_layer_id = max(gparams.keys())
102
+ max_layer_id = max_layer_id - unlocked_groups
103
+ for group_idx in range(max_layer_id + 1):
104
+ group = gparams[group_idx]
105
+ for param in group:
106
+ self.trunk.get_parameter(param).requires_grad = False
107
+ if freeze_bn_stats:
108
+ gmodules = group_modules(self.trunk, matcher, reverse=True)
109
+ gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
110
+ freeze_batch_norm_2d(self.trunk, gmodules)
111
+
112
+ @torch.jit.ignore
113
+ def set_grad_checkpointing(self, enable=True):
114
+ try:
115
+ self.trunk.set_grad_checkpointing(enable)
116
+ except Exception as e:
117
+ logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
118
+
119
+ def forward(self, x):
120
+ x = self.trunk(x)
121
+ x = self.head(x)
122
+ return x
eva_clip/tokenizer.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ CLIP tokenizer
2
+
3
+ Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ import gzip
6
+ import html
7
+ import os
8
+ from functools import lru_cache
9
+ from typing import Union, List
10
+
11
+ import ftfy
12
+ import regex as re
13
+ import torch
14
+
15
+ # https://stackoverflow.com/q/62691279
16
+ import os
17
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
18
+
19
+
20
+ @lru_cache()
21
+ def default_bpe():
22
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
23
+
24
+
25
+ @lru_cache()
26
+ def bytes_to_unicode():
27
+ """
28
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
29
+ The reversible bpe codes work on unicode strings.
30
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
31
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
32
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
33
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
34
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
35
+ """
36
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
37
+ cs = bs[:]
38
+ n = 0
39
+ for b in range(2**8):
40
+ if b not in bs:
41
+ bs.append(b)
42
+ cs.append(2**8+n)
43
+ n += 1
44
+ cs = [chr(n) for n in cs]
45
+ return dict(zip(bs, cs))
46
+
47
+
48
+ def get_pairs(word):
49
+ """Return set of symbol pairs in a word.
50
+ Word is represented as tuple of symbols (symbols being variable-length strings).
51
+ """
52
+ pairs = set()
53
+ prev_char = word[0]
54
+ for char in word[1:]:
55
+ pairs.add((prev_char, char))
56
+ prev_char = char
57
+ return pairs
58
+
59
+
60
+ def basic_clean(text):
61
+ text = ftfy.fix_text(text)
62
+ text = html.unescape(html.unescape(text))
63
+ return text.strip()
64
+
65
+
66
+ def whitespace_clean(text):
67
+ text = re.sub(r'\s+', ' ', text)
68
+ text = text.strip()
69
+ return text
70
+
71
+
72
+ class SimpleTokenizer(object):
73
+ def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
74
+ self.byte_encoder = bytes_to_unicode()
75
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
76
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
77
+ merges = merges[1:49152-256-2+1]
78
+ merges = [tuple(merge.split()) for merge in merges]
79
+ vocab = list(bytes_to_unicode().values())
80
+ vocab = vocab + [v+'</w>' for v in vocab]
81
+ for merge in merges:
82
+ vocab.append(''.join(merge))
83
+ if not special_tokens:
84
+ special_tokens = ['<start_of_text>', '<end_of_text>']
85
+ else:
86
+ special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
87
+ vocab.extend(special_tokens)
88
+ self.encoder = dict(zip(vocab, range(len(vocab))))
89
+ self.decoder = {v: k for k, v in self.encoder.items()}
90
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
91
+ self.cache = {t:t for t in special_tokens}
92
+ special = "|".join(special_tokens)
93
+ self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
94
+
95
+ self.vocab_size = len(self.encoder)
96
+ self.all_special_ids = [self.encoder[t] for t in special_tokens]
97
+
98
+ def bpe(self, token):
99
+ if token in self.cache:
100
+ return self.cache[token]
101
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
102
+ pairs = get_pairs(word)
103
+
104
+ if not pairs:
105
+ return token+'</w>'
106
+
107
+ while True:
108
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
109
+ if bigram not in self.bpe_ranks:
110
+ break
111
+ first, second = bigram
112
+ new_word = []
113
+ i = 0
114
+ while i < len(word):
115
+ try:
116
+ j = word.index(first, i)
117
+ new_word.extend(word[i:j])
118
+ i = j
119
+ except:
120
+ new_word.extend(word[i:])
121
+ break
122
+
123
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
124
+ new_word.append(first+second)
125
+ i += 2
126
+ else:
127
+ new_word.append(word[i])
128
+ i += 1
129
+ new_word = tuple(new_word)
130
+ word = new_word
131
+ if len(word) == 1:
132
+ break
133
+ else:
134
+ pairs = get_pairs(word)
135
+ word = ' '.join(word)
136
+ self.cache[token] = word
137
+ return word
138
+
139
+ def encode(self, text):
140
+ bpe_tokens = []
141
+ text = whitespace_clean(basic_clean(text)).lower()
142
+ for token in re.findall(self.pat, text):
143
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
144
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
145
+ return bpe_tokens
146
+
147
+ def decode(self, tokens):
148
+ text = ''.join([self.decoder[token] for token in tokens])
149
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
150
+ return text
151
+
152
+
153
+ _tokenizer = SimpleTokenizer()
154
+
155
+
156
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
157
+ """
158
+ Returns the tokenized representation of given input string(s)
159
+
160
+ Parameters
161
+ ----------
162
+ texts : Union[str, List[str]]
163
+ An input string or a list of input strings to tokenize
164
+ context_length : int
165
+ The context length to use; all CLIP models use 77 as the context length
166
+
167
+ Returns
168
+ -------
169
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
170
+ """
171
+ if isinstance(texts, str):
172
+ texts = [texts]
173
+
174
+ sot_token = _tokenizer.encoder["<start_of_text>"]
175
+ eot_token = _tokenizer.encoder["<end_of_text>"]
176
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
177
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
178
+
179
+ for i, tokens in enumerate(all_tokens):
180
+ if len(tokens) > context_length:
181
+ tokens = tokens[:context_length] # Truncate
182
+ tokens[-1] = eot_token
183
+ result[i, :len(tokens)] = torch.tensor(tokens)
184
+
185
+ return result
186
+
187
+
188
+ class HFTokenizer:
189
+ "HuggingFace tokenizer wrapper"
190
+ def __init__(self, tokenizer_name:str):
191
+ from transformers import AutoTokenizer
192
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
193
+
194
+ def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor:
195
+ # same cleaning as for default tokenizer, except lowercasing
196
+ # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
197
+ if isinstance(texts, str):
198
+ texts = [texts]
199
+ texts = [whitespace_clean(basic_clean(text)) for text in texts]
200
+ input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids
201
+ return input_ids
eva_clip/transform.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Sequence, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchvision.transforms.functional as F
6
+
7
+ from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
8
+ CenterCrop
9
+
10
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
11
+
12
+
13
+ class ResizeMaxSize(nn.Module):
14
+
15
+ def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
16
+ super().__init__()
17
+ if not isinstance(max_size, int):
18
+ raise TypeError(f"Size should be int. Got {type(max_size)}")
19
+ self.max_size = max_size
20
+ self.interpolation = interpolation
21
+ self.fn = min if fn == 'min' else min
22
+ self.fill = fill
23
+
24
+ def forward(self, img):
25
+ if isinstance(img, torch.Tensor):
26
+ height, width = img.shape[:2]
27
+ else:
28
+ width, height = img.size
29
+ scale = self.max_size / float(max(height, width))
30
+ if scale != 1.0:
31
+ new_size = tuple(round(dim * scale) for dim in (height, width))
32
+ img = F.resize(img, new_size, self.interpolation)
33
+ pad_h = self.max_size - new_size[0]
34
+ pad_w = self.max_size - new_size[1]
35
+ img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
36
+ return img
37
+
38
+
39
+ def _convert_to_rgb(image):
40
+ return image.convert('RGB')
41
+
42
+
43
+ # class CatGen(nn.Module):
44
+ # def __init__(self, num=4):
45
+ # self.num = num
46
+ # def mixgen_batch(image, text):
47
+ # batch_size = image.shape[0]
48
+ # index = np.random.permutation(batch_size)
49
+
50
+ # cat_images = []
51
+ # for i in range(batch_size):
52
+ # # image mixup
53
+ # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:]
54
+ # # text concat
55
+ # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0]
56
+ # text = torch.stack(text)
57
+ # return image, text
58
+
59
+
60
+ def image_transform(
61
+ image_size: int,
62
+ is_train: bool,
63
+ mean: Optional[Tuple[float, ...]] = None,
64
+ std: Optional[Tuple[float, ...]] = None,
65
+ resize_longest_max: bool = False,
66
+ fill_color: int = 0,
67
+ ):
68
+ mean = mean or OPENAI_DATASET_MEAN
69
+ if not isinstance(mean, (list, tuple)):
70
+ mean = (mean,) * 3
71
+
72
+ std = std or OPENAI_DATASET_STD
73
+ if not isinstance(std, (list, tuple)):
74
+ std = (std,) * 3
75
+
76
+ if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
77
+ # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
78
+ image_size = image_size[0]
79
+
80
+ normalize = Normalize(mean=mean, std=std)
81
+ if is_train:
82
+ return Compose([
83
+ RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
84
+ _convert_to_rgb,
85
+ ToTensor(),
86
+ normalize,
87
+ ])
88
+ else:
89
+ if resize_longest_max:
90
+ transforms = [
91
+ ResizeMaxSize(image_size, fill=fill_color)
92
+ ]
93
+ else:
94
+ transforms = [
95
+ Resize(image_size, interpolation=InterpolationMode.BICUBIC),
96
+ CenterCrop(image_size),
97
+ ]
98
+ transforms.extend([
99
+ _convert_to_rgb,
100
+ ToTensor(),
101
+ normalize,
102
+ ])
103
+ return Compose(transforms)
eva_clip/transformer.py ADDED
@@ -0,0 +1,737 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from collections import OrderedDict
4
+ import math
5
+ from typing import Callable, Optional, Sequence
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+ try:
12
+ from timm.models.layers import trunc_normal_
13
+ except:
14
+ from timm.layers import trunc_normal_
15
+
16
+ from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
17
+ from .utils import to_2tuple
18
+
19
+ if os.getenv('ENV_TYPE') == 'deepspeed':
20
+ try:
21
+ import deepspeed
22
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
23
+ except:
24
+ print("Please 'pip install deepspeed'")
25
+ deepspeed = None
26
+ from torch.utils.checkpoint import checkpoint
27
+ else:
28
+ from torch.utils.checkpoint import checkpoint
29
+
30
+ try:
31
+ import xformers.ops as xops
32
+ except ImportError:
33
+ xops = None
34
+ print("Please 'pip install xformers'")
35
+
36
+ class LayerNormFp32(nn.LayerNorm):
37
+ """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
38
+ def __init__(self, *args, **kwargs):
39
+ super().__init__(*args, **kwargs)
40
+
41
+ def forward(self, x: torch.Tensor):
42
+ output = F.layer_norm(
43
+ x.float(),
44
+ self.normalized_shape,
45
+ self.weight.float() if self.weight is not None else None,
46
+ self.bias.float() if self.bias is not None else None,
47
+ self.eps,
48
+ )
49
+ return output.type_as(x)
50
+
51
+
52
+ class LayerNorm(nn.LayerNorm):
53
+ """Subclass torch's LayerNorm (with cast back to input dtype)."""
54
+
55
+ def forward(self, x: torch.Tensor):
56
+ orig_type = x.dtype
57
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
58
+ return x.to(orig_type)
59
+
60
+ class QuickGELU(nn.Module):
61
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
62
+ def forward(self, x: torch.Tensor):
63
+ return x * torch.sigmoid(1.702 * x)
64
+
65
+
66
+ class LayerScale(nn.Module):
67
+ def __init__(self, dim, init_values=1e-5, inplace=False):
68
+ super().__init__()
69
+ self.inplace = inplace
70
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
71
+
72
+ def forward(self, x):
73
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
74
+
75
+ class PatchDropout(nn.Module):
76
+ """
77
+ https://arxiv.org/abs/2212.00794
78
+ """
79
+
80
+ def __init__(self, prob, exclude_first_token=True):
81
+ super().__init__()
82
+ assert 0 <= prob < 1.
83
+ self.prob = prob
84
+ self.exclude_first_token = exclude_first_token # exclude CLS token
85
+ logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
86
+
87
+ def forward(self, x):
88
+ if not self.training or self.prob == 0.:
89
+ return x
90
+
91
+ if self.exclude_first_token:
92
+ cls_tokens, x = x[:, :1], x[:, 1:]
93
+ else:
94
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
95
+
96
+ batch = x.size()[0]
97
+ num_tokens = x.size()[1]
98
+
99
+ batch_indices = torch.arange(batch)
100
+ batch_indices = batch_indices[..., None]
101
+
102
+ keep_prob = 1 - self.prob
103
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
104
+
105
+ rand = torch.randn(batch, num_tokens)
106
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
107
+
108
+ x = x[batch_indices, patch_indices_keep]
109
+
110
+ if self.exclude_first_token:
111
+ x = torch.cat((cls_tokens, x), dim=1)
112
+
113
+ if self.training and os.getenv('RoPE') == '1':
114
+ return x, patch_indices_keep
115
+
116
+ return x
117
+
118
+
119
+ def _in_projection_packed(
120
+ q: torch.Tensor,
121
+ k: torch.Tensor,
122
+ v: torch.Tensor,
123
+ w: torch.Tensor,
124
+ b: Optional[torch.Tensor] = None,
125
+ ):
126
+ """
127
+ https://github.com/pytorch/pytorch/blob/db2a237763eb8693a20788be94f8c192e762baa8/torch/nn/functional.py#L4726
128
+ """
129
+ E = q.size(-1)
130
+ if k is v:
131
+ if q is k:
132
+ # self-attention
133
+ return F.linear(q, w, b).chunk(3, dim=-1)
134
+ else:
135
+ # encoder-decoder attention
136
+ w_q, w_kv = w.split([E, E * 2])
137
+ if b is None:
138
+ b_q = b_kv = None
139
+ else:
140
+ b_q, b_kv = b.split([E, E * 2])
141
+ return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
142
+ else:
143
+ w_q, w_k, w_v = w.chunk(3)
144
+ if b is None:
145
+ b_q = b_k = b_v = None
146
+ else:
147
+ b_q, b_k, b_v = b.chunk(3)
148
+ return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
149
+
150
+ class Attention(nn.Module):
151
+ def __init__(
152
+ self,
153
+ dim,
154
+ num_heads=8,
155
+ qkv_bias=True,
156
+ scaled_cosine=False,
157
+ scale_heads=False,
158
+ logit_scale_max=math.log(1. / 0.01),
159
+ attn_drop=0.,
160
+ proj_drop=0.,
161
+ xattn=False,
162
+ rope=False
163
+ ):
164
+ super().__init__()
165
+ self.scaled_cosine = scaled_cosine
166
+ self.scale_heads = scale_heads
167
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
168
+ self.num_heads = num_heads
169
+ self.head_dim = dim // num_heads
170
+ self.scale = self.head_dim ** -0.5
171
+ self.logit_scale_max = logit_scale_max
172
+
173
+ # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
174
+ self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
175
+ if qkv_bias:
176
+ self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
177
+ else:
178
+ self.in_proj_bias = None
179
+
180
+ if self.scaled_cosine:
181
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
182
+ else:
183
+ self.logit_scale = None
184
+ self.attn_drop = nn.Dropout(attn_drop)
185
+ if self.scale_heads:
186
+ self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
187
+ else:
188
+ self.head_scale = None
189
+ self.out_proj = nn.Linear(dim, dim)
190
+ self.out_drop = nn.Dropout(proj_drop)
191
+ self.xattn = xattn
192
+ self.xattn_drop = attn_drop
193
+ self.rope = rope
194
+
195
+ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
196
+ L, N, C = x.shape
197
+ q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
198
+ if self.xattn:
199
+ q = q.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
200
+ k = k.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
201
+ v = v.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
202
+
203
+ x = xops.memory_efficient_attention(
204
+ q, k, v,
205
+ p=self.xattn_drop,
206
+ scale=self.scale if self.logit_scale is None else None,
207
+ attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None,
208
+ )
209
+ else:
210
+ q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
211
+ k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
212
+ v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
213
+
214
+ if self.logit_scale is not None:
215
+ attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
216
+ logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
217
+ attn = attn.view(N, self.num_heads, L, L) * logit_scale
218
+ attn = attn.view(-1, L, L)
219
+ else:
220
+ q = q * self.scale
221
+ attn = torch.bmm(q, k.transpose(-1, -2))
222
+
223
+ if attn_mask is not None:
224
+ if attn_mask.dtype == torch.bool:
225
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
226
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
227
+ attn_mask = new_attn_mask
228
+ attn += attn_mask
229
+
230
+ attn = attn.softmax(dim=-1)
231
+ attn = self.attn_drop(attn)
232
+
233
+ x = torch.bmm(attn, v)
234
+
235
+ if self.head_scale is not None:
236
+ x = x.view(N, self.num_heads, L, C) * self.head_scale
237
+ x = x.view(-1, L, C)
238
+ x = x.transpose(0, 1).reshape(L, N, C)
239
+ x = self.out_proj(x)
240
+ x = self.out_drop(x)
241
+ return x
242
+
243
+ class CustomAttention(nn.Module):
244
+ def __init__(
245
+ self,
246
+ dim,
247
+ num_heads=8,
248
+ qkv_bias=True,
249
+ scaled_cosine=True,
250
+ scale_heads=False,
251
+ logit_scale_max=math.log(1. / 0.01),
252
+ attn_drop=0.,
253
+ proj_drop=0.,
254
+ xattn=False
255
+ ):
256
+ super().__init__()
257
+ self.scaled_cosine = scaled_cosine
258
+ self.scale_heads = scale_heads
259
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
260
+ self.num_heads = num_heads
261
+ self.head_dim = dim // num_heads
262
+ self.scale = self.head_dim ** -0.5
263
+ self.logit_scale_max = logit_scale_max
264
+
265
+ # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
266
+ self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
267
+ if qkv_bias:
268
+ self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
269
+ else:
270
+ self.in_proj_bias = None
271
+
272
+ if self.scaled_cosine:
273
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
274
+ else:
275
+ self.logit_scale = None
276
+ self.attn_drop = nn.Dropout(attn_drop)
277
+ if self.scale_heads:
278
+ self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
279
+ else:
280
+ self.head_scale = None
281
+ self.out_proj = nn.Linear(dim, dim)
282
+ self.out_drop = nn.Dropout(proj_drop)
283
+ self.xattn = xattn
284
+ self.xattn_drop = attn_drop
285
+
286
+ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
287
+ q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
288
+ N_q, B_q, C_q = q.shape
289
+ N_k, B_k, C_k = k.shape
290
+ N_v, B_v, C_v = v.shape
291
+ if self.xattn:
292
+ # B, N, C -> B, N, num_heads, C
293
+ q = q.permute(1, 0, 2).reshape(B_q, N_q, self.num_heads, -1)
294
+ k = k.permute(1, 0, 2).reshape(B_k, N_k, self.num_heads, -1)
295
+ v = v.permute(1, 0, 2).reshape(B_v, N_v, self.num_heads, -1)
296
+
297
+ x = xops.memory_efficient_attention(
298
+ q, k, v,
299
+ p=self.xattn_drop,
300
+ scale=self.scale if self.logit_scale is None else None,
301
+ attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None
302
+ )
303
+ else:
304
+ # B*H, L, C
305
+ q = q.contiguous().view(N_q, B_q * self.num_heads, -1).transpose(0, 1)
306
+ k = k.contiguous().view(N_k, B_k * self.num_heads, -1).transpose(0, 1)
307
+ v = v.contiguous().view(N_v, B_v * self.num_heads, -1).transpose(0, 1)
308
+
309
+ if self.logit_scale is not None:
310
+ # B*H, N_q, N_k
311
+ attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
312
+ logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
313
+ attn = attn.view(B_q, self.num_heads, N_q, N_k) * logit_scale
314
+ attn = attn.view(-1, N_q, N_k)
315
+ else:
316
+ q = q * self.scale
317
+ attn = torch.bmm(q, k.transpose(-1, -2))
318
+
319
+ if attn_mask is not None:
320
+ if attn_mask.dtype == torch.bool:
321
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
322
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
323
+ attn_mask = new_attn_mask
324
+ attn += attn_mask
325
+
326
+ attn = attn.softmax(dim=-1)
327
+ attn = self.attn_drop(attn)
328
+
329
+ x = torch.bmm(attn, v)
330
+
331
+ if self.head_scale is not None:
332
+ x = x.view(B_q, self.num_heads, N_q, C_q) * self.head_scale
333
+ x = x.view(-1, N_q, C_q)
334
+ x = x.transpose(0, 1).reshape(N_q, B_q, C_q)
335
+ x = self.out_proj(x)
336
+ x = self.out_drop(x)
337
+ return x
338
+
339
+ class CustomResidualAttentionBlock(nn.Module):
340
+ def __init__(
341
+ self,
342
+ d_model: int,
343
+ n_head: int,
344
+ mlp_ratio: float = 4.0,
345
+ ls_init_value: float = None,
346
+ act_layer: Callable = nn.GELU,
347
+ norm_layer: Callable = LayerNorm,
348
+ scale_cosine_attn: bool = False,
349
+ scale_heads: bool = False,
350
+ scale_attn: bool = False,
351
+ scale_fc: bool = False,
352
+ cross_attn: bool = False,
353
+ xattn: bool = False,
354
+ ):
355
+ super().__init__()
356
+
357
+ self.ln_1 = norm_layer(d_model)
358
+ self.ln_1_k = norm_layer(d_model) if cross_attn else self.ln_1
359
+ self.ln_1_v = norm_layer(d_model) if cross_attn else self.ln_1
360
+ self.attn = CustomAttention(
361
+ d_model, n_head,
362
+ qkv_bias=True,
363
+ attn_drop=0.,
364
+ proj_drop=0.,
365
+ scaled_cosine=scale_cosine_attn,
366
+ scale_heads=scale_heads,
367
+ xattn=xattn
368
+ )
369
+
370
+ self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
371
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
372
+
373
+ self.ln_2 = norm_layer(d_model)
374
+ mlp_width = int(d_model * mlp_ratio)
375
+ self.mlp = nn.Sequential(OrderedDict([
376
+ ("c_fc", nn.Linear(d_model, mlp_width)),
377
+ ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
378
+ ("gelu", act_layer()),
379
+ ("c_proj", nn.Linear(mlp_width, d_model))
380
+ ]))
381
+
382
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
383
+
384
+ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
385
+ q = q + self.ls_1(self.ln_attn(self.attn(self.ln_1(q), self.ln_1_k(k), self.ln_1_v(v), attn_mask=attn_mask)))
386
+ q = q + self.ls_2(self.mlp(self.ln_2(q)))
387
+ return q
388
+
389
+ class CustomTransformer(nn.Module):
390
+ def __init__(
391
+ self,
392
+ width: int,
393
+ layers: int,
394
+ heads: int,
395
+ mlp_ratio: float = 4.0,
396
+ ls_init_value: float = None,
397
+ act_layer: Callable = nn.GELU,
398
+ norm_layer: Callable = LayerNorm,
399
+ scale_cosine_attn: bool = True,
400
+ scale_heads: bool = False,
401
+ scale_attn: bool = False,
402
+ scale_fc: bool = False,
403
+ cross_attn: bool = False,
404
+ xattn: bool = False,
405
+ ):
406
+ super().__init__()
407
+ self.width = width
408
+ self.layers = layers
409
+ self.grad_checkpointing = False
410
+ self.xattn = xattn
411
+
412
+ self.resblocks = nn.ModuleList([
413
+ CustomResidualAttentionBlock(
414
+ width,
415
+ heads,
416
+ mlp_ratio,
417
+ ls_init_value=ls_init_value,
418
+ act_layer=act_layer,
419
+ norm_layer=norm_layer,
420
+ scale_cosine_attn=scale_cosine_attn,
421
+ scale_heads=scale_heads,
422
+ scale_attn=scale_attn,
423
+ scale_fc=scale_fc,
424
+ cross_attn=cross_attn,
425
+ xattn=xattn)
426
+ for _ in range(layers)
427
+ ])
428
+
429
+ def get_cast_dtype(self) -> torch.dtype:
430
+ return self.resblocks[0].mlp.c_fc.weight.dtype
431
+
432
+ def forward(self, q: torch.Tensor, k: torch.Tensor = None, v: torch.Tensor = None, attn_mask: Optional[torch.Tensor] = None):
433
+ if k is None and v is None:
434
+ k = v = q
435
+ for r in self.resblocks:
436
+ if self.grad_checkpointing and not torch.jit.is_scripting():
437
+ q = checkpoint(r, q, k, v, attn_mask)
438
+ else:
439
+ q = r(q, k, v, attn_mask=attn_mask)
440
+ return q
441
+
442
+
443
+ class ResidualAttentionBlock(nn.Module):
444
+ def __init__(
445
+ self,
446
+ d_model: int,
447
+ n_head: int,
448
+ mlp_ratio: float = 4.0,
449
+ ls_init_value: float = None,
450
+ act_layer: Callable = nn.GELU,
451
+ norm_layer: Callable = LayerNorm,
452
+ xattn: bool = False,
453
+ ):
454
+ super().__init__()
455
+
456
+ self.ln_1 = norm_layer(d_model)
457
+ if xattn:
458
+ self.attn = Attention(d_model, n_head, xattn=True)
459
+ else:
460
+ self.attn = nn.MultiheadAttention(d_model, n_head)
461
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
462
+
463
+ self.ln_2 = norm_layer(d_model)
464
+ mlp_width = int(d_model * mlp_ratio)
465
+ self.mlp = nn.Sequential(OrderedDict([
466
+ ("c_fc", nn.Linear(d_model, mlp_width)),
467
+ ("gelu", act_layer()),
468
+ ("c_proj", nn.Linear(mlp_width, d_model))
469
+ ]))
470
+
471
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
472
+ self.xattn = xattn
473
+
474
+ def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
475
+ attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
476
+ if self.xattn:
477
+ return self.attn(x, attn_mask=attn_mask)
478
+ return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
479
+
480
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
481
+ x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask))
482
+ x = x + self.ls_2(self.mlp(self.ln_2(x)))
483
+ return x
484
+
485
+ class Transformer(nn.Module):
486
+ def __init__(
487
+ self,
488
+ width: int,
489
+ layers: int,
490
+ heads: int,
491
+ mlp_ratio: float = 4.0,
492
+ ls_init_value: float = None,
493
+ act_layer: Callable = nn.GELU,
494
+ norm_layer: Callable = LayerNorm,
495
+ xattn: bool = False,
496
+ ):
497
+ super().__init__()
498
+ self.width = width
499
+ self.layers = layers
500
+ self.grad_checkpointing = False
501
+
502
+ self.resblocks = nn.ModuleList([
503
+ ResidualAttentionBlock(
504
+ width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn)
505
+ for _ in range(layers)
506
+ ])
507
+
508
+ def get_cast_dtype(self) -> torch.dtype:
509
+ return self.resblocks[0].mlp.c_fc.weight.dtype
510
+
511
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
512
+ for r in self.resblocks:
513
+ if self.grad_checkpointing and not torch.jit.is_scripting():
514
+ x = checkpoint(r, x, attn_mask)
515
+ else:
516
+ x = r(x, attn_mask=attn_mask)
517
+ return x
518
+
519
+
520
+ class VisionTransformer(nn.Module):
521
+ def __init__(
522
+ self,
523
+ image_size: int,
524
+ patch_size: int,
525
+ width: int,
526
+ layers: int,
527
+ heads: int,
528
+ mlp_ratio: float,
529
+ ls_init_value: float = None,
530
+ patch_dropout: float = 0.,
531
+ global_average_pool: bool = False,
532
+ output_dim: int = 512,
533
+ act_layer: Callable = nn.GELU,
534
+ norm_layer: Callable = LayerNorm,
535
+ xattn: bool = False,
536
+ ):
537
+ super().__init__()
538
+ self.image_size = to_2tuple(image_size)
539
+ self.patch_size = to_2tuple(patch_size)
540
+ self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
541
+ self.output_dim = output_dim
542
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
543
+
544
+ scale = width ** -0.5
545
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
546
+ self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
547
+
548
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
549
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
550
+ self.ln_pre = norm_layer(width)
551
+
552
+ self.transformer = Transformer(
553
+ width,
554
+ layers,
555
+ heads,
556
+ mlp_ratio,
557
+ ls_init_value=ls_init_value,
558
+ act_layer=act_layer,
559
+ norm_layer=norm_layer,
560
+ xattn=xattn
561
+ )
562
+
563
+ self.global_average_pool = global_average_pool
564
+ self.ln_post = norm_layer(width)
565
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
566
+
567
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
568
+ for param in self.parameters():
569
+ param.requires_grad = False
570
+
571
+ if unlocked_groups != 0:
572
+ groups = [
573
+ [
574
+ self.conv1,
575
+ self.class_embedding,
576
+ self.positional_embedding,
577
+ self.ln_pre,
578
+ ],
579
+ *self.transformer.resblocks[:-1],
580
+ [
581
+ self.transformer.resblocks[-1],
582
+ self.ln_post,
583
+ ],
584
+ self.proj,
585
+ ]
586
+
587
+ def _unlock(x):
588
+ if isinstance(x, Sequence):
589
+ for g in x:
590
+ _unlock(g)
591
+ else:
592
+ if isinstance(x, torch.nn.Parameter):
593
+ x.requires_grad = True
594
+ else:
595
+ for p in x.parameters():
596
+ p.requires_grad = True
597
+
598
+ _unlock(groups[-unlocked_groups:])
599
+
600
+ def get_num_layers(self):
601
+ return self.transformer.layers
602
+
603
+ @torch.jit.ignore
604
+ def set_grad_checkpointing(self, enable=True):
605
+ self.transformer.grad_checkpointing = enable
606
+
607
+ @torch.jit.ignore
608
+ def no_weight_decay(self):
609
+ return {'positional_embedding', 'class_embedding'}
610
+
611
+ def forward(self, x: torch.Tensor, return_all_features: bool=False):
612
+ x = self.conv1(x) # shape = [*, width, grid, grid]
613
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
614
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
615
+ x = torch.cat(
616
+ [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
617
+ x], dim=1) # shape = [*, grid ** 2 + 1, width]
618
+ x = x + self.positional_embedding.to(x.dtype)
619
+
620
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
621
+ x = self.patch_dropout(x)
622
+ x = self.ln_pre(x)
623
+
624
+ x = x.permute(1, 0, 2) # NLD -> LND
625
+ x = self.transformer(x)
626
+ x = x.permute(1, 0, 2) # LND -> NLD
627
+
628
+ if not return_all_features:
629
+ if self.global_average_pool:
630
+ x = x.mean(dim=1) #x = x[:,1:,:].mean(dim=1)
631
+ else:
632
+ x = x[:, 0]
633
+
634
+ x = self.ln_post(x)
635
+
636
+ if self.proj is not None:
637
+ x = x @ self.proj
638
+
639
+ return x
640
+
641
+
642
+ class TextTransformer(nn.Module):
643
+ def __init__(
644
+ self,
645
+ context_length: int = 77,
646
+ vocab_size: int = 49408,
647
+ width: int = 512,
648
+ heads: int = 8,
649
+ layers: int = 12,
650
+ ls_init_value: float = None,
651
+ output_dim: int = 512,
652
+ act_layer: Callable = nn.GELU,
653
+ norm_layer: Callable = LayerNorm,
654
+ xattn: bool= False,
655
+ attn_mask: bool = True
656
+ ):
657
+ super().__init__()
658
+ self.context_length = context_length
659
+ self.vocab_size = vocab_size
660
+ self.width = width
661
+ self.output_dim = output_dim
662
+
663
+ self.token_embedding = nn.Embedding(vocab_size, width)
664
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
665
+ self.transformer = Transformer(
666
+ width=width,
667
+ layers=layers,
668
+ heads=heads,
669
+ ls_init_value=ls_init_value,
670
+ act_layer=act_layer,
671
+ norm_layer=norm_layer,
672
+ xattn=xattn
673
+ )
674
+
675
+ self.xattn = xattn
676
+ self.ln_final = norm_layer(width)
677
+ self.text_projection = nn.Parameter(torch.empty(width, output_dim))
678
+
679
+ if attn_mask:
680
+ self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
681
+ else:
682
+ self.attn_mask = None
683
+
684
+ self.init_parameters()
685
+
686
+ def init_parameters(self):
687
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
688
+ nn.init.normal_(self.positional_embedding, std=0.01)
689
+
690
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
691
+ attn_std = self.transformer.width ** -0.5
692
+ fc_std = (2 * self.transformer.width) ** -0.5
693
+ for block in self.transformer.resblocks:
694
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
695
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
696
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
697
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
698
+
699
+ if self.text_projection is not None:
700
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
701
+
702
+ @torch.jit.ignore
703
+ def set_grad_checkpointing(self, enable=True):
704
+ self.transformer.grad_checkpointing = enable
705
+
706
+ @torch.jit.ignore
707
+ def no_weight_decay(self):
708
+ # return {'positional_embedding', 'token_embedding'}
709
+ return {'positional_embedding'}
710
+
711
+ def get_num_layers(self):
712
+ return self.transformer.layers
713
+
714
+ def build_attention_mask(self):
715
+ # lazily create causal attention mask, with full attention between the vision tokens
716
+ # pytorch uses additive attention mask; fill with -inf
717
+ mask = torch.empty(self.context_length, self.context_length)
718
+ mask.fill_(float("-inf"))
719
+ mask.triu_(1) # zero out the lower diagonal
720
+ return mask
721
+
722
+ def forward(self, text, return_all_features: bool=False):
723
+ cast_dtype = self.transformer.get_cast_dtype()
724
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
725
+
726
+ x = x + self.positional_embedding.to(cast_dtype)
727
+ x = x.permute(1, 0, 2) # NLD -> LND
728
+ x = self.transformer(x, attn_mask=self.attn_mask)
729
+ # x = self.transformer(x) # no attention mask is applied
730
+ x = x.permute(1, 0, 2) # LND -> NLD
731
+ x = self.ln_final(x)
732
+
733
+ if not return_all_features:
734
+ # x.shape = [batch_size, n_ctx, transformer.width]
735
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
736
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
737
+ return x
eva_clip/utils.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from itertools import repeat
2
+ import collections.abc
3
+ import logging
4
+ import math
5
+ import numpy as np
6
+
7
+ import torch
8
+ from torch import nn as nn
9
+ from torchvision.ops.misc import FrozenBatchNorm2d
10
+ import torch.nn.functional as F
11
+
12
+ # open CLIP
13
+ def resize_clip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
14
+ # Rescale the grid of position embeddings when loading from state_dict
15
+ old_pos_embed = state_dict.get('visual.positional_embedding', None)
16
+ if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
17
+ return
18
+ grid_size = to_2tuple(model.visual.grid_size)
19
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
20
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
21
+ if new_seq_len == old_pos_embed.shape[0]:
22
+ return
23
+
24
+ if extra_tokens:
25
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
26
+ else:
27
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
28
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
29
+
30
+ logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
31
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
32
+ pos_emb_img = F.interpolate(
33
+ pos_emb_img,
34
+ size=grid_size,
35
+ mode=interpolation,
36
+ align_corners=True,
37
+ )
38
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
39
+ if pos_emb_tok is not None:
40
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
41
+ else:
42
+ new_pos_embed = pos_emb_img
43
+ state_dict['visual.positional_embedding'] = new_pos_embed
44
+
45
+
46
+ def resize_visual_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
47
+ # Rescale the grid of position embeddings when loading from state_dict
48
+ old_pos_embed = state_dict.get('positional_embedding', None)
49
+ if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
50
+ return
51
+ grid_size = to_2tuple(model.visual.grid_size)
52
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
53
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
54
+ if new_seq_len == old_pos_embed.shape[0]:
55
+ return
56
+
57
+ if extra_tokens:
58
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
59
+ else:
60
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
61
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
62
+
63
+ logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
64
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
65
+ pos_emb_img = F.interpolate(
66
+ pos_emb_img,
67
+ size=grid_size,
68
+ mode=interpolation,
69
+ align_corners=True,
70
+ )
71
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
72
+ if pos_emb_tok is not None:
73
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
74
+ else:
75
+ new_pos_embed = pos_emb_img
76
+ state_dict['positional_embedding'] = new_pos_embed
77
+
78
+ def resize_evaclip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
79
+ all_keys = list(state_dict.keys())
80
+ # interpolate position embedding
81
+ if 'visual.pos_embed' in state_dict:
82
+ pos_embed_checkpoint = state_dict['visual.pos_embed']
83
+ embedding_size = pos_embed_checkpoint.shape[-1]
84
+ num_patches = model.visual.patch_embed.num_patches
85
+ num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
86
+ # height (== width) for the checkpoint position embedding
87
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
88
+ # height (== width) for the new position embedding
89
+ new_size = int(num_patches ** 0.5)
90
+ # class_token and dist_token are kept unchanged
91
+ if orig_size != new_size:
92
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
93
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
94
+ # only the position tokens are interpolated
95
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
96
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
97
+ pos_tokens = torch.nn.functional.interpolate(
98
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
99
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
100
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
101
+ state_dict['visual.pos_embed'] = new_pos_embed
102
+
103
+ patch_embed_proj = state_dict['visual.patch_embed.proj.weight']
104
+ patch_size = model.visual.patch_embed.patch_size
105
+ state_dict['visual.patch_embed.proj.weight'] = torch.nn.functional.interpolate(
106
+ patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
107
+
108
+
109
+ def resize_eva_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
110
+ all_keys = list(state_dict.keys())
111
+ # interpolate position embedding
112
+ if 'pos_embed' in state_dict:
113
+ pos_embed_checkpoint = state_dict['pos_embed']
114
+ embedding_size = pos_embed_checkpoint.shape[-1]
115
+ num_patches = model.visual.patch_embed.num_patches
116
+ num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
117
+ # height (== width) for the checkpoint position embedding
118
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
119
+ # height (== width) for the new position embedding
120
+ new_size = int(num_patches ** 0.5)
121
+ # class_token and dist_token are kept unchanged
122
+ if orig_size != new_size:
123
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
124
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
125
+ # only the position tokens are interpolated
126
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
127
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
128
+ pos_tokens = torch.nn.functional.interpolate(
129
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
130
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
131
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
132
+ state_dict['pos_embed'] = new_pos_embed
133
+
134
+ patch_embed_proj = state_dict['patch_embed.proj.weight']
135
+ patch_size = model.visual.patch_embed.patch_size
136
+ state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate(
137
+ patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
138
+
139
+
140
+ def resize_rel_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
141
+ all_keys = list(state_dict.keys())
142
+ for key in all_keys:
143
+ if "relative_position_index" in key:
144
+ state_dict.pop(key)
145
+
146
+ if "relative_position_bias_table" in key:
147
+ rel_pos_bias = state_dict[key]
148
+ src_num_pos, num_attn_heads = rel_pos_bias.size()
149
+ dst_num_pos, _ = model.visual.state_dict()[key].size()
150
+ dst_patch_shape = model.visual.patch_embed.patch_shape
151
+ if dst_patch_shape[0] != dst_patch_shape[1]:
152
+ raise NotImplementedError()
153
+ num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
154
+ src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
155
+ dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
156
+ if src_size != dst_size:
157
+ print("Position interpolate for %s from %dx%d to %dx%d" % (
158
+ key, src_size, src_size, dst_size, dst_size))
159
+ extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
160
+ rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
161
+
162
+ def geometric_progression(a, r, n):
163
+ return a * (1.0 - r ** n) / (1.0 - r)
164
+
165
+ left, right = 1.01, 1.5
166
+ while right - left > 1e-6:
167
+ q = (left + right) / 2.0
168
+ gp = geometric_progression(1, q, src_size // 2)
169
+ if gp > dst_size // 2:
170
+ right = q
171
+ else:
172
+ left = q
173
+
174
+ # if q > 1.090307:
175
+ # q = 1.090307
176
+
177
+ dis = []
178
+ cur = 1
179
+ for i in range(src_size // 2):
180
+ dis.append(cur)
181
+ cur += q ** (i + 1)
182
+
183
+ r_ids = [-_ for _ in reversed(dis)]
184
+
185
+ x = r_ids + [0] + dis
186
+ y = r_ids + [0] + dis
187
+
188
+ t = dst_size // 2.0
189
+ dx = np.arange(-t, t + 0.1, 1.0)
190
+ dy = np.arange(-t, t + 0.1, 1.0)
191
+
192
+ print("Original positions = %s" % str(x))
193
+ print("Target positions = %s" % str(dx))
194
+
195
+ all_rel_pos_bias = []
196
+
197
+ for i in range(num_attn_heads):
198
+ z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
199
+ f = F.interpolate.interp2d(x, y, z, kind='cubic')
200
+ all_rel_pos_bias.append(
201
+ torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
202
+
203
+ rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
204
+
205
+ new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
206
+ state_dict[key] = new_rel_pos_bias
207
+
208
+ # interpolate position embedding
209
+ if 'pos_embed' in state_dict:
210
+ pos_embed_checkpoint = state_dict['pos_embed']
211
+ embedding_size = pos_embed_checkpoint.shape[-1]
212
+ num_patches = model.visual.patch_embed.num_patches
213
+ num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
214
+ # height (== width) for the checkpoint position embedding
215
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
216
+ # height (== width) for the new position embedding
217
+ new_size = int(num_patches ** 0.5)
218
+ # class_token and dist_token are kept unchanged
219
+ if orig_size != new_size:
220
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
221
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
222
+ # only the position tokens are interpolated
223
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
224
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
225
+ pos_tokens = torch.nn.functional.interpolate(
226
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
227
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
228
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
229
+ state_dict['pos_embed'] = new_pos_embed
230
+
231
+ patch_embed_proj = state_dict['patch_embed.proj.weight']
232
+ patch_size = model.visual.patch_embed.patch_size
233
+ state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate(
234
+ patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
235
+
236
+
237
+ def freeze_batch_norm_2d(module, module_match={}, name=''):
238
+ """
239
+ Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
240
+ itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
241
+ returned. Otherwise, the module is walked recursively and submodules are converted in place.
242
+
243
+ Args:
244
+ module (torch.nn.Module): Any PyTorch module.
245
+ module_match (dict): Dictionary of full module names to freeze (all if empty)
246
+ name (str): Full module name (prefix)
247
+
248
+ Returns:
249
+ torch.nn.Module: Resulting module
250
+
251
+ Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
252
+ """
253
+ res = module
254
+ is_match = True
255
+ if module_match:
256
+ is_match = name in module_match
257
+ if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
258
+ res = FrozenBatchNorm2d(module.num_features)
259
+ res.num_features = module.num_features
260
+ res.affine = module.affine
261
+ if module.affine:
262
+ res.weight.data = module.weight.data.clone().detach()
263
+ res.bias.data = module.bias.data.clone().detach()
264
+ res.running_mean.data = module.running_mean.data
265
+ res.running_var.data = module.running_var.data
266
+ res.eps = module.eps
267
+ else:
268
+ for child_name, child in module.named_children():
269
+ full_child_name = '.'.join([name, child_name]) if name else child_name
270
+ new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
271
+ if new_child is not child:
272
+ res.add_module(child_name, new_child)
273
+ return res
274
+
275
+
276
+ # From PyTorch internals
277
+ def _ntuple(n):
278
+ def parse(x):
279
+ if isinstance(x, collections.abc.Iterable):
280
+ return x
281
+ return tuple(repeat(x, n))
282
+ return parse
283
+
284
+
285
+ to_1tuple = _ntuple(1)
286
+ to_2tuple = _ntuple(2)
287
+ to_3tuple = _ntuple(3)
288
+ to_4tuple = _ntuple(4)
289
+ to_ntuple = lambda n, x: _ntuple(n)(x)
290
+
291
+
292
+ def is_logging(args):
293
+ def is_global_master(args):
294
+ return args.rank == 0
295
+
296
+ def is_local_master(args):
297
+ return args.local_rank == 0
298
+
299
+ def is_master(args, local=False):
300
+ return is_local_master(args) if local else is_global_master(args)
301
+ return is_master
302
+
303
+
304
+ class AllGather(torch.autograd.Function):
305
+ """An autograd function that performs allgather on a tensor.
306
+ Performs all_gather operation on the provided tensors.
307
+ *** Warning ***: torch.distributed.all_gather has no gradient.
308
+ """
309
+
310
+ @staticmethod
311
+ def forward(ctx, tensor, rank, world_size):
312
+ tensors_gather = [torch.empty_like(tensor) for _ in range(world_size)]
313
+ torch.distributed.all_gather(tensors_gather, tensor)
314
+ ctx.rank = rank
315
+ ctx.batch_size = tensor.shape[0]
316
+ return torch.cat(tensors_gather, 0)
317
+
318
+ @staticmethod
319
+ def backward(ctx, grad_output):
320
+ return (
321
+ grad_output[ctx.batch_size * ctx.rank: ctx.batch_size * (ctx.rank + 1)],
322
+ None,
323
+ None
324
+ )
325
+
326
+ allgather = AllGather.apply
example_inputs/hinton.jpeg ADDED
example_inputs/lecun.jpg ADDED
example_inputs/lifeifei.jpg ADDED
example_inputs/liuyifei.png ADDED
example_inputs/pengwei.jpg ADDED

Git LFS Details

  • SHA256: 1d163eb4cc3244e063895263490ee5abc199fe915e6dae9aadbdfb435523644c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.24 MB
example_inputs/rihanna.webp ADDED
example_inputs/zcy.webp ADDED
flux/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try:
2
+ from ._version import version as __version__ # type: ignore
3
+ from ._version import version_tuple
4
+ except ImportError:
5
+ __version__ = "unknown (no version information available)"
6
+ version_tuple = (0, 0, "unknown", "noinfo")
7
+
8
+ from pathlib import Path
9
+
10
+ PACKAGE = __package__.replace("_", "-")
11
+ PACKAGE_ROOT = Path(__file__).parent
flux/math.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from einops import rearrange
3
+ from torch import Tensor
4
+
5
+
6
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
7
+ if pe is not None:
8
+ q, k = apply_rope(q, k, pe)
9
+
10
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
11
+ x = rearrange(x, "B H L D -> B L (H D)")
12
+
13
+ return x
14
+
15
+
16
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
17
+ assert dim % 2 == 0
18
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
19
+ omega = 1.0 / (theta**scale)
20
+ out = torch.einsum("...n,d->...nd", pos, omega)
21
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
22
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
23
+ return out.float()
24
+
25
+
26
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
27
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
28
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
29
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
30
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
31
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
flux/model.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ from torch import Tensor, nn
5
+
6
+ from flux.modules.layers import (
7
+ DoubleStreamBlock,
8
+ EmbedND,
9
+ LastLayer,
10
+ MLPEmbedder,
11
+ SingleStreamBlock,
12
+ timestep_embedding,
13
+ )
14
+
15
+ DEVICE = torch.device("cuda")
16
+
17
+ @dataclass
18
+ class FluxParams:
19
+ in_channels: int
20
+ vec_in_dim: int
21
+ context_in_dim: int
22
+ hidden_size: int
23
+ mlp_ratio: float
24
+ num_heads: int
25
+ depth: int
26
+ depth_single_blocks: int
27
+ axes_dim: list[int]
28
+ theta: int
29
+ qkv_bias: bool
30
+ guidance_embed: bool
31
+
32
+
33
+ class Flux(nn.Module):
34
+ """
35
+ Transformer model for flow matching on sequences.
36
+ """
37
+
38
+ def __init__(self, params: FluxParams):
39
+ super().__init__()
40
+
41
+ self.params = params
42
+ self.in_channels = params.in_channels
43
+ self.out_channels = self.in_channels
44
+ if params.hidden_size % params.num_heads != 0:
45
+ raise ValueError(
46
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
47
+ )
48
+ pe_dim = params.hidden_size // params.num_heads
49
+ if sum(params.axes_dim) != pe_dim:
50
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
51
+ self.hidden_size = params.hidden_size
52
+ self.num_heads = params.num_heads
53
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
54
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
55
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
56
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
57
+ self.guidance_in = (
58
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
59
+ )
60
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
61
+
62
+ self.double_blocks = nn.ModuleList(
63
+ [
64
+ DoubleStreamBlock(
65
+ self.hidden_size,
66
+ self.num_heads,
67
+ mlp_ratio=params.mlp_ratio,
68
+ qkv_bias=params.qkv_bias,
69
+ )
70
+ for _ in range(params.depth)
71
+ ]
72
+ )
73
+
74
+ self.single_blocks = nn.ModuleList(
75
+ [
76
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
77
+ for _ in range(params.depth_single_blocks)
78
+ ]
79
+ )
80
+
81
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
82
+
83
+ self.pulid_ca = None
84
+ self.pulid_double_interval = 2
85
+ self.pulid_single_interval = 4
86
+
87
+ def forward(
88
+ self,
89
+ img: Tensor,
90
+ img_ids: Tensor,
91
+ txt: Tensor,
92
+ txt_ids: Tensor,
93
+ timesteps: Tensor,
94
+ y: Tensor,
95
+ guidance: Tensor = None,
96
+ id: Tensor = None,
97
+ id_weight: float = 1.0,
98
+ aggressive_offload: bool = False,
99
+ ) -> Tensor:
100
+ if img.ndim != 3 or txt.ndim != 3:
101
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
102
+
103
+ # running on sequences img
104
+ img = self.img_in(img)
105
+ vec = self.time_in(timestep_embedding(timesteps, 256))
106
+ if self.params.guidance_embed:
107
+ if guidance is None:
108
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
109
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
110
+ vec = vec + self.vector_in(y)
111
+ txt = self.txt_in(txt)
112
+
113
+ ids = torch.cat((txt_ids, img_ids), dim=1)
114
+ pe = self.pe_embedder(ids)
115
+
116
+ ca_idx = 0
117
+ if aggressive_offload:
118
+ self.double_blocks = self.double_blocks.to(DEVICE)
119
+ for i, block in enumerate(self.double_blocks):
120
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
121
+
122
+ if i % self.pulid_double_interval == 0 and id is not None:
123
+ img = img + id_weight * self.pulid_ca[ca_idx](id, img)
124
+ ca_idx += 1
125
+ if aggressive_offload:
126
+ self.double_blocks.cpu()
127
+
128
+ img = torch.cat((txt, img), 1)
129
+ if aggressive_offload:
130
+ # put half of the single blcoks to gpu
131
+ for i in range(len(self.single_blocks) // 2):
132
+ self.single_blocks[i] = self.single_blocks[i].to(DEVICE)
133
+ for i, block in enumerate(self.single_blocks):
134
+ if aggressive_offload and i == len(self.single_blocks)//2:
135
+ # put first half of the single blcoks to cpu and last half to gpu
136
+ for j in range(len(self.single_blocks) // 2):
137
+ self.single_blocks[j].cpu()
138
+ for j in range(len(self.single_blocks) // 2, len(self.single_blocks)):
139
+ self.single_blocks[j] = self.single_blocks[j].to(DEVICE)
140
+ x = block(img, vec=vec, pe=pe)
141
+ real_img, txt = x[:, txt.shape[1]:, ...], x[:, :txt.shape[1], ...]
142
+
143
+ if i % self.pulid_single_interval == 0 and id is not None:
144
+ real_img = real_img + id_weight * self.pulid_ca[ca_idx](id, real_img)
145
+ ca_idx += 1
146
+
147
+ img = torch.cat((txt, real_img), 1)
148
+ if aggressive_offload:
149
+ self.single_blocks.cpu()
150
+ img = img[:, txt.shape[1] :, ...]
151
+
152
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
153
+ return img
154
+
155
+ def components_to_gpu(self):
156
+ # everything but double_blocks, single_blocks
157
+ self.img_in.to(DEVICE)
158
+ self.time_in.to(DEVICE)
159
+ self.guidance_in.to(DEVICE)
160
+ self.vector_in.to(DEVICE)
161
+ self.txt_in.to(DEVICE)
162
+ self.pe_embedder.to(DEVICE)
163
+ self.final_layer.to(DEVICE)
164
+ if self.pulid_ca:
165
+ self.pulid_ca.to(DEVICE)
flux/modules/__init__.py ADDED
File without changes
flux/modules/autoencoder.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ from einops import rearrange
5
+ from torch import Tensor, nn
6
+
7
+
8
+ @dataclass
9
+ class AutoEncoderParams:
10
+ resolution: int
11
+ in_channels: int
12
+ ch: int
13
+ out_ch: int
14
+ ch_mult: list[int]
15
+ num_res_blocks: int
16
+ z_channels: int
17
+ scale_factor: float
18
+ shift_factor: float
19
+
20
+
21
+ def swish(x: Tensor) -> Tensor:
22
+ return x * torch.sigmoid(x)
23
+
24
+
25
+ class AttnBlock(nn.Module):
26
+ def __init__(self, in_channels: int):
27
+ super().__init__()
28
+ self.in_channels = in_channels
29
+
30
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
31
+
32
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
33
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
34
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
35
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
36
+
37
+ def attention(self, h_: Tensor) -> Tensor:
38
+ h_ = self.norm(h_)
39
+ q = self.q(h_)
40
+ k = self.k(h_)
41
+ v = self.v(h_)
42
+
43
+ b, c, h, w = q.shape
44
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
45
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
46
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
47
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
48
+
49
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
50
+
51
+ def forward(self, x: Tensor) -> Tensor:
52
+ return x + self.proj_out(self.attention(x))
53
+
54
+
55
+ class ResnetBlock(nn.Module):
56
+ def __init__(self, in_channels: int, out_channels: int):
57
+ super().__init__()
58
+ self.in_channels = in_channels
59
+ out_channels = in_channels if out_channels is None else out_channels
60
+ self.out_channels = out_channels
61
+
62
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
63
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
64
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
65
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
66
+ if self.in_channels != self.out_channels:
67
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
68
+
69
+ def forward(self, x):
70
+ h = x
71
+ h = self.norm1(h)
72
+ h = swish(h)
73
+ h = self.conv1(h)
74
+
75
+ h = self.norm2(h)
76
+ h = swish(h)
77
+ h = self.conv2(h)
78
+
79
+ if self.in_channels != self.out_channels:
80
+ x = self.nin_shortcut(x)
81
+
82
+ return x + h
83
+
84
+
85
+ class Downsample(nn.Module):
86
+ def __init__(self, in_channels: int):
87
+ super().__init__()
88
+ # no asymmetric padding in torch conv, must do it ourselves
89
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
90
+
91
+ def forward(self, x: Tensor):
92
+ pad = (0, 1, 0, 1)
93
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
94
+ x = self.conv(x)
95
+ return x
96
+
97
+
98
+ class Upsample(nn.Module):
99
+ def __init__(self, in_channels: int):
100
+ super().__init__()
101
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
102
+
103
+ def forward(self, x: Tensor):
104
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
105
+ x = self.conv(x)
106
+ return x
107
+
108
+
109
+ class Encoder(nn.Module):
110
+ def __init__(
111
+ self,
112
+ resolution: int,
113
+ in_channels: int,
114
+ ch: int,
115
+ ch_mult: list[int],
116
+ num_res_blocks: int,
117
+ z_channels: int,
118
+ ):
119
+ super().__init__()
120
+ self.ch = ch
121
+ self.num_resolutions = len(ch_mult)
122
+ self.num_res_blocks = num_res_blocks
123
+ self.resolution = resolution
124
+ self.in_channels = in_channels
125
+ # downsampling
126
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
127
+
128
+ curr_res = resolution
129
+ in_ch_mult = (1,) + tuple(ch_mult)
130
+ self.in_ch_mult = in_ch_mult
131
+ self.down = nn.ModuleList()
132
+ block_in = self.ch
133
+ for i_level in range(self.num_resolutions):
134
+ block = nn.ModuleList()
135
+ attn = nn.ModuleList()
136
+ block_in = ch * in_ch_mult[i_level]
137
+ block_out = ch * ch_mult[i_level]
138
+ for _ in range(self.num_res_blocks):
139
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
140
+ block_in = block_out
141
+ down = nn.Module()
142
+ down.block = block
143
+ down.attn = attn
144
+ if i_level != self.num_resolutions - 1:
145
+ down.downsample = Downsample(block_in)
146
+ curr_res = curr_res // 2
147
+ self.down.append(down)
148
+
149
+ # middle
150
+ self.mid = nn.Module()
151
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
152
+ self.mid.attn_1 = AttnBlock(block_in)
153
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
154
+
155
+ # end
156
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
157
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
158
+
159
+ def forward(self, x: Tensor) -> Tensor:
160
+ # downsampling
161
+ hs = [self.conv_in(x)]
162
+ for i_level in range(self.num_resolutions):
163
+ for i_block in range(self.num_res_blocks):
164
+ h = self.down[i_level].block[i_block](hs[-1])
165
+ if len(self.down[i_level].attn) > 0:
166
+ h = self.down[i_level].attn[i_block](h)
167
+ hs.append(h)
168
+ if i_level != self.num_resolutions - 1:
169
+ hs.append(self.down[i_level].downsample(hs[-1]))
170
+
171
+ # middle
172
+ h = hs[-1]
173
+ h = self.mid.block_1(h)
174
+ h = self.mid.attn_1(h)
175
+ h = self.mid.block_2(h)
176
+ # end
177
+ h = self.norm_out(h)
178
+ h = swish(h)
179
+ h = self.conv_out(h)
180
+ return h
181
+
182
+
183
+ class Decoder(nn.Module):
184
+ def __init__(
185
+ self,
186
+ ch: int,
187
+ out_ch: int,
188
+ ch_mult: list[int],
189
+ num_res_blocks: int,
190
+ in_channels: int,
191
+ resolution: int,
192
+ z_channels: int,
193
+ ):
194
+ super().__init__()
195
+ self.ch = ch
196
+ self.num_resolutions = len(ch_mult)
197
+ self.num_res_blocks = num_res_blocks
198
+ self.resolution = resolution
199
+ self.in_channels = in_channels
200
+ self.ffactor = 2 ** (self.num_resolutions - 1)
201
+
202
+ # compute in_ch_mult, block_in and curr_res at lowest res
203
+ block_in = ch * ch_mult[self.num_resolutions - 1]
204
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
205
+ self.z_shape = (1, z_channels, curr_res, curr_res)
206
+
207
+ # z to block_in
208
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
209
+
210
+ # middle
211
+ self.mid = nn.Module()
212
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
213
+ self.mid.attn_1 = AttnBlock(block_in)
214
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
215
+
216
+ # upsampling
217
+ self.up = nn.ModuleList()
218
+ for i_level in reversed(range(self.num_resolutions)):
219
+ block = nn.ModuleList()
220
+ attn = nn.ModuleList()
221
+ block_out = ch * ch_mult[i_level]
222
+ for _ in range(self.num_res_blocks + 1):
223
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
224
+ block_in = block_out
225
+ up = nn.Module()
226
+ up.block = block
227
+ up.attn = attn
228
+ if i_level != 0:
229
+ up.upsample = Upsample(block_in)
230
+ curr_res = curr_res * 2
231
+ self.up.insert(0, up) # prepend to get consistent order
232
+
233
+ # end
234
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
235
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
236
+
237
+ def forward(self, z: Tensor) -> Tensor:
238
+ # z to block_in
239
+ h = self.conv_in(z)
240
+
241
+ # middle
242
+ h = self.mid.block_1(h)
243
+ h = self.mid.attn_1(h)
244
+ h = self.mid.block_2(h)
245
+
246
+ # upsampling
247
+ for i_level in reversed(range(self.num_resolutions)):
248
+ for i_block in range(self.num_res_blocks + 1):
249
+ h = self.up[i_level].block[i_block](h)
250
+ if len(self.up[i_level].attn) > 0:
251
+ h = self.up[i_level].attn[i_block](h)
252
+ if i_level != 0:
253
+ h = self.up[i_level].upsample(h)
254
+
255
+ # end
256
+ h = self.norm_out(h)
257
+ h = swish(h)
258
+ h = self.conv_out(h)
259
+ return h
260
+
261
+
262
+ class DiagonalGaussian(nn.Module):
263
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
264
+ super().__init__()
265
+ self.sample = sample
266
+ self.chunk_dim = chunk_dim
267
+
268
+ def forward(self, z: Tensor) -> Tensor:
269
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
270
+ if self.sample:
271
+ std = torch.exp(0.5 * logvar)
272
+ return mean + std * torch.randn_like(mean)
273
+ else:
274
+ return mean
275
+
276
+
277
+ class AutoEncoder(nn.Module):
278
+ def __init__(self, params: AutoEncoderParams):
279
+ super().__init__()
280
+ self.encoder = Encoder(
281
+ resolution=params.resolution,
282
+ in_channels=params.in_channels,
283
+ ch=params.ch,
284
+ ch_mult=params.ch_mult,
285
+ num_res_blocks=params.num_res_blocks,
286
+ z_channels=params.z_channels,
287
+ )
288
+ self.decoder = Decoder(
289
+ resolution=params.resolution,
290
+ in_channels=params.in_channels,
291
+ ch=params.ch,
292
+ out_ch=params.out_ch,
293
+ ch_mult=params.ch_mult,
294
+ num_res_blocks=params.num_res_blocks,
295
+ z_channels=params.z_channels,
296
+ )
297
+ self.reg = DiagonalGaussian()
298
+
299
+ self.scale_factor = params.scale_factor
300
+ self.shift_factor = params.shift_factor
301
+
302
+ def encode(self, x: Tensor) -> Tensor:
303
+ z = self.reg(self.encoder(x))
304
+ z = self.scale_factor * (z - self.shift_factor)
305
+ return z
306
+
307
+ def decode(self, z: Tensor) -> Tensor:
308
+ z = z / self.scale_factor + self.shift_factor
309
+ return self.decoder(z)
310
+
311
+ def forward(self, x: Tensor) -> Tensor:
312
+ return self.decode(self.encode(x))
flux/modules/conditioner.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import Tensor, nn
2
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
3
+
4
+
5
+ class HFEmbedder(nn.Module):
6
+ def __init__(self, version: str, max_length: int, **hf_kwargs):
7
+ super().__init__()
8
+ self.is_clip = version.startswith("openai")
9
+ self.max_length = max_length
10
+ self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
11
+
12
+ if self.is_clip:
13
+ self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
14
+ self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
15
+ else:
16
+ self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
17
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
18
+
19
+ self.hf_module = self.hf_module.eval().requires_grad_(False)
20
+
21
+ def forward(self, text: list[str]) -> Tensor:
22
+ batch_encoding = self.tokenizer(
23
+ text,
24
+ truncation=True,
25
+ max_length=self.max_length,
26
+ return_length=False,
27
+ return_overflowing_tokens=False,
28
+ padding="max_length",
29
+ return_tensors="pt",
30
+ )
31
+
32
+ outputs = self.hf_module(
33
+ input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
34
+ attention_mask=None,
35
+ output_hidden_states=False,
36
+ )
37
+ return outputs[self.output_key]
flux/modules/layers.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from torch import Tensor, nn
7
+
8
+ from flux.math import attention, rope
9
+
10
+
11
+ class EmbedND(nn.Module):
12
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
13
+ super().__init__()
14
+ self.dim = dim
15
+ self.theta = theta
16
+ self.axes_dim = axes_dim
17
+
18
+ def forward(self, ids: Tensor) -> Tensor:
19
+ n_axes = ids.shape[-1]
20
+ emb = torch.cat(
21
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
22
+ dim=-3,
23
+ )
24
+
25
+ return emb.unsqueeze(1)
26
+
27
+
28
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
29
+ """
30
+ Create sinusoidal timestep embeddings.
31
+ :param t: a 1-D Tensor of N indices, one per batch element.
32
+ These may be fractional.
33
+ :param dim: the dimension of the output.
34
+ :param max_period: controls the minimum frequency of the embeddings.
35
+ :return: an (N, D) Tensor of positional embeddings.
36
+ """
37
+ t = time_factor * t
38
+ half = dim // 2
39
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
40
+ t.device
41
+ )
42
+
43
+ args = t[:, None].float() * freqs[None]
44
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
45
+ if dim % 2:
46
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
47
+ if torch.is_floating_point(t):
48
+ embedding = embedding.to(t)
49
+ return embedding
50
+
51
+
52
+ class MLPEmbedder(nn.Module):
53
+ def __init__(self, in_dim: int, hidden_dim: int):
54
+ super().__init__()
55
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
56
+ self.silu = nn.SiLU()
57
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
58
+
59
+ def forward(self, x: Tensor) -> Tensor:
60
+ return self.out_layer(self.silu(self.in_layer(x)))
61
+
62
+
63
+ class RMSNorm(torch.nn.Module):
64
+ def __init__(self, dim: int):
65
+ super().__init__()
66
+ self.scale = nn.Parameter(torch.ones(dim))
67
+
68
+ def forward(self, x: Tensor):
69
+ x_dtype = x.dtype
70
+ x = x.float()
71
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
72
+ return (x * rrms).to(dtype=x_dtype) * self.scale
73
+
74
+
75
+ class QKNorm(torch.nn.Module):
76
+ def __init__(self, dim: int):
77
+ super().__init__()
78
+ self.query_norm = RMSNorm(dim)
79
+ self.key_norm = RMSNorm(dim)
80
+
81
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
82
+ q = self.query_norm(q)
83
+ k = self.key_norm(k)
84
+ return q.to(v), k.to(v)
85
+
86
+
87
+ class SelfAttention(nn.Module):
88
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
89
+ super().__init__()
90
+ self.num_heads = num_heads
91
+ head_dim = dim // num_heads
92
+
93
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
94
+ self.norm = QKNorm(head_dim)
95
+ self.proj = nn.Linear(dim, dim)
96
+
97
+ def forward(self, x: Tensor, pe: Tensor) -> Tensor:
98
+ qkv = self.qkv(x)
99
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
100
+ q, k = self.norm(q, k, v)
101
+ x = attention(q, k, v, pe=pe)
102
+ x = self.proj(x)
103
+ return x
104
+
105
+
106
+ @dataclass
107
+ class ModulationOut:
108
+ shift: Tensor
109
+ scale: Tensor
110
+ gate: Tensor
111
+
112
+
113
+ class Modulation(nn.Module):
114
+ def __init__(self, dim: int, double: bool):
115
+ super().__init__()
116
+ self.is_double = double
117
+ self.multiplier = 6 if double else 3
118
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
119
+
120
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut]:
121
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
122
+
123
+ return (
124
+ ModulationOut(*out[:3]),
125
+ ModulationOut(*out[3:]) if self.is_double else None,
126
+ )
127
+
128
+
129
+ class DoubleStreamBlock(nn.Module):
130
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
131
+ super().__init__()
132
+
133
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
134
+ self.num_heads = num_heads
135
+ self.hidden_size = hidden_size
136
+ self.img_mod = Modulation(hidden_size, double=True)
137
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
138
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
139
+
140
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
141
+ self.img_mlp = nn.Sequential(
142
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
143
+ nn.GELU(approximate="tanh"),
144
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
145
+ )
146
+
147
+ self.txt_mod = Modulation(hidden_size, double=True)
148
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
149
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
150
+
151
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
152
+ self.txt_mlp = nn.Sequential(
153
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
154
+ nn.GELU(approximate="tanh"),
155
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
156
+ )
157
+
158
+ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
159
+ img_mod1, img_mod2 = self.img_mod(vec)
160
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
161
+
162
+ # prepare image for attention
163
+ img_modulated = self.img_norm1(img)
164
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
165
+ img_qkv = self.img_attn.qkv(img_modulated)
166
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
167
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
168
+
169
+ # prepare txt for attention
170
+ txt_modulated = self.txt_norm1(txt)
171
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
172
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
173
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
174
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
175
+
176
+ # run actual attention
177
+ q = torch.cat((txt_q, img_q), dim=2)
178
+ k = torch.cat((txt_k, img_k), dim=2)
179
+ v = torch.cat((txt_v, img_v), dim=2)
180
+
181
+ attn = attention(q, k, v, pe=pe)
182
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
183
+
184
+ # calculate the img bloks
185
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
186
+ img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
187
+
188
+ # calculate the txt bloks
189
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
190
+ txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
191
+ return img, txt
192
+
193
+
194
+ class SingleStreamBlock(nn.Module):
195
+ """
196
+ A DiT block with parallel linear layers as described in
197
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
198
+ """
199
+
200
+ def __init__(
201
+ self,
202
+ hidden_size: int,
203
+ num_heads: int,
204
+ mlp_ratio: float = 4.0,
205
+ qk_scale: float = None,
206
+ ):
207
+ super().__init__()
208
+ self.hidden_dim = hidden_size
209
+ self.num_heads = num_heads
210
+ head_dim = hidden_size // num_heads
211
+ self.scale = qk_scale or head_dim**-0.5
212
+
213
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
214
+ # qkv and mlp_in
215
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
216
+ # proj and mlp_out
217
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
218
+
219
+ self.norm = QKNorm(head_dim)
220
+
221
+ self.hidden_size = hidden_size
222
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
223
+
224
+ self.mlp_act = nn.GELU(approximate="tanh")
225
+ self.modulation = Modulation(hidden_size, double=False)
226
+
227
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
228
+ mod, _ = self.modulation(vec)
229
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
230
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
231
+
232
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
233
+ q, k = self.norm(q, k, v)
234
+
235
+ # compute attention
236
+ attn = attention(q, k, v, pe=pe)
237
+ # compute activation in mlp stream, cat again and run second linear layer
238
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
239
+ return x + mod.gate * output
240
+
241
+
242
+ class LastLayer(nn.Module):
243
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
244
+ super().__init__()
245
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
246
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
247
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
248
+
249
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
250
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
251
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
252
+ x = self.linear(x)
253
+ return x