Spaces: Running on L40S
Akash Garg committed · Commit 616f571 · Parent(s): 7e8f630
adding cube sources
Browse files:
- cube/.gitignore +176 -0
- cube/LICENSE +80 -0
- cube/README.md +166 -0
- cube/SECURITY.md +9 -0
- cube/cube3d/__init__.py +0 -0
- cube/cube3d/colab_cube3d.ipynb +0 -0
- cube/cube3d/configs/open_model.yaml +32 -0
- cube/cube3d/generate.py +144 -0
- cube/cube3d/inference/__init__.py +0 -0
- cube/cube3d/inference/engine.py +499 -0
- cube/cube3d/inference/logits_postprocesses.py +44 -0
- cube/cube3d/inference/utils.py +56 -0
- cube/cube3d/mesh_utils/postprocessing.py +84 -0
- cube/cube3d/model/__init__.py +0 -0
- cube/cube3d/model/autoencoder/__init__.py +0 -0
- cube/cube3d/model/autoencoder/embedder.py +52 -0
- cube/cube3d/model/autoencoder/grid.py +84 -0
- cube/cube3d/model/autoencoder/one_d_autoencoder.py +694 -0
- cube/cube3d/model/autoencoder/spherical_vq.py +159 -0
- cube/cube3d/model/gpt/__init__.py +0 -0
- cube/cube3d/model/gpt/dual_stream_roformer.py +279 -0
- cube/cube3d/model/transformers/__init__.py +0 -0
- cube/cube3d/model/transformers/attention.py +300 -0
- cube/cube3d/model/transformers/cache.py +9 -0
- cube/cube3d/model/transformers/dual_stream_attention.py +348 -0
- cube/cube3d/model/transformers/norm.py +46 -0
- cube/cube3d/model/transformers/roformer.py +229 -0
- cube/cube3d/model/transformers/rope.py +91 -0
- cube/cube3d/renderer/blender_script.py +723 -0
- cube/cube3d/renderer/renderer.py +88 -0
- cube/cube3d/vq_vae_encode_decode.py +150 -0
- cube/pyproject.toml +38 -0
- cube/setup.py +6 -0
cube/.gitignore
ADDED
@@ -0,0 +1,176 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

.vscode/*

.DS_Store

# Output folder
outputs/
model_weights/
cube/LICENSE
ADDED
@@ -0,0 +1,80 @@
CUBE3D RESEARCH-ONLY RAIL-MS LICENSE

Licensed Artifacts:
Cube3d-v0.1 and related inference code

I. SCOPE
This Research-Only RAIL License is generally applicable to the Artifacts identified above.
For valuable consideration, You and Licensor agree as follows:
1. Definitions
(a) “Artifact” means a software application (in either binary or source code format), Model, or Source Code, in accordance with what are specified above as the “Licensed Artifacts.”
(b) “Contribution” means any work, including any modifications or additions to an Artifact, that is intentionally submitted to Licensor for inclusion or incorporation in the Artifact directly or indirectly by the rights owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing, sharing and improving the Artifact, but excluding communication that is conspicuously marked or otherwise designated in writing by the contributor as “Not a Contribution.”
(c) “Contributor” means Licensor or any other individual or legal entity that creates or owns a Contribution that is added to or incorporated into an Artifact or Derivative.
(d) “Data” means a collection of information or content extracted from the dataset used with a given Model, including to train, pretrain, or otherwise evaluate the Model.
(e) “Derivative” means a work derived from or based upon an Artifact, and includes all modified versions of such Artifact.
(f) “Distribution” means any transmission, reproduction, publication or other sharing of an Artifact or Derivative to a third party, including providing a hosted service incorporating the Artifact, which is made available by electronic or other remote means—e.g., API-based or web access.
(g) “License” means the terms and conditions for use, reproduction, and Distribution as defined in this document.
(h) “Licensor” means the rights owner (by virtue of creation or documented transfer of ownership) or entity authorized by the rights owner (e.g., exclusive licensee) that is granting the rights in this License.
(i) “Model” means any machine-learning based assembly or assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Source Code.
(j) “Output” means the results of operating a Model as embodied in informational content resulting therefrom.
(k) “Permitted Purpose” means for academic or research purposes only.
(l) “Source Code” means any collection of text written using human-readable programming language, including the code and scripts used to define, run, load, benchmark or evaluate a Model or any component thereof, or used to prepare data for training or evaluation. Source Code includes any accompanying documentation, tutorials and examples. For clarity, the term “Source Code” as used in this License includes any and all Derivatives of such Source Code.
(m) “Third Party” means any individual or legal entity that is not under common control with Licensor or You.
(n) “Use,” with respect to an Artifact, means accessing, using, copying, modifying, distributing, and making available the Artifact; in connection with a Model as Artifact, Use also includes creating content, fine-tuning, updating, running, training, evaluating and re-parametrizing such Model.
(o) “You” (or “Your”) means an individual or legal entity receiving and exercising permissions granted by this License or making use of the Artifact for the Permitted Purpose and in any permitted field of use, including usage of the Artifact in an end-use application.

II. INTELLECTUAL PROPERTY RIGHTS
1. Both copyright and patent grants may apply to the Artifacts. The Artifacts are subject to additional terms as described in Section III below, which govern the Use of the Artifacts in the event that Section II is held unenforceable or inapplicable.
2. Grant of Copyright License. Conditioned upon compliance with Section III below and subject to the terms and conditions of this License, each Contributor hereby grants to You, only in connection with the Permitted Purpose, a worldwide, non-exclusive, royalty-free copyright license to reproduce, publicly display, publicly perform, distribute, and make derivatives of the Artifacts.
3. Grant of Patent License. Conditioned upon compliance with Section III below and subject to the terms and conditions of this License, and only where and as applicable, each Contributor hereby grants to You, only in connection with the Permitted Purpose, a worldwide, non-exclusive, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, sell, offer to sell, import, and otherwise transfer the Artifacts where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contributions alone or by combination of their Contributions with the Artifact to which such Contribution was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that an Artifact or Contribution constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License in connection with the Artifact shall terminate as of the date such litigation is asserted or filed.
4. Licensor and Contributor each have the right to grant the licenses above.
5. The Data is not licensed under this License.

III. CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
1. Use-based restrictions. The restrictions set forth in Attachment A are mandatory Use-based restrictions. Therefore You may not Use any Artifact in violation of such restrictions. You may Use Artifacts only subject to this License. You shall require all of Your users who Use the Artifacts or Derivatives to comply with the terms of this paragraph and only Use the Artifacts and Derivatives for the Permitted Purpose.
2. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output generated by You or Your users. You are accountable for the Output You generate and its subsequent uses. No use of the Output may contravene any provision as stated in this License.
3. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce, distribute, and make available Artifacts and Derivatives in any medium, with or without modifications, provided that You meet the following conditions:
(a) Use-based restrictions in Paragraph III.1 MUST be included as a condition precedent to effect any type of legal agreement (e.g., a license) governing the Use of the Artifacts and Derivatives, and You shall give such notice to any subsequent Third Party recipients.
(b) You shall give any Third Party recipients of any Artifacts or Derivatives a copy of this License.
(c) You shall cause any modified files to carry prominent notices stating that You changed the files.
(d) You shall retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Artifacts or Derivatives.
(e) You and any Third Party recipients of any Artifacts or Derivatives shall adhere to the Permitted Purpose.
4. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions with respect to Paragraph III.3(a), to govern the Use or Distribution of Your modifications, or for any Derivative, provided that Your Use and Distribution of the Artifacts or their Derivatives otherwise complies with the conditions stated in this License. In other words, the Use-based restrictions referred to in Paragraph III.1 form the minimum set of terms for You to license to Third Parties any Artifacts or Derivatives, but You may add more restrictive terms if You deem it necessary.

IV. OTHER PROVISIONS
1. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of Artifacts in violation of this License or update Artifacts through electronic means.
2. Trademarks. Nothing in this License permits You to make use of Licensor’s trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between You and Licensor; and any rights not expressly granted herein are reserved by the Licensor.
3. DISCLAIMER OF WARRANTY. UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING, LICENSOR PROVIDES THE ARTIFACT (AND EACH CONTRIBUTOR PROVIDES ITS CONTRIBUTIONS) ON AN “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OR CONDITIONS OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING THE ARTIFACT, AND ASSUME ANY RISKS ASSOCIATED WITH YOUR EXERCISE OF PERMISSIONS UNDER THIS LICENSE.
4. LIMITATION OF LIABILITY. IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE, UNLESS REQUIRED BY APPLICABLE LAW (SUCH AS DELIBERATE AND GROSSLY NEGLIGENT ACTS) OR AGREED TO IN WRITING, SHALL ANY CONTRIBUTOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES OF ANY CHARACTER ARISING AS A RESULT OF THIS LICENSE OR OUT OF THE USE OR INABILITY TO USE THE ARTIFACT (INCLUDING BUT NOT LIMITED TO DAMAGES FOR LOSS OF GOODWILL, WORK STOPPAGE, COMPUTER FAILURE OR MALFUNCTION, OR ANY AND ALL OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF SUCH CONTRIBUTOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
5. Severability. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
6. Term and Termination. The term of this License will commence upon the earlier of (a) Your acceptance of this License or (b) accessing the Artifact; and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Licensor may terminate this License if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You shall delete and cease use of the Artifacts. This paragraph shall survive the termination of this License.

Attachment A – Use Restrictions
1. Discrimination. You agree not to Use, or allow others to Use, Artifacts or Derivatives
(a) to discriminate, mock, or promote hatred against individuals or groups, or encourage others to do so directly or indirectly, on the basis of their age; race, perceived race, or ethnicity; national origin; sexual orientation; gender, gender identity, or gender expression; religion or religious affiliation or beliefs; disability status including diseases, bodily conditions, disfigurement, mobility issues, and mental impairment; veteran status; caste; or familial status.
(b) to exploit any of the vulnerabilities of an individual or specific group of persons based on their age, social, physical, or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm.
(c) to engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, or other essential goods and services.
2. Intellectual Property. You agree not to, and not to allow others to Use Artifacts or Derivatives
(a) to infringe or attempt to infringe, misappropriate or otherwise violate any intellectual property rights of Licensor or any Third Party;
(b) synthesize or modify a natural person’s appearance, voice, or other individual characteristics, unless prior informed consent of said natural person is obtained; or
(c) to reverse engineer, disassemble, decompile, or otherwise attempt to derive or gain access to Data that was used to create, train, pretrain, or otherwise evaluate such Artifacts or Derivatives.
3. Legal. You agree not to Use, or allow others to Use, Artifacts or Derivatives
(a) in any way that violates any applicable national, federal, state, local or international law or regulation;
(b) to engage in, facilitate, or assist in the planning or development of criminal activities; or
(c) to generate unlawful content.
4. Disinformation. You agree not to Use, or allow others to Use, Artifacts or Derivatives
(a) to create, present or disseminate false or misleading information for economic gain or to intentionally deceive the public, including creating false impersonations of natural persons;
(b) to defame or harm a person’s reputation, such as by generating, creating, promoting, or spreading defamatory content.
5. Privacy. You agree not to Use, or allow others to Use, Artifacts or Derivatives
(a) to engage in, promote, incite, or facilitate the harassment, abuse, threatening or bullying of individuals or groups of individuals; or
(b) in connection with personal information to infer additional personal information about a natural person, including but not limited to legally protected characteristics, vulnerabilities or categories; unless informed consent from the data subject to collect said inferred personal information for a stated purpose and defined duration is received.
6. Health and Safety. You agree not to Use, or allow others to Use, Artifacts or Derivatives
(a) to provide health or medical advice, medical results interpretation, or make clinical decisions; or
(b) in connection with any activities that present a risk of death or bodily harm to individuals, including self-harm or harm to others, or in connection with regulated or controlled substances.
7. Military or Law Enforcement. You agree not to Use, or allow others to Use, Artifacts or Derivatives
(a) for purposes of administration of justice, law enforcement, immigration, or asylum processes, such as predicting that a natural person will commit a crime or the likelihood thereof;
(b) for weaponry or warfare; for building or optimizing weapons; or in service of nuclear proliferation or nuclear weapons technology; or
(c) military surveillance, including any research or development relating to military surveillance.
8. General. You agree not to Use, or allow others to Use, Artifacts or Derivatives
(a) in any manner that would constitute high risk, restricted, or prohibited of AI under applicable law; or
(b) to generate or disseminate malware or ransomware or to otherwise harm electronic systems.
cube/README.md
ADDED
@@ -0,0 +1,166 @@
# Cube: Generative AI System for 3D

<p align="center">
  <img src="./resources/teaser.png" width="800" style="margin: 5px;">
</p>

<div align="center">
  <a href=https://corp.roblox.com/newsroom/2025/03/introducing-roblox-cube target="_blank"><img src=https://img.shields.io/badge/Roblox-Blog-000000.svg?logo=Roblox height=22px></a>
  <a href=https://huggingface.co/Roblox/cube3d-0.1 target="_blank"><img src=https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Models-d96902.svg height=22px></a>
  <a href=https://arxiv.org/abs/2503.15475 target="_blank"><img src=https://img.shields.io/badge/ArXiv-Report-b5212f.svg?logo=arxiv height=22px></a>
  <a href=https://colab.research.google.com/drive/1ZvTj49pjDCD_crX5WPZNTAoTTzL6-E5t target="_blank"><img src=https://img.shields.io/badge/Google-Open_In_Colab-blue.svg?logo=googlecolab height=22px></a>
</div>


Foundation models trained on vast amounts of data have demonstrated remarkable reasoning and
generation capabilities in the domains of text, images, audio and video. Our goal is to build
such a foundation model for 3D intelligence, a model that can support developers in producing all aspects
of a Roblox experience, from generating 3D objects and scenes to rigging characters for animation to
producing programmatic scripts describing object behaviors. As we start open-sourcing a family of models
towards this vision, we hope to engage others in the research community to address these goals with us.

## Get Started with Cube 3D

<p align="center">
  <img src="./resources/greyscale_512.gif" width="600" style="margin: 5px;">
</p>

Cube 3D is our first step towards 3D intelligence, which involves a shape tokenizer and a text-to-shape generation model. We are unlocking the power of generating 3D assets and enhancing creativity for all artists. Our latest version of Cube 3D is now accessible to individuals, creators, researchers and businesses of all sizes so that they can experiment, innovate and scale their ideas responsibly. This release includes model weights and starting code for using our text-to-shape model to create 3D assets.

### Try it out on [Google Colab](https://colab.research.google.com/drive/1ZvTj49pjDCD_crX5WPZNTAoTTzL6-E5t)

### Install Requirements

Clone and install this repo in a virtual environment, via:

```bash
git clone https://github.com/Roblox/cube.git
cd cube
pip install -e .[meshlab]
```

> **CUDA**: If you are using a Windows machine, you may need to install the [CUDA](https://developer.nvidia.com/cuda-downloads) toolkit as well as `torch` with CUDA support via `pip install torch --index-url https://download.pytorch.org/whl/cu124 --force-reinstall`

> **Note**: `[meshlab]` is an optional dependency and can be removed by simply running `pip install -e .` for better compatibility, but mesh simplification will be disabled.

### Download Models from Huggingface 🤗

Download the model weights from [Hugging Face](https://huggingface.co/Roblox/cube3d-v0.1) or use the
`huggingface-cli`:

```bash
huggingface-cli download Roblox/cube3d-v0.1 --local-dir ./model_weights
```

### Inference

#### 1. Shape Generation

To generate 3D models using the downloaded models, simply run:

```bash
python -m cube3d.generate \
    --gpt-ckpt-path model_weights/shape_gpt.safetensors \
    --shape-ckpt-path model_weights/shape_tokenizer.safetensors \
    --fast-inference \
    --prompt "Broad-winged flying red dragon, elongated, folded legs."
```

> **Note**: `--fast-inference` is optional and may not be available on GPUs with limited VRAM. This flag will also not work on macOS.

The output will be an `.obj` file saved in the specified `output` directory.

If you want to render a turntable gif of the mesh, you can use the `--render-gif` flag, which will render the gif
and save it as `turntable.gif` in the specified `output` directory.

We provide several example output objects and their corresponding text prompts in the `examples` folder.

> **Note**: You must have Blender installed and available in your system's PATH to render the turntable GIF. You can download it from [Blender's official website](https://www.blender.org/). Ensure that the Blender executable is accessible from the command line.

> **Note**: If shape decoding is slow, you can try to specify a lower resolution using the `--resolution-base` flag. A lower resolution will create a coarser, lower-quality output mesh but decode faster. Values between 4.0 and 9.0 are recommended.

#### 2. Shape Tokenization and De-tokenization

To tokenize a 3D shape into token indices and reconstruct it back, you can use the following command:

```bash
python -m cube3d.vq_vae_encode_decode \
    --shape-ckpt-path model_weights/shape_tokenizer.safetensors \
    --mesh-path ./outputs/output.obj
```

This will process the `.obj` file located at `./outputs/output.obj`, print the tokenized representation, and export the mesh reconstructed from the token indices.

### Hardware Requirements

We have tested our model on:
* Nvidia H100 GPU
* Nvidia A100 GPU
* Nvidia Geforce 3080
* Apple Silicon M2-4 chips

We recommend using a GPU with at least 24GB of VRAM available when using `--fast-inference` (or `EngineFast`) and 16GB otherwise.

### Code Usage

We have designed a minimalist API that allows the use of this repo as a Python library:

```python
import torch
import trimesh
from cube3d.inference.engine import Engine, EngineFast

# load ckpt
config_path = "cube3d/configs/open_model.yaml"
gpt_ckpt_path = "model_weights/shape_gpt.safetensors"
shape_ckpt_path = "model_weights/shape_tokenizer.safetensors"
engine_fast = EngineFast(  # only supported on CUDA devices, replace with Engine otherwise
    config_path,
    gpt_ckpt_path,
    shape_ckpt_path,
    device=torch.device("cuda"),
)

# inference
input_prompt = "A pair of noise-canceling headphones"
# NOTE: Reduce `resolution_base` for faster inference and lower VRAM usage
# The `top_k` parameter controls randomness between inferences:
#   - A value of 1 yields deterministic results.
#   - Higher values introduce more randomness.
mesh_v_f = engine_fast.t2s([input_prompt], use_kv_cache=True, resolution_base=8.0, top_k=5)

# save output
vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
_ = trimesh.Trimesh(vertices=vertices, faces=faces).export("output.obj")
```

## Coming Soon

### Controlling shape generation with bounding box conditioning
<div align="center">
  <img src="./resources/truck_black_text_512.gif" width="300" height="300" style="margin: 5px;">
  <img src="./resources/couch_black_text_512.gif" width="300" height="300" style="margin: 5px;">
</div>

### Scene Generation

https://github.com/user-attachments/assets/987c459a-5708-41a5-9b92-89068a70a239

https://github.com/user-attachments/assets/ab501a86-b0cb-4c73-827e-988b2120d4c0

## Citation
If you find this work helpful, please consider citing our technical report:

```bibtex
@article{roblox2025cube,
    title = {Cube: A Roblox View of 3D Intelligence},
    author = {Roblox, Foundation AI Team},
    journal = {arXiv preprint arXiv:2503.15475},
    year = {2025}
}
```

## Acknowledgements

We would like to thank the contributors of the [TRELLIS](https://github.com/microsoft/TRELLIS), [CraftsMan3D](https://github.com/wyysf-98/CraftsMan3D), [threestudio](https://github.com/threestudio-project/threestudio), [Hunyuan3D-2](https://github.com/Tencent/Hunyuan3D-2), [minGPT](https://github.com/karpathy/minGPT), [dinov2](https://github.com/facebookresearch/dinov2), [OptVQ](https://github.com/zbr17/OptVQ), and [1d-tokenizer](https://github.com/bytedance/1d-tokenizer)
repositories for their open-source contributions.
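
The library example above is pinned to `EngineFast`, which the inline comment notes is CUDA-only. A minimal sketch of the fallback path with the plain `Engine` class, assuming the same config and weight paths as in the README (the prompt and `resolution_base` value here are purely illustrative), might look like this:

```python
import torch
import trimesh

from cube3d.inference.engine import Engine

# Assumed paths, mirroring the README; adjust to wherever the weights were downloaded.
config_path = "cube3d/configs/open_model.yaml"
gpt_ckpt_path = "model_weights/shape_gpt.safetensors"
shape_ckpt_path = "model_weights/shape_tokenizer.safetensors"

# Engine works on CPU (and non-CUDA devices), just more slowly than EngineFast.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
engine = Engine(config_path, gpt_ckpt_path, shape_ckpt_path, device=device)

# A lower resolution_base trades mesh quality for faster decoding (README suggests 4.0-9.0).
mesh_v_f = engine.t2s(
    ["A wooden rocking chair"],
    use_kv_cache=True,
    resolution_base=5.0,
    top_k=1,  # deterministic sampling
)
vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
trimesh.Trimesh(vertices=vertices, faces=faces).export("output_cpu.obj")
```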
cube/SECURITY.md
ADDED
@@ -0,0 +1,9 @@
# Security Policy

## Reporting a Vulnerability

If you discover a security vulnerability in this repository, we appreciate your help in ensuring that the issue is addressed quickly.

Report any vulnerabilities found to our bug bounty program on HackerOne: https://hackerone.com/roblox

Please **do not create a public issue in this repo**.
cube/cube3d/__init__.py
ADDED
File without changes
cube/cube3d/colab_cube3d.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
cube/cube3d/configs/open_model.yaml
ADDED
@@ -0,0 +1,32 @@
gpt_model:
  n_layer: 23
  n_single_layer: 1
  rope_theta: 10000
  n_head: 12
  n_embd: 1536
  bias: true
  eps: 1.e-6
  shape_model_vocab_size: 16384
  text_model_embed_dim: 768
  use_pooled_text_embed: False
  shape_model_embed_dim: 32
  encoder_with_cls_token: true

shape_model:
  encoder_with_cls_token: true
  num_encoder_latents: 512
  num_decoder_latents: 0
  embed_dim: 32
  width: 768
  num_heads: 12
  out_dim: 1
  eps: 1.e-6
  num_freqs: 128
  point_feats: 3
  embed_point_feats: false
  num_encoder_layers: 13
  encoder_cross_attention_levels: [0, 2, 4, 8]
  num_decoder_layers: 24
  num_codes: 16384

text_model_pretrained_model_name_or_path: "openai/clip-vit-large-patch14"
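
The inference engine (see `cube3d/cube3d/inference/engine.py` below) reads this YAML via `load_config` and `parse_structured` from `cube3d.inference.utils`, whose implementations are not part of this commit. As a rough, hedged sketch of what consuming such a config looks like, assuming an OmegaConf-style loader (a common choice for attribute-style YAML access, not necessarily what `load_config` does):

```python
from omegaconf import OmegaConf

# Load the open-model config and access sections with attribute syntax.
cfg = OmegaConf.load("cube3d/configs/open_model.yaml")

print(cfg.gpt_model.n_layer)        # 23
print(cfg.shape_model.num_codes)    # 16384
print(cfg.text_model_pretrained_model_name_or_path)

# The GPT vocabulary must cover the shape tokenizer's codebook, so these two
# values are expected to agree (both 16384 in this config).
assert cfg.gpt_model.shape_model_vocab_size == cfg.shape_model.num_codes
```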
cube/cube3d/generate.py
ADDED
@@ -0,0 +1,144 @@
import argparse
import os

import torch
import trimesh

from cube3d.inference.engine import Engine, EngineFast
from cube3d.mesh_utils.postprocessing import (
    PYMESHLAB_AVAILABLE,
    create_pymeshset,
    postprocess_mesh,
    save_mesh,
)
from cube3d.renderer import renderer

def generate_mesh(
    engine,
    prompt,
    output_dir,
    output_name,
    resolution_base=8.0,
    disable_postprocess=False,
    top_k: int = 1,
):
    mesh_v_f = engine.t2s(
        [prompt],
        use_kv_cache=True,
        resolution_base=resolution_base,
        top_k=top_k,
    )
    vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
    obj_path = os.path.join(output_dir, f"{output_name}.obj")
    if PYMESHLAB_AVAILABLE:
        ms = create_pymeshset(vertices, faces)
        if not disable_postprocess:
            target_face_num = max(10000, int(faces.shape[0] * 0.1))
            print(f"Postprocessing mesh to {target_face_num} faces")
            postprocess_mesh(ms, target_face_num, obj_path)

        save_mesh(ms, obj_path)
    else:
        print(
            "WARNING: pymeshlab is not available, using trimesh to export obj and skipping optional post processing."
        )
        mesh = trimesh.Trimesh(vertices, faces)
        mesh.export(obj_path)

    return obj_path


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="cube shape generation script")
    parser.add_argument(
        "--config-path",
        type=str,
        default="cube3d/configs/open_model.yaml",
        help="Path to the configuration YAML file.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="outputs/",
        help="Path to the output directory to store .obj and .gif files",
    )
    parser.add_argument(
        "--gpt-ckpt-path",
        type=str,
        required=True,
        help="Path to the main GPT checkpoint file.",
    )
    parser.add_argument(
        "--shape-ckpt-path",
        type=str,
        required=True,
        help="Path to the shape encoder/decoder checkpoint file.",
    )
    parser.add_argument(
        "--fast-inference",
        help="Use optimized inference",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Text prompt for generating a 3D mesh",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=1,
        help="Top k filtering, 0 means no filtering, by default 1, which is deterministic.",
    )
    parser.add_argument(
        "--render-gif",
        help="Render a turntable gif of the mesh",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--disable-postprocessing",
        help="Disable postprocessing on the mesh. This will result in a mesh with more faces.",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--resolution-base",
        type=float,
        default=8.0,
        help="Resolution base for the shape decoder.",
    )
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(f"Using device: {device}")
    # Initialize engine based on fast_inference flag
    if args.fast_inference:
        print(
            "Using cuda graphs, this will take some time to warmup and capture the graph."
        )
        engine = EngineFast(
            args.config_path, args.gpt_ckpt_path, args.shape_ckpt_path, device=device
        )
        print("Compiled the graph.")
    else:
        engine = Engine(
            args.config_path, args.gpt_ckpt_path, args.shape_ckpt_path, device=device
        )

    # Generate meshes based on input source
    obj_path = generate_mesh(
        engine,
        args.prompt,
        args.output_dir,
        "output",
        args.resolution_base,
        args.disable_postprocessing,
        args.top_k,
    )
    if args.render_gif:
        gif_path = renderer.render_turntable(obj_path, args.output_dir)
        print(f"Rendered turntable gif for {args.prompt} at `{gif_path}`")
    print(f"Generated mesh for {args.prompt} at `{obj_path}`")
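
`generate.py` is a thin CLI wrapper around `generate_mesh()`. A hedged sketch (not part of this commit) of driving the same helper programmatically, so one engine instance is reused across several prompts, under the assumption that the weights live at the README's `model_weights/` paths:

```python
import os

import torch

from cube3d.generate import generate_mesh
from cube3d.inference.engine import Engine

# Build the engine once; this is the expensive step.
engine = Engine(
    "cube3d/configs/open_model.yaml",
    "model_weights/shape_gpt.safetensors",        # assumed download location (see README)
    "model_weights/shape_tokenizer.safetensors",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

os.makedirs("outputs/", exist_ok=True)
prompts = ["A low-poly pine tree", "A simple wooden stool"]
for i, prompt in enumerate(prompts):
    # Lower resolution_base keeps decoding quick; postprocessing is skipped here for speed.
    obj_path = generate_mesh(
        engine,
        prompt,
        output_dir="outputs/",
        output_name=f"example_{i}",
        resolution_base=6.0,
        disable_postprocess=True,
    )
    print(f"wrote {obj_path}")
```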
cube/cube3d/inference/__init__.py
ADDED
File without changes
cube/cube3d/inference/engine.py
ADDED
@@ -0,0 +1,499 @@
1 |
+
import torch
|
2 |
+
from tqdm import tqdm
|
3 |
+
from transformers import CLIPTextModelWithProjection, CLIPTokenizerFast
|
4 |
+
|
5 |
+
from cube3d.inference.logits_postprocesses import process_logits
|
6 |
+
from cube3d.inference.utils import load_config, load_model_weights, parse_structured
|
7 |
+
from cube3d.model.autoencoder.one_d_autoencoder import OneDAutoEncoder
|
8 |
+
from cube3d.model.gpt.dual_stream_roformer import DualStreamRoformer
|
9 |
+
from cube3d.model.transformers.cache import Cache
|
10 |
+
|
11 |
+
|
12 |
+
class Engine:
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
config_path: str,
|
16 |
+
gpt_ckpt_path: str,
|
17 |
+
shape_ckpt_path: str,
|
18 |
+
device: torch.device,
|
19 |
+
):
|
20 |
+
"""
|
21 |
+
Initializes the inference engine with the given configuration and checkpoint paths.
|
22 |
+
Args:
|
23 |
+
config_path (str): Path to the configuration file.
|
24 |
+
gpt_ckpt_path (str): Path to the GPT model checkpoint file.
|
25 |
+
shape_ckpt_path (str): Path to the shape model checkpoint file.
|
26 |
+
device (torch.device): The device to run the models on (e.g., 'cpu' or 'cuda').
|
27 |
+
Attributes:
|
28 |
+
cfg (dict): Loaded configuration from the config file.
|
29 |
+
device (torch.device): The device to run the models on.
|
30 |
+
gpt_model (DualStreamRoformer): The GPT model initialized and loaded with weights.
|
31 |
+
shape_model (OneDAutoEncoder): The shape model initialized and loaded with weights.
|
32 |
+
text_model (CLIPTextModelWithProjection): The text model initialized from a pretrained model.
|
33 |
+
text_tokenizer (CLIPTokenizerFast): The tokenizer for the text model.
|
34 |
+
max_new_tokens (int): Maximum number of new tokens for the shape model.
|
35 |
+
min_id (int): Minimum ID for the shape model codes.
|
36 |
+
max_id (int): Maximum ID for the shape model codes.
|
37 |
+
"""
|
38 |
+
|
39 |
+
self.cfg = load_config(config_path)
|
40 |
+
self.device = device
|
41 |
+
|
42 |
+
self.gpt_model = DualStreamRoformer(
|
43 |
+
parse_structured(DualStreamRoformer.Config, self.cfg.gpt_model)
|
44 |
+
)
|
45 |
+
load_model_weights(
|
46 |
+
self.gpt_model,
|
47 |
+
gpt_ckpt_path,
|
48 |
+
)
|
49 |
+
self.gpt_model = self.gpt_model.eval().to(self.device)
|
50 |
+
|
51 |
+
self.shape_model = OneDAutoEncoder(
|
52 |
+
parse_structured(OneDAutoEncoder.Config, self.cfg.shape_model)
|
53 |
+
)
|
54 |
+
load_model_weights(
|
55 |
+
self.shape_model,
|
56 |
+
shape_ckpt_path,
|
57 |
+
)
|
58 |
+
self.shape_model = self.shape_model.eval().to(self.device)
|
59 |
+
|
60 |
+
# copy vq codebook to gpt
|
61 |
+
with torch.no_grad():
|
62 |
+
codebook = self.shape_model.bottleneck.block.get_codebook()
|
63 |
+
codebook = self.gpt_model.shape_proj(codebook).detach()
|
64 |
+
self.gpt_model.transformer.wte.weight.data[: codebook.shape[0]] = codebook
|
65 |
+
|
66 |
+
self.text_model = CLIPTextModelWithProjection.from_pretrained(
|
67 |
+
self.cfg.text_model_pretrained_model_name_or_path,
|
68 |
+
force_download=False,
|
69 |
+
device_map=self.device,
|
70 |
+
).eval()
|
71 |
+
self.text_tokenizer = CLIPTokenizerFast.from_pretrained(
|
72 |
+
self.cfg.text_model_pretrained_model_name_or_path
|
73 |
+
)
|
74 |
+
|
75 |
+
self.max_new_tokens = self.shape_model.cfg.num_encoder_latents
|
76 |
+
self.min_id = 0
|
77 |
+
self.max_id = self.shape_model.cfg.num_codes
|
78 |
+
|
79 |
+
@torch.inference_mode()
|
80 |
+
def prepare_inputs(self, prompts: list[str], guidance_scale: float):
|
81 |
+
"""
|
82 |
+
Prepares the input embeddings for the model based on the provided prompts and guidance scale.
|
83 |
+
Args:
|
84 |
+
prompts (list[str]): A list of prompt strings to be encoded.
|
85 |
+
guidance_scale (float): A scaling factor for guidance. If greater than 0.0, additional processing is applied.
|
86 |
+
Returns:
|
87 |
+
tuple: A tuple containing:
|
88 |
+
- embed (torch.Tensor): The encoded input embeddings.
|
89 |
+
- cond (torch.Tensor): The condition embeddings, which may include unconditional embeddings if guidance_scale is greater than 0.0.
|
90 |
+
"""
|
91 |
+
|
92 |
+
prompt_embeds = self.run_clip(prompts)
|
93 |
+
|
94 |
+
with torch.autocast(self.device.type, dtype=torch.bfloat16):
|
95 |
+
embed = self.encode_input(prompt_embeds, self.gpt_model.shape_bos_id)
|
96 |
+
|
97 |
+
cond = prompt_embeds
|
98 |
+
if guidance_scale > 0.0:
|
99 |
+
embed = torch.cat([embed, embed], dim=0)
|
100 |
+
uncond_embeds = self.run_clip([""] * len(prompts))
|
101 |
+
cond = torch.cat([prompt_embeds, uncond_embeds], dim=0)
|
102 |
+
|
103 |
+
return embed, cond
|
104 |
+
|
105 |
+
@torch.inference_mode()
|
106 |
+
def run_clip(self, text_inputs):
|
107 |
+
"""
|
108 |
+
Processes the given text inputs using a text tokenizer and a text model, and returns the encoded text embeddings.
|
109 |
+
Args:
|
110 |
+
text_inputs (str or List[str]): The input text or list of texts to be processed.
|
111 |
+
Returns:
|
112 |
+
torch.Tensor: The encoded text embeddings.
|
113 |
+
"""
|
114 |
+
|
115 |
+
text_inputs = self.text_tokenizer(
|
116 |
+
text_inputs,
|
117 |
+
max_length=self.text_tokenizer.model_max_length,
|
118 |
+
padding="max_length",
|
119 |
+
truncation=True,
|
120 |
+
return_tensors="pt",
|
121 |
+
)
|
122 |
+
with torch.no_grad():
|
123 |
+
text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
|
124 |
+
# use full precision for text encoder
|
125 |
+
with torch.autocast(device_type=self.device.type, enabled=False):
|
126 |
+
encoded = self.text_model(**text_inputs)
|
127 |
+
if self.gpt_model.cfg.use_pooled_text_embed:
|
128 |
+
embed = encoded.text_embeds.unsqueeze(1) # [bs, 1, 512]
|
129 |
+
else:
|
130 |
+
embed = encoded.last_hidden_state # [bs, 77, 512]
|
131 |
+
embed = self.gpt_model.encode_text(embed)
|
132 |
+
|
133 |
+
return embed
|
134 |
+
|
135 |
+
@torch.inference_mode()
|
136 |
+
def encode_input(self, inputs: torch.Tensor, bos: int):
|
137 |
+
"""
|
138 |
+
Encodes the beginning of sequence (BOS) token for the given input tensor.
|
139 |
+
Args:
|
140 |
+
inputs (torch.Tensor): The input tensor containing sequences.
|
141 |
+
bos (int): The beginning of sequence token ID.
|
142 |
+
Returns:
|
143 |
+
torch.Tensor: The encoded BOS token embeddings.
|
144 |
+
"""
|
145 |
+
|
146 |
+
b = inputs.shape[0]
|
147 |
+
bos_embed = self.gpt_model.encode_token(
|
148 |
+
torch.full(
|
149 |
+
(b, 1),
|
150 |
+
fill_value=bos,
|
151 |
+
dtype=torch.long,
|
152 |
+
device=self.device,
|
153 |
+
)
|
154 |
+
)
|
155 |
+
return bos_embed
|
156 |
+
|
157 |
+
@torch.inference_mode()
|
158 |
+
def run_gpt(
|
159 |
+
self,
|
160 |
+
prompts: list[str],
|
161 |
+
use_kv_cache: bool,
|
162 |
+
guidance_scale: float = 3.0,
|
163 |
+
top_k: int = 1,
|
164 |
+
):
|
165 |
+
"""
|
166 |
+
Generates text using a GPT model based on the provided prompts.
|
167 |
+
Args:
|
168 |
+
prompts (list[str]): A list of input prompts to generate text from.
|
169 |
+
use_kv_cache (bool): Whether to use key-value caching for faster generation.
|
170 |
+
guidance_scale (float, optional): The scale for guidance during generation. Default is 3.0.
|
171 |
+
top_k : (int, optional): Top k filtering, 0 means no filtering, by default 1.
|
172 |
+
Returns:
|
173 |
+
torch.Tensor: A tensor containing the generated token IDs.
|
174 |
+
"""
|
175 |
+
embed, cond = self.prepare_inputs(prompts, guidance_scale)
|
176 |
+
|
177 |
+
output_ids = []
|
178 |
+
|
179 |
+
batch_size, input_seq_len, dim = embed.shape
|
180 |
+
max_seq_len = input_seq_len + self.max_new_tokens
|
181 |
+
embed_buffer = torch.zeros(
|
182 |
+
(batch_size, max_seq_len, dim), dtype=embed.dtype, device=embed.device
|
183 |
+
)
|
184 |
+
embed_buffer[:, :input_seq_len, :].copy_(embed)
|
185 |
+
cond_len = cond.shape[1]
|
186 |
+
kv_cache = None
|
187 |
+
if use_kv_cache:
|
188 |
+
kv_cache = self.gpt_model.init_kv_cache(
|
189 |
+
batch_size,
|
190 |
+
cond_len,
|
191 |
+
self.max_new_tokens + 1, # +1 for the BOS token
|
192 |
+
torch.bfloat16,
|
193 |
+
embed.device,
|
194 |
+
)
|
195 |
+
with torch.autocast(self.device.type, dtype=torch.bfloat16):
|
196 |
+
for i in tqdm(range(self.max_new_tokens), desc=f"generating"):
|
197 |
+
curr_pos_id = torch.tensor([i], dtype=torch.long, device=embed.device)
|
198 |
+
logits = self.gpt_model(
|
199 |
+
embed_buffer,
|
200 |
+
cond,
|
201 |
+
kv_cache=kv_cache,
|
202 |
+
curr_pos_id=curr_pos_id if use_kv_cache else None,
|
203 |
+
decode=(i > 0) if use_kv_cache else False,
|
204 |
+
)
|
205 |
+
if use_kv_cache:
|
206 |
+
logits = logits[:, 0, ...]
|
207 |
+
else:
|
208 |
+
logits = logits[:, i, ...]
|
209 |
+
|
210 |
+
logits = logits[..., self.min_id : self.max_id]
|
211 |
+
|
212 |
+
if guidance_scale > 0.0:
|
213 |
+
logits, uncond_logits = logits.float().chunk(2, dim=0)
|
214 |
+
gamma = (
|
215 |
+
guidance_scale * (self.max_new_tokens - i) / self.max_new_tokens
|
216 |
+
)
|
217 |
+
logits = (1 + gamma) * logits - gamma * uncond_logits
|
218 |
+
probs = process_logits(
|
219 |
+
logits,
|
220 |
+
top_k=top_k,
|
221 |
+
)
|
222 |
+
next_id = torch.multinomial(probs, num_samples=1, replacement=True)
|
223 |
+
output_ids.append(next_id)
|
224 |
+
next_embed = self.gpt_model.encode_token(next_id)
|
225 |
+
if guidance_scale > 0.0:
|
226 |
+
next_embed = torch.cat([next_embed, next_embed], dim=0)
|
227 |
+
embed_buffer[:, i + input_seq_len, :].copy_(next_embed.squeeze(1))
|
228 |
+
|
229 |
+
return torch.cat(output_ids, dim=1)
|
230 |
+
|
231 |
+
@torch.inference_mode()
|
232 |
+
def run_shape_decode(
|
233 |
+
self,
|
234 |
+
output_ids: torch.Tensor,
|
235 |
+
resolution_base: float = 8.0,
|
236 |
+
chunk_size: int = 100_000,
|
237 |
+
):
|
238 |
+
"""
|
239 |
+
Decodes the shape from the given output IDs and extracts the geometry.
|
240 |
+
Args:
|
241 |
+
output_ids (torch.Tensor): The tensor containing the output IDs.
|
242 |
+
resolution_base (float, optional): The base resolution for geometry extraction. Defaults to 8.43.
|
243 |
+
chunk_size (int, optional): The chunk size for processing. Defaults to 100,000.
|
244 |
+
Returns:
|
245 |
+
tuple: A tuple containing the vertices and faces of the mesh.
|
246 |
+
"""
|
247 |
+
shape_ids = (
|
248 |
+
output_ids[:, : self.shape_model.cfg.num_encoder_latents, ...]
|
249 |
+
.clamp_(0, self.shape_model.cfg.num_codes - 1)
|
250 |
+
.view(-1, self.shape_model.cfg.num_encoder_latents)
|
251 |
+
)
|
252 |
+
latents = self.shape_model.decode_indices(shape_ids)
|
253 |
+
mesh_v_f, _ = self.shape_model.extract_geometry(
|
254 |
+
latents,
|
255 |
+
resolution_base=resolution_base,
|
256 |
+
chunk_size=chunk_size,
|
257 |
+
use_warp=True,
|
258 |
+
)
|
259 |
+
return mesh_v_f
|
260 |
+
|
261 |
+
@torch.inference_mode()
|
262 |
+
def t2s(
|
263 |
+
self,
|
264 |
+
prompts: list[str],
|
265 |
+
use_kv_cache: bool,
|
266 |
+
guidance_scale: float = 3.0,
|
267 |
+
resolution_base: float = 8.0,
|
268 |
+
chunk_size: int = 100_000,
|
269 |
+
top_k: int = 5,
|
270 |
+
):
|
271 |
+
"""
|
272 |
+
Generates a 3D mesh from text prompts using a GPT model and shape decoder.
|
273 |
+
Args:
|
274 |
+
prompts (list[str]): A list of text prompts to guide the generation.
|
275 |
+
use_kv_cache (bool): Whether to use key-value caching for the GPT model.
|
276 |
+
guidance_scale (float, optional): The scale of guidance for the GPT model. Default is 3.0.
|
277 |
+
resolution_base (float, optional): The base resolution for the shape decoder. Default is 8.0.
|
278 |
+
chunk_size (int, optional): The chunk size for processing the shape decoding. Default is 100,000.
|
279 |
+
Returns:
|
280 |
+
mesh_v_f: The generated 3D mesh vertices and faces.
|
281 |
+
"""
|
282 |
+
output_ids = self.run_gpt(prompts, use_kv_cache, guidance_scale, top_k)
|
283 |
+
with torch.autocast(self.device.type, dtype=torch.bfloat16):
|
284 |
+
mesh_v_f = self.run_shape_decode(output_ids, resolution_base, chunk_size)
|
285 |
+
return mesh_v_f
|
286 |
+
|
287 |
+
|
288 |
+
class EngineFast(Engine):
|
289 |
+
def __init__(
|
290 |
+
self,
|
291 |
+
config_path: str,
|
292 |
+
gpt_ckpt_path: str,
|
293 |
+
shape_ckpt_path: str,
|
294 |
+
device: torch.device,
|
295 |
+
):
|
296 |
+
"""
|
297 |
+
Initializes the inference engine with the given configuration and checkpoint paths.
|
298 |
+
Args:
|
299 |
+
config_path (str): Path to the configuration file.
|
300 |
+
gpt_ckpt_path (str): Path to the GPT checkpoint file.
|
301 |
+
shape_ckpt_path (str): Path to the shape checkpoint file.
|
302 |
+
device (torch.device): The device to run the inference on (e.g., CPU or CUDA).
|
303 |
+
"""
|
304 |
+
|
305 |
+
super().__init__(config_path, gpt_ckpt_path, shape_ckpt_path, device)
|
306 |
+
|
307 |
+
# CUDA Graph params
|
308 |
+
self.graph = torch.cuda.CUDAGraph()
|
309 |
+
self.embed_buffer = torch.Tensor()
|
310 |
+
self.cond_buffer = torch.Tensor()
|
311 |
+
self.logits_buffer = torch.Tensor()
|
312 |
+
self.curr_pos_id = torch.tensor([0], dtype=torch.long, device=self.device)
|
313 |
+
self.kv_cache: list[Cache] = []
|
314 |
+
|
315 |
+
self._warmup_and_capture_graph()
|
316 |
+
|
317 |
+
def _warmup_and_capture_graph(self):
|
318 |
+
"""
|
319 |
+
Warms up the model by running a series of forward passes and captures the CUDA graph for efficient execution.
|
320 |
+
This method performs the following steps:
|
321 |
+
1. Prepares the input embeddings and conditions using a warmup prompt.
|
322 |
+
2. Initializes buffers for embeddings and conditions.
|
323 |
+
3. Initializes the key-value cache for the GPT model.
|
324 |
+
4. Runs a series of warmup passes to prefill the model and generate logits.
|
325 |
+
5. Captures the CUDA graph for the model's forward pass to optimize future executions.
|
326 |
+
"""
|
327 |
+
|
328 |
+
warmup_prompt = "A cube"
|
329 |
+
embed, cond = self.prepare_inputs([warmup_prompt], guidance_scale=3.0)
|
330 |
+
|
331 |
+
batch_size, input_seq_len, dim = embed.shape
|
332 |
+
max_seq_len = input_seq_len + self.max_new_tokens
|
333 |
+
self.embed_buffer = torch.zeros(
|
334 |
+
(batch_size, max_seq_len, dim), dtype=embed.dtype, device=self.device
|
335 |
+
)
|
336 |
+
self.embed_buffer[:, :input_seq_len, :].copy_(embed)
|
337 |
+
|
338 |
+
self.cond_buffer = torch.empty_like(cond)
|
339 |
+
self.cond_buffer.copy_(cond)
|
340 |
+
cond_len = self.cond_buffer.shape[1]
|
341 |
+
|
342 |
+
# Initialize kv_cache for the first time
|
343 |
+
self.kv_cache = self.gpt_model.init_kv_cache(
|
344 |
+
batch_size,
|
345 |
+
cond_len,
|
346 |
+
self.max_new_tokens + 1, # +1 for the BOS token
|
347 |
+
torch.bfloat16,
|
348 |
+
self.device,
|
349 |
+
)
|
350 |
+
|
351 |
+
num_warmup_passes = 10
|
352 |
+
|
353 |
+
with torch.autocast(self.device.type, dtype=torch.bfloat16):
|
354 |
+
self._set_curr_pos_id(0)
|
355 |
+
_ = self._prefill_and_return_logits()
|
356 |
+
|
357 |
+
for x in range(1, num_warmup_passes):
|
358 |
+
self._set_curr_pos_id(x)
|
359 |
+
self.logits_buffer = self.gpt_model(
|
360 |
+
embed=self.embed_buffer,
|
361 |
+
cond=self.cond_buffer,
|
362 |
+
kv_cache=self.kv_cache,
|
363 |
+
curr_pos_id=self.curr_pos_id,
|
364 |
+
decode=True,
|
365 |
+
)
|
366 |
+
|
367 |
+
side_stream = torch.cuda.Stream(device=self.device)
|
368 |
+
with torch.cuda.graph(self.graph, stream=side_stream):
|
369 |
+
with torch.autocast(self.device.type, dtype=torch.bfloat16):
|
370 |
+
self.logits_buffer = self.gpt_model(
|
371 |
+
embed=self.embed_buffer,
|
372 |
+
cond=self.cond_buffer,
|
373 |
+
kv_cache=self.kv_cache,
|
374 |
+
curr_pos_id=self.curr_pos_id,
|
375 |
+
decode=True,
|
376 |
+
)
|
377 |
+
|
378 |
+
def _reset_kv_cache(self):
|
379 |
+
"""
|
380 |
+
Resets the key-value cache by setting all key and value states to zero.
|
381 |
+
This method iterates through each cache in the `kv_cache` attribute and
|
382 |
+
calls the `zero_()` method on both `key_states` and `value_states` to
|
383 |
+
reset them to their initial state.
|
384 |
+
"""
|
385 |
+
|
386 |
+
for cache in self.kv_cache:
|
387 |
+
cache.key_states.zero_()
|
388 |
+
cache.value_states.zero_()
|
389 |
+
|
390 |
+
def _prefill_and_return_logits(self) -> torch.Tensor:
|
391 |
+
"""
|
392 |
+
Prefills the model's key-value cache and returns the logits.
|
393 |
+
This method resets the key-value cache and then performs a forward pass
|
394 |
+
through the GPT model in eager mode to prefill the logits.
|
395 |
+
Returns:
|
396 |
+
torch.Tensor: The prefilled logits tensor with the first dimension removed.
|
397 |
+
"""
|
398 |
+
|
399 |
+
self._reset_kv_cache()
|
400 |
+
|
401 |
+
# Prefill is always eager
|
402 |
+
prefill_logits = self.gpt_model(
|
403 |
+
embed=self.embed_buffer,
|
404 |
+
cond=self.cond_buffer,
|
405 |
+
kv_cache=self.kv_cache,
|
406 |
+
curr_pos_id=self.curr_pos_id,
|
407 |
+
decode=False,
|
408 |
+
)
|
409 |
+
|
410 |
+
return prefill_logits[:, 0, ...]
|
411 |
+
|
412 |
+
def _set_curr_pos_id(self, pos: int):
|
413 |
+
"""
|
414 |
+
Set the current position ID.
|
415 |
+
This method updates the `curr_pos_id` attribute with the given position.
|
416 |
+
Args:
|
417 |
+
pos (int): The position ID to set.
|
418 |
+
"""
|
419 |
+
|
420 |
+
self.curr_pos_id.copy_(
|
421 |
+
torch.tensor([pos], dtype=torch.long, device=self.device)
|
422 |
+
)
|
423 |
+
|
424 |
+
def run_gpt(
|
425 |
+
self,
|
426 |
+
prompts: list[str],
|
427 |
+
use_kv_cache: bool,
|
428 |
+
guidance_scale: float = 3.0,
|
429 |
+
top_k: int = 1,
|
430 |
+
):
|
431 |
+
"""
|
432 |
+
Runs the GPT model to generate shape token ids based on the provided prompts.
|
433 |
+
Args:
|
434 |
+
prompts (list[str]): A list of input prompts for the GPT model. Only a single prompt is supported.
|
435 |
+
use_kv_cache (bool): Flag indicating whether to use key-value caching. (Currently not used)
|
436 |
+
guidance_scale (float, optional): The scale factor for guidance. Default is 3.0.
|
437 |
+
Returns:
|
438 |
+
torch.Tensor: A tensor containing the generated output token IDs.
|
439 |
+
Raises:
|
440 |
+
AssertionError: If the batch size is greater than 1.
|
441 |
+
"""
|
442 |
+
|
443 |
+
embed, cond = self.prepare_inputs(prompts, guidance_scale)
|
444 |
+
assert len(prompts) == 1, "batch size > 1 not supported for EngineFast"
|
445 |
+
|
446 |
+
batch_size, input_seq_len, _ = embed.shape
|
447 |
+
self.embed_buffer.zero_()
|
448 |
+
self.embed_buffer[:, :input_seq_len, :].copy_(embed)
|
449 |
+
|
450 |
+
assert self.cond_buffer.shape == cond.shape
|
451 |
+
self.cond_buffer.copy_(cond)
|
452 |
+
|
453 |
+
output_ids = torch.zeros(
|
454 |
+
(batch_size // 2, self.max_new_tokens), dtype=torch.int, device=self.device
|
455 |
+
)
|
456 |
+
|
457 |
+
with torch.autocast(self.device.type, dtype=torch.bfloat16):
|
458 |
+
self._set_curr_pos_id(0)
|
459 |
+
|
460 |
+
logits = self._prefill_and_return_logits()
|
461 |
+
|
462 |
+
logits = logits[..., self.min_id : self.max_id]
|
463 |
+
if guidance_scale > 0.0:
|
464 |
+
logits, uncond_logits = logits.float().chunk(2, dim=0)
|
465 |
+
gamma = guidance_scale
|
466 |
+
logits = (1 + gamma) * logits - gamma * uncond_logits
|
467 |
+
|
468 |
+
probs = process_logits(logits, top_k=top_k)
|
469 |
+
next_id = torch.multinomial(probs, num_samples=1, replacement=True)
|
470 |
+
|
471 |
+
output_ids[:, 0] = next_id.squeeze()
|
472 |
+
next_embed = self.gpt_model.encode_token(next_id)
|
473 |
+
next_embed = next_embed.repeat(2, 1, 1)
|
474 |
+
self.embed_buffer[:, input_seq_len, :].copy_(next_embed.squeeze(1))
|
475 |
+
|
476 |
+
for i in tqdm(
|
477 |
+
range(1, self.max_new_tokens), desc="generating"
|
478 |
+
):
|
479 |
+
self._set_curr_pos_id(i)
|
480 |
+
self.graph.replay()
|
481 |
+
|
482 |
+
logits = self.logits_buffer[:, 0, ...]
|
483 |
+
|
484 |
+
logits = logits[..., self.min_id : self.max_id]
|
485 |
+
if guidance_scale > 0.0:
|
486 |
+
logits, uncond_logits = logits.float().chunk(2, dim=0)
|
487 |
+
gamma = (
|
488 |
+
guidance_scale * (self.max_new_tokens - i) / self.max_new_tokens
|
489 |
+
)
|
490 |
+
logits = (1 + gamma) * logits - gamma * uncond_logits
|
491 |
+
probs = process_logits(logits, top_k=top_k)
|
492 |
+
next_id = torch.multinomial(probs, num_samples=1, replacement=True)
|
493 |
+
|
494 |
+
output_ids[:, i] = next_id.squeeze()
|
495 |
+
next_embed = self.gpt_model.encode_token(next_id)
|
496 |
+
next_embed = next_embed.repeat(2, 1, 1)
|
497 |
+
self.embed_buffer[:, i + input_seq_len, :].copy_(next_embed.squeeze(1))
|
498 |
+
|
499 |
+
return output_ids
|
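The decode loop above applies classifier-free guidance before sampling: conditional and unconditional logits travel as one doubled batch and are recombined with a weight that decays linearly over the generated sequence. A minimal, self-contained sketch of that recombination step (the vocabulary size, step index, and random logits below are placeholders, not values taken from the engine):

import torch

vocab_size, max_new_tokens, step = 16384, 512, 10
paired = torch.randn(2, vocab_size)  # row 0: conditional logits, row 1: unconditional logits

guidance_scale = 3.0
cond, uncond = paired.float().chunk(2, dim=0)
gamma = guidance_scale * (max_new_tokens - step) / max_new_tokens  # linear decay over the sequence
mixed = (1 + gamma) * cond - gamma * uncond

probs = torch.softmax(mixed, dim=-1)               # the engine applies top-k filtering before softmax
next_id = torch.multinomial(probs, num_samples=1)  # sampled shape token id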
cube/cube3d/inference/logits_postprocesses.py
ADDED
@@ -0,0 +1,44 @@
import torch
import torch.nn.functional as F


def top_k_filtering(logits, top_k: int = 1):
    """
    Filter a distribution of logits using top-k filtering.
    The input logits tensor is modified in-place.

    Args:
        logits: A tensor of logits to be filtered. Expected shape is [..., vocab_size].
        top_k: If > 0, only keep the top k tokens with highest probability.

    Returns:
        A tensor of logits where values outside the top-k threshold are set to -∞.
    """
    if top_k > 0:
        idx_to_remove = logits < logits.topk(top_k, largest=True, sorted=False, dim=-1)[
            0
        ].amin(dim=-1, keepdim=True)
        logits.masked_fill_(idx_to_remove, -torch.inf)

    return logits


def process_logits(
    logits,
    top_k: int = 1,
):
    """
    Process logits by optionally applying top-k filtering.
    The final probabilities are returned after applying softmax on the filtered logits.

    Args:
        logits: A tensor of logits to process. Expected shape is [..., vocab_size].
        top_k: If > 0, only keep the top k tokens with highest probability.

    Returns:
        A tensor of probabilities after filtering, with the same shape as the input logits.
    """
    logits = top_k_filtering(logits, top_k=top_k)
    probs = F.softmax(logits, dim=-1)
    return probs
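A short usage sketch for the helpers above (the vocabulary size and logits are placeholders):

import torch
from cube3d.inference.logits_postprocesses import process_logits

logits = torch.randn(1, 16384)            # [batch, vocab_size]
probs = process_logits(logits, top_k=5)   # keep the top 5 logits, zero out the rest, softmax
next_id = torch.multinomial(probs, num_samples=1)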
cube/cube3d/inference/utils.py
ADDED
@@ -0,0 +1,56 @@
import logging
from typing import Any, Optional

import torch
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_model


def load_config(cfg_path: str) -> Any:
    """
    Load and resolve a configuration file.
    Args:
        cfg_path (str): The path to the configuration file.
    Returns:
        Any: The loaded and resolved configuration object.
    Raises:
        AssertionError: If the loaded configuration is not an instance of DictConfig.
    """

    cfg = OmegaConf.load(cfg_path)
    OmegaConf.resolve(cfg)
    assert isinstance(cfg, DictConfig)
    return cfg


def parse_structured(cfg_type: Any, cfg: DictConfig) -> Any:
    """
    Parses a configuration dictionary into a structured configuration object.
    Args:
        cfg_type (Any): The type of the structured configuration object.
        cfg (DictConfig): The configuration dictionary to be parsed.
    Returns:
        Any: The structured configuration object created from the dictionary.
    """

    scfg = OmegaConf.structured(cfg_type(**cfg))
    return scfg


def load_model_weights(model: torch.nn.Module, ckpt_path: str) -> None:
    """
    Load a safetensors checkpoint into a PyTorch model.
    The model is updated in place.

    Args:
        model: PyTorch model to load weights into
        ckpt_path: Path to the safetensors checkpoint file

    Returns:
        None
    """
    assert ckpt_path.endswith(".safetensors"), (
        f"Checkpoint path '{ckpt_path}' is not a safetensors file"
    )

    load_model(model, ckpt_path)
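A small usage sketch for these helpers; the dataclass and the `gpt_model` config section below are hypothetical stand-ins for whatever schema the caller defines:

from dataclasses import dataclass
from cube3d.inference.utils import load_config, parse_structured

@dataclass
class DummyGPTConfig:        # hypothetical schema, for illustration only
    n_layer: int = 12
    n_embd: int = 2048

cfg = load_config("cube3d/configs/open_model.yaml")        # DictConfig with interpolations resolved
gpt_cfg = parse_structured(DummyGPTConfig, cfg.gpt_model)  # assumes a `gpt_model` section exists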
cube/cube3d/mesh_utils/postprocessing.py
ADDED
@@ -0,0 +1,84 @@
import logging

import numpy as np

try:
    import pymeshlab

    PYMESHLAB_AVAILABLE = True
except ImportError:
    logging.warning(
        "pymeshlab is not installed or could not be loaded. Please install it with `pip install pymeshlab`."
    )
    PYMESHLAB_AVAILABLE = False
    from typing import Any

    # Create stub class for typing
    class pymeshlab:
        MeshSet = Any
        Mesh = Any


def create_pymeshset(vertices: np.ndarray, faces: np.ndarray):
    """
    Creates a MeshLab MeshSet given a list of vertices and faces.
    """
    assert PYMESHLAB_AVAILABLE, "pymeshlab is not installed or could not be loaded."
    # Initialize MeshSet and create pymeshlab.Mesh
    mesh_set = pymeshlab.MeshSet()
    input_mesh = pymeshlab.Mesh(vertex_matrix=vertices, face_matrix=faces)
    mesh_set.add_mesh(input_mesh, "input_mesh")
    logging.info("Mesh successfully added to pymeshlab MeshSet.")
    return mesh_set


def cleanup(ms: pymeshlab.MeshSet):
    """
    General cleanup for a given Mesh. Removes degenerate elements from the
    geometry.
    """
    ms.meshing_remove_null_faces()
    ms.meshing_remove_folded_faces()
    ms.meshing_remove_duplicate_vertices()
    ms.meshing_remove_duplicate_faces()
    ms.meshing_remove_t_vertices()
    ms.meshing_remove_unreferenced_vertices()


def remove_floaters(ms: pymeshlab.MeshSet, threshold: float = 0.005):
    """
    Remove any floating artifacts that exist from our mesh generation.
    """
    assert PYMESHLAB_AVAILABLE, "pymeshlab is not installed or could not be loaded."
    ms.meshing_remove_connected_component_by_diameter(
        mincomponentdiag=pymeshlab.PercentageValue(15), removeunref=True
    )


def simplify_mesh(ms: pymeshlab.MeshSet, target_face_num: int):
    """
    Simplify the mesh to the target number of faces.
    """
    ms.meshing_decimation_quadric_edge_collapse(
        targetfacenum=target_face_num,
        qualitythr=0.4,
        preservenormal=True,
        autoclean=True,
    )


def save_mesh(ms: pymeshlab.MeshSet, output_path: str):
    """
    Save the mesh to a file.
    """
    ms.save_current_mesh(output_path)
    logging.info(f"Mesh saved to {output_path}.")


def postprocess_mesh(ms: pymeshlab.MeshSet, target_face_num: int, output_path: str):
    """
    Postprocess the mesh to the target number of faces.
    """
    cleanup(ms)
    remove_floaters(ms)
    simplify_mesh(ms, target_face_num)
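A typical end-to-end use of these helpers, assuming pymeshlab is installed; the tetrahedron and face budget below are placeholders:

import numpy as np
from cube3d.mesh_utils.postprocessing import create_pymeshset, postprocess_mesh, save_mesh

vertices = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float64)
faces = np.array([[0, 2, 1], [0, 1, 3], [0, 3, 2], [1, 2, 3]], dtype=np.int32)

ms = create_pymeshset(vertices, faces)
postprocess_mesh(ms, target_face_num=1000, output_path="output.obj")  # cleanup, floater removal, decimation
save_mesh(ms, "output.obj")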
cube/cube3d/model/__init__.py
ADDED
File without changes
|
cube/cube3d/model/autoencoder/__init__.py
ADDED
File without changes
|
cube/cube3d/model/autoencoder/embedder.py
ADDED
@@ -0,0 +1,52 @@
import math

import torch
import torch.nn as nn


class PhaseModulatedFourierEmbedder(torch.nn.Module):
    def __init__(
        self,
        num_freqs: int,
        input_dim: int = 3,
    ):
        """
        Initializes the PhaseModulatedFourierEmbedder class.
        Args:
            num_freqs (int): The number of frequencies to be used.
            input_dim (int, optional): The dimension of the input. Defaults to 3.
        Attributes:
            weight (torch.nn.Parameter): The weight parameter initialized with random values.
            carrier (torch.Tensor): The carrier frequencies calculated based on the Nyquist-Shannon sampling theorem.
            out_dim (int): The output dimension calculated based on the input dimension and number of frequencies.
        """

        super().__init__()

        self.weight = nn.Parameter(
            torch.randn(input_dim, num_freqs) * math.sqrt(0.5 * num_freqs)
        )

        # NOTE this is the highest frequency we can get (2 for peaks, 2 for zeros, and 4 for interpolation points), see also https://en.wikipedia.org/wiki/Nyquist%E2%80%93Shannon_sampling_theorem
        carrier = (num_freqs / 8) ** torch.linspace(1, 0, num_freqs)
        carrier = (carrier + torch.linspace(0, 1, num_freqs)) * 2 * torch.pi
        self.register_buffer("carrier", carrier, persistent=False)

        self.out_dim = input_dim * (num_freqs * 2 + 1)

    def forward(self, x):
        """
        Perform the forward pass of the embedder model.
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, ..., input_dim).
        Returns:
            torch.Tensor: Output tensor of shape (batch_size, ..., output_dim) where
                output_dim = input_dim * (2 * num_freqs + 1).
        """

        m = x.float().unsqueeze(-1)
        fm = (m * self.weight).view(*x.shape[:-1], -1)
        pm = (m * 0.5 * torch.pi + self.carrier).view(*x.shape[:-1], -1)
        embedding = torch.cat([x, fm.cos() + pm.cos(), fm.sin() + pm.sin()], dim=-1)

        return embedding
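For reference, the embedding width is input_dim * (2 * num_freqs + 1): the raw coordinates plus one cosine and one sine channel per frequency. A quick check:

import torch
from cube3d.model.autoencoder.embedder import PhaseModulatedFourierEmbedder

embedder = PhaseModulatedFourierEmbedder(num_freqs=128, input_dim=3)
pts = torch.rand(2, 1024, 3)              # a batch of 1024 points in 3D
out = embedder(pts)
print(embedder.out_dim, out.shape[-1])    # both 771 == 3 * (2 * 128 + 1)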
cube/cube3d/model/autoencoder/grid.py
ADDED
@@ -0,0 +1,84 @@
from typing import Literal, Union

import numpy as np
import torch
import warp as wp


def generate_dense_grid_points(
    bbox_min: np.ndarray,
    bbox_max: np.ndarray,
    resolution_base: float,
    indexing: Literal["xy", "ij"] = "ij",
) -> tuple[np.ndarray, list[int], np.ndarray]:
    """
    Generate a dense grid of points within a bounding box.

    Parameters:
        bbox_min (np.ndarray): The minimum coordinates of the bounding box (3D).
        bbox_max (np.ndarray): The maximum coordinates of the bounding box (3D).
        resolution_base (float): The base resolution for the grid. The number of cells along each axis will be 2^resolution_base.
        indexing (Literal["xy", "ij"], optional): The indexing convention for the grid. "xy" for Cartesian indexing, "ij" for matrix indexing. Default is "ij".
    Returns:
        tuple: A tuple containing:
            - xyz (np.ndarray): A 2D array of shape (N, 3) where N is the total number of grid points. Each row represents the (x, y, z) coordinates of a grid point.
            - grid_size (list): A list of three integers representing the number of grid points along each axis.
            - length (np.ndarray): The length of the bounding box along each axis.
    """
    length = bbox_max - bbox_min
    num_cells = np.exp2(resolution_base)
    x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
    y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
    z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
    [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
    xyz = np.stack((xs, ys, zs), axis=-1)
    xyz = xyz.reshape(-1, 3)
    grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]

    return xyz, grid_size, length


def marching_cubes_with_warp(
    grid_logits: torch.Tensor,
    level: float,
    device: Union[str, torch.device] = "cuda",
    max_verts: int = 3_000_000,
    max_tris: int = 3_000_000,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Perform the marching cubes algorithm on a 3D grid with warp support.
    Args:
        grid_logits (torch.Tensor): A 3D tensor containing the grid logits.
        level (float): The threshold level for the isosurface.
        device (Union[str, torch.device], optional): The device to perform the computation on. Defaults to "cuda".
        max_verts (int, optional): The maximum number of vertices. Defaults to 3,000,000.
        max_tris (int, optional): The maximum number of triangles. Defaults to 3,000,000.
    Returns:
        Tuple[np.ndarray, np.ndarray]: A tuple containing the vertices and faces of the isosurface.
    """
    if isinstance(device, torch.device):
        device = str(device)

    assert grid_logits.ndim == 3
    if "cuda" in device:
        assert wp.is_cuda_available()
    else:
        raise ValueError(
            f"Device {device} is not supported for marching_cubes_with_warp"
        )

    dim = grid_logits.shape[0]
    field = wp.from_torch(grid_logits)

    iso = wp.MarchingCubes(
        nx=dim,
        ny=dim,
        nz=dim,
        max_verts=int(max_verts),
        max_tris=int(max_tris),
        device=device,
    )
    iso.surface(field=field, threshold=level)
    vertices = iso.verts.numpy()
    faces = iso.indices.numpy().reshape(-1, 3)
    return vertices, faces
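A quick sketch of the grid helper: resolution_base=8 gives 2^8 cells per axis, i.e. 257 corner points per axis:

import numpy as np
from cube3d.model.autoencoder.grid import generate_dense_grid_points

xyz, grid_size, length = generate_dense_grid_points(
    bbox_min=np.array([-1.05, -1.05, -1.05]),
    bbox_max=np.array([1.05, 1.05, 1.05]),
    resolution_base=8.0,
)
print(grid_size)     # [257, 257, 257]
print(xyz.shape)     # (257 ** 3, 3) == (16974593, 3)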
cube/cube3d/model/autoencoder/one_d_autoencoder.py
ADDED
@@ -0,0 +1,694 @@
1 |
+
import logging
|
2 |
+
import sys
|
3 |
+
from dataclasses import dataclass, field
|
4 |
+
from functools import partial
|
5 |
+
from typing import List, Optional, Tuple
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from skimage import measure
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
from cube3d.model.autoencoder.embedder import PhaseModulatedFourierEmbedder
|
15 |
+
from cube3d.model.autoencoder.grid import (
|
16 |
+
generate_dense_grid_points,
|
17 |
+
marching_cubes_with_warp,
|
18 |
+
)
|
19 |
+
from cube3d.model.autoencoder.spherical_vq import SphericalVectorQuantizer
|
20 |
+
from cube3d.model.transformers.attention import (
|
21 |
+
EncoderCrossAttentionLayer,
|
22 |
+
EncoderLayer,
|
23 |
+
init_linear,
|
24 |
+
init_tfixup,
|
25 |
+
)
|
26 |
+
from cube3d.model.transformers.norm import LayerNorm
|
27 |
+
|
28 |
+
|
29 |
+
def init_sort(x):
|
30 |
+
"""
|
31 |
+
Sorts the input tensor `x` based on its pairwise distances to the first element.
|
32 |
+
This function computes the pairwise distances between all elements in `x` and the
|
33 |
+
first element of `x`. It then sorts the elements of `x` in ascending order of
|
34 |
+
their distances to the first element.
|
35 |
+
Args:
|
36 |
+
x (torch.Tensor): A 2D tensor where each row represents a data point.
|
37 |
+
Returns:
|
38 |
+
torch.Tensor: A tensor containing the rows of `x` sorted by their distances
|
39 |
+
to the first row of `x`.
|
40 |
+
"""
|
41 |
+
|
42 |
+
distances = torch.cdist(x, x[:1])
|
43 |
+
_, indices = torch.sort(distances.squeeze(), dim=0)
|
44 |
+
x = x[indices]
|
45 |
+
return x
|
46 |
+
|
47 |
+
|
48 |
+
class MLPEmbedder(nn.Module):
|
49 |
+
def __init__(self, in_dim: int, embed_dim: int, bias: bool = True):
|
50 |
+
super().__init__()
|
51 |
+
self.in_layer = nn.Linear(in_dim, embed_dim, bias=bias)
|
52 |
+
self.silu = nn.SiLU()
|
53 |
+
self.out_layer = nn.Linear(embed_dim, embed_dim, bias=bias)
|
54 |
+
|
55 |
+
self.apply(partial(init_linear, embed_dim=embed_dim))
|
56 |
+
|
57 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
58 |
+
return self.out_layer(self.silu(self.in_layer(x)))
|
59 |
+
|
60 |
+
|
61 |
+
class OneDEncoder(nn.Module):
|
62 |
+
def __init__(
|
63 |
+
self,
|
64 |
+
embedder,
|
65 |
+
num_latents: int,
|
66 |
+
point_feats: int,
|
67 |
+
embed_point_feats: bool,
|
68 |
+
width: int,
|
69 |
+
num_heads: int,
|
70 |
+
num_layers: int,
|
71 |
+
with_cls_token: bool = False,
|
72 |
+
cross_attention_levels: Optional[List[int]] = None,
|
73 |
+
eps: float = 1e-6,
|
74 |
+
) -> None:
|
75 |
+
"""
|
76 |
+
Initializes the OneDEncoder model.
|
77 |
+
Args:
|
78 |
+
embedder: An embedding module that provides the input embedding functionality.
|
79 |
+
num_latents (int): The number of latent variables.
|
80 |
+
point_feats (int): The number of point features.
|
81 |
+
embed_point_feats (bool): Whether to embed point features or not.
|
82 |
+
width (int): The width of the embedding dimension.
|
83 |
+
num_heads (int): The number of attention heads.
|
84 |
+
num_layers (int): The number of encoder layers.
|
85 |
+
with_cls_token (bool, optional): Whether to include a classification token like in Vision Transformers (ViT). Defaults to False.
|
86 |
+
cross_attention_levels (Optional[List[int]], optional): The indices of layers where cross-attention is applied. Defaults to None.
|
87 |
+
eps (float, optional): A small value added for numerical stability in normalization layers. Defaults to 1e-6.
|
88 |
+
Returns:
|
89 |
+
None
|
90 |
+
"""
|
91 |
+
super().__init__()
|
92 |
+
|
93 |
+
self.embedder = embedder
|
94 |
+
|
95 |
+
# add cls token like ViT
|
96 |
+
self.with_cls_token = with_cls_token
|
97 |
+
if self.with_cls_token:
|
98 |
+
query = torch.empty((1 + num_latents, width))
|
99 |
+
else:
|
100 |
+
query = torch.empty((num_latents, width))
|
101 |
+
|
102 |
+
# initialize then sort query to potentially get better ordering
|
103 |
+
query.uniform_(-1.0, 1.0)
|
104 |
+
query = init_sort(query)
|
105 |
+
|
106 |
+
# set parameter
|
107 |
+
self.query = nn.Parameter(query)
|
108 |
+
|
109 |
+
self.embed_point_feats = embed_point_feats
|
110 |
+
in_dim = (
|
111 |
+
self.embedder.out_dim * 2
|
112 |
+
if self.embed_point_feats
|
113 |
+
else self.embedder.out_dim + point_feats
|
114 |
+
)
|
115 |
+
self.feat_in = MLPEmbedder(in_dim, embed_dim=width)
|
116 |
+
|
117 |
+
if cross_attention_levels is None:
|
118 |
+
cross_attention_levels = [0]
|
119 |
+
|
120 |
+
self.blocks = nn.ModuleList()
|
121 |
+
for i in range(num_layers):
|
122 |
+
if i in cross_attention_levels:
|
123 |
+
self.blocks.append(
|
124 |
+
EncoderCrossAttentionLayer(
|
125 |
+
embed_dim=width,
|
126 |
+
num_heads=num_heads,
|
127 |
+
eps=eps,
|
128 |
+
)
|
129 |
+
)
|
130 |
+
else:
|
131 |
+
self.blocks.append(
|
132 |
+
EncoderLayer(embed_dim=width, num_heads=num_heads, eps=eps)
|
133 |
+
)
|
134 |
+
self.ln_f = LayerNorm(width, eps=eps)
|
135 |
+
|
136 |
+
init_tfixup(self, num_layers)
|
137 |
+
|
138 |
+
def _forward(self, h, data, attn_mask=None):
|
139 |
+
"""
|
140 |
+
Forward pass for the autoencoder model.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
h (torch.Tensor): The input tensor to be processed, typically representing
|
144 |
+
the hidden state or intermediate representation.
|
145 |
+
data (torch.Tensor): The input data tensor to be transformed by the feature
|
146 |
+
extraction layer and used in cross-attention layers.
|
147 |
+
attn_mask (torch.Tensor, optional): An optional attention mask tensor to be
|
148 |
+
used in attention layers for masking specific positions. Defaults to None.
|
149 |
+
Returns:
|
150 |
+
torch.Tensor: The output tensor after processing through the layers and
|
151 |
+
applying final normalization.
|
152 |
+
"""
|
153 |
+
|
154 |
+
data = self.feat_in(data)
|
155 |
+
|
156 |
+
for block in self.blocks:
|
157 |
+
if isinstance(block, EncoderCrossAttentionLayer):
|
158 |
+
h = block(h, data)
|
159 |
+
else:
|
160 |
+
h = block(h, attn_mask=attn_mask)
|
161 |
+
|
162 |
+
h = self.ln_f(h)
|
163 |
+
return h
|
164 |
+
|
165 |
+
def forward(
|
166 |
+
self, pts: torch.Tensor, feats: torch.Tensor
|
167 |
+
) -> Tuple[torch.Tensor, list[torch.Tensor]]:
|
168 |
+
"""
|
169 |
+
Forward pass of the 1D autoencoder model.
|
170 |
+
Args:
|
171 |
+
pts (torch.Tensor): Input tensor representing points with shape (batch_size, num_points, point_dim).
|
172 |
+
feats (torch.Tensor): Input tensor representing features with shape (batch_size, num_points, feature_dim).
|
173 |
+
Can be None if no features are provided.
|
174 |
+
Returns:
|
175 |
+
Tuple[torch.Tensor, list[torch.Tensor]]:
|
176 |
+
- The output tensor after processing the input data.
|
177 |
+
- A list of intermediate tensors (if applicable) generated during the forward pass.
|
178 |
+
"""
|
179 |
+
|
180 |
+
b = pts.shape[0]
|
181 |
+
data = self.embedder(pts)
|
182 |
+
|
183 |
+
if feats is not None:
|
184 |
+
if self.embed_point_feats:
|
185 |
+
feats = self.embedder(feats)
|
186 |
+
data = torch.cat([data, feats], dim=-1)
|
187 |
+
|
188 |
+
# prepare query and data
|
189 |
+
h = self.query.unsqueeze(0).expand(b, -1, -1)
|
190 |
+
return self._forward(h, data, attn_mask=None)
|
191 |
+
|
192 |
+
|
193 |
+
class OneDBottleNeck(nn.Module):
|
194 |
+
def __init__(
|
195 |
+
self,
|
196 |
+
block,
|
197 |
+
) -> None:
|
198 |
+
"""
|
199 |
+
Initializes the OneDBottleNeck class.
|
200 |
+
Args:
|
201 |
+
block: The building block or module used within the autoencoder.
|
202 |
+
"""
|
203 |
+
super().__init__()
|
204 |
+
|
205 |
+
self.block = block
|
206 |
+
|
207 |
+
def forward(self, h: torch.Tensor) -> Tuple[torch.Tensor, dict]:
|
208 |
+
"""
|
209 |
+
Forward pass of the OneDBottleNeck function.
|
210 |
+
Args:
|
211 |
+
h (torch.Tensor): Input tensor to the model.
|
212 |
+
Returns:
|
213 |
+
Tuple[torch.Tensor, dict]: A tuple containing:
|
214 |
+
- The transformed tensor `z` after passing through the block (if applicable).
|
215 |
+
- A dictionary `ret_dict` containing additional information:
|
216 |
+
- "indices": Indices from the block output (if present).
|
217 |
+
- "z_q": Quantized tensor from the block output (if present).
|
218 |
+
|
219 |
+
"""
|
220 |
+
|
221 |
+
z = h
|
222 |
+
ret_dict = {}
|
223 |
+
if self.block is not None:
|
224 |
+
z, d = self.block(z)
|
225 |
+
|
226 |
+
key_mappings = {
|
227 |
+
"q": "indices",
|
228 |
+
"z_q": "z_q",
|
229 |
+
}
|
230 |
+
for in_key, out_key in key_mappings.items():
|
231 |
+
if in_key in d:
|
232 |
+
ret_dict[out_key] = d[in_key]
|
233 |
+
|
234 |
+
return z, ret_dict
|
235 |
+
|
236 |
+
|
237 |
+
class OneDDecoder(nn.Module):
|
238 |
+
def __init__(
|
239 |
+
self,
|
240 |
+
num_latents: int,
|
241 |
+
width: int,
|
242 |
+
num_heads: int,
|
243 |
+
num_layers: int,
|
244 |
+
eps: float = 1e-6,
|
245 |
+
) -> None:
|
246 |
+
"""
|
247 |
+
Initializes the OneDDecoder class.
|
248 |
+
Args:
|
249 |
+
num_latents (int): The number of latent variables.
|
250 |
+
width (int): The width of the embedding dimension.
|
251 |
+
num_heads (int): The number of attention heads in each encoder layer.
|
252 |
+
num_layers (int): The number of encoder layers.
|
253 |
+
eps (float, optional): A small value added for numerical stability. Defaults to 1e-6.
|
254 |
+
"""
|
255 |
+
super().__init__()
|
256 |
+
|
257 |
+
self.register_buffer("query", torch.empty([0, width]), persistent=False)
|
258 |
+
self.positional_encodings = nn.Parameter(
|
259 |
+
init_sort(F.normalize(torch.empty(num_latents, width).normal_()))
|
260 |
+
)
|
261 |
+
self.blocks = nn.ModuleList(
|
262 |
+
[
|
263 |
+
EncoderLayer(embed_dim=width, num_heads=num_heads, eps=eps)
|
264 |
+
for _ in range(num_layers)
|
265 |
+
]
|
266 |
+
)
|
267 |
+
|
268 |
+
init_tfixup(self, num_layers)
|
269 |
+
|
270 |
+
def _forward(self, h):
|
271 |
+
"""
|
272 |
+
Applies a sequence of operations to the input tensor `h` using the blocks
|
273 |
+
defined in the model.
|
274 |
+
Args:
|
275 |
+
h (torch.Tensor): The input tensor to be processed by the blocks.
|
276 |
+
Returns:
|
277 |
+
torch.Tensor: The output tensor after applying all blocks sequentially.
|
278 |
+
"""
|
279 |
+
|
280 |
+
for block in self.blocks:
|
281 |
+
h = block(h)
|
282 |
+
return h
|
283 |
+
|
284 |
+
def forward(self, z):
|
285 |
+
"""
|
286 |
+
This method processes the input tensor `z` by padding it to a fixed length,
|
287 |
+
adding positional encodings, and then passing it through the `_forward` method.
|
288 |
+
|
289 |
+
Args:
|
290 |
+
z (torch.Tensor): Input tensor.
|
291 |
+
Returns:
|
292 |
+
torch.Tensor: Output tensor after processing through the autoencoder.
|
293 |
+
Notes:
|
294 |
+
- If the `query` attribute has a non-zero shape, the input tensor `z` is padded
|
295 |
+
to match the required length using slices of `query`.
|
296 |
+
- Positional encodings are added to the padded input tensor before passing it
|
297 |
+
to the `_forward` method.
|
298 |
+
"""
|
299 |
+
|
300 |
+
# pad input to fixed length
|
301 |
+
if self.query.shape[0] > 0:
|
302 |
+
pad_len = self.query.shape[0] + 1 - z.shape[1]
|
303 |
+
paddings = self.query[:pad_len, ...].unsqueeze(0).expand(z.shape[0], -1, -1)
|
304 |
+
z = torch.cat([paddings, z], dim=1)
|
305 |
+
h = z + self.positional_encodings[: z.shape[1], :].unsqueeze(0).expand(
|
306 |
+
z.shape[0], -1, -1
|
307 |
+
)
|
308 |
+
|
309 |
+
return self._forward(h)
|
310 |
+
|
311 |
+
|
312 |
+
class OneDOccupancyDecoder(nn.Module):
|
313 |
+
def __init__(
|
314 |
+
self, embedder, out_features: int, width: int, num_heads: int, eps=1e-6
|
315 |
+
) -> None:
|
316 |
+
"""
|
317 |
+
Initializes the OneDOccupancyDecoder module.
|
318 |
+
Args:
|
319 |
+
embedder: An embedding module that provides input embeddings.
|
320 |
+
out_features (int): The number of output features for the final linear layer.
|
321 |
+
width (int): The width of the intermediate layers.
|
322 |
+
num_heads (int): The number of attention heads for the cross-attention layer.
|
323 |
+
eps (float, optional): A small value added for numerical stability in layer normalization. Defaults to 1e-6.
|
324 |
+
"""
|
325 |
+
super().__init__()
|
326 |
+
|
327 |
+
self.embedder = embedder
|
328 |
+
self.query_in = MLPEmbedder(self.embedder.out_dim, width)
|
329 |
+
|
330 |
+
self.attn_out = EncoderCrossAttentionLayer(embed_dim=width, num_heads=num_heads)
|
331 |
+
self.ln_f = LayerNorm(width, eps=eps)
|
332 |
+
self.c_head = nn.Linear(width, out_features)
|
333 |
+
|
334 |
+
def query(self, queries: torch.Tensor):
|
335 |
+
"""
|
336 |
+
Processes the input tensor through the embedder and query_in layers.
|
337 |
+
Args:
|
338 |
+
queries (torch.Tensor): A tensor containing the input data to be processed.
|
339 |
+
Returns:
|
340 |
+
torch.Tensor: The output tensor after being processed by the embedder and query_in layers.
|
341 |
+
"""
|
342 |
+
|
343 |
+
return self.query_in(self.embedder(queries))
|
344 |
+
|
345 |
+
def forward(self, queries: torch.Tensor, latents: torch.Tensor):
|
346 |
+
"""
|
347 |
+
Defines the forward pass of the model.
|
348 |
+
Args:
|
349 |
+
queries (torch.Tensor): Input tensor representing the queries.
|
350 |
+
latents (torch.Tensor): Input tensor representing the latent representations.
|
351 |
+
Returns:
|
352 |
+
torch.Tensor: Output tensor after applying the query transformation,
|
353 |
+
attention mechanism, and final processing layers.
|
354 |
+
"""
|
355 |
+
queries = self.query(queries)
|
356 |
+
x = self.attn_out(queries, latents)
|
357 |
+
x = self.c_head(self.ln_f(x))
|
358 |
+
return x
|
359 |
+
|
360 |
+
|
361 |
+
class OneDAutoEncoder(nn.Module):
|
362 |
+
@dataclass
|
363 |
+
class Config:
|
364 |
+
checkpoint_path: str = ""
|
365 |
+
|
366 |
+
# network params
|
367 |
+
num_encoder_latents: int = 256
|
368 |
+
num_decoder_latents: int = 256
|
369 |
+
embed_dim: int = 12
|
370 |
+
width: int = 768
|
371 |
+
num_heads: int = 12
|
372 |
+
out_dim: int = 1
|
373 |
+
eps: float = 1e-6
|
374 |
+
|
375 |
+
# grid features embedding
|
376 |
+
num_freqs: int = 128
|
377 |
+
point_feats: int = 0
|
378 |
+
embed_point_feats: bool = False
|
379 |
+
|
380 |
+
num_encoder_layers: int = 1
|
381 |
+
encoder_cross_attention_levels: list[int] = field(default_factory=list)
|
382 |
+
num_decoder_layers: int = 23
|
383 |
+
|
384 |
+
encoder_with_cls_token: bool = True
|
385 |
+
num_codes: int = 16384
|
386 |
+
|
387 |
+
def __init__(self, cfg: Config) -> None:
|
388 |
+
"""
|
389 |
+
Initializes the OneDAutoencoder model.
|
390 |
+
Args:
|
391 |
+
cfg (Config): Configuration object containing the parameters for the model.
|
392 |
+
Attributes:
|
393 |
+
cfg (Config): Stores the configuration object.
|
394 |
+
embedder (PhaseModulatedFourierEmbedder): Embeds input data using phase-modulated Fourier features.
|
395 |
+
encoder (OneDEncoder): Encodes the input data into latent representations.
|
396 |
+
bottleneck (OneDBottleNeck): Bottleneck layer containing a spherical vector quantizer for dimensionality reduction.
|
397 |
+
decoder (OneDDecoder): Decodes latent representations back into the original data space.
|
398 |
+
occupancy_decoder (OneDOccupancyDecoder): Decodes occupancy information from latent representations.
|
399 |
+
"""
|
400 |
+
|
401 |
+
super().__init__()
|
402 |
+
|
403 |
+
self.cfg = cfg
|
404 |
+
|
405 |
+
self.embedder = PhaseModulatedFourierEmbedder(
|
406 |
+
num_freqs=self.cfg.num_freqs, input_dim=3
|
407 |
+
)
|
408 |
+
|
409 |
+
self.encoder = OneDEncoder(
|
410 |
+
embedder=self.embedder,
|
411 |
+
num_latents=self.cfg.num_encoder_latents,
|
412 |
+
with_cls_token=self.cfg.encoder_with_cls_token,
|
413 |
+
point_feats=self.cfg.point_feats,
|
414 |
+
embed_point_feats=self.cfg.embed_point_feats,
|
415 |
+
width=self.cfg.width,
|
416 |
+
num_heads=self.cfg.num_heads,
|
417 |
+
num_layers=self.cfg.num_encoder_layers,
|
418 |
+
cross_attention_levels=self.cfg.encoder_cross_attention_levels,
|
419 |
+
eps=self.cfg.eps,
|
420 |
+
)
|
421 |
+
|
422 |
+
block = SphericalVectorQuantizer(
|
423 |
+
self.cfg.embed_dim,
|
424 |
+
self.cfg.num_codes,
|
425 |
+
self.cfg.width,
|
426 |
+
codebook_regularization="kl",
|
427 |
+
)
|
428 |
+
self.bottleneck = OneDBottleNeck(block=block)
|
429 |
+
|
430 |
+
self.decoder = OneDDecoder(
|
431 |
+
num_latents=self.cfg.num_encoder_latents,
|
432 |
+
width=self.cfg.width,
|
433 |
+
num_heads=self.cfg.num_heads,
|
434 |
+
num_layers=self.cfg.num_decoder_layers,
|
435 |
+
eps=self.cfg.eps,
|
436 |
+
)
|
437 |
+
|
438 |
+
self.occupancy_decoder = OneDOccupancyDecoder(
|
439 |
+
embedder=self.embedder,
|
440 |
+
out_features=self.cfg.out_dim,
|
441 |
+
width=self.cfg.width,
|
442 |
+
num_heads=self.cfg.num_heads,
|
443 |
+
eps=self.cfg.eps,
|
444 |
+
)
|
445 |
+
|
446 |
+
@torch.no_grad()
|
447 |
+
def decode_indices(self, shape_ids: torch.Tensor):
|
448 |
+
"""
|
449 |
+
Decodes the given shape indices into latent representations.
|
450 |
+
Args:
|
451 |
+
shape_ids (torch.Tensor): A tensor containing the shape indices to be decoded.
|
452 |
+
Returns:
|
453 |
+
torch.Tensor: The decoded latent representations corresponding to the input shape indices.
|
454 |
+
"""
|
455 |
+
|
456 |
+
z_q = self.bottleneck.block.lookup_codebook(shape_ids)
|
457 |
+
latents = self.decode(z_q)
|
458 |
+
return latents
|
459 |
+
|
460 |
+
@torch.no_grad()
|
461 |
+
def query_embeds(self, shape_ids: torch.Tensor):
|
462 |
+
"""
|
463 |
+
Retrieves the latent embeddings corresponding to the given shape IDs.
|
464 |
+
Args:
|
465 |
+
shape_ids (torch.Tensor): A tensor containing the IDs of the shapes
|
466 |
+
for which the latent embeddings are to be queried.
|
467 |
+
Returns:
|
468 |
+
torch.Tensor: A tensor containing the latent embeddings retrieved
|
469 |
+
from the codebook for the provided shape IDs.
|
470 |
+
"""
|
471 |
+
|
472 |
+
z_q = self.bottleneck.block.lookup_codebook_latents(shape_ids)
|
473 |
+
return z_q
|
474 |
+
|
475 |
+
@torch.no_grad()
|
476 |
+
def query_indices(self, shape_embs: torch.Tensor):
|
477 |
+
"""
|
478 |
+
Queries the indices of the quantized embeddings from the bottleneck layer.
|
479 |
+
Args:
|
480 |
+
shape_embs (torch.Tensor): The input tensor containing shape embeddings
|
481 |
+
to be quantized.
|
482 |
+
Returns:
|
483 |
+
torch.Tensor: A tensor containing the quantized indices.
|
484 |
+
"""
|
485 |
+
|
486 |
+
_, ret_dict = self.bottleneck.block.quantize(shape_embs)
|
487 |
+
return ret_dict["q"]
|
488 |
+
|
489 |
+
def encode(self, x: torch.Tensor, **kwargs):
|
490 |
+
"""
|
491 |
+
Encodes the input tensor using the encoder and bottleneck layers.
|
492 |
+
Args:
|
493 |
+
x (torch.Tensor): Input tensor with shape (..., N), where the first 3
|
494 |
+
dimensions represent points (pts) and the remaining dimensions
|
495 |
+
represent features (feats).
|
496 |
+
**kwargs: Additional keyword arguments.
|
497 |
+
Returns:
|
498 |
+
Tuple[torch.Tensor, torch.Tensor, None, dict]: A tuple containing:
|
499 |
+
- z_e (torch.Tensor): Encoded tensor before bottleneck processing.
|
500 |
+
- z (torch.Tensor): Encoded tensor after bottleneck processing.
|
501 |
+
- None: Placeholder for compatibility with other methods.
|
502 |
+
- d (dict): Dictionary containing additional information, including:
|
503 |
+
- "z_cls" (torch.Tensor, optional): Class token if
|
504 |
+
`self.cfg.encoder_with_cls_token` is True.
|
505 |
+
"""
|
506 |
+
|
507 |
+
pts, feats = x[..., :3], x[..., 3:]
|
508 |
+
z_e = self.encoder(pts, feats)
|
509 |
+
|
510 |
+
# split class token
|
511 |
+
if self.cfg.encoder_with_cls_token:
|
512 |
+
z_cls = z_e[:, 0, ...]
|
513 |
+
z_e = z_e[:, 1:, ...]
|
514 |
+
|
515 |
+
# quantize or kl
|
516 |
+
z, d = self.bottleneck(z_e)
|
517 |
+
|
518 |
+
if self.cfg.encoder_with_cls_token:
|
519 |
+
d["z_cls"] = z_cls
|
520 |
+
return z_e, z, None, d
|
521 |
+
|
522 |
+
def decode(self, z: torch.Tensor):
|
523 |
+
"""
|
524 |
+
Decodes the latent representation `z` using the decoder network.
|
525 |
+
Args:
|
526 |
+
z (torch.Tensor): The latent representation tensor to be decoded.
|
527 |
+
Returns:
|
528 |
+
torch.Tensor: The decoded output tensor.
|
529 |
+
"""
|
530 |
+
|
531 |
+
h = self.decoder(z)
|
532 |
+
return h
|
533 |
+
|
534 |
+
def query(self, queries: torch.Tensor, latents: torch.Tensor):
|
535 |
+
"""
|
536 |
+
Computes the logits by decoding the given queries and latent representations.
|
537 |
+
Args:
|
538 |
+
queries (torch.Tensor): A tensor containing the query points to be decoded.
|
539 |
+
latents (torch.Tensor): A tensor containing the latent representations corresponding to the queries.
|
540 |
+
Returns:
|
541 |
+
torch.Tensor: A tensor containing the decoded logits for the given queries and latents.
|
542 |
+
"""
|
543 |
+
|
544 |
+
logits = self.occupancy_decoder(queries, latents).squeeze(-1)
|
545 |
+
return logits
|
546 |
+
|
547 |
+
def forward(self, surface, queries, **kwargs):
|
548 |
+
"""
|
549 |
+
Perform a forward pass through the autoencoder model.
|
550 |
+
Args:
|
551 |
+
surface (torch.Tensor): The input surface tensor to be encoded.
|
552 |
+
queries (torch.Tensor): The query tensor used for generating logits.
|
553 |
+
**kwargs: Additional keyword arguments.
|
554 |
+
Returns:
|
555 |
+
tuple: A tuple containing:
|
556 |
+
- z (torch.Tensor): The latent representation of the input surface.
|
557 |
+
- latents (torch.Tensor): The decoded output from the latent representation.
|
558 |
+
- None: Placeholder for a potential future return value.
|
559 |
+
- logits (torch.Tensor): The logits generated from the queries and latents.
|
560 |
+
- d (torch.Tensor): Additional output from the encoding process.
|
561 |
+
"""
|
562 |
+
|
563 |
+
_, z, _, d = self.encode(surface)
|
564 |
+
|
565 |
+
latents = self.decode(z)
|
566 |
+
logits = self.query(queries, latents)
|
567 |
+
|
568 |
+
return z, latents, None, logits, d
|
569 |
+
|
570 |
+
@torch.no_grad()
|
571 |
+
def extract_geometry(
|
572 |
+
self,
|
573 |
+
latents: torch.FloatTensor,
|
574 |
+
bounds: list[float] = [
|
575 |
+
-1.05,
|
576 |
+
-1.05,
|
577 |
+
-1.05,
|
578 |
+
1.05,
|
579 |
+
1.05,
|
580 |
+
1.05,
|
581 |
+
],
|
582 |
+
resolution_base: float = 9.0,
|
583 |
+
chunk_size: int = 2_000_000,
|
584 |
+
use_warp: bool = False,
|
585 |
+
):
|
586 |
+
"""
|
587 |
+
Extracts 3D geometry from latent representations using a dense grid sampling
|
588 |
+
and marching cubes algorithm.
|
589 |
+
Args:
|
590 |
+
latents (torch.FloatTensor): A tensor of latent representations with shape
|
591 |
+
(batch_size, latent_dim).
|
592 |
+
bounds (list[float], optional): A list of six floats defining the bounding box
|
593 |
+
for the 3D grid in the format [xmin, ymin, zmin, xmax, ymax, zmax].
|
594 |
+
Defaults to [-1.05, -1.05, -1.05, 1.05, 1.05, 1.05].
|
595 |
+
resolution_base (float, optional): The base resolution for the grid. Higher
|
596 |
+
values result in finer grids. Defaults to 9.0.
|
597 |
+
chunk_size (int, optional): The number of grid points to process in a single
|
598 |
+
chunk. Defaults to 2,000,000.
|
599 |
+
use_warp (bool, optional): Whether to use a GPU-accelerated marching cubes
|
600 |
+
implementation. If False, falls back to a CPU implementation. Defaults to False.
|
601 |
+
Returns:
|
602 |
+
tuple:
|
603 |
+
- mesh_v_f (list[tuple]): A list of tuples containing vertices and faces
|
604 |
+
for each batch element. Each tuple is of the form
|
605 |
+
(vertices, faces), where:
|
606 |
+
- vertices (np.ndarray): Array of vertex coordinates with shape
|
607 |
+
(num_vertices, 3).
|
608 |
+
- faces (np.ndarray): Array of face indices with shape
|
609 |
+
(num_faces, 3).
|
610 |
+
If geometry extraction fails for a batch element, the tuple will be
|
611 |
+
(None, None).
|
612 |
+
- has_surface (np.ndarray): A boolean array indicating whether a surface
|
613 |
+
was successfully extracted for each batch element.
|
614 |
+
Raises:
|
615 |
+
Exception: Logs warnings or errors if geometry extraction fails for any
|
616 |
+
batch element or if the marching cubes algorithm encounters issues.
|
617 |
+
"""
|
618 |
+
bbox_min = np.array(bounds[0:3])
|
619 |
+
bbox_max = np.array(bounds[3:6])
|
620 |
+
bbox_size = bbox_max - bbox_min
|
621 |
+
|
622 |
+
xyz_samples, grid_size, length = generate_dense_grid_points(
|
623 |
+
bbox_min=bbox_min,
|
624 |
+
bbox_max=bbox_max,
|
625 |
+
resolution_base=resolution_base,
|
626 |
+
indexing="ij",
|
627 |
+
)
|
628 |
+
xyz_samples = torch.FloatTensor(xyz_samples)
|
629 |
+
batch_size = latents.shape[0]
|
630 |
+
|
631 |
+
batch_logits = []
|
632 |
+
|
633 |
+
progress_bar = tqdm(
|
634 |
+
range(0, xyz_samples.shape[0], chunk_size),
|
635 |
+
desc=f"extracting geometry",
|
636 |
+
unit="chunk",
|
637 |
+
)
|
638 |
+
for start in progress_bar:
|
639 |
+
queries = xyz_samples[start : start + chunk_size, :]
|
640 |
+
|
641 |
+
num_queries = queries.shape[0]
|
642 |
+
if start > 0 and num_queries < chunk_size:
|
643 |
+
queries = F.pad(queries, [0, 0, 0, chunk_size - num_queries])
|
644 |
+
batch_queries = queries.unsqueeze(0).expand(batch_size, -1, -1).to(latents)
|
645 |
+
|
646 |
+
logits = self.query(batch_queries, latents)[:, :num_queries]
|
647 |
+
batch_logits.append(logits)
|
648 |
+
|
649 |
+
grid_logits = (
|
650 |
+
torch.cat(batch_logits, dim=1)
|
651 |
+
.detach()
|
652 |
+
.view((batch_size, grid_size[0], grid_size[1], grid_size[2]))
|
653 |
+
.float()
|
654 |
+
)
|
655 |
+
|
656 |
+
mesh_v_f = []
|
657 |
+
has_surface = np.zeros((batch_size,), dtype=np.bool_)
|
658 |
+
for i in range(batch_size):
|
659 |
+
try:
|
660 |
+
warp_success = False
|
661 |
+
if use_warp:
|
662 |
+
try:
|
663 |
+
vertices, faces = marching_cubes_with_warp(
|
664 |
+
grid_logits[i],
|
665 |
+
level=0.0,
|
666 |
+
device=grid_logits.device,
|
667 |
+
)
|
668 |
+
warp_success = True
|
669 |
+
except Exception as e:
|
670 |
+
logging.warning(
|
671 |
+
f"Warning: error in marching cubes with warp: {e}"
|
672 |
+
)
|
673 |
+
warp_success = False # Fall back to CPU version
|
674 |
+
|
675 |
+
if not warp_success:
|
676 |
+
logging.warning(
|
677 |
+
"Warning: falling back to CPU version of marching cubes using skimage measure"
|
678 |
+
)
|
679 |
+
vertices, faces, _, _ = measure.marching_cubes(
|
680 |
+
grid_logits[i].cpu().numpy(), 0, method="lewiner"
|
681 |
+
)
|
682 |
+
|
683 |
+
vertices = vertices / grid_size * bbox_size + bbox_min
|
684 |
+
faces = faces[:, [2, 1, 0]]
|
685 |
+
mesh_v_f.append(
|
686 |
+
(vertices.astype(np.float32), np.ascontiguousarray(faces))
|
687 |
+
)
|
688 |
+
has_surface[i] = True
|
689 |
+
except Exception as e:
|
690 |
+
logging.error(f"Error: error in extract_geometry: {e}")
|
691 |
+
mesh_v_f.append((None, None))
|
692 |
+
has_surface[i] = False
|
693 |
+
|
694 |
+
return mesh_v_f, has_surface
|
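A hedged sketch of how the decode path above is typically driven: codebook indices become latents via decode_indices, and extract_geometry runs the dense occupancy query plus marching cubes. The freshly constructed (untrained) model and the random index tensor below are placeholders; real use would load the shape checkpoint first:

import torch
from cube3d.model.autoencoder.one_d_autoencoder import OneDAutoEncoder

model = OneDAutoEncoder(OneDAutoEncoder.Config()).eval()       # untrained weights, illustration only
shape_ids = torch.randint(0, model.cfg.num_codes, (1, 256))    # placeholder token ids
latents = model.decode_indices(shape_ids)
mesh_v_f, has_surface = model.extract_geometry(
    latents, resolution_base=6.0, chunk_size=100_000, use_warp=False
)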
cube/cube3d/model/autoencoder/spherical_vq.py
ADDED
@@ -0,0 +1,159 @@
1 |
+
import sys
|
2 |
+
from typing import Literal, Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from cube3d.model.transformers.norm import RMSNorm
|
9 |
+
|
10 |
+
|
11 |
+
class SphericalVectorQuantizer(nn.Module):
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
embed_dim: int,
|
15 |
+
num_codes: int,
|
16 |
+
width: Optional[int] = None,
|
17 |
+
codebook_regularization: Literal["batch_norm", "kl"] = "batch_norm",
|
18 |
+
):
|
19 |
+
"""
|
20 |
+
Initializes the SphericalVQ module.
|
21 |
+
Args:
|
22 |
+
embed_dim (int): The dimensionality of the embeddings.
|
23 |
+
num_codes (int): The number of codes in the codebook.
|
24 |
+
width (Optional[int], optional): The width of the input. Defaults to None.
|
25 |
+
Raises:
|
26 |
+
ValueError: If beta is not in the range [0, 1].
|
27 |
+
"""
|
28 |
+
super().__init__()
|
29 |
+
|
30 |
+
self.num_codes = num_codes
|
31 |
+
|
32 |
+
self.codebook = nn.Embedding(num_codes, embed_dim)
|
33 |
+
self.codebook.weight.data.uniform_(-1.0 / num_codes, 1.0 / num_codes)
|
34 |
+
|
35 |
+
width = width or embed_dim
|
36 |
+
if width != embed_dim:
|
37 |
+
self.c_in = nn.Linear(width, embed_dim)
|
38 |
+
self.c_x = nn.Linear(width, embed_dim) # shortcut
|
39 |
+
self.c_out = nn.Linear(embed_dim, width)
|
40 |
+
else:
|
41 |
+
self.c_in = self.c_out = self.c_x = nn.Identity()
|
42 |
+
|
43 |
+
self.norm = RMSNorm(embed_dim, elementwise_affine=False)
|
44 |
+
self.cb_reg = codebook_regularization
|
45 |
+
if self.cb_reg == "batch_norm":
|
46 |
+
self.cb_norm = nn.BatchNorm1d(embed_dim, track_running_stats=False)
|
47 |
+
else:
|
48 |
+
self.cb_weight = nn.Parameter(torch.ones([embed_dim]))
|
49 |
+
self.cb_bias = nn.Parameter(torch.zeros([embed_dim]))
|
50 |
+
self.cb_norm = lambda x: x.mul(self.cb_weight).add_(self.cb_bias)
|
51 |
+
|
52 |
+
def get_codebook(self):
|
53 |
+
"""
|
54 |
+
Retrieves the normalized codebook weights.
|
55 |
+
This method applies a series of normalization operations to the
|
56 |
+
codebook weights, ensuring they are properly scaled and normalized
|
57 |
+
before being returned.
|
58 |
+
Returns:
|
59 |
+
torch.Tensor: The normalized weights of the codebook.
|
60 |
+
"""
|
61 |
+
|
62 |
+
return self.norm(self.cb_norm(self.codebook.weight))
|
63 |
+
|
64 |
+
@torch.no_grad()
|
65 |
+
|
66 |
+
def lookup_codebook(self, q: torch.Tensor):
|
67 |
+
"""
|
68 |
+
Perform a lookup in the codebook and process the result.
|
69 |
+
This method takes an input tensor of indices, retrieves the corresponding
|
70 |
+
embeddings from the codebook, and applies a transformation to the retrieved
|
71 |
+
embeddings.
|
72 |
+
Args:
|
73 |
+
q (torch.Tensor): A tensor containing indices to look up in the codebook.
|
74 |
+
Returns:
|
75 |
+
torch.Tensor: The transformed embeddings retrieved from the codebook.
|
76 |
+
"""
|
77 |
+
|
78 |
+
# normalize codebook
|
79 |
+
z_q = F.embedding(q, self.get_codebook())
|
80 |
+
z_q = self.c_out(z_q)
|
81 |
+
return z_q
|
82 |
+
|
83 |
+
@torch.no_grad()
|
84 |
+
def lookup_codebook_latents(self, q: torch.Tensor):
|
85 |
+
"""
|
86 |
+
Retrieves the latent representations from the codebook corresponding to the given indices.
|
87 |
+
Args:
|
88 |
+
q (torch.Tensor): A tensor containing the indices of the codebook entries to retrieve.
|
89 |
+
The indices should be integers and correspond to the rows in the codebook.
|
90 |
+
Returns:
|
91 |
+
torch.Tensor: A tensor containing the latent representations retrieved from the codebook.
|
92 |
+
The shape of the returned tensor depends on the shape of the input indices
|
93 |
+
and the dimensionality of the codebook entries.
|
94 |
+
"""
|
95 |
+
|
96 |
+
# normalize codebook
|
97 |
+
z_q = F.embedding(q, self.get_codebook())
|
98 |
+
return z_q
|
99 |
+
|
100 |
+
def quantize(self, z: torch.Tensor):
|
101 |
+
"""
|
102 |
+
Quantizes the latent codes z with the codebook
|
103 |
+
|
104 |
+
Args:
|
105 |
+
z (Tensor): B x ... x F
|
106 |
+
"""
|
107 |
+
|
108 |
+
# normalize codebook
|
109 |
+
codebook = self.get_codebook()
|
110 |
+
# the process of finding quantized codes is non differentiable
|
111 |
+
with torch.no_grad():
|
112 |
+
# flatten z
|
113 |
+
z_flat = z.view(-1, z.shape[-1])
|
114 |
+
|
115 |
+
# calculate distance and find the closest code
|
116 |
+
d = torch.cdist(z_flat, codebook)
|
117 |
+
q = torch.argmin(d, dim=1) # num_ele
|
118 |
+
|
119 |
+
z_q = codebook[q, :].reshape(*z.shape[:-1], -1)
|
120 |
+
q = q.view(*z.shape[:-1])
|
121 |
+
|
122 |
+
return z_q, {"z": z.detach(), "q": q}
|
123 |
+
|
124 |
+
def straight_through_approximation(self, z, z_q):
|
125 |
+
"""passed gradient from z_q to z"""
|
126 |
+
z_q = z + (z_q - z).detach()
|
127 |
+
return z_q
|
128 |
+
|
129 |
+
def forward(self, z: torch.Tensor):
|
130 |
+
"""
|
131 |
+
Forward pass of the spherical vector quantization autoencoder.
|
132 |
+
Args:
|
133 |
+
z (torch.Tensor): Input tensor of shape (batch_size, ..., feature_dim).
|
134 |
+
Returns:
|
135 |
+
Tuple[torch.Tensor, Dict[str, Any]]:
|
136 |
+
- z_q (torch.Tensor): The quantized output tensor after applying the
|
137 |
+
straight-through approximation and output projection.
|
138 |
+
- ret_dict (Dict[str, Any]): A dictionary containing additional
|
139 |
+
information:
|
140 |
+
- "z_q" (torch.Tensor): Detached quantized tensor.
|
141 |
+
- "q" (torch.Tensor): Indices of the quantized vectors.
|
142 |
+
- "perplexity" (torch.Tensor): The perplexity of the quantization,
|
143 |
+
calculated as the exponential of the negative sum of the
|
144 |
+
probabilities' log values.
|
145 |
+
"""
|
146 |
+
|
147 |
+
with torch.autocast(device_type=z.device.type, enabled=False):
|
148 |
+
# work in full precision
|
149 |
+
z = z.float()
|
150 |
+
|
151 |
+
# project and normalize
|
152 |
+
z_e = self.norm(self.c_in(z))
|
153 |
+
z_q, ret_dict = self.quantize(z_e)
|
154 |
+
|
155 |
+
ret_dict["z_q"] = z_q.detach()
|
156 |
+
z_q = self.straight_through_approximation(z_e, z_q)
|
157 |
+
z_q = self.c_out(z_q)
|
158 |
+
|
159 |
+
return z_q, ret_dict
|
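The straight_through_approximation above is the standard vector-quantization trick: the forward value equals the quantized code while gradients flow back to the pre-quantization embedding as if quantization were the identity. A toy illustration, independent of the codebook:

import torch

z = torch.randn(4, 12, requires_grad=True)       # pre-quantization embeddings
z_q = torch.round(z)                             # stand-in for "nearest codebook entry"
z_st = z + (z_q - z).detach()                    # value == z_q, gradient w.r.t. z is identity
z_st.sum().backward()
print(torch.allclose(z_st, z_q), z.grad.unique())  # True tensor([1.])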
cube/cube3d/model/gpt/__init__.py
ADDED
File without changes
|
cube/cube3d/model/gpt/dual_stream_roformer.py
ADDED
@@ -0,0 +1,279 @@
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
|
7 |
+
from cube3d.model.transformers.cache import Cache
|
8 |
+
from cube3d.model.transformers.dual_stream_attention import (
|
9 |
+
DualStreamDecoderLayerWithRotaryEmbedding,
|
10 |
+
)
|
11 |
+
from cube3d.model.transformers.norm import LayerNorm
|
12 |
+
from cube3d.model.transformers.roformer import DecoderLayerWithRotaryEmbedding
|
13 |
+
from cube3d.model.transformers.rope import precompute_freqs_cis
|
14 |
+
|
15 |
+
|
16 |
+
class DualStreamRoformer(nn.Module):
|
17 |
+
@dataclass
|
18 |
+
class Config:
|
19 |
+
checkpoint_path: str = ""
|
20 |
+
n_layer: int = 12
|
21 |
+
n_single_layer: int = 0
|
22 |
+
rope_theta: float = 1000
|
23 |
+
|
24 |
+
n_head: int = 16
|
25 |
+
n_embd: int = 2048
|
26 |
+
bias: bool = False # bias in Linears and LayerNorms
|
27 |
+
eps: float = 1e-6 # Norm eps
|
28 |
+
|
29 |
+
shape_model_vocab_size: int = 4096
|
30 |
+
shape_model_embed_dim: int = 16
|
31 |
+
|
32 |
+
text_model_embed_dim: int = 512
|
33 |
+
use_pooled_text_embed: bool = False
|
34 |
+
|
35 |
+
encoder_with_cls_token: bool = True
|
36 |
+
|
37 |
+
def __init__(self, cfg: Config) -> None:
|
38 |
+
"""
|
39 |
+
Initializes the DualStreamRoFormer model.
|
40 |
+
Args:
|
41 |
+
cfg (Config): Configuration object containing model parameters.
|
42 |
+
Attributes:
|
43 |
+
cfg (Config): Stores the configuration object.
|
44 |
+
text_proj (nn.Linear): Linear layer to project text model embeddings to the desired embedding dimension.
|
45 |
+
shape_proj (nn.Linear, optional): Linear layer to project shape model embeddings to the desired embedding
|
46 |
+
dimension
|
47 |
+
vocab_size (int): Vocabulary size for the shape model, including special tokens.
|
48 |
+
shape_bos_id (int): Token ID for the beginning-of-sequence (BOS) token for the shape model.
|
49 |
+
shape_eos_id (int): Token ID for the end-of-sequence (EOS) token for the shape model.
|
50 |
+
padding_id (int): Token ID for the padding token.
|
51 |
+
transformer (nn.ModuleDict): Dictionary containing the following components:
|
52 |
+
- wte (nn.Embedding): Embedding layer for the vocabulary.
|
53 |
+
- dual_blocks (nn.ModuleList): List of dual-stream decoder layers with rotary embeddings.
|
54 |
+
- single_blocks (nn.ModuleList): List of single-stream decoder layers with rotary embeddings.
|
55 |
+
- ln_f (LayerNorm): Layer normalization applied to the final output.
|
56 |
+
lm_head (nn.Linear): Linear layer mapping the final embeddings to the vocabulary size for language modeling.
|
57 |
+
"""
|
58 |
+
|
59 |
+
super().__init__()
|
60 |
+
|
61 |
+
self.cfg = cfg
|
62 |
+
|
63 |
+
self.text_proj = nn.Linear(
|
64 |
+
in_features=self.cfg.text_model_embed_dim,
|
65 |
+
out_features=self.cfg.n_embd,
|
66 |
+
bias=self.cfg.bias,
|
67 |
+
)
|
68 |
+
|
69 |
+
self.shape_proj = nn.Linear(self.cfg.shape_model_embed_dim, self.cfg.n_embd)
|
70 |
+
|
71 |
+
self.vocab_size = self.cfg.shape_model_vocab_size
|
72 |
+
|
73 |
+
def add_special_token():
|
74 |
+
token_id = self.vocab_size
|
75 |
+
self.vocab_size += 1
|
76 |
+
return token_id
|
77 |
+
|
78 |
+
self.shape_bos_id = add_special_token()
|
79 |
+
self.shape_eos_id = add_special_token()
|
80 |
+
self.padding_id = add_special_token()
|
81 |
+
|
82 |
+
self.transformer = nn.ModuleDict(
|
83 |
+
dict(
|
84 |
+
wte=nn.Embedding(
|
85 |
+
self.vocab_size,
|
86 |
+
self.cfg.n_embd,
|
87 |
+
padding_idx=self.padding_id,
|
88 |
+
),
|
89 |
+
dual_blocks=nn.ModuleList(
|
90 |
+
[
|
91 |
+
DualStreamDecoderLayerWithRotaryEmbedding.from_config(
|
92 |
+
self.cfg, cond_pre_only=(i == self.cfg.n_layer - 1)
|
93 |
+
)
|
94 |
+
for i in range(self.cfg.n_layer)
|
95 |
+
]
|
96 |
+
),
|
97 |
+
single_blocks=nn.ModuleList(
|
98 |
+
[
|
99 |
+
DecoderLayerWithRotaryEmbedding.from_config(self.cfg)
|
100 |
+
for _ in range(self.cfg.n_single_layer)
|
101 |
+
]
|
102 |
+
),
|
103 |
+
ln_f=LayerNorm(
|
104 |
+
self.cfg.n_embd, elementwise_affine=False, eps=self.cfg.eps
|
105 |
+
),
|
106 |
+
)
|
107 |
+
)
|
108 |
+
|
109 |
+
self.lm_head = nn.Linear(self.cfg.n_embd, self.vocab_size, bias=False)
|
110 |
+
|
111 |
+
def encode_text(self, text_embed):
|
112 |
+
"""
|
113 |
+
Encodes the given text embeddings by projecting them through a linear transformation.
|
114 |
+
Args:
|
115 |
+
text_embed (torch.Tensor): A tensor representing the text embeddings to be encoded.
|
116 |
+
Returns:
|
117 |
+
torch.Tensor: The projected text embeddings after applying the linear transformation.
|
118 |
+
"""
|
119 |
+
|
120 |
+
return self.text_proj(text_embed)
|
121 |
+
|
122 |
+
def encode_token(self, tokens):
|
123 |
+
"""
|
124 |
+
Encodes the input tokens using the word token embedding layer of the transformer model.
|
125 |
+
Args:
|
126 |
+
tokens (torch.Tensor): A tensor containing the input tokens to be encoded.
|
127 |
+
Returns:
|
128 |
+
torch.Tensor: A tensor containing the encoded token embeddings.
|
129 |
+
"""
|
130 |
+
|
131 |
+
return self.transformer.wte(tokens)
|
132 |
+
|
133 |
+
def init_kv_cache(
|
134 |
+
self,
|
135 |
+
batch_size: int,
|
136 |
+
cond_len: int,
|
137 |
+
max_shape_tokens: int,
|
138 |
+
dtype: torch.dtype,
|
139 |
+
device: torch.device,
|
140 |
+
) -> list[Cache]:
|
141 |
+
"""
|
142 |
+
Initializes the key-value cache for the transformer model.
|
143 |
+
This method creates a list of `Cache` objects to store the key and value
|
144 |
+
states for both dual-stream and single-stream transformer blocks. The
|
145 |
+
cache is pre-allocated with zeros and is used to optimize the computation
|
146 |
+
of attention mechanisms during model inference.
|
147 |
+
Args:
|
148 |
+
batch_size (int): The batch size for the input data.
|
149 |
+
cond_len (int): The length of the conditioning sequence.
|
150 |
+
max_shape_tokens (int): The maximum number of tokens in the shape sequence.
|
151 |
+
dtype (torch.dtype): The data type for the tensors (e.g., torch.float32).
|
152 |
+
device (torch.device): The device on which the tensors will be allocated
|
153 |
+
(e.g., torch.device('cuda') or torch.device('cpu')).
|
154 |
+
Returns:
|
155 |
+
list[Cache]: A list of `Cache` objects containing pre-allocated key and
|
156 |
+
value states for each transformer block.
|
157 |
+
"""
|
158 |
+
num_heads = self.cfg.n_head
|
159 |
+
max_all_tokens = cond_len + max_shape_tokens
|
160 |
+
per_head_dim = self.cfg.n_embd // num_heads
|
161 |
+
|
162 |
+
kv_cache = [
|
163 |
+
Cache(
|
164 |
+
key_states=torch.zeros(
|
165 |
+
(batch_size, num_heads, max_all_tokens, per_head_dim),
|
166 |
+
dtype=dtype,
|
167 |
+
device=device,
|
168 |
+
),
|
169 |
+
value_states=torch.zeros(
|
170 |
+
(batch_size, num_heads, max_all_tokens, per_head_dim),
|
171 |
+
dtype=dtype,
|
172 |
+
device=device,
|
173 |
+
),
|
174 |
+
)
|
175 |
+
for _ in range(len(self.transformer.dual_blocks))
|
176 |
+
]
|
177 |
+
kv_cache += [
|
178 |
+
Cache(
|
179 |
+
key_states=torch.zeros(
|
180 |
+
(batch_size, num_heads, max_shape_tokens, per_head_dim),
|
181 |
+
dtype=dtype,
|
182 |
+
device=device,
|
183 |
+
),
|
184 |
+
value_states=torch.zeros(
|
185 |
+
(batch_size, num_heads, max_shape_tokens, per_head_dim),
|
186 |
+
dtype=dtype,
|
187 |
+
device=device,
|
188 |
+
),
|
189 |
+
)
|
190 |
+
for _ in range(len(self.transformer.single_blocks))
|
191 |
+
]
|
192 |
+
return kv_cache
|
193 |
+
|
194 |
+
def forward(
|
195 |
+
self,
|
196 |
+
embed: torch.Tensor,
|
197 |
+
cond: torch.Tensor,
|
198 |
+
kv_cache: Optional[list[Cache]] = None,
|
199 |
+
curr_pos_id: Optional[torch.Tensor] = None,
|
200 |
+
decode: bool = False,
|
201 |
+
):
|
202 |
+
"""
|
203 |
+
Forward pass for the dual-stream RoFormer model.
|
204 |
+
Args:
|
205 |
+
embed (torch.Tensor): The input embedding tensor.
|
206 |
+
cond (torch.Tensor): The conditioning tensor.
|
207 |
+
kv_cache (Optional[list[Cache]]): A list of key-value caches for each layer, used for decoding. Default is None.
|
208 |
+
curr_pos_id (Optional[torch.Tensor]): The current position ID tensor of shape (batch_size,). Required if `decode` is True. Default is None.
|
209 |
+
decode (bool): Whether the model is in decoding mode. Default is False.
|
210 |
+
Returns:
|
211 |
+
torch.Tensor: The output logits tensor.
|
212 |
+
"""
|
213 |
+
b, l = embed.shape[:2]
|
214 |
+
s = cond.shape[1]
|
215 |
+
device = embed.device
|
216 |
+
|
217 |
+
attn_mask = torch.tril(
|
218 |
+
torch.ones(s + l, s + l, dtype=torch.bool, device=device)
|
219 |
+
)
|
220 |
+
|
221 |
+
position_ids = torch.arange(l, dtype=torch.long, device=device) # shape (t)
|
222 |
+
position_ids = position_ids.unsqueeze_(0).expand(b, -1)
|
223 |
+
|
224 |
+
s_freqs_cis = precompute_freqs_cis(
|
225 |
+
dim=self.cfg.n_embd // self.cfg.n_head,
|
226 |
+
t=position_ids,
|
227 |
+
theta=self.cfg.rope_theta,
|
228 |
+
)
|
229 |
+
|
230 |
+
position_ids = torch.cat(
|
231 |
+
[
|
232 |
+
torch.zeros([b, s], dtype=torch.long, device=position_ids.device),
|
233 |
+
position_ids,
|
234 |
+
],
|
235 |
+
dim=1,
|
236 |
+
)
|
237 |
+
d_freqs_cis = precompute_freqs_cis(
|
238 |
+
dim=self.cfg.n_embd // self.cfg.n_head,
|
239 |
+
t=position_ids,
|
240 |
+
theta=self.cfg.rope_theta,
|
241 |
+
)
|
242 |
+
|
243 |
+
if kv_cache is not None and decode:
|
244 |
+
assert curr_pos_id is not None
|
245 |
+
embed = embed[:, curr_pos_id, :]
|
246 |
+
|
247 |
+
h = embed
|
248 |
+
c = cond
|
249 |
+
|
250 |
+
layer_idx = 0
|
251 |
+
for block in self.transformer.dual_blocks:
|
252 |
+
h, c = block(
|
253 |
+
h,
|
254 |
+
c=c,
|
255 |
+
freqs_cis=d_freqs_cis,
|
256 |
+
attn_mask=attn_mask,
|
257 |
+
is_causal=True,
|
258 |
+
kv_cache=kv_cache[layer_idx] if kv_cache is not None else None,
|
259 |
+
curr_pos_id=curr_pos_id + s if curr_pos_id is not None else None,
|
260 |
+
decode=decode,
|
261 |
+
)
|
262 |
+
layer_idx += 1
|
263 |
+
for block in self.transformer.single_blocks:
|
264 |
+
h = block(
|
265 |
+
h,
|
266 |
+
freqs_cis=s_freqs_cis,
|
267 |
+
attn_mask=None,
|
268 |
+
is_causal=True,
|
269 |
+
kv_cache=kv_cache[layer_idx] if kv_cache is not None else None,
|
270 |
+
curr_pos_id=curr_pos_id,
|
271 |
+
decode=decode,
|
272 |
+
)
|
273 |
+
layer_idx += 1
|
274 |
+
|
275 |
+
# Normalization
|
276 |
+
h = self.transformer.ln_f(h)
|
277 |
+
logits = self.lm_head(h)
|
278 |
+
|
279 |
+
return logits
|
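To make the prefill/decode split above concrete, here is a rough driving loop. This is a sketch only, not the project's inference engine; the config sizes, greedy sampling, and the prefill-then-decode schedule are illustrative assumptions.

import torch
from cube3d.model.gpt.dual_stream_roformer import DualStreamRoformer

cfg = DualStreamRoformer.Config(n_layer=2, n_single_layer=1, n_head=4, n_embd=64,
                                shape_model_vocab_size=32, text_model_embed_dim=16)
model = DualStreamRoformer(cfg).eval()

b, cond_len, max_tokens = 1, 3, 8
cond = model.encode_text(torch.randn(b, cond_len, cfg.text_model_embed_dim))
tokens = torch.full((b, max_tokens), model.shape_bos_id, dtype=torch.long)
embed = model.encode_token(tokens)
kv_cache = model.init_kv_cache(b, cond_len, max_tokens,
                               dtype=embed.dtype, device=embed.device)

with torch.no_grad():
    # Prefill writes keys/values for the condition and the BOS-initialized sequence.
    logits = model(embed, cond, kv_cache=kv_cache, decode=False)
    next_token = logits[:, 0].argmax(dim=-1)         # prediction after the BOS token
    for pos in range(1, max_tokens):
        tokens[:, pos] = next_token
        embed = model.encode_token(tokens)
        logits = model(embed, cond, kv_cache=kv_cache,
                       curr_pos_id=torch.tensor([pos]), decode=True)
        next_token = logits[:, -1].argmax(dim=-1)    # greedy pick, for illustration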
cube/cube3d/model/transformers/__init__.py
ADDED
File without changes
|
cube/cube3d/model/transformers/attention.py
ADDED
@@ -0,0 +1,300 @@
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from cube3d.model.transformers.norm import LayerNorm, RMSNorm
|
7 |
+
|
8 |
+
|
9 |
+
def init_linear(module, embed_dim: int):
|
10 |
+
"""
|
11 |
+
Initializes the weights and biases of a given linear module.
|
12 |
+
Args:
|
13 |
+
module (nn.Module): The module to initialize. Expected to be an instance of nn.Linear.
|
14 |
+
embed_dim (int): The embedding dimension used to calculate the standard deviation
|
15 |
+
for weight initialization.
|
16 |
+
Returns:
|
17 |
+
None
|
18 |
+
"""
|
19 |
+
|
20 |
+
if isinstance(module, nn.Linear):
|
21 |
+
nn.init.normal_(module.weight, std=math.sqrt(1.0 / embed_dim))
|
22 |
+
if module.bias is not None:
|
23 |
+
torch.nn.init.zeros_(module.bias)
|
24 |
+
|
25 |
+
|
26 |
+
def init_tfixup(module: nn.Module, num_layers: int):
|
27 |
+
"""Special initialization from https://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
|
28 |
+
|
29 |
+
Args:
|
30 |
+
module (nn.Module): decoder/encoder module
|
31 |
+
num_layers (int): number of layers in the module
|
32 |
+
"""
|
33 |
+
with torch.no_grad():
|
34 |
+
for pn, p in module.named_parameters():
|
35 |
+
if (
|
36 |
+
pn.endswith("c_proj.weight")
|
37 |
+
or pn.endswith("up_proj.weight")
|
38 |
+
or pn.endswith("down_proj.weight")
|
39 |
+
):
|
40 |
+
p *= (4 * num_layers) ** (-0.25)
|
41 |
+
elif pn.endswith("c_v.weight"):
|
42 |
+
p *= (4 * num_layers) ** (-0.25) * math.sqrt(2)
|
43 |
+
|
44 |
+
|
45 |
+
class MLP(nn.Module):
|
46 |
+
def __init__(self, embed_dim, hidden_dim, bias=True, approximate="none"):
|
47 |
+
"""
|
48 |
+
MLP with GELU activation function.
|
49 |
+
"""
|
50 |
+
|
51 |
+
super().__init__()
|
52 |
+
self.up_proj = nn.Linear(embed_dim, hidden_dim, bias=bias)
|
53 |
+
self.down_proj = nn.Linear(hidden_dim, embed_dim, bias=bias)
|
54 |
+
self.act_fn = nn.GELU(approximate=approximate)
|
55 |
+
|
56 |
+
def forward(self, x):
|
57 |
+
return self.down_proj(self.act_fn(self.up_proj(x)))
|
58 |
+
|
59 |
+
|
60 |
+
class SelfAttention(nn.Module):
|
61 |
+
def __init__(
|
62 |
+
self,
|
63 |
+
embed_dim: int,
|
64 |
+
num_heads: int,
|
65 |
+
bias: bool = True,
|
66 |
+
eps: float = 1e-6,
|
67 |
+
):
|
68 |
+
"""
|
69 |
+
Initializes the self attention mechanism.
|
70 |
+
Args:
|
71 |
+
embed_dim (int): The dimensionality of the embedding space.
|
72 |
+
num_heads (int): The number of attention heads.
|
73 |
+
bias (bool, optional): Whether to include bias terms in the linear layers. Defaults to True.
|
74 |
+
eps (float, optional): A small value added for numerical stability. Defaults to 1e-6.
|
75 |
+
Raises:
|
76 |
+
AssertionError: If `embed_dim` is not divisible by `num_heads`.
|
77 |
+
"""
|
78 |
+
|
79 |
+
super().__init__()
|
80 |
+
assert embed_dim % num_heads == 0
|
81 |
+
self.num_heads = num_heads
|
82 |
+
self.c_qk = nn.Linear(embed_dim, 2 * embed_dim, bias=bias)
|
83 |
+
self.c_v = nn.Linear(embed_dim, embed_dim, bias=bias)
|
84 |
+
self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
85 |
+
|
86 |
+
head_dim = embed_dim // num_heads
|
87 |
+
self.q_norm = RMSNorm(head_dim)
|
88 |
+
self.k_norm = RMSNorm(head_dim)
|
89 |
+
|
90 |
+
def forward(self, x, attn_mask=None, is_causal: bool = False):
|
91 |
+
"""
|
92 |
+
Performs the forward pass of the attention mechanism.
|
93 |
+
Args:
|
94 |
+
x (torch.Tensor): Input tensor.
|
95 |
+
attn_mask (Optional[torch.Tensor]): Attention mask to apply. Default is None.
|
96 |
+
is_causal (bool): If True, applies a causal mask to prevent attending to future positions.
|
97 |
+
Default is False.
|
98 |
+
Returns:
|
99 |
+
torch.Tensor: Output tensor after applying
|
100 |
+
the attention mechanism and projection.
|
101 |
+
"""
|
102 |
+
|
103 |
+
b, l, d = x.shape
|
104 |
+
|
105 |
+
q, k = self.c_qk(x).chunk(2, dim=-1)
|
106 |
+
v = self.c_v(x)
|
107 |
+
|
108 |
+
q = q.view(b, l, self.num_heads, -1).transpose(1, 2) # (B, nh, T, hs)
|
109 |
+
k = k.view(b, l, self.num_heads, -1).transpose(1, 2) # (B, nh, T, hs)
|
110 |
+
v = v.view(b, l, self.num_heads, -1).transpose(1, 2) # (B, nh, T, hs)
|
111 |
+
|
112 |
+
q = self.q_norm(q)
|
113 |
+
k = self.k_norm(k)
|
114 |
+
|
115 |
+
is_causal = is_causal and attn_mask is None
|
116 |
+
y = torch.nn.functional.scaled_dot_product_attention(
|
117 |
+
q,
|
118 |
+
k,
|
119 |
+
v,
|
120 |
+
attn_mask=attn_mask,
|
121 |
+
dropout_p=0.0,
|
122 |
+
is_causal=is_causal,
|
123 |
+
)
|
124 |
+
|
125 |
+
y = y.transpose(1, 2).contiguous().view(b, l, d)
|
126 |
+
|
127 |
+
y = self.c_proj(y)
|
128 |
+
return y
|
129 |
+
|
130 |
+
|
131 |
+
class CrossAttention(nn.Module):
|
132 |
+
def __init__(
|
133 |
+
self,
|
134 |
+
embed_dim: int,
|
135 |
+
num_heads: int,
|
136 |
+
q_dim=None,
|
137 |
+
kv_dim=None,
|
138 |
+
bias: bool = True,
|
139 |
+
):
|
140 |
+
"""
|
141 |
+
Initializes the cross attention mechanism.
|
142 |
+
Args:
|
143 |
+
embed_dim (int): The dimensionality of the embedding space.
|
144 |
+
num_heads (int): The number of attention heads.
|
145 |
+
q_dim (int, optional): The dimensionality of the query input. Defaults to `embed_dim`.
|
146 |
+
kv_dim (int, optional): The dimensionality of the key and value inputs. Defaults to `embed_dim`.
|
147 |
+
bias (bool, optional): Whether to include a bias term in the linear projections. Defaults to True.
|
148 |
+
Raises:
|
149 |
+
AssertionError: If `embed_dim` is not divisible by `num_heads`.
|
150 |
+
"""
|
151 |
+
super().__init__()
|
152 |
+
assert embed_dim % num_heads == 0
|
153 |
+
|
154 |
+
q_dim = q_dim or embed_dim
|
155 |
+
kv_dim = kv_dim or embed_dim
|
156 |
+
|
157 |
+
self.c_q = nn.Linear(q_dim, embed_dim, bias=bias)
|
158 |
+
self.c_k = nn.Linear(kv_dim, embed_dim, bias=bias)
|
159 |
+
self.c_v = nn.Linear(kv_dim, embed_dim, bias=bias)
|
160 |
+
self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
161 |
+
self.num_heads = num_heads
|
162 |
+
|
163 |
+
def forward(self, x, c, attn_mask=None, is_causal: bool = False):
|
164 |
+
"""
|
165 |
+
Forward pass for the attention mechanism.
|
166 |
+
Args:
|
167 |
+
x (torch.Tensor): Query input tensor of shape (B, L, D).
|
168 |
+
c (torch.Tensor): Context tensor.
|
169 |
+
attn_mask (torch.Tensor, optional): Attention mask.
|
170 |
+
Defaults to None.
|
171 |
+
is_causal (bool, optional): Whether to apply causal masking. Defaults to False.
|
172 |
+
Returns:
|
173 |
+
torch.Tensor: Output tensor.
|
174 |
+
"""
|
175 |
+
|
176 |
+
q, k = self.c_q(x), self.c_k(c)
|
177 |
+
v = self.c_v(c)
|
178 |
+
|
179 |
+
b, l, d = q.shape
|
180 |
+
s = k.shape[1]
|
181 |
+
|
182 |
+
q = q.view(b, l, self.num_heads, -1).transpose(1, 2) # (B, nh, T, hs)
|
183 |
+
k = k.view(b, s, self.num_heads, -1).transpose(1, 2) # (B, nh, T, hs)
|
184 |
+
v = v.view(b, s, self.num_heads, -1).transpose(1, 2) # (B, nh, T, hs)
|
185 |
+
|
186 |
+
y = torch.nn.functional.scaled_dot_product_attention(
|
187 |
+
q,
|
188 |
+
k,
|
189 |
+
v,
|
190 |
+
attn_mask=attn_mask,
|
191 |
+
dropout_p=0.0,
|
192 |
+
is_causal=(attn_mask is None) and is_causal,
|
193 |
+
)
|
194 |
+
|
195 |
+
y = y.transpose(1, 2).contiguous().view(b, l, d)
|
196 |
+
|
197 |
+
y = self.c_proj(y)
|
198 |
+
return y
|
199 |
+
|
200 |
+
|
201 |
+
class EncoderLayer(nn.Module):
|
202 |
+
def __init__(
|
203 |
+
self,
|
204 |
+
embed_dim: int,
|
205 |
+
num_heads: int,
|
206 |
+
bias: bool = True,
|
207 |
+
eps: float = 1e-6,
|
208 |
+
) -> None:
|
209 |
+
"""
|
210 |
+
Initializes the EncoderLayer module.
|
211 |
+
Args:
|
212 |
+
embed_dim (int): The dimensionality of the embedding space.
|
213 |
+
num_heads (int): The number of attention heads.
|
214 |
+
bias (bool, optional): Whether to include bias terms in the layers. Defaults to True.
|
215 |
+
eps (float, optional): A small value added for numerical stability in normalization layers. Defaults to 1e-6.
|
216 |
+
"""
|
217 |
+
super().__init__()
|
218 |
+
self.ln_1 = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
|
219 |
+
self.attn = SelfAttention(embed_dim, num_heads, bias=bias, eps=eps)
|
220 |
+
self.ln_2 = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
|
221 |
+
self.mlp = MLP(embed_dim=embed_dim, hidden_dim=embed_dim * 4, bias=bias)
|
222 |
+
|
223 |
+
def forward(self, x, attn_mask=None, is_causal: bool = False):
|
224 |
+
"""
|
225 |
+
Performs the forward pass of the transformer block.
|
226 |
+
Args:
|
227 |
+
x (torch.Tensor): The input tensor.
|
228 |
+
attn_mask (torch.Tensor, optional): An optional attention mask tensor to apply during the
|
229 |
+
attention computation. Default is None.
|
230 |
+
is_causal (bool, optional): If True, applies a causal mask to prevent attention to future
|
231 |
+
positions. Default is False.
|
232 |
+
Returns:
|
233 |
+
torch.Tensor: The output tensor of the same shape as the input.
|
234 |
+
"""
|
235 |
+
|
236 |
+
x = x + self.attn(self.ln_1(x), attn_mask=attn_mask, is_causal=is_causal)
|
237 |
+
x = x + self.mlp(self.ln_2(x))
|
238 |
+
return x
|
239 |
+
|
240 |
+
|
241 |
+
class EncoderCrossAttentionLayer(nn.Module):
|
242 |
+
def __init__(
|
243 |
+
self,
|
244 |
+
embed_dim: int,
|
245 |
+
num_heads: int,
|
246 |
+
q_dim=None,
|
247 |
+
kv_dim=None,
|
248 |
+
bias: bool = True,
|
249 |
+
eps: float = 1e-6,
|
250 |
+
) -> None:
|
251 |
+
"""
|
252 |
+
Initializes the EncoderAttentionLayer module with cross-attention,
|
253 |
+
and a feed-forward MLP.
|
254 |
+
Args:
|
255 |
+
embed_dim (int): The dimensionality of the embedding space.
|
256 |
+
num_heads (int): The number of attention heads.
|
257 |
+
q_dim (int, optional): Dimensionality of the query input. Defaults to `embed_dim`.
|
258 |
+
kv_dim (int, optional): Dimensionality of the key and value inputs. Defaults to `embed_dim`.
|
259 |
+
bias (bool, optional): Whether to include bias terms in the layers. Defaults to True.
|
260 |
+
eps (float, optional): A small value added to the denominator for numerical stability
|
261 |
+
in layer normalization. Defaults to 1e-6.
|
262 |
+
"""
|
263 |
+
super().__init__()
|
264 |
+
|
265 |
+
q_dim = q_dim or embed_dim
|
266 |
+
kv_dim = kv_dim or embed_dim
|
267 |
+
|
268 |
+
self.attn = CrossAttention(
|
269 |
+
embed_dim,
|
270 |
+
num_heads,
|
271 |
+
q_dim=q_dim,
|
272 |
+
kv_dim=kv_dim,
|
273 |
+
bias=bias,
|
274 |
+
)
|
275 |
+
|
276 |
+
self.ln_1 = LayerNorm(q_dim, elementwise_affine=False, eps=eps)
|
277 |
+
self.ln_2 = LayerNorm(kv_dim, elementwise_affine=False, eps=eps)
|
278 |
+
|
279 |
+
self.ln_f = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
|
280 |
+
self.mlp = MLP(embed_dim=embed_dim, hidden_dim=embed_dim * 4, bias=bias)
|
281 |
+
|
282 |
+
def forward(self, x, c, attn_mask=None, is_causal: bool = False):
|
283 |
+
"""
|
284 |
+
Forward pass for the attention mechanism.
|
285 |
+
Args:
|
286 |
+
x (torch.Tensor): The input tensor to the attention mechanism.
|
287 |
+
c (torch.Tensor): The context tensor used for cross-attention.
|
288 |
+
attn_mask (torch.Tensor, optional): An optional attention mask to control
|
289 |
+
which positions can attend to others. Defaults to None.
|
290 |
+
is_causal (bool, optional): If True, applies a causal mask to prevent
|
291 |
+
attending to future positions. Defaults to False.
|
292 |
+
Returns:
|
293 |
+
torch.Tensor: The output tensor after applying attention and MLP layers.
|
294 |
+
"""
|
295 |
+
|
296 |
+
x = x + self.attn(
|
297 |
+
self.ln_1(x), self.ln_2(c), attn_mask=attn_mask, is_causal=is_causal
|
298 |
+
)
|
299 |
+
x = x + self.mlp(self.ln_f(x))
|
300 |
+
return x
|
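A quick shape-level sanity check for the building blocks above (a sketch; the dimensions are arbitrary):

import torch
from cube3d.model.transformers.attention import EncoderCrossAttentionLayer, EncoderLayer

x = torch.randn(2, 10, 64)                       # (batch, tokens, embed_dim)

# Pre-norm self-attention block: residual attention followed by a GELU MLP.
layer = EncoderLayer(embed_dim=64, num_heads=4)
assert layer(x).shape == x.shape

# Cross-attention block: queries come from x, keys/values from a narrower context c.
xattn = EncoderCrossAttentionLayer(embed_dim=64, num_heads=4, q_dim=64, kv_dim=32)
c = torch.randn(2, 7, 32)
assert xattn(x, c).shape == x.shape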
cube/cube3d/model/transformers/cache.py
ADDED
@@ -0,0 +1,9 @@
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
@dataclass
|
7 |
+
class Cache:
|
8 |
+
key_states: torch.Tensor
|
9 |
+
value_states: torch.Tensor
|
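The Cache above is simply a pair of pre-allocated key/value buffers that the attention layers write into in place. A small illustration (shapes are arbitrary):

import torch
from cube3d.model.transformers.cache import Cache

# (batch, num_heads, max_tokens, head_dim) buffers, sized once up front.
cache = Cache(key_states=torch.zeros(1, 4, 16, 32),
              value_states=torch.zeros(1, 4, 16, 32))

# During decoding a layer writes the current step's key/value at its position.
new_k = torch.randn(1, 4, 1, 32)
cache.key_states.index_copy_(2, torch.tensor([5]), new_k)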
cube/cube3d/model/transformers/dual_stream_attention.py
ADDED
@@ -0,0 +1,348 @@
1 |
+
from typing import Optional, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from cube3d.model.transformers.cache import Cache
|
7 |
+
from cube3d.model.transformers.norm import LayerNorm, RMSNorm
|
8 |
+
from cube3d.model.transformers.roformer import SwiGLUMLP
|
9 |
+
from cube3d.model.transformers.rope import scaled_dot_product_attention_with_rotary_emb
|
10 |
+
|
11 |
+
|
12 |
+
class DismantledPreAttention(nn.Module):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
embed_dim: int,
|
16 |
+
num_heads: int,
|
17 |
+
query: bool = True,
|
18 |
+
bias: bool = True,
|
19 |
+
) -> None:
|
20 |
+
"""
|
21 |
+
Initializes the DismantledPreAttention module.
|
22 |
+
Args:
|
23 |
+
embed_dim (int): The dimensionality of the embedding space.
|
24 |
+
num_heads (int): The number of attention heads.
|
25 |
+
query (bool, optional): Whether to include query-key projection. Defaults to True.
|
26 |
+
bias (bool, optional): Whether to include bias in linear layers. Defaults to True.
|
27 |
+
Raises:
|
28 |
+
AssertionError: If `embed_dim` is not divisible by `num_heads`.
|
29 |
+
"""
|
30 |
+
super().__init__()
|
31 |
+
assert embed_dim % num_heads == 0
|
32 |
+
self.query = query
|
33 |
+
|
34 |
+
head_dim = embed_dim // num_heads
|
35 |
+
# key, query, value projections for all heads, but in a batch
|
36 |
+
if query:
|
37 |
+
self.c_qk = nn.Linear(embed_dim, 2 * embed_dim, bias=False)
|
38 |
+
self.q_norm = RMSNorm(head_dim)
|
39 |
+
else:
|
40 |
+
self.c_k = nn.Linear(embed_dim, embed_dim, bias=bias)
|
41 |
+
self.k_norm = RMSNorm(head_dim)
|
42 |
+
self.c_v = nn.Linear(embed_dim, embed_dim, bias=bias)
|
43 |
+
|
44 |
+
# (B, T, C) -> (B, nh, T, hs)
|
45 |
+
self.to_mha = lambda x: x.view(*x.shape[:2], num_heads, -1).transpose(1, 2)
|
46 |
+
|
47 |
+
def forward(self, x):
|
48 |
+
"""
|
49 |
+
Forward pass for the dismantled pre-attention mechanism.
|
50 |
+
Args:
|
51 |
+
x (torch.Tensor): Input tensor of shape (..., input_dim).
|
52 |
+
Returns:
|
53 |
+
tuple: A tuple containing:
|
54 |
+
- q (torch.Tensor or None): Query tensor after normalization and transformation,
|
55 |
+
or None if `self.query` is False.
|
56 |
+
- k (torch.Tensor): Key tensor after normalization and transformation.
|
57 |
+
- v (torch.Tensor): Value tensor after transformation.
|
58 |
+
"""
|
59 |
+
|
60 |
+
if self.query:
|
61 |
+
q, k = self.c_qk(x).chunk(2, dim=-1)
|
62 |
+
q = self.q_norm(self.to_mha(q))
|
63 |
+
else:
|
64 |
+
q = None
|
65 |
+
k = self.c_k(x)
|
66 |
+
|
67 |
+
k = self.k_norm(self.to_mha(k))
|
68 |
+
v = self.to_mha(self.c_v(x))
|
69 |
+
|
70 |
+
return (q, k, v)
|
71 |
+
|
72 |
+
|
73 |
+
class DismantledPostAttention(nn.Module):
|
74 |
+
def __init__(
|
75 |
+
self,
|
76 |
+
embed_dim,
|
77 |
+
bias: bool = True,
|
78 |
+
eps: float = 1e-6,
|
79 |
+
) -> None:
|
80 |
+
"""
|
81 |
+
Initializes the DismantledPostAttention module.
|
82 |
+
Args:
|
83 |
+
embed_dim (int): The dimensionality of the embedding space.
|
84 |
+
bias (bool, optional): Whether to include a bias term in the linear projection. Defaults to True.
|
85 |
+
eps (float, optional): A small value added to the denominator for numerical stability in layer normalization. Defaults to 1e-6.
|
86 |
+
"""
|
87 |
+
super().__init__()
|
88 |
+
self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
89 |
+
self.ln_3 = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
|
90 |
+
self.mlp = SwiGLUMLP(embed_dim, embed_dim * 4, bias=bias)
|
91 |
+
|
92 |
+
def forward(self, x, a):
|
93 |
+
"""
|
94 |
+
Forward pass of the dual stream attention mechanism.
|
95 |
+
Args:
|
96 |
+
x (torch.Tensor): The input tensor to the model.
|
97 |
+
a (torch.Tensor): The attention tensor to be combined with the input.
|
98 |
+
Returns:
|
99 |
+
torch.Tensor: The output tensor after applying the projection,
|
100 |
+
layer normalization, and MLP transformations.
|
101 |
+
"""
|
102 |
+
|
103 |
+
x = x + self.c_proj(a)
|
104 |
+
x = x + self.mlp(self.ln_3(x))
|
105 |
+
return x
|
106 |
+
|
107 |
+
|
108 |
+
class DualStreamAttentionWithRotaryEmbedding(nn.Module):
|
109 |
+
def __init__(
|
110 |
+
self,
|
111 |
+
embed_dim: int,
|
112 |
+
num_heads: int,
|
113 |
+
cond_pre_only: bool = False,
|
114 |
+
bias: bool = True,
|
115 |
+
):
|
116 |
+
"""
|
117 |
+
Initializes the DualStreamAttention module.
|
118 |
+
Args:
|
119 |
+
embed_dim (int): The dimensionality of the embedding space.
|
120 |
+
num_heads (int): The number of attention heads.
|
121 |
+
cond_pre_only (bool, optional): If True, the conditional pre-attention
|
122 |
+
will only process the key and value, not the query. Defaults to False.
|
123 |
+
bias (bool, optional): Whether to include a bias term in the attention layers.
|
124 |
+
Defaults to True.
|
125 |
+
"""
|
126 |
+
super().__init__()
|
127 |
+
|
128 |
+
self.cond_pre_only = cond_pre_only
|
129 |
+
|
130 |
+
self.pre_x = DismantledPreAttention(
|
131 |
+
embed_dim=embed_dim, num_heads=num_heads, query=True, bias=bias
|
132 |
+
)
|
133 |
+
|
134 |
+
self.pre_c = DismantledPreAttention(
|
135 |
+
embed_dim=embed_dim, num_heads=num_heads, query=not cond_pre_only, bias=bias
|
136 |
+
)
|
137 |
+
|
138 |
+
def forward(
|
139 |
+
self,
|
140 |
+
x,
|
141 |
+
c: Optional[torch.Tensor],
|
142 |
+
freqs_cis,
|
143 |
+
attn_mask: Optional[torch.Tensor] = None,
|
144 |
+
is_causal: bool = False,
|
145 |
+
kv_cache: Optional[Cache] = None,
|
146 |
+
curr_pos_id: Optional[torch.Tensor] = None,
|
147 |
+
decode: bool = False,
|
148 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
149 |
+
"""
|
150 |
+
Forward pass for dual stream Multi-Head Attention.
|
151 |
+
|
152 |
+
Efficient single weight matrix multiplication with results split into query, key, value.
|
153 |
+
|
154 |
+
Parameters
|
155 |
+
----------
|
156 |
+
x : torch.Tensor
|
157 |
+
Hidden states [B, L, D]
|
158 |
+
c : torch.Tensor
|
159 |
+
Condition [B, S, D]
|
160 |
+
freqs_cis: torch.Tensor
|
161 |
+
Precomputed RoPE matrix from precompute_freqs_cis [B, S+L, Hd]
|
162 |
+
attn_mask : torch.Tensor, optional
|
163 |
+
Attention mask [B, S+L, S+L], by default None
|
164 |
+
kv_cache : Cache, optional
    Pre-allocated key/value cache, or None to disable caching; during
    decoding it holds the keys and values accumulated over all previous
    steps (condition tokens followed by shape tokens).
|
170 |
+
|
171 |
+
Returns
|
172 |
+
-------
|
173 |
+
torch.Tensor
|
174 |
+
Hidden state output [B, L, D]
|
175 |
+
"""
|
176 |
+
if kv_cache is None or not decode:
|
177 |
+
# Either training or prefill
|
178 |
+
qkv_c = self.pre_c(c)
|
179 |
+
qkv_x = self.pre_x(x)
|
180 |
+
# prepend condition stream
|
181 |
+
# (B, nh, Tc, hs) + (B, nh, Tx, hs) -> (B, nh, Tc+Tx, hs)
|
182 |
+
if self.cond_pre_only:
|
183 |
+
q = qkv_x[0]
|
184 |
+
else:
|
185 |
+
q = torch.cat([qkv_c[0], qkv_x[0]], dim=2)
|
186 |
+
k = torch.cat([qkv_c[1], qkv_x[1]], dim=2)
|
187 |
+
v = torch.cat([qkv_c[2], qkv_x[2]], dim=2)
|
188 |
+
|
189 |
+
else:
|
190 |
+
# if using kv cache, query would only be the last token in the sequence, hence is_causal is False
|
191 |
+
assert x.shape[1] == 1
|
192 |
+
is_causal = False
|
193 |
+
q, k, v = self.pre_x(x)
|
194 |
+
|
195 |
+
if kv_cache is not None:
|
196 |
+
if not decode:
|
197 |
+
kv_cache.key_states[:, :, : k.shape[2], :].copy_(k)
|
198 |
+
kv_cache.value_states[:, :, : k.shape[2], :].copy_(v)
|
199 |
+
else:
|
200 |
+
assert curr_pos_id is not None
|
201 |
+
kv_cache.key_states.index_copy_(2, curr_pos_id, k)
|
202 |
+
kv_cache.value_states.index_copy_(2, curr_pos_id, v)
|
203 |
+
k = kv_cache.key_states
|
204 |
+
v = kv_cache.value_states
|
205 |
+
|
206 |
+
if attn_mask is not None:
|
207 |
+
# trim attention mask to length
|
208 |
+
if decode:
|
209 |
+
assert curr_pos_id is not None
|
210 |
+
attn_mask = attn_mask[..., curr_pos_id, :]
|
211 |
+
else:
|
212 |
+
attn_mask = attn_mask[..., -q.shape[2] :, :]
|
213 |
+
|
214 |
+
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
215 |
+
# efficient attention using Flash Attention CUDA kernels
|
216 |
+
y = scaled_dot_product_attention_with_rotary_emb(
|
217 |
+
q,
|
218 |
+
k,
|
219 |
+
v,
|
220 |
+
freqs_cis=freqs_cis,
|
221 |
+
attn_mask=attn_mask,
|
222 |
+
curr_pos_id=curr_pos_id if decode else None,
|
223 |
+
is_causal=is_causal,
|
224 |
+
)
|
225 |
+
|
226 |
+
# re-assemble all head outputs side by side
|
227 |
+
y = y.transpose(1, 2).contiguous().view(x.shape[0], -1, x.shape[2])
|
228 |
+
|
229 |
+
if y.shape[1] == x.shape[1]:
|
230 |
+
y_c = None
|
231 |
+
y_x = y
|
232 |
+
else:
|
233 |
+
assert c is not None, "Conditioning is required for dual stream attention"
|
234 |
+
y_c, y_x = torch.split(y, [c.shape[1], x.shape[1]], dim=1)
|
235 |
+
return y_x, y_c
|
236 |
+
|
237 |
+
|
238 |
+
class DualStreamDecoderLayerWithRotaryEmbedding(nn.Module):
|
239 |
+
"""Nicely wrapped decoder layer block for dual stream GPT model"""
|
240 |
+
|
241 |
+
def __init__(
|
242 |
+
self,
|
243 |
+
embed_dim,
|
244 |
+
num_heads: int,
|
245 |
+
cond_pre_only: bool = False,
|
246 |
+
bias: bool = True,
|
247 |
+
eps: float = 1.0e-6,
|
248 |
+
) -> None:
|
249 |
+
"""
|
250 |
+
Initializes the DualStreamDecoderLayerWithRotaryEmbedding module with optional conditional pre-only mode.
|
251 |
+
Args:
|
252 |
+
embed_dim (int): The dimensionality of the embedding space.
|
253 |
+
num_heads (int): The number of attention heads.
|
254 |
+
cond_pre_only (bool, optional): If True, applies conditional processing only before attention. Defaults to False.
|
255 |
+
bias (bool, optional): If True, includes bias terms in the attention and post-attention layers. Defaults to True.
|
256 |
+
eps (float, optional): A small value added for numerical stability in layer normalization. Defaults to 1.0e-6.
|
257 |
+
"""
|
258 |
+
super().__init__()
|
259 |
+
|
260 |
+
self.ln_1 = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
|
261 |
+
self.ln_2 = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
|
262 |
+
|
263 |
+
self.attn = DualStreamAttentionWithRotaryEmbedding(
|
264 |
+
embed_dim=embed_dim,
|
265 |
+
num_heads=num_heads,
|
266 |
+
cond_pre_only=cond_pre_only,
|
267 |
+
bias=bias,
|
268 |
+
)
|
269 |
+
|
270 |
+
self.post_1 = DismantledPostAttention(embed_dim, bias=bias, eps=eps)
|
271 |
+
if not cond_pre_only:
|
272 |
+
self.post_2 = DismantledPostAttention(embed_dim, bias=bias, eps=eps)
|
273 |
+
|
274 |
+
@classmethod
|
275 |
+
def from_config(cls, cfg, cond_pre_only: bool = False):
|
276 |
+
"""
|
277 |
+
Create an instance of the class using the provided configuration.
|
278 |
+
Args:
|
279 |
+
cfg: A configuration object containing the necessary parameters:
|
280 |
+
- n_embd (int): The size of the embedding dimension.
|
281 |
+
- n_head (int): The number of attention heads.
|
282 |
+
- bias (bool): Whether to include a bias term.
|
283 |
+
- eps (float): A small value added for numerical stability.
|
284 |
+
cond_pre_only (bool, optional): If True, applies conditioning only in the pre-processing step.
|
285 |
+
Defaults to False.
|
286 |
+
Returns:
|
287 |
+
An instance of the class initialized with the specified configuration.
|
288 |
+
"""
|
289 |
+
|
290 |
+
return cls(
|
291 |
+
cfg.n_embd,
|
292 |
+
num_heads=cfg.n_head,
|
293 |
+
cond_pre_only=cond_pre_only,
|
294 |
+
bias=cfg.bias,
|
295 |
+
eps=cfg.eps,
|
296 |
+
)
|
297 |
+
|
298 |
+
def forward(
|
299 |
+
self,
|
300 |
+
x,
|
301 |
+
c,
|
302 |
+
freqs_cis: torch.Tensor,
|
303 |
+
attn_mask: Optional[torch.Tensor] = None,
|
304 |
+
is_causal: bool = True,
|
305 |
+
kv_cache: Optional[Cache] = None,
|
306 |
+
curr_pos_id: Optional[torch.Tensor] = None,
|
307 |
+
decode: bool = False,
|
308 |
+
):
|
309 |
+
"""
|
310 |
+
Forward pass for DualStreamDecoderLayerWithRotaryEmbedding.
|
311 |
+
|
312 |
+
Parameters
|
313 |
+
----------
|
314 |
+
x : torch.Tensor
|
315 |
+
Hidden states [B, L, D]
|
316 |
+
c : torch.Tensor
|
317 |
+
Condition [B, S, D]
|
318 |
+
freqs_cis: torch.Tensor
|
319 |
+
Positional embedding from RoPE [B, S+L, hd]
|
320 |
+
attn_mask : torch.Tensor, optional
|
321 |
+
Attention mask [B, S+L, S+L], by default None
|
322 |
+
kv_cache : Cache, optional
    Pre-allocated key/value cache, by default None
|
324 |
+
|
325 |
+
Returns
|
326 |
+
-------
|
327 |
+
torch.Tensor
|
328 |
+
Hidden state output [B, L, D]
|
329 |
+
torch.Tensor
|
330 |
+
Updated condition stream [B, S, D], or None when the block is cond_pre_only
|
331 |
+
"""
|
332 |
+
a_x, a_c = self.attn(
|
333 |
+
self.ln_1(x),
|
334 |
+
# NOTE condition could be none if using kv cache
|
335 |
+
self.ln_2(c) if c is not None else None,
|
336 |
+
freqs_cis=freqs_cis,
|
337 |
+
attn_mask=attn_mask,
|
338 |
+
is_causal=is_causal,
|
339 |
+
kv_cache=kv_cache,
|
340 |
+
curr_pos_id=curr_pos_id,
|
341 |
+
decode=decode,
|
342 |
+
)
|
343 |
+
x = self.post_1(x, a_x)
|
344 |
+
if a_c is not None:
|
345 |
+
c = self.post_2(c, a_c)
|
346 |
+
else:
|
347 |
+
c = None
|
348 |
+
return x, c
|
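In the prefill path above, the condition stream is prepended to the shape-token stream and both are attended over jointly; in decode mode only the newest shape token is processed against the cache. A standalone prefill-style call (a sketch with illustrative sizes; the real position ids are built by DualStreamRoformer.forward, which zeroes the condition positions, as mimicked here):

import torch
from cube3d.model.transformers.dual_stream_attention import (
    DualStreamDecoderLayerWithRotaryEmbedding,
)
from cube3d.model.transformers.rope import precompute_freqs_cis

embed_dim, num_heads, b, s, l = 64, 4, 1, 3, 5
layer = DualStreamDecoderLayerWithRotaryEmbedding(embed_dim, num_heads=num_heads)

x = torch.randn(b, l, embed_dim)                 # shape-token stream
c = torch.randn(b, s, embed_dim)                 # text-condition stream
pos = torch.cat([torch.zeros(b, s, dtype=torch.long),
                 torch.arange(l).unsqueeze(0)], dim=1)
freqs_cis = precompute_freqs_cis(embed_dim // num_heads, pos)
mask = torch.tril(torch.ones(s + l, s + l, dtype=torch.bool))

x_out, c_out = layer(x, c, freqs_cis=freqs_cis, attn_mask=mask, is_causal=True)
assert x_out.shape == x.shape and c_out.shape == c.shape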
cube/cube3d/model/transformers/norm.py
ADDED
@@ -0,0 +1,46 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
def fused_rms_norm(x: torch.Tensor, weight: nn.Parameter, eps: float):
|
6 |
+
"""
|
7 |
+
Applies a fused Root Mean Square (RMS) normalization to the input tensor.
|
8 |
+
Args:
|
9 |
+
x (torch.Tensor): The input tensor to be normalized. Expected to have
|
10 |
+
at least one dimension.
|
11 |
+
weight (nn.Parameter): A learnable parameter used to scale the normalized
|
12 |
+
tensor. Its shape must be broadcastable to the shape of `x`.
|
13 |
+
eps (float): A small constant added to the denominator for numerical
|
14 |
+
stability during normalization.
|
15 |
+
Returns:
|
16 |
+
torch.Tensor: The normalized and scaled tensor with the same shape as `x`.
|
17 |
+
"""
|
18 |
+
|
19 |
+
x = x.float()
|
20 |
+
return (x * torch.rsqrt((x * x).mean(-1, keepdim=True).add_(eps))) * weight
|
21 |
+
|
22 |
+
|
23 |
+
class LayerNorm(nn.LayerNorm):
|
24 |
+
def forward(self, input: torch.Tensor):
|
25 |
+
"""
|
26 |
+
Wrapper to ensure that the input tensor is cast to float before normalization.
|
27 |
+
"""
|
28 |
+
y = super().forward(input.float())
|
29 |
+
return y.type_as(input)
|
30 |
+
|
31 |
+
|
32 |
+
class RMSNorm(nn.Module):
|
33 |
+
def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine: bool = True):
|
34 |
+
"""
|
35 |
+
Initializes the normalization layer.
|
36 |
+
Args:
|
37 |
+
dim (int): The number of features in the input tensor.
|
38 |
+
eps (float, optional): A small value added to the denominator for numerical stability. Defaults to 1e-5.
|
39 |
+
elementwise_affine (bool, optional): If True, this layer will have learnable per-element affine parameters. Defaults to True.
|
40 |
+
"""
|
41 |
+
super().__init__()
|
42 |
+
self.eps = eps
|
43 |
+
self.weight = nn.Parameter(torch.ones(dim), requires_grad=elementwise_affine)
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
return fused_rms_norm(x, weight=self.weight, eps=self.eps).type_as(x)
|
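The fused RMS norm above is the usual x / sqrt(mean(x^2) + eps) scaled by a learned weight; a quick numerical check against that definition (a sketch):

import torch
from cube3d.model.transformers.norm import RMSNorm

x = torch.randn(2, 8, 16)
norm = RMSNorm(16, eps=1e-5)

# Reference definition of RMS normalization with a learned per-channel scale.
ref = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5) * norm.weight
assert torch.allclose(norm(x), ref, atol=1e-5)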
cube/cube3d/model/transformers/roformer.py
ADDED
@@ -0,0 +1,229 @@
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from cube3d.model.transformers.cache import Cache
|
8 |
+
from cube3d.model.transformers.norm import LayerNorm, RMSNorm
|
9 |
+
from cube3d.model.transformers.rope import scaled_dot_product_attention_with_rotary_emb
|
10 |
+
|
11 |
+
|
12 |
+
class SwiGLUMLP(nn.Module):
|
13 |
+
def __init__(self, embed_dim, hidden_dim, bias=True, **kwargs):
|
14 |
+
"""
|
15 |
+
A PyTorch implementation of the SwiGLU (Swish-Gated Linear Unit) MLP layer.
|
16 |
+
This module consists of three linear projections: `gate_proj`, `up_proj`, and `down_proj`.
|
17 |
+
It applies the SwiGLU activation function, which combines the Swish activation with a gating mechanism,
|
18 |
+
followed by a projection back to the original embedding dimension.
|
19 |
+
Args:
|
20 |
+
embed_dim (int): The dimensionality of the input embeddings.
|
21 |
+
hidden_dim (int): The dimensionality of the hidden layer.
|
22 |
+
bias (bool, optional): Whether to include bias terms in the linear layers. Defaults to True.
|
23 |
+
**kwargs: Additional keyword arguments (currently unused).
|
24 |
+
"""
|
25 |
+
super().__init__()
|
26 |
+
self.gate_proj = nn.Linear(embed_dim, hidden_dim, bias=bias)
|
27 |
+
self.up_proj = nn.Linear(embed_dim, hidden_dim, bias=bias)
|
28 |
+
self.down_proj = nn.Linear(hidden_dim, embed_dim, bias=bias)
|
29 |
+
|
30 |
+
# Ignore copy
|
31 |
+
def forward(self, x):
|
32 |
+
"""
|
33 |
+
Applies a forward pass.
|
34 |
+
Args:
|
35 |
+
x (torch.Tensor): The input tensor.
|
36 |
+
Returns:
|
37 |
+
torch.Tensor: The output tensor after applying the forward pass.
|
38 |
+
"""
|
39 |
+
|
40 |
+
down_proj = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
|
41 |
+
return down_proj
|
42 |
+
|
43 |
+
|
44 |
+
class SelfAttentionWithRotaryEmbedding(nn.Module):
|
45 |
+
def __init__(
|
46 |
+
self,
|
47 |
+
embed_dim: int,
|
48 |
+
num_heads: int,
|
49 |
+
bias: bool = True,
|
50 |
+
eps: float = 1e-6,
|
51 |
+
):
|
52 |
+
"""
|
53 |
+
A PyTorch module implementing self-attention with rotary embeddings.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
embed_dim (int): The dimensionality of the input embeddings.
|
57 |
+
num_heads (int): The number of attention heads.
|
58 |
+
bias (bool, optional): Whether to include bias terms in the linear projections. Defaults to True.
|
59 |
+
eps (float, optional): A small value added for numerical stability in normalization. Defaults to 1e-6.
|
60 |
+
"""
|
61 |
+
super().__init__()
|
62 |
+
assert embed_dim % num_heads == 0
|
63 |
+
self.num_heads = num_heads
|
64 |
+
# key, query, value projections for all heads, but in a batch
|
65 |
+
self.c_qk = nn.Linear(embed_dim, 2 * embed_dim, bias=False)
|
66 |
+
self.c_v = nn.Linear(embed_dim, embed_dim, bias=bias)
|
67 |
+
# output projection
|
68 |
+
self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
69 |
+
|
70 |
+
head_dim = embed_dim // num_heads
|
71 |
+
self.q_norm = RMSNorm(head_dim)
|
72 |
+
self.k_norm = RMSNorm(head_dim)
|
73 |
+
|
74 |
+
def forward(
|
75 |
+
self,
|
76 |
+
x,
|
77 |
+
freqs_cis: torch.Tensor,
|
78 |
+
attn_mask=None,
|
79 |
+
is_causal: bool = False,
|
80 |
+
kv_cache: Optional[Cache] = None,
|
81 |
+
curr_pos_id: Optional[torch.Tensor] = None,
|
82 |
+
decode: bool = False,
|
83 |
+
):
|
84 |
+
"""
|
85 |
+
Forward pass for the SelfAttentionWithRotaryEmbedding instance.
|
86 |
+
Args:
|
87 |
+
x (torch.Tensor): Input tensor.
|
88 |
+
freqs_cis (torch.Tensor): Precomputed rotary positional embeddings.
|
89 |
+
attn_mask (Optional[torch.Tensor], optional): Attention mask to apply during self-attention. Defaults to None.
|
90 |
+
is_causal (bool, optional): Whether to apply causal masking for autoregressive decoding. Defaults to False.
|
91 |
+
kv_cache (Optional[Cache], optional): Cache object for storing key and value states for decoding. Defaults to None.
|
92 |
+
curr_pos_id (Optional[torch.Tensor], optional): Current position indices for decoding. Required if `decode` is True. Defaults to None.
|
93 |
+
decode (bool, optional): Whether the model is in decoding mode. Defaults to False.
|
94 |
+
Returns:
|
95 |
+
torch.Tensor: Output tensor after applying self-attention and projection.
|
96 |
+
"""
|
97 |
+
# batch size, sequence length, embedding dim
|
98 |
+
b, l, d = x.shape
|
99 |
+
|
100 |
+
# compute q, k, v and then split per q, k, v
|
101 |
+
q, k = self.c_qk(x).chunk(2, dim=-1)
|
102 |
+
v = self.c_v(x)
|
103 |
+
|
104 |
+
# split per head
|
105 |
+
q = q.view(b, l, self.num_heads, -1).transpose(1, 2) # (B, nh, T, hs)
|
106 |
+
k = k.view(b, l, self.num_heads, -1).transpose(1, 2) # (B, nh, T, hs)
|
107 |
+
v = v.view(b, l, self.num_heads, -1).transpose(1, 2) # (B, nh, T, hs)
|
108 |
+
|
109 |
+
q = self.q_norm(q)
|
110 |
+
k = self.k_norm(k)
|
111 |
+
|
112 |
+
if kv_cache is not None:
|
113 |
+
if not decode:
|
114 |
+
kv_cache.key_states[:, :, : k.shape[2], :].copy_(k)
|
115 |
+
kv_cache.value_states[:, :, : k.shape[2], :].copy_(v)
|
116 |
+
else:
|
117 |
+
assert curr_pos_id is not None
|
118 |
+
kv_cache.key_states.index_copy_(2, curr_pos_id, k)
|
119 |
+
kv_cache.value_states.index_copy_(2, curr_pos_id, v)
|
120 |
+
k = kv_cache.key_states
|
121 |
+
v = kv_cache.value_states
|
122 |
+
|
123 |
+
# self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
124 |
+
# efficient attention using Flash Attention CUDA kernels
|
125 |
+
y = scaled_dot_product_attention_with_rotary_emb(
|
126 |
+
q,
|
127 |
+
k,
|
128 |
+
v,
|
129 |
+
freqs_cis=freqs_cis,
|
130 |
+
attn_mask=attn_mask,
|
131 |
+
curr_pos_id=curr_pos_id if decode else None,
|
132 |
+
is_causal=is_causal,
|
133 |
+
)
|
134 |
+
|
135 |
+
y = (
|
136 |
+
y.transpose(1, 2).contiguous().view(b, l, d)
|
137 |
+
) # re-assemble all head outputs side by side
|
138 |
+
|
139 |
+
# output projection
|
140 |
+
y = self.c_proj(y)
|
141 |
+
return y
|
142 |
+
|
143 |
+
|
144 |
+
class DecoderLayerWithRotaryEmbedding(nn.Module):
|
145 |
+
def __init__(
|
146 |
+
self,
|
147 |
+
embed_dim: int,
|
148 |
+
num_heads: int,
|
149 |
+
bias: bool = True,
|
150 |
+
eps: float = 1e-6,
|
151 |
+
) -> None:
|
152 |
+
"""
|
153 |
+
Initializes the transformer model with rotary embeddings.
|
154 |
+
Args:
|
155 |
+
embed_dim (int): The dimensionality of the embedding space.
|
156 |
+
num_heads (int): The number of attention heads.
|
157 |
+
bias (bool, optional): Whether to include bias terms in the layers. Defaults to True.
|
158 |
+
eps (float, optional): A small value added for numerical stability in layer normalization. Defaults to 1e-6.
|
159 |
+
"""
|
160 |
+
super().__init__()
|
161 |
+
|
162 |
+
self.ln_1 = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
|
163 |
+
self.attn = SelfAttentionWithRotaryEmbedding(
|
164 |
+
embed_dim, num_heads=num_heads, bias=bias, eps=eps
|
165 |
+
)
|
166 |
+
self.ln_2 = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
|
167 |
+
self.mlp = SwiGLUMLP(embed_dim, embed_dim * 4, bias=bias)
|
168 |
+
|
169 |
+
@classmethod
|
170 |
+
def from_config(cls, cfg):
|
171 |
+
"""
|
172 |
+
Create an instance of the class using the provided configuration.
|
173 |
+
Args:
|
174 |
+
cfg: A configuration object containing the following attributes:
|
175 |
+
- n_embd (int): The size of the embedding dimension.
|
176 |
+
- n_head (int): The number of attention heads.
|
177 |
+
- bias (bool): Whether to include a bias term.
|
178 |
+
- eps (float): A small value added for numerical stability.
|
179 |
+
Returns:
|
180 |
+
An instance of the class initialized with the specified configuration.
|
181 |
+
"""
|
182 |
+
|
183 |
+
return cls(
|
184 |
+
cfg.n_embd,
|
185 |
+
num_heads=cfg.n_head,
|
186 |
+
bias=cfg.bias,
|
187 |
+
eps=cfg.eps,
|
188 |
+
)
|
189 |
+
|
190 |
+
def forward(
|
191 |
+
self,
|
192 |
+
x,
|
193 |
+
freqs_cis: torch.Tensor,
|
194 |
+
attn_mask=None,
|
195 |
+
is_causal: bool = True,
|
196 |
+
kv_cache: Optional[Cache] = None,
|
197 |
+
curr_pos_id: Optional[torch.Tensor] = None,
|
198 |
+
decode: bool = False,
|
199 |
+
):
|
200 |
+
"""
|
201 |
+
Forward pass for the transformer model.
|
202 |
+
Args:
|
203 |
+
x (torch.Tensor): Input tensor.
|
204 |
+
freqs_cis (torch.Tensor): Precomputed sinusoidal positional encodings.
|
205 |
+
attn_mask (Optional[torch.Tensor], optional): Attention mask to apply during self-attention.
|
206 |
+
Defaults to None.
|
207 |
+
is_causal (bool, optional): Whether to apply causal masking for autoregressive decoding.
|
208 |
+
Defaults to True.
|
209 |
+
kv_cache (Optional[Cache], optional): Key-value cache for efficient decoding.
|
210 |
+
Defaults to None.
|
211 |
+
curr_pos_id (Optional[torch.Tensor], optional): Current position IDs for decoding.
|
212 |
+
Defaults to None.
|
213 |
+
decode (bool, optional): Whether the model is in decoding mode.
|
214 |
+
Defaults to False.
|
215 |
+
Returns:
|
216 |
+
torch.Tensor: Output tensor.
|
217 |
+
"""
|
218 |
+
out = self.attn(
|
219 |
+
self.ln_1(x),
|
220 |
+
freqs_cis=freqs_cis,
|
221 |
+
attn_mask=attn_mask,
|
222 |
+
is_causal=is_causal,
|
223 |
+
kv_cache=kv_cache,
|
224 |
+
curr_pos_id=curr_pos_id,
|
225 |
+
decode=decode,
|
226 |
+
)
|
227 |
+
x = x + out
|
228 |
+
x = x + self.mlp(self.ln_2(x))
|
229 |
+
return x
|
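A minimal standalone use of the single-stream decoder layer above (a sketch; sizes are arbitrary, and the kv-cache decode path is exercised by the dual-stream model instead):

import torch
from cube3d.model.transformers.roformer import DecoderLayerWithRotaryEmbedding
from cube3d.model.transformers.rope import precompute_freqs_cis

embed_dim, num_heads, b, l = 64, 4, 2, 6
layer = DecoderLayerWithRotaryEmbedding(embed_dim, num_heads=num_heads)

x = torch.randn(b, l, embed_dim)
pos = torch.arange(l).unsqueeze(0).expand(b, -1)
freqs_cis = precompute_freqs_cis(embed_dim // num_heads, pos)

y = layer(x, freqs_cis=freqs_cis, is_causal=True)   # causal self-attention + SwiGLU MLP
assert y.shape == x.shape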
cube/cube3d/model/transformers/rope.py
ADDED
@@ -0,0 +1,91 @@
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
|
7 |
+
def apply_rotary_emb(
|
8 |
+
x: torch.Tensor,
|
9 |
+
freqs_cis: torch.Tensor,
|
10 |
+
curr_pos_id: Optional[torch.Tensor] = None,
|
11 |
+
) -> torch.Tensor:
|
12 |
+
"""
|
13 |
+
Applies rotary positional embeddings to the input tensor.
|
14 |
+
Args:
|
15 |
+
x (torch.Tensor): The input tensor.
|
16 |
+
freqs_cis (torch.Tensor): A tensor containing the precomputed rotary
|
17 |
+
frequency components.
|
18 |
+
curr_pos_id (Optional[torch.Tensor]): An optional tensor specifying the
|
19 |
+
current position IDs to use for selecting a subset of `freqs_cis`.
|
20 |
+
If None, the function uses the last `seq_len` positions.
|
21 |
+
Returns:
|
22 |
+
torch.Tensor: The input tensor `x` with rotary positional embeddings
|
23 |
+
applied.
|
24 |
+
"""
|
25 |
+
x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
26 |
+
if curr_pos_id is None:
|
27 |
+
freqs_cis = freqs_cis[:, -x.shape[2] :].unsqueeze(1)
|
28 |
+
else:
|
29 |
+
freqs_cis = freqs_cis[:, curr_pos_id, :].unsqueeze(1)
|
30 |
+
y = torch.view_as_real(x_ * freqs_cis).flatten(3)
|
31 |
+
return y.type_as(x)
|
32 |
+
|
33 |
+
|
34 |
+
@torch.no_grad
|
35 |
+
def precompute_freqs_cis(dim: int, t: torch.Tensor, theta: float = 10000.0):
|
36 |
+
"""Calculate rotary embedding cos & sin, this is useful when every blocks in the network use same positional embedding.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
dim (int): dimension of the single head of the transformer block
|
40 |
+
t (torch.Tensor): position ids [..., L]
|
41 |
+
theta (int, optional): rope theta. Defaults to 10000.
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
Tuple[torch.Tensor, torch.Tensor]: tuple of cos and sin of rope
|
45 |
+
"""
|
46 |
+
assert dim % 2 == 0, (
|
47 |
+
"RoPE only supports embedding dimensions that are multiples of 2"
|
48 |
+
)
|
49 |
+
freqs = 1.0 / (
|
50 |
+
theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=t.device) / dim)
|
51 |
+
)
|
52 |
+
# [batch_size, seq_len, num_freqs]
|
53 |
+
freqs = torch.outer(t.contiguous().view(-1), freqs).reshape(*t.shape, -1)
|
54 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
55 |
+
|
56 |
+
return freqs_cis
|
57 |
+
|
58 |
+
|
59 |
+
def scaled_dot_product_attention_with_rotary_emb(
|
60 |
+
q: torch.Tensor,
|
61 |
+
k: torch.Tensor,
|
62 |
+
v: torch.Tensor,
|
63 |
+
freqs_cis: torch.Tensor,
|
64 |
+
attn_mask: Optional[torch.Tensor] = None,
|
65 |
+
curr_pos_id: Optional[torch.Tensor] = None,
|
66 |
+
is_causal: bool = False,
|
67 |
+
) -> torch.Tensor:
|
68 |
+
"""
|
69 |
+
Computes scaled dot product attention on query, key and value tensors
|
70 |
+
with rotary position embeddings on query and key.
|
71 |
+
|
72 |
+
Without caching enabled,
|
73 |
+
q should be (bs, nh, seqlen, hd).
|
74 |
+
k and v should stay unchanged, (bs, nh, seqlen, hd).
|
75 |
+
With caching enabled,
|
76 |
+
q should be (bs, nh, 1, hd).
|
77 |
+
k and v should hold the full cached sequence, (bs, nh, cached_len, hd).
is_causal must be False.
|
79 |
+
"""
|
80 |
+
q = apply_rotary_emb(q, freqs_cis, curr_pos_id=curr_pos_id) # (bs, nh, l, hd)
|
81 |
+
k = apply_rotary_emb(k, freqs_cis, curr_pos_id=None) # (bs, nh, s + l, hd)
|
82 |
+
|
83 |
+
x = F.scaled_dot_product_attention(
|
84 |
+
q,
|
85 |
+
k,
|
86 |
+
v,
|
87 |
+
attn_mask=attn_mask,
|
88 |
+
dropout_p=0.0,
|
89 |
+
is_causal=is_causal and attn_mask is None,
|
90 |
+
)
|
91 |
+
return x
|
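precompute_freqs_cis returns a complex tensor of unit-magnitude rotations (one per position and channel pair), and apply_rotary_emb multiplies channel pairs by it, so per-vector norms are preserved. A small check (a sketch):

import torch
from cube3d.model.transformers.rope import apply_rotary_emb, precompute_freqs_cis

head_dim, b, nh, l = 16, 1, 2, 5
pos = torch.arange(l).unsqueeze(0)                  # (1, L) position ids
freqs_cis = precompute_freqs_cis(head_dim, pos)     # complex, (1, L, head_dim // 2)

q = torch.randn(b, nh, l, head_dim)
q_rot = apply_rotary_emb(q, freqs_cis)              # rotate channel pairs per position

assert q_rot.shape == q.shape
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4)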
cube/cube3d/renderer/blender_script.py
ADDED
@@ -0,0 +1,723 @@
1 |
+
"""
|
2 |
+
Blender script to render images of 3D models.
|
3 |
+
|
4 |
+
This script is adopted from the Trellis rendering script:
|
5 |
+
https://github.com/microsoft/TRELLIS/blob/main/dataset_toolkits/render.py
|
6 |
+
|
7 |
+
"""
|
8 |
+
|
9 |
+
import argparse
|
10 |
+
import math
|
11 |
+
import os
|
12 |
+
import platform
|
13 |
+
import random
|
14 |
+
import sys
|
15 |
+
from pathlib import Path
|
16 |
+
from typing import Any, Callable, Dict, Generator, Literal, Optional, Tuple
|
17 |
+
|
18 |
+
import bpy
|
19 |
+
import numpy as np
|
20 |
+
from mathutils import Vector
|
21 |
+
|
22 |
+
pathdir = Path(__file__).parent
|
23 |
+
sys.path.append(pathdir.as_posix())
|
24 |
+
|
25 |
+
print(dir(bpy), bpy.__path__)
|
26 |
+
|
27 |
+
IMPORT_FUNCTIONS: Dict[str, Callable] = {
|
28 |
+
".obj": bpy.ops.wm.obj_import,
|
29 |
+
".glb": bpy.ops.import_scene.gltf,
|
30 |
+
".gltf": bpy.ops.import_scene.gltf,
|
31 |
+
}
|
32 |
+
|
33 |
+
|
34 |
+
def center_and_scale_mesh(scale_value: float = 1.0) -> Optional[float]:
|
35 |
+
"""Centers and scales the scene to fit in a unit cube.
|
36 |
+
For example,
|
37 |
+
scale_value = 1.0 ==> [-0.5, 0.5]
|
38 |
+
scale_value = 2.0 ==> [-1.0, 1.0]
|
39 |
+
"""
|
40 |
+
# Get all mesh objects
|
41 |
+
mesh_objects = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"]
|
42 |
+
if not mesh_objects:
|
43 |
+
return
|
44 |
+
|
45 |
+
# Calculate bounds
|
46 |
+
min_coords = Vector((float("inf"),) * 3)
|
47 |
+
max_coords = Vector((float("-inf"),) * 3)
|
48 |
+
|
49 |
+
for obj in mesh_objects:
|
50 |
+
# Get all vertices in world space
|
51 |
+
for vertex in obj.data.vertices:
|
52 |
+
world_coord = obj.matrix_world @ vertex.co
|
53 |
+
min_coords.x = min(min_coords.x, world_coord.x)
|
54 |
+
min_coords.y = min(min_coords.y, world_coord.y)
|
55 |
+
min_coords.z = min(min_coords.z, world_coord.z)
|
56 |
+
max_coords.x = max(max_coords.x, world_coord.x)
|
57 |
+
max_coords.y = max(max_coords.y, world_coord.y)
|
58 |
+
max_coords.z = max(max_coords.z, world_coord.z)
|
59 |
+
|
60 |
+
# Calculate center and dimensions
|
61 |
+
center = (min_coords + max_coords) / 2
|
62 |
+
dimensions = max_coords - min_coords
|
63 |
+
scale = scale_value / max(
|
64 |
+
dimensions.x, dimensions.y, dimensions.z
|
65 |
+
) # Scale to fit in [-scale_value/2, scale_value/2] cube
|
66 |
+
|
67 |
+
# Create an empty to serve as the parent
|
68 |
+
empty = bpy.data.objects.new("Parent_Empty", None)
|
69 |
+
bpy.context.scene.collection.objects.link(empty)
|
70 |
+
|
71 |
+
# Parent all mesh objects to the empty
|
72 |
+
for obj in mesh_objects:
|
73 |
+
obj.parent = empty
|
74 |
+
|
75 |
+
# Move empty to center everything
|
76 |
+
empty.location = -center
|
77 |
+
|
78 |
+
# Apply scale to empty
|
79 |
+
empty.scale = (scale, scale, scale)
|
80 |
+
|
81 |
+
bpy.context.view_layer.update()
|
82 |
+
bpy.ops.object.select_all(action="DESELECT")
|
83 |
+
empty.select_set(True)
|
84 |
+
bpy.context.view_layer.objects.active = empty
|
85 |
+
bpy.ops.object.transform_apply(location=True, rotation=True, scale=True)
|
86 |
+
print(f"Empty location: {empty.location}")
|
87 |
+
print(f"Empty scale: {empty.scale}")
|
88 |
+
|
89 |
+
return scale
|
90 |
+
|
91 |
+
|
92 |
+
def normalize_scene() -> bpy.types.Object:
|
93 |
+
"""Normalizes the scene by scaling and translating it to fit in a unit cube centered
|
94 |
+
at the origin.
|
95 |
+
|
96 |
+
Mostly taken from the Point-E / Shap-E rendering script
|
97 |
+
(https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112),
|
98 |
+
but fix for multiple root objects: (see bug report here:
|
99 |
+
https://github.com/openai/shap-e/pull/60).
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
The new parent object that all objects descend from.
|
103 |
+
"""
|
104 |
+
if len(list(get_scene_root_objects())) > 1:
|
105 |
+
# create an empty object to be used as a parent for all root objects
|
106 |
+
parent_empty = bpy.data.objects.new("ParentEmpty", None)
|
107 |
+
bpy.context.scene.collection.objects.link(parent_empty)
|
108 |
+
|
109 |
+
# parent all root objects to the empty object
|
110 |
+
for obj in get_scene_root_objects():
|
111 |
+
if obj != parent_empty:
|
112 |
+
obj.parent = parent_empty
|
113 |
+
|
114 |
+
bbox_min, bbox_max = scene_bbox()
|
115 |
+
scale = 1 / max(bbox_max - bbox_min)
|
116 |
+
for obj in get_scene_root_objects():
|
117 |
+
obj.scale = obj.scale * scale
|
118 |
+
|
119 |
+
# Apply scale to matrix_world.
|
120 |
+
bpy.context.view_layer.update()
|
121 |
+
bbox_min, bbox_max = scene_bbox()
|
122 |
+
offset = -(bbox_min + bbox_max) / 2
|
123 |
+
for obj in get_scene_root_objects():
|
124 |
+
obj.matrix_world.translation += offset
|
125 |
+
bpy.ops.object.select_all(action="DESELECT")
|
126 |
+
bbox_min, bbox_max = scene_bbox()
|
127 |
+
print(f"After normalize_scene: bbox_min: {bbox_min}, bbox_max: {bbox_max}")
|
128 |
+
|
129 |
+
# unparent the camera
|
130 |
+
bpy.data.objects["Camera"].parent = None
|
131 |
+
|
132 |
+
return parent_empty
|
133 |
+
|
134 |
+
|
135 |
+
def reset_cameras() -> None:
|
136 |
+
"""Resets the cameras in the scene to a single default camera."""
|
137 |
+
# Delete all existing cameras
|
138 |
+
bpy.ops.object.select_all(action="DESELECT")
|
139 |
+
bpy.ops.object.select_by_type(type="CAMERA")
|
140 |
+
bpy.ops.object.delete()
|
141 |
+
|
142 |
+
# Create a new camera with default properties
|
143 |
+
bpy.ops.object.camera_add()
|
144 |
+
|
145 |
+
# Rename the new camera to 'NewDefaultCamera'
|
146 |
+
new_camera = bpy.context.active_object
|
147 |
+
new_camera.name = "Camera"
|
148 |
+
|
149 |
+
# Set the new camera as the active camera for the scene
|
150 |
+
scene.camera = new_camera
|
151 |
+
|
152 |
+
|
153 |
+
def get_camera_with_position(x, y, z, fov_degrees=40):
|
154 |
+
camera = bpy.data.objects["Camera"]
|
155 |
+
camera.data.angle = math.radians(fov_degrees)
|
156 |
+
camera.location = np.array([x, y, z])
|
157 |
+
direction = -camera.location
|
158 |
+
rot_quat = direction.to_track_quat("-Z", "Y")
|
159 |
+
camera.rotation_euler = rot_quat.to_euler()
|
160 |
+
return camera
|
161 |
+
|
162 |
+
|
163 |
+
def reset_scene() -> None:
|
164 |
+
"""Resets the scene to a clean state.
|
165 |
+
|
166 |
+
Returns:
|
167 |
+
None
|
168 |
+
"""
|
169 |
+
# delete everything that isn't part of a camera or a light
|
170 |
+
for obj in bpy.data.objects:
|
171 |
+
if obj.type not in {"CAMERA", "LIGHT"}:
|
172 |
+
bpy.data.objects.remove(obj, do_unlink=True)
|
173 |
+
|
174 |
+
# delete all the materials
|
175 |
+
for material in bpy.data.materials:
|
176 |
+
bpy.data.materials.remove(material, do_unlink=True)
|
177 |
+
|
178 |
+
# delete all the textures
|
179 |
+
for texture in bpy.data.textures:
|
180 |
+
bpy.data.textures.remove(texture, do_unlink=True)
|
181 |
+
|
182 |
+
# delete all the images
|
183 |
+
for image in bpy.data.images:
|
184 |
+
bpy.data.images.remove(image, do_unlink=True)
|
185 |
+
|
186 |
+
|
187 |
+
def load_object(object_path: str) -> None:
|
188 |
+
"""Loads a model with a supported file extension into the scene.
|
189 |
+
|
190 |
+
Args:
|
191 |
+
object_path (str): Path to the model file.
|
192 |
+
|
193 |
+
Raises:
|
194 |
+
ValueError: If the file extension is not supported.
|
195 |
+
|
196 |
+
Returns:
|
197 |
+
None
|
198 |
+
"""
|
199 |
+
file_extension = Path(object_path).suffix
|
200 |
+
if file_extension is None or file_extension == "":
|
201 |
+
raise ValueError(f"Unsupported file type: {object_path}")
|
202 |
+
|
203 |
+
# load from existing import functions
|
204 |
+
import_function = IMPORT_FUNCTIONS[file_extension]
|
205 |
+
|
206 |
+
if file_extension in {".glb", ".gltf"}:
|
207 |
+
import_function(filepath=object_path, merge_vertices=True)
|
208 |
+
else:
|
209 |
+
import_function(filepath=object_path)
|
210 |
+
|
211 |
+
|
212 |
+
def clear_lights():
|
213 |
+
bpy.ops.object.select_all(action="DESELECT")
|
214 |
+
for obj in bpy.context.scene.objects.values():
|
215 |
+
if isinstance(obj.data, bpy.types.Light):
|
216 |
+
obj.select_set(True)
|
217 |
+
bpy.ops.object.delete()
|
218 |
+
|
219 |
+
|
220 |
+
def create_light(
|
221 |
+
location,
|
222 |
+
energy=1.0,
|
223 |
+
angle=0.5 * math.pi / 180,
|
224 |
+
light_type: Literal["POINT", "SUN", "SPOT", "AREA"] = "SUN",
|
225 |
+
):
|
226 |
+
# https://blender.stackexchange.com/questions/215624/how-to-create-a-light-with-the-python-api-in-blender-2-92
|
227 |
+
light_data = bpy.data.lights.new(name="Light", type=light_type)
|
228 |
+
light_data.energy = energy
|
229 |
+
if light_type != "AREA" and light_type != "POINT":
|
230 |
+
light_data.angle = angle
|
231 |
+
light_object = bpy.data.objects.new(name="Light", object_data=light_data)
|
232 |
+
|
233 |
+
direction = -location
|
234 |
+
rot_quat = direction.to_track_quat("-Z", "Y")
|
235 |
+
light_object.rotation_euler = rot_quat.to_euler()
|
236 |
+
bpy.context.view_layer.update()
|
237 |
+
|
238 |
+
bpy.context.collection.objects.link(light_object)
|
239 |
+
light_object.location = location
|
240 |
+
|
241 |
+
|
242 |
+
def create_uniform_lights(
|
243 |
+
distance=2.0,
|
244 |
+
energy=3.0,
|
245 |
+
light_type: Literal["POINT", "SUN", "SPOT", "AREA"] = "SUN",
|
246 |
+
):
|
247 |
+
clear_lights()
|
248 |
+
create_light(Vector([1, 0, 0]) * distance, energy=energy, light_type=light_type)
|
249 |
+
create_light(-Vector([1, 0, 0]) * distance, energy=energy, light_type=light_type)
|
250 |
+
create_light(Vector([0, 1, 0]) * distance, energy=energy, light_type=light_type)
|
251 |
+
create_light(-Vector([0, 1, 0]) * distance, energy=energy, light_type=light_type)
|
252 |
+
create_light(Vector([0, 0, 1]) * distance, energy=energy, light_type=light_type)
|
253 |
+
create_light(-Vector([0, 0, 1]) * distance, energy=energy, light_type=light_type)
|
254 |
+
|
255 |
+
|
256 |
+
def create_light_at_camera_position(
|
257 |
+
camera_position: Vector,
|
258 |
+
energy=1.5,
|
259 |
+
use_shadow=False,
|
260 |
+
light_type: Literal["POINT", "SUN", "SPOT", "AREA"] = "SUN",
|
261 |
+
):
|
262 |
+
clear_lights()
|
263 |
+
create_light(camera_position, energy=energy, light_type=light_type)
|
264 |
+
# disable shadows
|
265 |
+
if not use_shadow:
|
266 |
+
for light in bpy.data.lights:
|
267 |
+
light.use_shadow = False
|
268 |
+
|
269 |
+
|
270 |
+
def set_world_background_color(
|
271 |
+
color: Tuple[float, float, float, float] = (1.0, 1.0, 1.0, 1.0),
|
272 |
+
) -> None:
|
273 |
+
bpy.context.scene.world.use_nodes = True
|
274 |
+
bpy.context.scene.world.node_tree.nodes["Background"].inputs[
|
275 |
+
0
|
276 |
+
].default_value = color
|
277 |
+
bpy.context.scene.view_settings.view_transform = "Standard"
|
278 |
+
|
279 |
+
|
280 |
+
def scene_bbox(
|
281 |
+
single_obj: Optional[bpy.types.Object] = None, ignore_matrix: bool = False
|
282 |
+
) -> Tuple[Vector, Vector]:
|
283 |
+
"""Returns the bounding box of the scene.
|
284 |
+
|
285 |
+
Taken from Shap-E rendering script
|
286 |
+
(https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82)
|
287 |
+
|
288 |
+
Args:
|
289 |
+
single_obj (Optional[bpy.types.Object], optional): If not None, only computes
|
290 |
+
the bounding box for the given object. Defaults to None.
|
291 |
+
ignore_matrix (bool, optional): Whether to ignore the object's matrix. Defaults
|
292 |
+
to False.
|
293 |
+
|
294 |
+
Raises:
|
295 |
+
RuntimeError: If there are no objects in the scene.
|
296 |
+
|
297 |
+
Returns:
|
298 |
+
Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box.
|
299 |
+
"""
|
300 |
+
bbox_min = (math.inf,) * 3
|
301 |
+
bbox_max = (-math.inf,) * 3
|
302 |
+
found = False
|
303 |
+
for obj in get_scene_meshes() if single_obj is None else [single_obj]:
|
304 |
+
found = True
|
305 |
+
for coord in obj.bound_box:
|
306 |
+
coord = Vector(coord)
|
307 |
+
if not ignore_matrix:
|
308 |
+
coord = obj.matrix_world @ coord
|
309 |
+
bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
|
310 |
+
bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
|
311 |
+
|
312 |
+
if not found:
|
313 |
+
raise RuntimeError("no objects in scene to compute bounding box for")
|
314 |
+
|
315 |
+
return Vector(bbox_min), Vector(bbox_max)
|
316 |
+
|
317 |
+
|
318 |
+
def get_scene_root_objects() -> Generator[bpy.types.Object, None, None]:
|
319 |
+
"""Returns all root objects in the scene.
|
320 |
+
|
321 |
+
Yields:
|
322 |
+
Generator[bpy.types.Object, None, None]: Generator of all root objects in the
|
323 |
+
scene.
|
324 |
+
"""
|
325 |
+
for obj in bpy.context.scene.objects.values():
|
326 |
+
if not obj.parent and not isinstance(obj.data, bpy.types.Light):
|
327 |
+
yield obj
|
328 |
+
|
329 |
+
|
330 |
+
def get_scene_meshes() -> Generator[bpy.types.Object, None, None]:
|
331 |
+
"""Returns all meshes in the scene.
|
332 |
+
|
333 |
+
Yields:
|
334 |
+
Generator[bpy.types.Object, None, None]: Generator of all meshes in the scene.
|
335 |
+
"""
|
336 |
+
for obj in bpy.context.scene.objects.values():
|
337 |
+
if isinstance(obj.data, (bpy.types.Mesh)):
|
338 |
+
yield obj
|
339 |
+
|
340 |
+
|
341 |
+
def delete_missing_textures() -> Dict[str, Any]:
|
342 |
+
"""Deletes all missing textures in the scene.
|
343 |
+
|
344 |
+
Returns:
|
345 |
+
Dict[str, Any]: Dictionary with keys "count", "files", and "file_path_to_color".
|
346 |
+
"count" is the number of missing textures, "files" is a list of the missing
|
347 |
+
texture file paths, and "file_path_to_color" is a dictionary mapping the
|
348 |
+
missing texture file paths to a random color.
|
349 |
+
"""
|
350 |
+
missing_file_count = 0
|
351 |
+
out_files = []
|
352 |
+
file_path_to_color = {}
|
353 |
+
|
354 |
+
# Check all materials in the scene
|
355 |
+
for material in bpy.data.materials:
|
356 |
+
if material.use_nodes:
|
357 |
+
for node in material.node_tree.nodes:
|
358 |
+
if node.type == "TEX_IMAGE":
|
359 |
+
image = node.image
|
360 |
+
if image is not None:
|
361 |
+
file_path = bpy.path.abspath(image.filepath)
|
362 |
+
if file_path == "":
|
363 |
+
# means it's embedded
|
364 |
+
continue
|
365 |
+
|
366 |
+
if not os.path.exists(file_path):
|
367 |
+
# Find the connected Principled BSDF node
|
368 |
+
connected_node = node.outputs[0].links[0].to_node
|
369 |
+
|
370 |
+
if connected_node.type == "BSDF_PRINCIPLED":
|
371 |
+
if file_path not in file_path_to_color:
|
372 |
+
# Set a random color for the unique missing file path
|
373 |
+
random_color = [random.random() for _ in range(3)]
|
374 |
+
file_path_to_color[file_path] = random_color + [1]
|
375 |
+
|
376 |
+
connected_node.inputs[
|
377 |
+
"Base Color"
|
378 |
+
].default_value = file_path_to_color[file_path]
|
379 |
+
|
380 |
+
# Delete the TEX_IMAGE node
|
381 |
+
material.node_tree.nodes.remove(node)
|
382 |
+
missing_file_count += 1
|
383 |
+
out_files.append(image.filepath)
|
384 |
+
return {
|
385 |
+
"count": missing_file_count,
|
386 |
+
"files": out_files,
|
387 |
+
"file_path_to_color": file_path_to_color,
|
388 |
+
}
|
389 |
+
|
390 |
+
|
391 |
+
def setup_environment_lighting(envmap_path):
|
392 |
+
world = bpy.context.scene.world
|
393 |
+
world.use_nodes = True
|
394 |
+
nodes = world.node_tree.nodes
|
395 |
+
links = world.node_tree.links
|
396 |
+
|
397 |
+
# Clear existing nodes
|
398 |
+
for node in nodes:
|
399 |
+
nodes.remove(node)
|
400 |
+
|
401 |
+
# Create Background node
|
402 |
+
bg_node = nodes.new(type="ShaderNodeBackground")
|
403 |
+
bg_node.location = (0, 0)
|
404 |
+
|
405 |
+
# Create Environment Texture node
|
406 |
+
env_tex_node = nodes.new(type="ShaderNodeTexEnvironment")
|
407 |
+
env_tex_node.location = (-300, 0)
|
408 |
+
|
409 |
+
# Set the environment texture path (replace this with your file path)
|
410 |
+
env_tex_node.image = bpy.data.images.load(envmap_path)
|
411 |
+
|
412 |
+
# Create World Output node
|
413 |
+
world_output_node = nodes.new(type="ShaderNodeOutputWorld")
|
414 |
+
world_output_node.location = (300, 0)
|
415 |
+
|
416 |
+
# Link nodes
|
417 |
+
links.new(env_tex_node.outputs["Color"], bg_node.inputs["Color"])
|
418 |
+
links.new(bg_node.outputs["Background"], world_output_node.inputs["Surface"])
|
419 |
+
|
420 |
+
|
421 |
+
def create_solid_color_material(name, color):
|
422 |
+
mat = bpy.data.materials.new(name)
|
423 |
+
mat.use_nodes = True
|
424 |
+
node_tree = mat.node_tree
|
425 |
+
color_node = node_tree.nodes.new("ShaderNodeBsdfDiffuse")
|
426 |
+
color_node.inputs["Color"].default_value = color
|
427 |
+
mat_output = node_tree.nodes["Material Output"]
|
428 |
+
node_tree.links.new(color_node.outputs["BSDF"], mat_output.inputs["Surface"])
|
429 |
+
return mat
|
430 |
+
|
431 |
+
|
432 |
+
def create_phong_material(name, color):
|
433 |
+
mat = bpy.data.materials.new(name)
|
434 |
+
mat.use_nodes = True
|
435 |
+
node_tree = mat.node_tree
|
436 |
+
spec_node = node_tree.nodes.new("ShaderNodeBsdfPrincipled")
|
437 |
+
print(spec_node.inputs.keys())
|
438 |
+
spec_node.inputs["Base Color"].default_value = color
|
439 |
+
spec_node.inputs["Roughness"].default_value = 0.5
|
440 |
+
spec_node.inputs["Metallic"].default_value = 1.0
|
441 |
+
mat_output = node_tree.nodes["Material Output"]
|
442 |
+
node_tree.links.new(spec_node.outputs["BSDF"], mat_output.inputs["Surface"])
|
443 |
+
return mat
|
444 |
+
|
445 |
+
|
446 |
+
def render_object(
|
447 |
+
object_file: str,
|
448 |
+
num_renders: int,
|
449 |
+
output_dir: str,
|
450 |
+
transparent_background: bool = False,
|
451 |
+
environment_map: str = None,
|
452 |
+
) -> None:
|
453 |
+
"""Saves rendered images for given asset to specified output directory.
|
454 |
+
|
455 |
+
Args:
|
456 |
+
object_file (str): Path to the object file.
|
457 |
+
num_renders (int): Number of renders to save of the object.
|
458 |
+
output_dir (str): Path to the directory where the rendered images and metadata
|
459 |
+
will be saved. The rendered images are written directly into this directory
|
460 |
+
as `{view_index:03d}_textured.png`.
|
461 |
+
transparent_background (bool): Whether to use transparent background,
|
462 |
+
otherwise the background is white.
|
463 |
+
Returns:
|
464 |
+
None
|
465 |
+
"""
|
466 |
+
os.makedirs(output_dir, exist_ok=True)
|
467 |
+
|
468 |
+
# load the object
|
469 |
+
reset_scene()
|
470 |
+
load_object(object_file)
|
471 |
+
|
472 |
+
if transparent_background:
|
473 |
+
scene.render.film_transparent = True
|
474 |
+
else:
|
475 |
+
scene.render.film_transparent = False
|
476 |
+
|
477 |
+
set_world_background_color([0.2, 0.2, 0.2, 1.0])
|
478 |
+
|
479 |
+
# normalize the scene
|
480 |
+
_ = normalize_scene()
|
481 |
+
|
482 |
+
# Set up cameras
|
483 |
+
cam = scene.objects["Camera"]
|
484 |
+
fov_degrees = 40.0
|
485 |
+
cam.data.angle = np.radians(fov_degrees)
|
486 |
+
|
487 |
+
# Set up camera constraints
|
488 |
+
cam_constraint = cam.constraints.new(type="TRACK_TO")
|
489 |
+
cam_constraint.track_axis = "TRACK_NEGATIVE_Z"
|
490 |
+
cam_constraint.up_axis = "UP_Y"
|
491 |
+
empty = bpy.data.objects.new("Empty", None)
|
492 |
+
empty.location = (0, 0, 0)
|
493 |
+
scene.collection.objects.link(empty)
|
494 |
+
cam_constraint.target = empty
|
495 |
+
cam.parent = empty
|
496 |
+
|
497 |
+
# delete all objects that are not meshes
|
498 |
+
delete_missing_textures()
|
499 |
+
|
500 |
+
if environment_map:
|
501 |
+
setup_environment_lighting(environment_map)
|
502 |
+
else:
|
503 |
+
create_uniform_lights(energy=1.0, light_type="SUN")
|
504 |
+
|
505 |
+
camera_position = [0, -2, 0]
|
506 |
+
|
507 |
+
# determine how much to orbit camera by.
|
508 |
+
stepsize = 360.0 / num_renders
|
509 |
+
|
510 |
+
def render_views(name):
|
511 |
+
for i in range(num_renders):
|
512 |
+
# set camera
|
513 |
+
_ = get_camera_with_position(
|
514 |
+
camera_position[0],
|
515 |
+
camera_position[1],
|
516 |
+
camera_position[2],
|
517 |
+
fov_degrees=fov_degrees,
|
518 |
+
)
|
519 |
+
|
520 |
+
# Set output paths with absolute paths
|
521 |
+
render_path = os.path.abspath(
|
522 |
+
os.path.join(output_dir, f"{i:03d}_{name}.png")
|
523 |
+
)
|
524 |
+
|
525 |
+
# Set file output paths
|
526 |
+
scene.render.filepath = render_path
|
527 |
+
|
528 |
+
# Make sure the output directory exists
|
529 |
+
os.makedirs(output_dir, exist_ok=True)
|
530 |
+
|
531 |
+
# Render
|
532 |
+
bpy.ops.render.render(write_still=True)
|
533 |
+
|
534 |
+
context.view_layer.objects.active = empty
|
535 |
+
empty.rotation_euler[2] += math.radians(stepsize)
|
536 |
+
|
537 |
+
# ensure that all objects have materials, if not then add a default
|
538 |
+
# one.
|
539 |
+
textured_mat = create_solid_color_material("default texture", [0.6, 0.6, 0.6, 1])
|
540 |
+
|
541 |
+
for obj in get_scene_meshes():
|
542 |
+
if obj.active_material is None:
|
543 |
+
obj.active_material = textured_mat
|
544 |
+
|
545 |
+
render_views("textured")
|
546 |
+
|
547 |
+
|
548 |
+
def enable_gpus(device_type, use_cpus=False):
|
549 |
+
preferences = bpy.context.preferences
|
550 |
+
cycles_preferences = preferences.addons["cycles"].preferences
|
551 |
+
cycles_preferences.refresh_devices()
|
552 |
+
try:
|
553 |
+
devices = cycles_preferences.devices
|
554 |
+
except Exception:
|
555 |
+
print("No devices detected")
|
556 |
+
if device_type == "CPU":
|
557 |
+
return []
|
558 |
+
else:
|
559 |
+
raise RuntimeError(f"No devices detected, set use_cpus to True")
|
560 |
+
|
561 |
+
assert device_type in [
|
562 |
+
"CUDA",
|
563 |
+
"METAL",
|
564 |
+
"OPENCL",
|
565 |
+
"CPU",
|
566 |
+
"NONE",
|
567 |
+
], f"Unsupported device type: {device_type}"
|
568 |
+
|
569 |
+
try:
|
570 |
+
# print(devices)
|
571 |
+
iter(devices)
|
572 |
+
except TypeError:
|
573 |
+
# print("Single GPU Detected")
|
574 |
+
devices = [devices]
|
575 |
+
|
576 |
+
activated_gpus = []
|
577 |
+
for device in devices:
|
578 |
+
if device.type == "CPU":
|
579 |
+
device.use = use_cpus
|
580 |
+
else:
|
581 |
+
device.use = True
|
582 |
+
activated_gpus.append(device.name)
|
583 |
+
|
584 |
+
if device_type == "CUDA":
|
585 |
+
cycles_preferences.compute_device_type = "CUDA"
|
586 |
+
bpy.context.scene.cycles.device = "GPU"
|
587 |
+
elif device_type == "METAL":
|
588 |
+
cycles_preferences.compute_device_type = "METAL"
|
589 |
+
bpy.context.scene.cycles.device = "GPU"
|
590 |
+
elif device_type == "OPENCL":
|
591 |
+
cycles_preferences.compute_device_type = "OPENCL"
|
592 |
+
bpy.context.scene.cycles.device = "GPU"
|
593 |
+
else:
|
594 |
+
raise RuntimeError(f"Unsupported device type: {device_type}")
|
595 |
+
|
596 |
+
return activated_gpus
|
597 |
+
|
598 |
+
|
599 |
+
def set_render_settings(engine, resolution):
|
600 |
+
# Set render settings
|
601 |
+
render.engine = engine
|
602 |
+
render.image_settings.file_format = "PNG"
|
603 |
+
render.image_settings.color_mode = "RGBA"
|
604 |
+
render.resolution_x = resolution
|
605 |
+
render.resolution_y = resolution
|
606 |
+
render.resolution_percentage = 100
|
607 |
+
|
608 |
+
# Set cycles settings
|
609 |
+
scene.cycles.device = "GPU"
|
610 |
+
scene.cycles.use_adaptive_sampling = True
|
611 |
+
scene.cycles.adaptive_threshold = 0.1
|
612 |
+
scene.cycles.samples = 64
|
613 |
+
scene.cycles.adaptive_min_samples = 1
|
614 |
+
scene.cycles.filter_width = 2
|
615 |
+
scene.cycles.use_fast_gi = True
|
616 |
+
scene.cycles.fast_gi_method = "REPLACE"
|
617 |
+
world.light_settings.ao_factor = 1.0
|
618 |
+
world.light_settings.distance = 10
|
619 |
+
scene.cycles.use_denoising = True # ML denoising
|
620 |
+
scene.cycles.denoising_use_gpu = True
|
621 |
+
|
622 |
+
# bake existing frames for faster future renders
|
623 |
+
scene.render.use_persistent_data = True
|
624 |
+
|
625 |
+
# Set eevee settings
|
626 |
+
scene.eevee.use_shadows = True
|
627 |
+
scene.eevee.use_raytracing = True
|
628 |
+
scene.eevee.ray_tracing_options.use_denoise = True
|
629 |
+
scene.eevee.use_fast_gi = True
|
630 |
+
scene.eevee.fast_gi_method = "GLOBAL_ILLUMINATION"
|
631 |
+
scene.eevee.ray_tracing_options.trace_max_roughness = 0.5
|
632 |
+
scene.eevee.fast_gi_resolution = "2"
|
633 |
+
scene.eevee.fast_gi_ray_count = 2
|
634 |
+
scene.eevee.fast_gi_step_count = 8
|
635 |
+
|
636 |
+
|
637 |
+
def print_devices():
|
638 |
+
print("Devices:")
|
639 |
+
preferences = bpy.context.preferences
|
640 |
+
cycles_preferences = preferences.addons["cycles"].preferences
|
641 |
+
cycles_preferences.refresh_devices()
|
642 |
+
|
643 |
+
devices = cycles_preferences.devices
|
644 |
+
for device in devices:
|
645 |
+
print(f' [{device.id}]<{device.type}> "{device.name}" Using: {device.use}')
|
646 |
+
|
647 |
+
print(f"Compute device type: {cycles_preferences.compute_device_type}")
|
648 |
+
print(f"Cycles device: {bpy.context.scene.cycles.device}")
|
649 |
+
|
650 |
+
|
651 |
+
if __name__ == "__main__":
|
652 |
+
parser = argparse.ArgumentParser()
|
653 |
+
parser.add_argument(
|
654 |
+
"--object_path",
|
655 |
+
type=str,
|
656 |
+
required=False,
|
657 |
+
help="Path to the object file",
|
658 |
+
)
|
659 |
+
parser.add_argument(
|
660 |
+
"--output_dir",
|
661 |
+
type=str,
|
662 |
+
required=True,
|
663 |
+
help="Path to the directory where the rendered images and metadata will be saved.",
|
664 |
+
)
|
665 |
+
parser.add_argument(
|
666 |
+
"--engine",
|
667 |
+
type=str,
|
668 |
+
default="BLENDER_EEVEE_NEXT", # BLENDER_BLENDER_EEVEE_NEXT rasterization, better than nvdifrast, CYCLES
|
669 |
+
choices=["CYCLES", "BLENDER_EEVEE_NEXT"],
|
670 |
+
)
|
671 |
+
parser.add_argument(
|
672 |
+
"--num_renders",
|
673 |
+
type=int,
|
674 |
+
default=12,
|
675 |
+
help="Number of renders to save of the object.",
|
676 |
+
)
|
677 |
+
parser.add_argument(
|
678 |
+
"--render_resolution",
|
679 |
+
type=int,
|
680 |
+
default=512,
|
681 |
+
help="Resolution of the rendered images.",
|
682 |
+
)
|
683 |
+
parser.add_argument(
|
684 |
+
"--transparent_background",
|
685 |
+
action="store_true",
|
686 |
+
help="Whether to use transparent background",
|
687 |
+
)
|
688 |
+
parser.add_argument(
|
689 |
+
"--environment_map",
|
690 |
+
default=None,
|
691 |
+
type=str,
|
692 |
+
help="Use the given environment map for lighting",
|
693 |
+
)
|
694 |
+
|
695 |
+
argv = sys.argv[sys.argv.index("--") + 1 :]
|
696 |
+
args = parser.parse_args(argv)
|
697 |
+
|
698 |
+
context = bpy.context
|
699 |
+
scene = context.scene
|
700 |
+
render = scene.render
|
701 |
+
world = bpy.data.worlds["World"]
|
702 |
+
|
703 |
+
set_render_settings(args.engine, args.render_resolution)
|
704 |
+
|
705 |
+
# detect platform and activate GPUs
|
706 |
+
platform = platform.system()
|
707 |
+
if platform == "Darwin":
|
708 |
+
activated_gpus = enable_gpus("METAL", use_cpus=True)
|
709 |
+
elif platform == "Linux":
|
710 |
+
activated_gpus = enable_gpus("CUDA", use_cpus=False)
|
711 |
+
else:
|
712 |
+
raise RuntimeError("Unsupported platform")
|
713 |
+
print(f"Activated GPUs: {activated_gpus}")
|
714 |
+
|
715 |
+
print_devices()
|
716 |
+
|
717 |
+
render_object(
|
718 |
+
object_file=args.object_path,
|
719 |
+
num_renders=args.num_renders,
|
720 |
+
output_dir=args.output_dir,
|
721 |
+
transparent_background=args.transparent_background,
|
722 |
+
environment_map=args.environment_map,
|
723 |
+
)
|
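
The render loop in render_object keeps the camera at a fixed local position (0, -2, 0) and rotates the parent empty by 360 / num_renders degrees about Z between frames, which traces a turntable orbit around the origin. The sketch below (illustrative only, not part of the script) computes the equivalent world-space camera positions directly; it can help when reasoning about or debugging the orbit.

import math

num_renders, radius = 12, 2.0
orbit_positions = []
for i in range(num_renders):
    angle = math.radians(i * 360.0 / num_renders)
    # rotating the start position (0, -radius, 0) by `angle` about the Z axis
    orbit_positions.append((radius * math.sin(angle), -radius * math.cos(angle), 0.0))

# orbit_positions[0] == (0.0, -2.0, 0.0), the camera_position used in render_object.
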
cube/cube3d/renderer/renderer.py
ADDED
@@ -0,0 +1,88 @@
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import subprocess
|
4 |
+
import sys
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
|
10 |
+
def render_asset(
|
11 |
+
asset_path,
|
12 |
+
output_dir,
|
13 |
+
nviews=24,
|
14 |
+
img_resolution=512,
|
15 |
+
):
|
16 |
+
"""
|
17 |
+
Render given asset into output_dir and return the saved image paths.
|
18 |
+
Assumes that blender is installed and is in your path.
|
19 |
+
|
20 |
+
nviews : number of views to render
|
21 |
+
img_resolution : resolution of each rendered view in pixels
|
22 |
+
"""
|
23 |
+
|
24 |
+
curr_file_path = __file__
|
25 |
+
curr_dir = os.path.dirname(curr_file_path)
|
26 |
+
|
27 |
+
command = [
|
28 |
+
"blender",
|
29 |
+
"--background",
|
30 |
+
"-noaudio",
|
31 |
+
"--python",
|
32 |
+
f"{curr_dir}/blender_script.py",
|
33 |
+
"--",
|
34 |
+
"--object_path",
|
35 |
+
asset_path,
|
36 |
+
"--num_renders",
|
37 |
+
str(nviews),
|
38 |
+
"--output_dir",
|
39 |
+
output_dir,
|
40 |
+
"--render_resolution",
|
41 |
+
str(img_resolution),
|
42 |
+
"--transparent_background",
|
43 |
+
"--engine",
|
44 |
+
"CYCLES",
|
45 |
+
]
|
46 |
+
|
47 |
+
subprocess.run(command, check=True)
|
48 |
+
|
49 |
+
# return the saved images paths
|
50 |
+
images = []
|
51 |
+
|
52 |
+
for i in range(nviews):
|
53 |
+
fp = os.path.abspath(os.path.join(output_dir, f"{i:03d}_textured.png"))
|
54 |
+
images.append(fp)
|
55 |
+
|
56 |
+
return images
|
57 |
+
|
58 |
+
|
59 |
+
def save_gif(image_paths, outfile):
|
60 |
+
images = [Image.open(img) for img in image_paths]
|
61 |
+
if len(images) > 1:
|
62 |
+
background = Image.new("RGBA", images[0].size, (255, 255, 255))
|
63 |
+
images = [
|
64 |
+
Image.alpha_composite(background, png).convert("RGB") for png in images
|
65 |
+
]
|
66 |
+
images[0].save(
|
67 |
+
outfile, save_all=True, append_images=images[1:], duration=100, loop=0
|
68 |
+
)
|
69 |
+
|
70 |
+
|
71 |
+
def render_turntable(obj_path, output_dir, output_name="turntable"):
|
72 |
+
"""
|
73 |
+
Render a turntable gif of the mesh. Assumes that blender is installed and is in your path.
|
74 |
+
obj_path : path to the obj file
|
75 |
+
output_dir : directory to save the gif. Final image will be saved as `turntable.gif`
|
76 |
+
"""
|
77 |
+
image_paths = render_asset(obj_path, output_dir)
|
78 |
+
gif_turntable_outfile = Path(output_dir) / f"{output_name}.gif"
|
79 |
+
save_gif(image_paths, gif_turntable_outfile)
|
80 |
+
return gif_turntable_outfile
|
81 |
+
|
82 |
+
|
83 |
+
if __name__ == "__main__":
|
84 |
+
parser = argparse.ArgumentParser()
|
85 |
+
parser.add_argument("-i", "--input")
|
86 |
+
parser.add_argument("-o", "--output_dir")
|
87 |
+
args = parser.parse_args(sys.argv[1:])
|
88 |
+
render_turntable(args.input, args.output_dir)
|
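
A minimal programmatic usage sketch for the helpers above (paths are placeholders; it assumes Blender is on PATH and the cube3d package is importable):

from cube3d.renderer.renderer import render_turntable

# Renders 24 views of the mesh and assembles them into renders/turntable.gif.
gif_path = render_turntable("output.obj", "renders")
print(gif_path)
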
cube/cube3d/vq_vae_encode_decode.py
ADDED
@@ -0,0 +1,150 @@
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import trimesh
|
7 |
+
|
8 |
+
from cube3d.inference.utils import load_config, load_model_weights, parse_structured
|
9 |
+
from cube3d.model.autoencoder.one_d_autoencoder import OneDAutoEncoder
|
10 |
+
|
11 |
+
MESH_SCALE = 0.96
|
12 |
+
|
13 |
+
|
14 |
+
def rescale(vertices: np.ndarray, mesh_scale: float = MESH_SCALE) -> np.ndarray:
|
15 |
+
"""Rescale the vertices to a cube, e.g., [-1, -1, -1] to [1, 1, 1] when mesh_scale=1.0"""
|
16 |
+
vertices = vertices
|
17 |
+
bbmin = vertices.min(0)
|
18 |
+
bbmax = vertices.max(0)
|
19 |
+
center = (bbmin + bbmax) * 0.5
|
20 |
+
scale = 2.0 * mesh_scale / (bbmax - bbmin).max()
|
21 |
+
vertices = (vertices - center) * scale
|
22 |
+
return vertices
|
23 |
+
|
24 |
+
|
25 |
+
def load_scaled_mesh(file_path: str) -> trimesh.Trimesh:
|
26 |
+
"""
|
27 |
+
Load a mesh and scale it to a unit cube, and clean the mesh.
|
28 |
+
Parameters:
|
29 |
+
file_path (str): path to the mesh file
|
30 |
+
(loaded with trimesh.load(force="mesh"); the format is inferred from the extension)
|
31 |
+
Returns:
|
32 |
+
mesh: trimesh.Trimesh
|
33 |
+
"""
|
34 |
+
mesh: trimesh.Trimesh = trimesh.load(file_path, force="mesh")
|
35 |
+
mesh.remove_infinite_values()
|
36 |
+
mesh.update_faces(mesh.nondegenerate_faces())
|
37 |
+
mesh.update_faces(mesh.unique_faces())
|
38 |
+
mesh.remove_unreferenced_vertices()
|
39 |
+
if len(mesh.vertices) == 0 or len(mesh.faces) == 0:
|
40 |
+
raise ValueError("Mesh has no vertices or faces after cleaning")
|
41 |
+
mesh.vertices = rescale(mesh.vertices)
|
42 |
+
return mesh
|
43 |
+
|
44 |
+
|
45 |
+
def load_and_process_mesh(file_path: str, n_samples: int = 8192):
|
46 |
+
"""
|
47 |
+
Loads a 3D mesh from the specified file path, samples points from its surface,
|
48 |
+
and processes the sampled points into a point cloud with normals.
|
49 |
+
Args:
|
50 |
+
file_path (str): The file path to the 3D mesh file.
|
51 |
+
n_samples (int, optional): The number of points to sample from the mesh surface. Defaults to 8192.
|
52 |
+
Returns:
|
53 |
+
torch.Tensor: A tensor of shape (1, n_samples, 6) containing the processed point cloud.
|
54 |
+
Each point consists of its 3D position (x, y, z) and its normal vector (nx, ny, nz).
|
55 |
+
"""
|
56 |
+
|
57 |
+
mesh = load_scaled_mesh(file_path)
|
58 |
+
positions, face_indices = trimesh.sample.sample_surface(mesh, n_samples)
|
59 |
+
normals = mesh.face_normals[face_indices]
|
60 |
+
point_cloud = np.concatenate(
|
61 |
+
[positions, normals], axis=1
|
62 |
+
) # Shape: (num_samples, 6)
|
63 |
+
point_cloud = torch.from_numpy(point_cloud.reshape(1, -1, 6)).float()
|
64 |
+
return point_cloud
|
65 |
+
|
66 |
+
|
67 |
+
@torch.inference_mode()
|
68 |
+
def run_shape_decode(
|
69 |
+
shape_model: OneDAutoEncoder,
|
70 |
+
output_ids: torch.Tensor,
|
71 |
+
resolution_base: float = 8.0,
|
72 |
+
chunk_size: int = 100_000,
|
73 |
+
):
|
74 |
+
"""
|
75 |
+
Decodes the shape from the given output IDs and extracts the geometry.
|
76 |
+
Args:
|
77 |
+
shape_model (OneDAutoEncoder): The shape model.
|
78 |
+
output_ids (torch.Tensor): The tensor containing the output IDs.
|
79 |
+
resolution_base (float, optional): The base resolution for geometry extraction. Defaults to 8.0.
|
80 |
+
chunk_size (int, optional): The chunk size for processing. Defaults to 100,000.
|
81 |
+
Returns:
|
82 |
+
tuple: A tuple containing the vertices and faces of the mesh.
|
83 |
+
"""
|
84 |
+
shape_ids = (
|
85 |
+
output_ids[:, : shape_model.cfg.num_encoder_latents, ...]
|
86 |
+
.clamp_(0, shape_model.cfg.num_codes - 1)
|
87 |
+
.view(-1, shape_model.cfg.num_encoder_latents)
|
88 |
+
)
|
89 |
+
latents = shape_model.decode_indices(shape_ids)
|
90 |
+
mesh_v_f, _ = shape_model.extract_geometry(
|
91 |
+
latents,
|
92 |
+
resolution_base=resolution_base,
|
93 |
+
chunk_size=chunk_size,
|
94 |
+
use_warp=True,
|
95 |
+
)
|
96 |
+
return mesh_v_f
|
97 |
+
|
98 |
+
|
99 |
+
if __name__ == "__main__":
|
100 |
+
parser = argparse.ArgumentParser(
|
101 |
+
description="cube shape encode and decode example script"
|
102 |
+
)
|
103 |
+
parser.add_argument(
|
104 |
+
"--mesh-path",
|
105 |
+
type=str,
|
106 |
+
required=True,
|
107 |
+
help="Path to the input mesh file.",
|
108 |
+
)
|
109 |
+
parser.add_argument(
|
110 |
+
"--config-path",
|
111 |
+
type=str,
|
112 |
+
default="cube3d/configs/open_model.yaml",
|
113 |
+
help="Path to the configuration YAML file.",
|
114 |
+
)
|
115 |
+
parser.add_argument(
|
116 |
+
"--shape-ckpt-path",
|
117 |
+
type=str,
|
118 |
+
required=True,
|
119 |
+
help="Path to the shape encoder/decoder checkpoint file.",
|
120 |
+
)
|
121 |
+
parser.add_argument(
|
122 |
+
"--recovered-mesh-path",
|
123 |
+
type=str,
|
124 |
+
default="recovered_mesh.obj",
|
125 |
+
help="Path to save the recovered mesh file.",
|
126 |
+
)
|
127 |
+
args = parser.parse_args()
|
128 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
129 |
+
logging.info(f"Using device: {device}")
|
130 |
+
|
131 |
+
cfg = load_config(args.config_path)
|
132 |
+
|
133 |
+
shape_model = OneDAutoEncoder(
|
134 |
+
parse_structured(OneDAutoEncoder.Config, cfg.shape_model)
|
135 |
+
)
|
136 |
+
load_model_weights(
|
137 |
+
shape_model,
|
138 |
+
args.shape_ckpt_path,
|
139 |
+
)
|
140 |
+
shape_model = shape_model.eval().to(device)
|
141 |
+
point_cloud = load_and_process_mesh(args.mesh_path)
|
142 |
+
output = shape_model.encode(point_cloud.to(device))
|
143 |
+
indices = output[3]["indices"]
|
144 |
+
print("Got the following shape indices:")
|
145 |
+
print(indices)
|
146 |
+
print("Indices shape: ", indices.shape)
|
147 |
+
mesh_v_f = run_shape_decode(shape_model, indices)
|
148 |
+
vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
|
149 |
+
mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
|
150 |
+
mesh.export(args.recovered_mesh_path)
|
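
A quick worked check of the rescale normalization above (assuming rescale and MESH_SCALE are in scope as defined in this file): a box with extents 2 x 4 x 1 is centered, and its longest edge is mapped to 2 * MESH_SCALE = 1.92.

import numpy as np

verts = np.array([[0.0, 0.0, 0.0], [2.0, 4.0, 1.0]])
scaled = rescale(verts)
print(scaled.min(0))  # approx [-0.48, -0.96, -0.24]
print(scaled.max(0))  # approx [ 0.48,  0.96,  0.24]

The round trip itself is driven by the argparse flags defined in the __main__ block above (--mesh-path and --shape-ckpt-path are required; --config-path defaults to cube3d/configs/open_model.yaml).
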
cube/pyproject.toml
ADDED
@@ -0,0 +1,38 @@
1 |
+
[build-system]
|
2 |
+
requires = ["setuptools", "wheel"]
|
3 |
+
build-backend = "setuptools.build_meta"
|
4 |
+
|
5 |
+
[project]
|
6 |
+
name = "cube"
|
7 |
+
version = "0.1"
|
8 |
+
requires-python = ">=3.7"
|
9 |
+
description = "A generative 3D model to accelerate the creation of 3D assets, accessories, and experiences."
|
10 |
+
authors = [
|
11 |
+
{ name = "Foundation AI", email = "[email protected]" }
|
12 |
+
]
|
13 |
+
keywords = ["cube"]
|
14 |
+
classifiers = [
|
15 |
+
"Development Status :: 3 - Alpha",
|
16 |
+
"Intended Audience :: Developers",
|
17 |
+
"Programming Language :: Python :: 3.10",
|
18 |
+
]
|
19 |
+
dependencies = [
|
20 |
+
"numpy",
|
21 |
+
"torch>=2.2.2",
|
22 |
+
"tqdm",
|
23 |
+
"transformers",
|
24 |
+
"omegaconf",
|
25 |
+
"warp-lang",
|
26 |
+
"accelerate>=0.26.0",
|
27 |
+
"scikit-image",
|
28 |
+
"huggingface_hub[cli]",
|
29 |
+
"trimesh"
|
30 |
+
]
|
31 |
+
[project.optional-dependencies]
|
32 |
+
meshlab = ["pymeshlab"]
|
33 |
+
lint = ["ruff==0.9.10"]
|
34 |
+
|
35 |
+
[tool.setuptools.packages.find]
|
36 |
+
where = ["cube3d"]
|
37 |
+
include = ["cube/*"]
|
38 |
+
namespaces = false
|
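
Given the optional dependency groups declared above, an editable install with extras (run from the cube/ directory containing this pyproject.toml) is pip install -e ".[meshlab,lint]"; the base install pulls in torch, trimesh, warp-lang, and the other runtime dependencies listed under [project].
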
cube/setup.py
ADDED
@@ -0,0 +1,6 @@
1 |
+
from setuptools import setup
|
2 |
+
|
3 |
+
setup(
|
4 |
+
name="cube",
|
5 |
+
version="0.0.1",
|
6 |
+
)
|