Akash Garg committed on
Commit 616f571 · 1 Parent(s): 7e8f630

adding cube sources
cube/.gitignore ADDED
@@ -0,0 +1,176 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ .vscode/*
171
+
172
+ .DS_Store
173
+
174
+ # Output folder
175
+ outputs/
176
+ model_weights/
cube/LICENSE ADDED
@@ -0,0 +1,80 @@
1
+ CUBE3D RESEARCH-ONLY RAIL-MS LICENSE
2
+
3
+ Licensed Artifacts:
4
+ Cube3d-v0.1 and related inference code
5
+
6
+ I. SCOPE
7
+ This Research-Only RAIL License is generally applicable to the Artifacts identified above.
8
+ For valuable consideration, You and Licensor agree as follows:
9
+ 1. Definitions
10
+ (a) “Artifact” means a software application (in either binary or source code format), Model, or Source Code, in accordance with what are specified above as the “Licensed Artifacts.”
11
+ (b) “Contribution” means any work, including any modifications or additions to an Artifact, that is intentionally submitted to Licensor for inclusion or incorporation in the Artifact directly or indirectly by the rights owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing, sharing and improving the Artifact, but excluding communication that is conspicuously marked or otherwise designated in writing by the contributor as “Not a Contribution.”
12
+ (c) “Contributor” means Licensor or any other individual or legal entity that creates or owns a Contribution that is added to or incorporated into an Artifact or Derivative.
13
+ (d) “Data” means a collection of information or content extracted from the dataset used with a given Model, including to train, pretrain, or otherwise evaluate the Model.
14
+ (e) “Derivative” means a work derived from or based upon an Artifact, and includes all modified versions of such Artifact.
15
+ (f) “Distribution” means any transmission, reproduction, publication or other sharing of an Artifact or Derivative to a third party, including providing a hosted service incorporating the Artifact, which is made available by electronic or other remote means—e.g., API-based or web access.
16
+ (g) “License” means the terms and conditions for use, reproduction, and Distribution as defined in this document.
17
+ (h) “Licensor” means the rights owner (by virtue of creation or documented transfer of ownership) or entity authorized by the rights owner (e.g., exclusive licensee) that is granting the rights in this License.
18
+ (i) “Model” means any machine-learning based assembly or assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Source Code.
19
+ (j) “Output” means the results of operating a Model as embodied in informational content resulting therefrom.
20
+ (k) “Permitted Purpose” means for academic or research purposes only.
21
+ (l) “Source Code” means any collection of text written using human-readable programming language, including the code and scripts used to define, run, load, benchmark or evaluate a Model or any component thereof, or used to prepare data for training or evaluation. Source Code includes any accompanying documentation, tutorials and examples. For clarity, the term “Source Code” as used in this License includes any and all Derivatives of such Source Code.
22
+ (m) “Third Party” means any individual or legal entity that is not under common control with Licensor or You.
23
+ (n) “Use,” with respect to an Artifact, means accessing, using, copying, modifying, distributing, and making available the Artifact; in connection with a Model as Artifact, Use also includes creating content, fine-tuning, updating, running, training, evaluating and re-parametrizing such Model.
24
+ (o) “You” (or “Your”) means an individual or legal entity receiving and exercising permissions granted by this License or making use of the Artifact for the Permitted Purpose and in any permitted field of use, including usage of the Artifact in an end-use application.
25
+
26
+ II. INTELLECTUAL PROPERTY RIGHTS
27
+ 1. Both copyright and patent grants may apply to the Artifacts. The Artifacts are subject to additional terms as described in Section III below, which govern the Use of the Artifacts in the event that Section II is held unenforceable or inapplicable.
28
+ 2. Grant of Copyright License. Conditioned upon compliance with Section III below and subject to the terms and conditions of this License, each Contributor hereby grants to You, only in connection with the Permitted Purpose, a worldwide, non-exclusive, royalty-free copyright license to reproduce, publicly display, publicly perform, distribute, and make derivatives of the Artifacts.
29
+ 3. Grant of Patent License. Conditioned upon compliance with Section III below and subject to the terms and conditions of this License, and only where and as applicable, each Contributor hereby grants to You, only in connection with the Permitted Purpose, a worldwide, non-exclusive, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, sell, offer to sell, import, and otherwise transfer the Artifacts where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contributions alone or by combination of their Contributions with the Artifact to which such Contribution was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that an Artifact or Contribution constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License in connection with the Artifact shall terminate as of the date such litigation is asserted or filed.
30
+ 4. Licensor and Contributor each have the right to grant the licenses above.
31
+ 5. The Data is not licensed under this License.
32
+
33
+ III. CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
34
+ 1. Use-based restrictions. The restrictions set forth in Attachment A are mandatory Use-based restrictions. Therefore You may not Use any Artifact in violation of such restrictions. You may Use Artifacts only subject to this License. You shall require all of Your users who Use the Artifacts or Derivatives to comply with the terms of this paragraph and only Use the Artifacts and Derivatives for the Permitted Purpose.
35
+ 2. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output generated by You or Your users. You are accountable for the Output You generate and its subsequent uses. No use of the Output may contravene any provision as stated in this License.
36
+ 3. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce, distribute, and make available Artifacts and Derivatives in any medium, with or without modifications, provided that You meet the following conditions:
37
+ (a) Use-based restrictions in Paragraph III.1 MUST be included as a condition precedent to effect any type of legal agreement (e.g., a license) governing the Use of the Artifacts and Derivatives, and You shall give such notice to any subsequent Third Party recipients.
38
+ (b) You shall give any Third Party recipients of any Artifacts or Derivatives a copy of this License.
39
+ (c) You shall cause any modified files to carry prominent notices stating that You changed the files.
40
+ (d) You shall retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Artifacts or Derivatives.
41
+ (e) You and any Third Party recipients of any Artifacts or Derivatives shall adhere to the Permitted Purpose.
42
+ 4. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions with respect to Paragraph III.3(a), to govern the Use or Distribution of Your modifications, or for any Derivative, provided that Your Use and Distribution of the Artifacts or their Derivatives otherwise complies with the conditions stated in this License. In other words, the Use-based restrictions referred to in Paragraph III.1 form the minimum set of terms for You to license to Third Parties any Artifacts or Derivatives, but You may add more restrictive terms if You deem it necessary.
43
+
44
+ IV. OTHER PROVISIONS
45
+ 1. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of Artifacts in violation of this License or update Artifacts through electronic means.
46
+ 2. Trademarks. Nothing in this License permits You to make use of Licensor’s trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between You and Licensor; and any rights not expressly granted herein are reserved by the Licensor.
47
+ 3. DISCLAIMER OF WARRANTY. UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING, LICENSOR PROVIDES THE ARTIFACT (AND EACH CONTRIBUTOR PROVIDES ITS CONTRIBUTIONS) ON AN “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OR CONDITIONS OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING THE ARTIFACT, AND ASSUME ANY RISKS ASSOCIATED WITH YOUR EXERCISE OF PERMISSIONS UNDER THIS LICENSE.
48
+ 4. LIMITATION OF LIABILITY. IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE, UNLESS REQUIRED BY APPLICABLE LAW (SUCH AS DELIBERATE AND GROSSLY NEGLIGENT ACTS) OR AGREED TO IN WRITING, SHALL ANY CONTRIBUTOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES OF ANY CHARACTER ARISING AS A RESULT OF THIS LICENSE OR OUT OF THE USE OR INABILITY TO USE THE ARTIFACT (INCLUDING BUT NOT LIMITED TO DAMAGES FOR LOSS OF GOODWILL, WORK STOPPAGE, COMPUTER FAILURE OR MALFUNCTION, OR ANY AND ALL OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF SUCH CONTRIBUTOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
49
+ 5. Severability. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
50
+ 6. Term and Termination. The term of this License will commence upon the earlier of (a) Your acceptance of this License or (b) accessing the Artifact; and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Licensor may terminate this License if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You shall delete and cease use of the Artifacts. This paragraph shall survive the termination of this License.
51
+
52
+ Attachment A – Use Restrictions
53
+ 1. Discrimination. You agree not to Use, or allow others to Use, Artifacts or Derivatives
54
+ (a) to discriminate, mock, or promote hatred against individuals or groups, or encourage others to do so directly or indirectly, on the basis of their age; race, perceived race, or ethnicity; national origin; sexual orientation; gender, gender identity, or gender expression; religion or religious affiliation or beliefs; disability status including diseases, bodily conditions, disfigurement, mobility issues, and mental impairment; veteran status; caste; or familial status.
55
+ (b) to exploit any of the vulnerabilities of an individual or specific group of persons based on their age, social, physical, or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm.
56
+ (c) to engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, or other essential goods and services.
57
+ 2. Intellectual Property. You agree not to, and not to allow others to Use Artifacts or Derivatives
58
+ (a) to infringe or attempt to infringe, misappropriate or otherwise violate any intellectual property rights of Licensor or any Third Party;
59
+ (b) synthesize or modify a natural person’s appearance, voice, or other individual characteristics, unless prior informed consent of said natural person is obtained; or
60
+ (c) to reverse engineer, disassemble, decompile, or otherwise attempt to derive or gain access to Data that was used to create, train, pretrain, or otherwise evaluate such Artifacts or Derivatives.
61
+ 3. Legal. You agree not to Use, or allow others to Use, Artifacts or Derivatives
62
+ (a) in any way that violates any applicable national, federal, state, local or international law or regulation;
63
+ (b) to engage in, facilitate, or assist in the planning or development of criminal activities; or
64
+ (c) to generate unlawful content.
65
+ 4. Disinformation. You agree not to Use, or allow others to Use, Artifacts or Derivatives
66
+ (a) to create, present or disseminate false or misleading information for economic gain or to intentionally deceive the public, including creating false impersonations of natural persons;
67
+ (b) to defame or harm a person’s reputation, such as by generating, creating, promoting, or spreading defamatory content.
68
+ 5. Privacy. You agree not to Use, or allow others to Use, Artifacts or Derivatives
69
+ (a) to engage in, promote, incite, or facilitate the harassment, abuse, threatening or bullying of individuals or groups of individuals; or
70
+ (b) in connection with personal information to infer additional personal information about a natural person, including but not limited to legally protected characteristics, vulnerabilities or categories; unless informed consent from the data subject to collect said inferred personal information for a stated purpose and defined duration is received.
71
+ 6. Health and Safety. You agree not to Use, or allow others to Use, Artifacts or Derivatives
72
+ (a) to provide health or medical advice, medical results interpretation, or make clinical decisions; or
73
+ (b) in connection with any activities that present a risk of death or bodily harm to individuals, including self-harm or harm to others, or in connection with regulated or controlled substances.
74
+ 7. Military or Law Enforcement. You agree not to Use, or allow others to Use, Artifacts or Derivatives
75
+ (a) for purposes of administration of justice, law enforcement, immigration, or asylum processes, such as predicting that a natural person will commit a crime or the likelihood thereof;
76
+ (b) for weaponry or warfare; for building or optimizing weapons; or in service of nuclear proliferation or nuclear weapons technology; or
77
+ (c) military surveillance, including any research or development relating to military surveillance.
78
+ 8. General. You agree not to Use, or allow others to Use, Artifacts or Derivatives
79
+ (a) in any manner that would constitute high risk, restricted, or prohibited of AI under applicable law; or
80
+ (b) to generate or disseminate malware or ransomware or to otherwise harm electronic systems.
cube/README.md ADDED
@@ -0,0 +1,166 @@
1
+ # Cube: Generative AI System for 3D
2
+
3
+ <p align="center">
4
+ <img src="./resources/teaser.png" width="800" style="margin: 5px;">
5
+ </p>
6
+
7
+ <div align="center">
8
+ <a href=https://corp.roblox.com/newsroom/2025/03/introducing-roblox-cube target="_blank"><img src=https://img.shields.io/badge/Roblox-Blog-000000.svg?logo=Roblox height=22px></a>
9
+ <a href=https://huggingface.co/Roblox/cube3d-0.1 target="_blank"><img src=https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Models-d96902.svg height=22px></a>
10
+ <a href=https://arxiv.org/abs/2503.15475 target="_blank"><img src=https://img.shields.io/badge/ArXiv-Report-b5212f.svg?logo=arxiv height=22px></a>
11
+ <a href=https://colab.research.google.com/drive/1ZvTj49pjDCD_crX5WPZNTAoTTzL6-E5t target="_blank"><img src=https://img.shields.io/badge/Google-Open_In_Colab-blue.svg?logo=googlecolab height=22px></a>
12
+ </div>
13
+
14
+
15
+ Foundation models trained on vast amounts of data have demonstrated remarkable reasoning and
16
+ generation capabilities in the domains of text, images, audio and video. Our goal is to build
17
+ such a foundation model for 3D intelligence, a model that can support developers in producing all aspects
18
+ of a Roblox experience, from generating 3D objects and scenes to rigging characters for animation to
19
+ producing programmatic scripts describing object behaviors. As we start open-sourcing a family of models
20
+ towards this vision, we hope to engage others in the research community to address these goals with us.
21
+
22
+ ## Get Started with Cube 3D
23
+
24
+ <p align="center">
25
+ <img src="./resources/greyscale_512.gif" width="600" style="margin: 5px;">
26
+ </p>
27
+
28
+ Cube 3D is our first step towards 3D intelligence, which involves a shape tokenizer and a text-to-shape generation model. We are unlocking the power of generating 3D assets and enhancing creativity for all artists. Our latest version of Cube 3D is now accessible to individuals, creators, researchers and businesses of all sizes so that they can experiment, innovate and scale their ideas responsibly. This release includes model weights and starting code for using our text-to-shape model to create 3D assets.
29
+
30
+ ### Try it out on [Google Colab](https://colab.research.google.com/drive/1ZvTj49pjDCD_crX5WPZNTAoTTzL6-E5t)
31
+
32
+ ### Install Requirements
33
+
34
+ Clone and install this repo in a virtual environment, via:
35
+
36
+ ```bash
37
+ git clone https://github.com/Roblox/cube.git
38
+ cd cube
39
+ pip install -e .[meshlab]
40
+ ```
41
+
42
+ > **CUDA**: If you are using a Windows machine, you may need to install the [CUDA](https://developer.nvidia.com/cuda-downloads) toolkit as well as `torch` with CUDA support via `pip install torch --index-url https://download.pytorch.org/whl/cu124 --force-reinstall`.
43
+
44
+ > **Note**: `[meshlab]` is an optional dependency; for better compatibility you can omit it by running `pip install -e .` instead, but mesh simplification will then be disabled.
45
+
46
+ ### Download Models from Huggingface 🤗
47
+
48
+ Download the model weights from [Hugging Face](https://huggingface.co/Roblox/cube3d-v0.1) or use the
49
+ `huggingface-cli`:
50
+
51
+ ```bash
52
+ huggingface-cli download Roblox/cube3d-v0.1 --local-dir ./model_weights
53
+ ```
54
+
55
+ ### Inference
56
+
57
+ #### 1. Shape Generation
58
+
59
+ To generate a 3D model using the downloaded weights, simply run:
60
+
61
+ ```bash
62
+ python -m cube3d.generate \
63
+ --gpt-ckpt-path model_weights/shape_gpt.safetensors \
64
+ --shape-ckpt-path model_weights/shape_tokenizer.safetensors \
65
+ --fast-inference \
66
+ --prompt "Broad-winged flying red dragon, elongated, folded legs."
67
+ ```
68
+
69
+ > **Note**: `--fast-inference` is optional and may not be available on GPUs with limited VRAM. This flag is also not supported on macOS.
70
+
71
+ The output will be an `.obj` file saved in the specified `output` directory.
72
+
73
+ If you want a turntable gif of the mesh, you can pass the `--render-gif` flag, which will render one
74
+ and save it as `turntable.gif` in the specified `output` directory.
75
+
76
+ We provide several example output objects and their corresponding text prompts in the `examples` folder.
77
+
78
+ > **Note**: You must have Blender installed and available in your system's PATH to render the turntable GIF. You can download it from [Blender's official website](https://www.blender.org/). Ensure that the Blender executable is accessible from the command line.
79
+
80
+ > **Note**: If shape decoding is slow, you can specify a lower resolution with the `--resolution-base` flag. A lower resolution produces a coarser, lower-quality output mesh but decodes faster. Values between 4.0 and 9.0 are recommended; see the example below.
81
+
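+ For example, combining the flags from the generation command above with a lower decode resolution (the prompt here is only an illustration):
+
+ ```bash
+ python -m cube3d.generate \
+     --gpt-ckpt-path model_weights/shape_gpt.safetensors \
+     --shape-ckpt-path model_weights/shape_tokenizer.safetensors \
+     --resolution-base 6.0 \
+     --prompt "A wooden chair"
+ ```
+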
82
+ #### 2. Shape Tokenization and De-tokenization
83
+
84
+ To tokenize a 3D shape into token indices and reconstruct it back, you can use the following command:
85
+
86
+ ```bash
87
+ python -m cube3d.vq_vae_encode_decode \
88
+ --shape-ckpt-path model_weights/shape_tokenizer.safetensors \
89
+ --mesh-path ./outputs/output.obj
90
+ ```
91
+
92
+ This will process the `.obj` file located at `./outputs/output.obj`, print its tokenized representation, and export the mesh reconstructed from the token indices.
93
+
94
+ ### Hardware Requirements
95
+
96
+ We have tested our model on:
97
+ * Nvidia H100 GPU
98
+ * Nvidia A100 GPU
99
+ * Nvidia Geforce 3080
100
+ * Apple Silicon M2-M4 chips.
101
+
102
+ We recommend using a GPU with at least 24GB of VRAM available when using `--fast-inference` (or `EngineFast`) and 16GB otherwise.
103
+
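+ If you are unsure which engine fits your GPU, one quick way to check free VRAM before choosing `EngineFast` is sketched below; this is a convenience snippet, not part of the library:
+
+ ```python
+ import torch
+
+ if torch.cuda.is_available():
+     free_bytes, total_bytes = torch.cuda.mem_get_info()
+     print(f"Free VRAM: {free_bytes / 1e9:.1f} GB of {total_bytes / 1e9:.1f} GB")
+ else:
+     print("No CUDA device found; use Engine instead of EngineFast.")
+ ```
+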
104
+ ### Code Usage
105
+
106
+ We have designed a minimalist API that allows you to use this repo as a Python library:
107
+
108
+ ```python
109
+ import torch
110
+ import trimesh
111
+ from cube3d.inference.engine import Engine, EngineFast
112
+
113
+ # load ckpt
114
+ config_path = "cube3d/configs/open_model.yaml"
115
+ gpt_ckpt_path = "model_weights/shape_gpt.safetensors"
116
+ shape_ckpt_path = "model_weights/shape_tokenizer.safetensors"
117
+ engine_fast = EngineFast( # only supported on CUDA devices, replace with Engine otherwise
118
+ config_path,
119
+ gpt_ckpt_path,
120
+ shape_ckpt_path,
121
+ device=torch.device("cuda"),
122
+ )
123
+
124
+ # inference
125
+ input_prompt = "A pair of noise-canceling headphones"
126
+ # NOTE: Reduce `resolution_base` for faster inference and lower VRAM usage
127
+ # The `top_k` parameter controls randomness between inferences:
128
+ # - A value of 1 yields deterministic results.
129
+ # - Higher values introduce more randomness.
130
+ mesh_v_f = engine_fast.t2s([input_prompt], use_kv_cache=True, resolution_base=8.0, top_k=5)
131
+
132
+ # save output
133
+ vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
134
+ _ = trimesh.Trimesh(vertices=vertices, faces=faces).export("output.obj")
135
+ ```
136
+
137
+ ## Coming Soon
138
+
139
+ ### Controlling shape generation with bounding box conditioning
140
+ <div align="center">
141
+ <img src="./resources/truck_black_text_512.gif" width="300" height="300" style="margin: 5px;">
142
+ <img src="./resources/couch_black_text_512.gif" width="300" height="300" style="margin: 5px;">
143
+ </div>
144
+
145
+ ### Scene Generation
146
+
147
+ https://github.com/user-attachments/assets/987c459a-5708-41a5-9b92-89068a70a239
148
+
149
+ https://github.com/user-attachments/assets/ab501a86-b0cb-4c73-827e-988b2120d4c0
150
+
151
+ ## Citation
152
+ If you find this work helpful, please consider citing our technical report:
153
+
154
+ ```bibtex
155
+ @article{roblox2025cube,
156
+ title = {Cube: A Roblox View of 3D Intelligence},
157
+ author = {Roblox, Foundation AI Team},
158
+ journal = {arXiv preprint arXiv:2503.15475},
159
+ year = {2025}
160
+ }
161
+ ```
162
+
163
+ ## Acknowledgements
164
+
165
+ We would like to thank the contributors of [TRELLIS](https://github.com/microsoft/TRELLIS), [CraftsMan3D](https://github.com/wyysf-98/CraftsMan3D), [threestudio](https://github.com/threestudio-project/threestudio), [Hunyuan3D-2](https://github.com/Tencent/Hunyuan3D-2), [minGPT](https://github.com/karpathy/minGPT), [dinov2](https://github.com/facebookresearch/dinov2), [OptVQ](https://github.com/zbr17/OptVQ), [1d-tokenizer](https://github.com/bytedance/1d-tokenizer)
166
+ repositories, for their open source contributions.
cube/SECURITY.md ADDED
@@ -0,0 +1,9 @@
1
+ # Security Policy
2
+
3
+ ## Reporting a Vulnerability
4
+
5
+ If you discover a security vulnerability in this repository, we appreciate your help in ensuring that the issue is addressed quickly.
6
+
7
+ Report any vulnerabilities found to our bug bounty program on HackerOne: https://hackerone.com/roblox
8
+
9
+ Please **do not create a public issue in this repo**.
cube/cube3d/__init__.py ADDED
File without changes
cube/cube3d/colab_cube3d.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
cube/cube3d/configs/open_model.yaml ADDED
@@ -0,0 +1,32 @@
1
+ gpt_model:
2
+ n_layer: 23
3
+ n_single_layer: 1
4
+ rope_theta: 10000
5
+ n_head: 12
6
+ n_embd: 1536
7
+ bias: true
8
+ eps: 1.e-6
9
+ shape_model_vocab_size: 16384
10
+ text_model_embed_dim: 768
11
+ use_pooled_text_embed: False
12
+ shape_model_embed_dim: 32
13
+ encoder_with_cls_token: true
14
+
15
+ shape_model:
16
+ encoder_with_cls_token: true
17
+ num_encoder_latents: 512
18
+ num_decoder_latents: 0
19
+ embed_dim: 32
20
+ width: 768
21
+ num_heads: 12
22
+ out_dim: 1
23
+ eps: 1.e-6
24
+ num_freqs: 128
25
+ point_feats: 3
26
+ embed_point_feats: false
27
+ num_encoder_layers: 13
28
+ encoder_cross_attention_levels: [0, 2, 4, 8]
29
+ num_decoder_layers: 24
30
+ num_codes: 16384
31
+
32
+ text_model_pretrained_model_name_or_path: "openai/clip-vit-large-patch14"
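A minimal sketch of how this config is consumed at inference time, mirroring `Engine.__init__` in `cube3d/inference/engine.py` below (the config path is the repo default; treat this as illustrative rather than a supported entry point):

```python
from cube3d.inference.utils import load_config, parse_structured
from cube3d.model.autoencoder.one_d_autoencoder import OneDAutoEncoder
from cube3d.model.gpt.dual_stream_roformer import DualStreamRoformer

# Load the OmegaConf DictConfig, then build both models from their structured sub-configs.
cfg = load_config("cube3d/configs/open_model.yaml")
gpt_model = DualStreamRoformer(parse_structured(DualStreamRoformer.Config, cfg.gpt_model))
shape_model = OneDAutoEncoder(parse_structured(OneDAutoEncoder.Config, cfg.shape_model))
```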
cube/cube3d/generate.py ADDED
@@ -0,0 +1,144 @@
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+ import trimesh
6
+
7
+ from cube3d.inference.engine import Engine, EngineFast
8
+ from cube3d.mesh_utils.postprocessing import (
9
+ PYMESHLAB_AVAILABLE,
10
+ create_pymeshset,
11
+ postprocess_mesh,
12
+ save_mesh,
13
+ )
14
+ from cube3d.renderer import renderer
15
+
16
+ def generate_mesh(
17
+ engine,
18
+ prompt,
19
+ output_dir,
20
+ output_name,
21
+ resolution_base=8.0,
22
+ disable_postprocess=False,
23
+ top_k: int = 1,
24
+ ):
25
+ mesh_v_f = engine.t2s(
26
+ [prompt],
27
+ use_kv_cache=True,
28
+ resolution_base=resolution_base,
29
+ top_k=top_k,
30
+ )
31
+ vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
32
+ obj_path = os.path.join(output_dir, f"{output_name}.obj")
33
+ if PYMESHLAB_AVAILABLE:
34
+ ms = create_pymeshset(vertices, faces)
35
+ if not disable_postprocess:
36
+ target_face_num = max(10000, int(faces.shape[0] * 0.1))
37
+ print(f"Postprocessing mesh to {target_face_num} faces")
38
+ postprocess_mesh(ms, target_face_num, obj_path)
39
+
40
+ save_mesh(ms, obj_path)
41
+ else:
42
+ print(
43
+ "WARNING: pymeshlab is not available, using trimesh to export obj and skipping optional post processing."
44
+ )
45
+ mesh = trimesh.Trimesh(vertices, faces)
46
+ mesh.export(obj_path)
47
+
48
+ return obj_path
49
+
50
+
51
+ if __name__ == "__main__":
52
+ parser = argparse.ArgumentParser(description="cube shape generation script")
53
+ parser.add_argument(
54
+ "--config-path",
55
+ type=str,
56
+ default="cube3d/configs/open_model.yaml",
57
+ help="Path to the configuration YAML file.",
58
+ )
59
+ parser.add_argument(
60
+ "--output-dir",
61
+ type=str,
62
+ default="outputs/",
63
+ help="Path to the output directory to store .obj and .gif files",
64
+ )
65
+ parser.add_argument(
66
+ "--gpt-ckpt-path",
67
+ type=str,
68
+ required=True,
69
+ help="Path to the main GPT checkpoint file.",
70
+ )
71
+ parser.add_argument(
72
+ "--shape-ckpt-path",
73
+ type=str,
74
+ required=True,
75
+ help="Path to the shape encoder/decoder checkpoint file.",
76
+ )
77
+ parser.add_argument(
78
+ "--fast-inference",
79
+ help="Use optimized inference",
80
+ default=False,
81
+ action="store_true",
82
+ )
83
+ parser.add_argument(
84
+ "--prompt",
85
+ type=str,
86
+ required=True,
87
+ help="Text prompt for generating a 3D mesh",
88
+ )
89
+ parser.add_argument(
90
+ "--top-k",
91
+ type=int,
92
+ default=1,
93
+ help="Top k filtering, 0 means no filtering, by default 1, which is determistic.",
94
+ )
95
+ parser.add_argument(
96
+ "--render-gif",
97
+ help="Render a turntable gif of the mesh",
98
+ default=False,
99
+ action="store_true",
100
+ )
101
+ parser.add_argument(
102
+ "--disable-postprocessing",
103
+ help="Disable postprocessing on the mesh. This will result in a mesh with more faces.",
104
+ default=False,
105
+ action="store_true",
106
+ )
107
+ parser.add_argument(
108
+ "--resolution-base",
109
+ type=float,
110
+ default=8.0,
111
+ help="Resolution base for the shape decoder.",
112
+ )
113
+ args = parser.parse_args()
114
+ os.makedirs(args.output_dir, exist_ok=True)
115
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
116
+ print(f"Using device: {device}")
117
+ # Initialize engine based on fast_inference flag
118
+ if args.fast_inference:
119
+ print(
120
+ "Using cuda graphs, this will take some time to warmup and capture the graph."
121
+ )
122
+ engine = EngineFast(
123
+ args.config_path, args.gpt_ckpt_path, args.shape_ckpt_path, device=device
124
+ )
125
+ print("Compiled the graph.")
126
+ else:
127
+ engine = Engine(
128
+ args.config_path, args.gpt_ckpt_path, args.shape_ckpt_path, device=device
129
+ )
130
+
131
+ # Generate meshes based on input source
132
+ obj_path = generate_mesh(
133
+ engine,
134
+ args.prompt,
135
+ args.output_dir,
136
+ "output",
137
+ args.resolution_base,
138
+ args.disable_postprocessing,
139
+ args.top_k,
140
+ )
141
+ if args.render_gif:
142
+ gif_path = renderer.render_turntable(obj_path, args.output_dir)
143
+ print(f"Rendered turntable gif for {args.prompt} at `{gif_path}`")
144
+ print(f"Generated mesh for {args.prompt} at `{obj_path}`")
cube/cube3d/inference/__init__.py ADDED
File without changes
cube/cube3d/inference/engine.py ADDED
@@ -0,0 +1,499 @@
1
+ import torch
2
+ from tqdm import tqdm
3
+ from transformers import CLIPTextModelWithProjection, CLIPTokenizerFast
4
+
5
+ from cube3d.inference.logits_postprocesses import process_logits
6
+ from cube3d.inference.utils import load_config, load_model_weights, parse_structured
7
+ from cube3d.model.autoencoder.one_d_autoencoder import OneDAutoEncoder
8
+ from cube3d.model.gpt.dual_stream_roformer import DualStreamRoformer
9
+ from cube3d.model.transformers.cache import Cache
10
+
11
+
12
+ class Engine:
13
+ def __init__(
14
+ self,
15
+ config_path: str,
16
+ gpt_ckpt_path: str,
17
+ shape_ckpt_path: str,
18
+ device: torch.device,
19
+ ):
20
+ """
21
+ Initializes the inference engine with the given configuration and checkpoint paths.
22
+ Args:
23
+ config_path (str): Path to the configuration file.
24
+ gpt_ckpt_path (str): Path to the GPT model checkpoint file.
25
+ shape_ckpt_path (str): Path to the shape model checkpoint file.
26
+ device (torch.device): The device to run the models on (e.g., 'cpu' or 'cuda').
27
+ Attributes:
28
+ cfg (dict): Loaded configuration from the config file.
29
+ device (torch.device): The device to run the models on.
30
+ gpt_model (DualStreamRoformer): The GPT model initialized and loaded with weights.
31
+ shape_model (OneDAutoEncoder): The shape model initialized and loaded with weights.
32
+ text_model (CLIPTextModelWithProjection): The text model initialized from a pretrained model.
33
+ text_tokenizer (CLIPTokenizerFast): The tokenizer for the text model.
34
+ max_new_tokens (int): Maximum number of new tokens for the shape model.
35
+ min_id (int): Minimum ID for the shape model codes.
36
+ max_id (int): Maximum ID for the shape model codes.
37
+ """
38
+
39
+ self.cfg = load_config(config_path)
40
+ self.device = device
41
+
42
+ self.gpt_model = DualStreamRoformer(
43
+ parse_structured(DualStreamRoformer.Config, self.cfg.gpt_model)
44
+ )
45
+ load_model_weights(
46
+ self.gpt_model,
47
+ gpt_ckpt_path,
48
+ )
49
+ self.gpt_model = self.gpt_model.eval().to(self.device)
50
+
51
+ self.shape_model = OneDAutoEncoder(
52
+ parse_structured(OneDAutoEncoder.Config, self.cfg.shape_model)
53
+ )
54
+ load_model_weights(
55
+ self.shape_model,
56
+ shape_ckpt_path,
57
+ )
58
+ self.shape_model = self.shape_model.eval().to(self.device)
59
+
60
+ # copy vq codebook to gpt
61
+ with torch.no_grad():
62
+ codebook = self.shape_model.bottleneck.block.get_codebook()
63
+ codebook = self.gpt_model.shape_proj(codebook).detach()
64
+ self.gpt_model.transformer.wte.weight.data[: codebook.shape[0]] = codebook
65
+
66
+ self.text_model = CLIPTextModelWithProjection.from_pretrained(
67
+ self.cfg.text_model_pretrained_model_name_or_path,
68
+ force_download=False,
69
+ device_map=self.device,
70
+ ).eval()
71
+ self.text_tokenizer = CLIPTokenizerFast.from_pretrained(
72
+ self.cfg.text_model_pretrained_model_name_or_path
73
+ )
74
+
75
+ self.max_new_tokens = self.shape_model.cfg.num_encoder_latents
76
+ self.min_id = 0
77
+ self.max_id = self.shape_model.cfg.num_codes
78
+
79
+ @torch.inference_mode()
80
+ def prepare_inputs(self, prompts: list[str], guidance_scale: float):
81
+ """
82
+ Prepares the input embeddings for the model based on the provided prompts and guidance scale.
83
+ Args:
84
+ prompts (list[str]): A list of prompt strings to be encoded.
85
+ guidance_scale (float): A scaling factor for guidance. If greater than 0.0, additional processing is applied.
86
+ Returns:
87
+ tuple: A tuple containing:
88
+ - embed (torch.Tensor): The encoded input embeddings.
89
+ - cond (torch.Tensor): The condition embeddings, which may include unconditional embeddings if guidance_scale is greater than 0.0.
90
+ """
91
+
92
+ prompt_embeds = self.run_clip(prompts)
93
+
94
+ with torch.autocast(self.device.type, dtype=torch.bfloat16):
95
+ embed = self.encode_input(prompt_embeds, self.gpt_model.shape_bos_id)
96
+
97
+ cond = prompt_embeds
98
+ if guidance_scale > 0.0:
99
+ embed = torch.cat([embed, embed], dim=0)
100
+ uncond_embeds = self.run_clip([""] * len(prompts))
101
+ cond = torch.cat([prompt_embeds, uncond_embeds], dim=0)
102
+
103
+ return embed, cond
104
+
105
+ @torch.inference_mode()
106
+ def run_clip(self, text_inputs):
107
+ """
108
+ Processes the given text inputs using a text tokenizer and a text model, and returns the encoded text embeddings.
109
+ Args:
110
+ text_inputs (str or List[str]): The input text or list of texts to be processed.
111
+ Returns:
112
+ torch.Tensor: The encoded text embeddings.
113
+ """
114
+
115
+ text_inputs = self.text_tokenizer(
116
+ text_inputs,
117
+ max_length=self.text_tokenizer.model_max_length,
118
+ padding="max_length",
119
+ truncation=True,
120
+ return_tensors="pt",
121
+ )
122
+ with torch.no_grad():
123
+ text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
124
+ # use full precision for text encoder
125
+ with torch.autocast(device_type=self.device.type, enabled=False):
126
+ encoded = self.text_model(**text_inputs)
127
+ if self.gpt_model.cfg.use_pooled_text_embed:
128
+ embed = encoded.text_embeds.unsqueeze(1) # [bs, 1, 768]
129
+ else:
130
+ embed = encoded.last_hidden_state # [bs, 77, 768]
131
+ embed = self.gpt_model.encode_text(embed)
132
+
133
+ return embed
134
+
135
+ @torch.inference_mode()
136
+ def encode_input(self, inputs: torch.Tensor, bos: int):
137
+ """
138
+ Encodes the beginning of sequence (BOS) token for the given input tensor.
139
+ Args:
140
+ inputs (torch.Tensor): The input tensor containing sequences.
141
+ bos (int): The beginning of sequence token ID.
142
+ Returns:
143
+ torch.Tensor: The encoded BOS token embeddings.
144
+ """
145
+
146
+ b = inputs.shape[0]
147
+ bos_embed = self.gpt_model.encode_token(
148
+ torch.full(
149
+ (b, 1),
150
+ fill_value=bos,
151
+ dtype=torch.long,
152
+ device=self.device,
153
+ )
154
+ )
155
+ return bos_embed
156
+
157
+ @torch.inference_mode()
158
+ def run_gpt(
159
+ self,
160
+ prompts: list[str],
161
+ use_kv_cache: bool,
162
+ guidance_scale: float = 3.0,
163
+ top_k: int = 1,
164
+ ):
165
+ """
166
+ Generates text using a GPT model based on the provided prompts.
167
+ Args:
168
+ prompts (list[str]): A list of input prompts to generate text from.
169
+ use_kv_cache (bool): Whether to use key-value caching for faster generation.
170
+ guidance_scale (float, optional): The scale for guidance during generation. Default is 3.0.
171
+ top_k : (int, optional): Top k filtering, 0 means no filtering, by default 1.
172
+ Returns:
173
+ torch.Tensor: A tensor containing the generated token IDs.
174
+ """
175
+ embed, cond = self.prepare_inputs(prompts, guidance_scale)
176
+
177
+ output_ids = []
178
+
179
+ batch_size, input_seq_len, dim = embed.shape
180
+ max_seq_len = input_seq_len + self.max_new_tokens
181
+ embed_buffer = torch.zeros(
182
+ (batch_size, max_seq_len, dim), dtype=embed.dtype, device=embed.device
183
+ )
184
+ embed_buffer[:, :input_seq_len, :].copy_(embed)
185
+ cond_len = cond.shape[1]
186
+ kv_cache = None
187
+ if use_kv_cache:
188
+ kv_cache = self.gpt_model.init_kv_cache(
189
+ batch_size,
190
+ cond_len,
191
+ self.max_new_tokens + 1, # +1 for the BOS token
192
+ torch.bfloat16,
193
+ embed.device,
194
+ )
195
+ with torch.autocast(self.device.type, dtype=torch.bfloat16):
196
+ for i in tqdm(range(self.max_new_tokens), desc=f"generating"):
197
+ curr_pos_id = torch.tensor([i], dtype=torch.long, device=embed.device)
198
+ logits = self.gpt_model(
199
+ embed_buffer,
200
+ cond,
201
+ kv_cache=kv_cache,
202
+ curr_pos_id=curr_pos_id if use_kv_cache else None,
203
+ decode=(i > 0) if use_kv_cache else False,
204
+ )
205
+ if use_kv_cache:
206
+ logits = logits[:, 0, ...]
207
+ else:
208
+ logits = logits[:, i, ...]
209
+
210
+ logits = logits[..., self.min_id : self.max_id]
211
+
212
+ if guidance_scale > 0.0:
213
+ logits, uncond_logits = logits.float().chunk(2, dim=0)
214
+ gamma = (
215
+ guidance_scale * (self.max_new_tokens - i) / self.max_new_tokens
216
+ )
217
+ logits = (1 + gamma) * logits - gamma * uncond_logits
218
+ probs = process_logits(
219
+ logits,
220
+ top_k=top_k,
221
+ )
222
+ next_id = torch.multinomial(probs, num_samples=1, replacement=True)
223
+ output_ids.append(next_id)
224
+ next_embed = self.gpt_model.encode_token(next_id)
225
+ if guidance_scale > 0.0:
226
+ next_embed = torch.cat([next_embed, next_embed], dim=0)
227
+ embed_buffer[:, i + input_seq_len, :].copy_(next_embed.squeeze(1))
228
+
229
+ return torch.cat(output_ids, dim=1)
230
+
231
+ @torch.inference_mode()
232
+ def run_shape_decode(
233
+ self,
234
+ output_ids: torch.Tensor,
235
+ resolution_base: float = 8.0,
236
+ chunk_size: int = 100_000,
237
+ ):
238
+ """
239
+ Decodes the shape from the given output IDs and extracts the geometry.
240
+ Args:
241
+ output_ids (torch.Tensor): The tensor containing the output IDs.
242
+ resolution_base (float, optional): The base resolution for geometry extraction. Defaults to 8.43.
243
+ chunk_size (int, optional): The chunk size for processing. Defaults to 100,000.
244
+ Returns:
245
+ tuple: A tuple containing the vertices and faces of the mesh.
246
+ """
247
+ shape_ids = (
248
+ output_ids[:, : self.shape_model.cfg.num_encoder_latents, ...]
249
+ .clamp_(0, self.shape_model.cfg.num_codes - 1)
250
+ .view(-1, self.shape_model.cfg.num_encoder_latents)
251
+ )
252
+ latents = self.shape_model.decode_indices(shape_ids)
253
+ mesh_v_f, _ = self.shape_model.extract_geometry(
254
+ latents,
255
+ resolution_base=resolution_base,
256
+ chunk_size=chunk_size,
257
+ use_warp=True,
258
+ )
259
+ return mesh_v_f
260
+
261
+ @torch.inference_mode()
262
+ def t2s(
263
+ self,
264
+ prompts: list[str],
265
+ use_kv_cache: bool,
266
+ guidance_scale: float = 3.0,
267
+ resolution_base: float = 8.0,
268
+ chunk_size: int = 100_000,
269
+ top_k: int = 5,
270
+ ):
271
+ """
272
+ Generates a 3D mesh from text prompts using a GPT model and shape decoder.
273
+ Args:
274
+ prompts (list[str]): A list of text prompts to guide the generation.
275
+ use_kv_cache (bool): Whether to use key-value caching for the GPT model.
276
+ guidance_scale (float, optional): The scale of guidance for the GPT model. Default is 3.0.
277
+ resolution_base (float, optional): The base resolution for the shape decoder. Default is 8.0.
278
+ chunk_size (int, optional): The chunk size for processing the shape decoding. Default is 100,000.
+ top_k (int, optional): Top-k filtering for sampling; 1 is deterministic, larger values add randomness. Default is 5.
279
+ Returns:
280
+ mesh_v_f: The generated 3D mesh vertices and faces.
281
+ """
282
+ output_ids = self.run_gpt(prompts, use_kv_cache, guidance_scale, top_k)
283
+ with torch.autocast(self.device.type, dtype=torch.bfloat16):
284
+ mesh_v_f = self.run_shape_decode(output_ids, resolution_base, chunk_size)
285
+ return mesh_v_f
286
+
287
+
288
+ class EngineFast(Engine):
289
+ def __init__(
290
+ self,
291
+ config_path: str,
292
+ gpt_ckpt_path: str,
293
+ shape_ckpt_path: str,
294
+ device: torch.device,
295
+ ):
296
+ """
297
+ Initializes the inference engine with the given configuration and checkpoint paths.
298
+ Args:
299
+ config_path (str): Path to the configuration file.
300
+ gpt_ckpt_path (str): Path to the GPT checkpoint file.
301
+ shape_ckpt_path (str): Path to the shape checkpoint file.
302
+ device (torch.device): The device to run the inference on (e.g., CPU or CUDA).
303
+ """
304
+
305
+ super().__init__(config_path, gpt_ckpt_path, shape_ckpt_path, device)
306
+
307
+ # CUDA Graph params
308
+ self.graph = torch.cuda.CUDAGraph()
309
+ self.embed_buffer = torch.Tensor()
310
+ self.cond_buffer = torch.Tensor()
311
+ self.logits_buffer = torch.Tensor()
312
+ self.curr_pos_id = torch.tensor([0], dtype=torch.long, device=self.device)
313
+ self.kv_cache: list[Cache] = []
314
+
315
+ self._warmup_and_capture_graph()
316
+
317
+ def _warmup_and_capture_graph(self):
318
+ """
319
+ Warms up the model by running a series of forward passes and captures the CUDA graph for efficient execution.
320
+ This method performs the following steps:
321
+ 1. Prepares the input embeddings and conditions using a warmup prompt.
322
+ 2. Initializes buffers for embeddings and conditions.
323
+ 3. Initializes the key-value cache for the GPT model.
324
+ 4. Runs a series of warmup passes to prefill the model and generate logits.
325
+ 5. Captures the CUDA graph for the model's forward pass to optimize future executions.
326
+ """
327
+
328
+ warmup_prompt = "A cube"
329
+ embed, cond = self.prepare_inputs([warmup_prompt], guidance_scale=3.0)
330
+
331
+ batch_size, input_seq_len, dim = embed.shape
332
+ max_seq_len = input_seq_len + self.max_new_tokens
333
+ self.embed_buffer = torch.zeros(
334
+ (batch_size, max_seq_len, dim), dtype=embed.dtype, device=self.device
335
+ )
336
+ self.embed_buffer[:, :input_seq_len, :].copy_(embed)
337
+
338
+ self.cond_buffer = torch.empty_like(cond)
339
+ self.cond_buffer.copy_(cond)
340
+ cond_len = self.cond_buffer.shape[1]
341
+
342
+ # Initialize kv_cache for the first time
343
+ self.kv_cache = self.gpt_model.init_kv_cache(
344
+ batch_size,
345
+ cond_len,
346
+ self.max_new_tokens + 1, # +1 for the BOS token
347
+ torch.bfloat16,
348
+ self.device,
349
+ )
350
+
351
+ num_warmup_passes = 10
352
+
353
+ with torch.autocast(self.device.type, dtype=torch.bfloat16):
354
+ self._set_curr_pos_id(0)
355
+ _ = self._prefill_and_return_logits()
356
+
357
+ for x in range(1, num_warmup_passes):
358
+ self._set_curr_pos_id(x)
359
+ self.logits_buffer = self.gpt_model(
360
+ embed=self.embed_buffer,
361
+ cond=self.cond_buffer,
362
+ kv_cache=self.kv_cache,
363
+ curr_pos_id=self.curr_pos_id,
364
+ decode=True,
365
+ )
366
+
367
+ side_stream = torch.cuda.Stream(device=self.device)
368
+ with torch.cuda.graph(self.graph, stream=side_stream):
369
+ with torch.autocast(self.device.type, dtype=torch.bfloat16):
370
+ self.logits_buffer = self.gpt_model(
371
+ embed=self.embed_buffer,
372
+ cond=self.cond_buffer,
373
+ kv_cache=self.kv_cache,
374
+ curr_pos_id=self.curr_pos_id,
375
+ decode=True,
376
+ )
377
+
378
+ def _reset_kv_cache(self):
379
+ """
380
+ Resets the key-value cache by setting all key and value states to zero.
381
+ This method iterates through each cache in the `kv_cache` attribute and
382
+ calls the `zero_()` method on both `key_states` and `value_states` to
383
+ reset them to their initial state.
384
+ """
385
+
386
+ for cache in self.kv_cache:
387
+ cache.key_states.zero_()
388
+ cache.value_states.zero_()
389
+
390
+ def _prefill_and_return_logits(self) -> torch.Tensor:
391
+ """
392
+ Prefills the model's key-value cache and returns the logits.
393
+ This method resets the key-value cache and then performs a forward pass
394
+ through the GPT model in eager mode to prefill the logits.
395
+ Returns:
396
+ torch.Tensor: The prefilled logits tensor with the first dimension removed.
397
+ """
398
+
399
+ self._reset_kv_cache()
400
+
401
+ # Prefill is always eager
402
+ prefill_logits = self.gpt_model(
403
+ embed=self.embed_buffer,
404
+ cond=self.cond_buffer,
405
+ kv_cache=self.kv_cache,
406
+ curr_pos_id=self.curr_pos_id,
407
+ decode=False,
408
+ )
409
+
410
+ return prefill_logits[:, 0, ...]
411
+
412
+ def _set_curr_pos_id(self, pos: int):
413
+ """
414
+ Set the current position ID.
415
+ This method updates the `curr_pos_id` attribute with the given position.
416
+ Args:
417
+ pos (int): The position ID to set.
418
+ """
419
+
420
+ self.curr_pos_id.copy_(
421
+ torch.tensor([pos], dtype=torch.long, device=self.device)
422
+ )
423
+
424
+ def run_gpt(
425
+ self,
426
+ prompts: list[str],
427
+ use_kv_cache: bool,
428
+ guidance_scale: float = 3.0,
429
+ top_k: int = 1,
430
+ ):
431
+ """
432
+ Runs the GPT model to generate text based on the provided prompts.
433
+ Args:
434
+ prompts (list[str]): A list of input prompts for the GPT model. Only a single prompt is supported.
435
+ use_kv_cache (bool): Flag indicating whether to use key-value caching. (Currently not used)
436
+ guidance_scale (float, optional): The scale factor for guidance. Default is 3.0.
437
+ Returns:
438
+ torch.Tensor: A tensor containing the generated output token IDs.
439
+ Raises:
440
+ AssertionError: If the batch size is greater than 1.
441
+ """
442
+
443
+ embed, cond = self.prepare_inputs(prompts, guidance_scale)
444
+ assert len(prompts) == 1, "batch size > 1 not supported for EngineFast"
445
+
446
+ batch_size, input_seq_len, _ = embed.shape
447
+ self.embed_buffer.zero_()
448
+ self.embed_buffer[:, :input_seq_len, :].copy_(embed)
449
+
450
+ assert self.cond_buffer.shape == cond.shape
451
+ self.cond_buffer.copy_(cond)
452
+
453
+ output_ids = torch.zeros(
454
+ (batch_size // 2, self.max_new_tokens), dtype=torch.int, device=self.device
455
+ )
456
+
457
+ with torch.autocast(self.device.type, dtype=torch.bfloat16):
458
+ self._set_curr_pos_id(0)
459
+
460
+ logits = self._prefill_and_return_logits()
461
+
462
+ logits = logits[..., self.min_id : self.max_id]
463
+ if guidance_scale > 0.0:
464
+ logits, uncond_logits = logits.float().chunk(2, dim=0)
465
+ gamma = guidance_scale
466
+ logits = (1 + gamma) * logits - gamma * uncond_logits
467
+
468
+ probs = process_logits(logits, top_k=top_k)
469
+ next_id = torch.multinomial(probs, num_samples=1, replacement=True)
470
+
471
+ output_ids[:, 0] = next_id.squeeze()
472
+ next_embed = self.gpt_model.encode_token(next_id)
473
+ next_embed = next_embed.repeat(2, 1, 1)
474
+ self.embed_buffer[:, input_seq_len, :].copy_(next_embed.squeeze(1))
475
+
476
+ for i in tqdm(
477
+ range(1, self.max_new_tokens), desc=f"generating"
478
+ ):
479
+ self._set_curr_pos_id(i)
480
+ self.graph.replay()
481
+
482
+ logits = self.logits_buffer[:, 0, ...]
483
+
484
+ logits = logits[..., self.min_id : self.max_id]
485
+ if guidance_scale > 0.0:
486
+ logits, uncond_logits = logits.float().chunk(2, dim=0)
487
+ gamma = (
488
+ guidance_scale * (self.max_new_tokens - i) / self.max_new_tokens
489
+ )
490
+ logits = (1 + gamma) * logits - gamma * uncond_logits
491
+ probs = process_logits(logits, top_k=top_k)
492
+ next_id = torch.multinomial(probs, num_samples=1, replacement=True)
493
+
494
+ output_ids[:, i] = next_id.squeeze()
495
+ next_embed = self.gpt_model.encode_token(next_id)
496
+ next_embed = next_embed.repeat(2, 1, 1)
497
+ self.embed_buffer[:, i + input_seq_len, :].copy_(next_embed.squeeze(1))
498
+
499
+ return output_ids
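Both `Engine.run_gpt` and `EngineFast.run_gpt` above apply classifier-free guidance with a scale that decays linearly over the generated tokens. A small self-contained sketch of that schedule (the function and tensors here are illustrative, not part of the package):

```python
import torch

def guided_logits(cond_logits, uncond_logits, guidance_scale, step, max_new_tokens):
    # gamma decays linearly from guidance_scale at step 0 toward 0 at the final step,
    # matching the blend used inside run_gpt: (1 + gamma) * cond - gamma * uncond
    gamma = guidance_scale * (max_new_tokens - step) / max_new_tokens
    return (1 + gamma) * cond_logits - gamma * uncond_logits

cond = torch.randn(1, 16384)    # conditional logits over the 16384-code shape vocabulary
uncond = torch.randn(1, 16384)  # unconditional (empty-prompt) logits
blended = guided_logits(cond, uncond, guidance_scale=3.0, step=10, max_new_tokens=512)
```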
cube/cube3d/inference/logits_postprocesses.py ADDED
@@ -0,0 +1,44 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ def top_k_filtering(logits, top_k: int = 1):
6
+ """
7
+ Filter a distribution of logits using top-k and/or top-p (nucleus) filtering.
8
+ The input logits tensor is modified in-place.
9
+
10
+ Args:
11
+ logits: A tensor of logits to be filtered. Expected shape is [..., vocab_size].
12
+ top_k: If > 0, only keep the top k tokens with highest probability.
13
+ top_p: If < 1.0, only keep tokens whose cumulative probability is below this threshold.
14
+
15
+ Returns:
16
+ A tensor of logits where values outside the top-k/top-p threshold are set to -∞.
17
+ """
18
+ if top_k > 0:
19
+ idx_to_remove = logits < logits.topk(top_k, largest=True, sorted=False, dim=-1)[
20
+ 0
21
+ ].amin(dim=-1, keepdim=True)
22
+ logits.masked_fill_(idx_to_remove, -torch.inf)
23
+
24
+ return logits
25
+
26
+
27
+ def process_logits(
28
+ logits,
29
+ top_k: int = 1,
30
+ ):
31
+ """
32
+ Process logits by optionally applying top-k filtering.
33
+ The final probabilities are returned after applying softmax on the filtered logits.
34
+
35
+ Args:
36
+ logits: A tensor of logits to process. Expected shape is [..., vocab_size].
37
+ top_k: If > 0, only keep the top k tokens with highest probability.
38
+
39
+ Returns:
40
+ A tensor of probabilities after filtering, with the same shape as the input logits.
41
+ """
42
+ logits = top_k_filtering(logits, top_k=top_k)
43
+ probs = F.softmax(logits, dim=-1)
44
+ return probs
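A quick usage sketch for the helpers above (shapes and values are illustrative; note that `top_k_filtering` modifies its input in place, hence the `clone()`):

```python
import torch

from cube3d.inference.logits_postprocesses import process_logits

logits = torch.randn(2, 16384)                   # [batch, vocab_size]
probs = process_logits(logits.clone(), top_k=5)  # keep only the top-5 logits per row, then softmax
next_id = torch.multinomial(probs, num_samples=1)
print(next_id.shape)  # torch.Size([2, 1])
```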
cube/cube3d/inference/utils.py ADDED
@@ -0,0 +1,56 @@
1
+ import logging
2
+ from typing import Any, Optional
3
+
4
+ import torch
5
+ from omegaconf import DictConfig, OmegaConf
6
+ from safetensors.torch import load_model
7
+
8
+
9
+ def load_config(cfg_path: str) -> Any:
10
+ """
11
+ Load and resolve a configuration file.
12
+ Args:
13
+ cfg_path (str): The path to the configuration file.
14
+ Returns:
15
+ Any: The loaded and resolved configuration object.
16
+ Raises:
17
+ AssertionError: If the loaded configuration is not an instance of DictConfig.
18
+ """
19
+
20
+ cfg = OmegaConf.load(cfg_path)
21
+ OmegaConf.resolve(cfg)
22
+ assert isinstance(cfg, DictConfig)
23
+ return cfg
24
+
25
+
26
+ def parse_structured(cfg_type: Any, cfg: DictConfig) -> Any:
27
+ """
28
+ Parses a configuration dictionary into a structured configuration object.
29
+ Args:
30
+ cfg_type (Any): The type of the structured configuration object.
31
+ cfg (DictConfig): The configuration dictionary to be parsed.
32
+ Returns:
33
+ Any: The structured configuration object created from the dictionary.
34
+ """
35
+
36
+ scfg = OmegaConf.structured(cfg_type(**cfg))
37
+ return scfg
38
+
39
+
40
+ def load_model_weights(model: torch.nn.Module, ckpt_path: str) -> None:
41
+ """
42
+ Load a safetensors checkpoint into a PyTorch model.
43
+ The model is updated in place.
44
+
45
+ Args:
46
+ model: PyTorch model to load weights into
47
+ ckpt_path: Path to the safetensors checkpoint file
48
+
49
+ Returns:
50
+ None
51
+ """
52
+ assert ckpt_path.endswith(".safetensors"), (
53
+ f"Checkpoint path '{ckpt_path}' is not a safetensors file"
54
+ )
55
+
56
+ load_model(model, ckpt_path)
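A hedged sketch of how these helpers compose; the YAML path, the `shape_model` config section, and the checkpoint filename are placeholders rather than files shipped in this commit:

from cube3d.inference.utils import load_config, parse_structured, load_model_weights
from cube3d.model.autoencoder.one_d_autoencoder import OneDAutoEncoder

cfg = load_config("cube/cube3d/configs/open_model.yaml")               # hypothetical path
shape_cfg = parse_structured(OneDAutoEncoder.Config, cfg.shape_model)  # assumes a `shape_model` section
shape_model = OneDAutoEncoder(shape_cfg)
load_model_weights(shape_model, "shape_tokenizer.safetensors")         # hypothetical checkpoint
shape_model.eval()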
cube/cube3d/mesh_utils/postprocessing.py ADDED
@@ -0,0 +1,84 @@
1
+ import logging
2
+
3
+ import numpy as np
4
+
5
+ try:
6
+ import pymeshlab
7
+
8
+ PYMESHLAB_AVAILABLE = True
9
+ except ImportError:
10
+ logging.warning(
11
+ "pymeshlab is not installed or could not be loaded. Please install it with `pip install pymeshlab`."
12
+ )
13
+ PYMESHLAB_AVAILABLE = False
14
+ from typing import Any
15
+
16
+ # Create stub class for typing
17
+ class pymeshlab:
18
+ MeshSet = Any
19
+ Mesh = Any
20
+
21
+
22
+ def create_pymeshset(vertices: np.ndarray, faces: np.ndarray):
23
+ """
24
+ Creates a MeshLab MeshSet given a list of vertices and faces.
25
+ """
26
+ assert PYMESHLAB_AVAILABLE, "pymeshlab is not installed or could not be loaded."
27
+ # Initialize MeshSet and create pymeshlab.Mesh
28
+ mesh_set = pymeshlab.MeshSet()
29
+ input_mesh = pymeshlab.Mesh(vertex_matrix=vertices, face_matrix=faces)
30
+ mesh_set.add_mesh(input_mesh, "input_mesh")
31
+ logging.info("Mesh successfully added to pymeshlab MeshSet.")
32
+ return mesh_set
33
+
34
+
35
+ def cleanup(ms: pymeshlab.MeshSet):
36
+ """
37
+ General cleanup for a given Mesh. Removes degenerate elements from the
38
+ geometry.
39
+ """
40
+ ms.meshing_remove_null_faces()
41
+ ms.meshing_remove_folded_faces()
42
+ ms.meshing_remove_duplicate_vertices()
43
+ ms.meshing_remove_duplicate_faces()
44
+ ms.meshing_remove_t_vertices()
45
+ ms.meshing_remove_unreferenced_vertices()
46
+
47
+
48
+ def remove_floaters(ms: pymeshlab.MeshSet, threshold: float = 0.005):
49
+ """
50
+ Remove floating artifacts left over from mesh generation by discarding connected components whose diameter is below 15% of the bounding box diagonal. The `threshold` argument is currently unused.
51
+ """
52
+ assert PYMESHLAB_AVAILABLE, "pymeshlab is not installed or could not be loaded."
53
+ ms.meshing_remove_connected_component_by_diameter(
54
+ mincomponentdiag=pymeshlab.PercentageValue(15), removeunref=True
55
+ )
56
+
57
+
58
+ def simplify_mesh(ms: pymeshlab.MeshSet, target_face_num: int):
59
+ """
60
+ Simplify the mesh to the target number of faces.
61
+ """
62
+ ms.meshing_decimation_quadric_edge_collapse(
63
+ targetfacenum=target_face_num,
64
+ qualitythr=0.4,
65
+ preservenormal=True,
66
+ autoclean=True,
67
+ )
68
+
69
+
70
+ def save_mesh(ms: pymeshlab.MeshSet, output_path: str):
71
+ """
72
+ Save the mesh to a file.
73
+ """
74
+ ms.save_current_mesh(output_path)
75
+ logging.info(f"Mesh saved to {output_path}.")
76
+
77
+
78
+ def postprocess_mesh(ms: pymeshlab.MeshSet, target_face_num: int, output_path: str):
79
+ """
80
+ Clean up the mesh, remove floaters, and simplify it to the target number of faces. The result is not written to `output_path` here; call `save_mesh` separately.
81
+ """
82
+ cleanup(ms)
83
+ remove_floaters(ms)
84
+ simplify_mesh(ms, target_face_num)
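A sketch of the intended pipeline, assuming pymeshlab is installed; the single triangle is a stand-in for the vertices and faces produced by mesh extraction, and since postprocess_mesh does not write output_path itself, save_mesh is called explicitly:

import numpy as np
from cube3d.mesh_utils.postprocessing import create_pymeshset, postprocess_mesh, save_mesh

vertices = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])  # toy geometry
faces = np.array([[0, 1, 2]])

ms = create_pymeshset(vertices, faces)
postprocess_mesh(ms, target_face_num=10_000, output_path="output.obj")
save_mesh(ms, "output.obj")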
cube/cube3d/model/__init__.py ADDED
File without changes
cube/cube3d/model/autoencoder/__init__.py ADDED
File without changes
cube/cube3d/model/autoencoder/embedder.py ADDED
@@ -0,0 +1,52 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class PhaseModulatedFourierEmbedder(torch.nn.Module):
8
+ def __init__(
9
+ self,
10
+ num_freqs: int,
11
+ input_dim: int = 3,
12
+ ):
13
+ """
14
+ Initializes the PhaseModulatedFourierEmbedder class.
15
+ Args:
16
+ num_freqs (int): The number of frequencies to be used.
17
+ input_dim (int, optional): The dimension of the input. Defaults to 3.
18
+ Attributes:
19
+ weight (torch.nn.Parameter): The weight parameter initialized with random values.
20
+ carrier (torch.Tensor): The carrier frequencies calculated based on the Nyquist-Shannon sampling theorem.
21
+ out_dim (int): The output dimension calculated based on the input dimension and number of frequencies.
22
+ """
23
+
24
+ super().__init__()
25
+
26
+ self.weight = nn.Parameter(
27
+ torch.randn(input_dim, num_freqs) * math.sqrt(0.5 * num_freqs)
28
+ )
29
+
30
+ # NOTE this is the highest frequency we can get (2 for peaks, 2 for zeros, and 4 for interpolation points), see also https://en.wikipedia.org/wiki/Nyquist%E2%80%93Shannon_sampling_theorem
31
+ carrier = (num_freqs / 8) ** torch.linspace(1, 0, num_freqs)
32
+ carrier = (carrier + torch.linspace(0, 1, num_freqs)) * 2 * torch.pi
33
+ self.register_buffer("carrier", carrier, persistent=False)
34
+
35
+ self.out_dim = input_dim * (num_freqs * 2 + 1)
36
+
37
+ def forward(self, x):
38
+ """
39
+ Perform the forward pass of the embedder model.
40
+ Args:
41
+ x (torch.Tensor): Input tensor of shape (batch_size, ..., input_dim).
42
+ Returns:
43
+ torch.Tensor: Output tensor of shape (batch_size, ..., output_dim) where
44
+ output_dim = input_dim + 2 * input_dim.
45
+ """
46
+
47
+ m = x.float().unsqueeze(-1)
48
+ fm = (m * self.weight).view(*x.shape[:-1], -1)
49
+ pm = (m * 0.5 * torch.pi + self.carrier).view(*x.shape[:-1], -1)
50
+ embedding = torch.cat([x, fm.cos() + pm.cos(), fm.sin() + pm.sin()], dim=-1)
51
+
52
+ return embedding
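A small shape check for the embedder above; batch and point counts are arbitrary:

import torch
from cube3d.model.autoencoder.embedder import PhaseModulatedFourierEmbedder

embedder = PhaseModulatedFourierEmbedder(num_freqs=8, input_dim=3)
pts = torch.rand(2, 5, 3)   # (batch, num_points, xyz)
emb = embedder(pts)
print(embedder.out_dim)     # 3 * (2 * 8 + 1) == 51
print(emb.shape)            # torch.Size([2, 5, 51])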
cube/cube3d/model/autoencoder/grid.py ADDED
@@ -0,0 +1,84 @@
1
+ from typing import Literal, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import warp as wp
6
+
7
+
8
+ def generate_dense_grid_points(
9
+ bbox_min: np.ndarray,
10
+ bbox_max: np.ndarray,
11
+ resolution_base: float,
12
+ indexing: Literal["xy", "ij"] = "ij",
13
+ ) -> tuple[np.ndarray, list[int], np.ndarray]:
14
+ """
15
+ Generate a dense grid of points within a bounding box.
16
+
17
+ Parameters:
18
+ bbox_min (np.ndarray): The minimum coordinates of the bounding box (3D).
19
+ bbox_max (np.ndarray): The maximum coordinates of the bounding box (3D).
20
+ resolution_base (float): The base resolution for the grid. The number of cells along each axis will be 2^resolution_base.
21
+ indexing (Literal["xy", "ij"], optional): The indexing convention for the grid. "xy" for Cartesian indexing, "ij" for matrix indexing. Default is "ij".
22
+ Returns:
23
+ tuple: A tuple containing:
24
+ - xyz (np.ndarray): A 2D array of shape (N, 3) where N is the total number of grid points. Each row represents the (x, y, z) coordinates of a grid point.
25
+ - grid_size (list): A list of three integers representing the number of grid points along each axis.
26
+ - length (np.ndarray): The length of the bounding box along each axis.
27
+ """
28
+ length = bbox_max - bbox_min
29
+ num_cells = np.exp2(resolution_base)
30
+ x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
31
+ y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
32
+ z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
33
+ [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
34
+ xyz = np.stack((xs, ys, zs), axis=-1)
35
+ xyz = xyz.reshape(-1, 3)
36
+ grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
37
+
38
+ return xyz, grid_size, length
39
+
40
+
41
+ def marching_cubes_with_warp(
42
+ grid_logits: torch.Tensor,
43
+ level: float,
44
+ device: Union[str, torch.device] = "cuda",
45
+ max_verts: int = 3_000_000,
46
+ max_tris: int = 3_000_000,
47
+ ) -> tuple[np.ndarray, np.ndarray]:
48
+ """
49
+ Perform the marching cubes algorithm on a 3D grid with warp support.
50
+ Args:
51
+ grid_logits (torch.Tensor): A 3D tensor containing the grid logits.
52
+ level (float): The threshold level for the isosurface.
53
+ device (Union[str, torch.device], optional): The device to perform the computation on. Defaults to "cuda".
54
+ max_verts (int, optional): The maximum number of vertices. Defaults to 3,000,000.
55
+ max_tris (int, optional): The maximum number of triangles. Defaults to 3,000,000.
56
+ Returns:
57
+ Tuple[np.ndarray, np.ndarray]: A tuple containing the vertices and faces of the isosurface.
58
+ """
59
+ if isinstance(device, torch.device):
60
+ device = str(device)
61
+
62
+ assert grid_logits.ndim == 3
63
+ if "cuda" in device:
64
+ assert wp.is_cuda_available()
65
+ else:
66
+ raise ValueError(
67
+ f"Device {device} is not supported for marching_cubes_with_warp"
68
+ )
69
+
70
+ dim = grid_logits.shape[0]
71
+ field = wp.from_torch(grid_logits)
72
+
73
+ iso = wp.MarchingCubes(
74
+ nx=dim,
75
+ ny=dim,
76
+ nz=dim,
77
+ max_verts=int(max_verts),
78
+ max_tris=int(max_tris),
79
+ device=device,
80
+ )
81
+ iso.surface(field=field, threshold=level)
82
+ vertices = iso.verts.numpy()
83
+ faces = iso.indices.numpy().reshape(-1, 3)
84
+ return vertices, faces
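A quick check of the grid helper at a coarse resolution (importing this module also pulls in warp, so it assumes warp is installed); 2**4 = 16 cells per axis yields 17 samples per axis:

import numpy as np
from cube3d.model.autoencoder.grid import generate_dense_grid_points

xyz, grid_size, length = generate_dense_grid_points(
    bbox_min=np.array([-1.05, -1.05, -1.05]),
    bbox_max=np.array([1.05, 1.05, 1.05]),
    resolution_base=4.0,
)
print(grid_size)   # [17, 17, 17]
print(xyz.shape)   # (4913, 3)
print(length)      # [2.1 2.1 2.1]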
cube/cube3d/model/autoencoder/one_d_autoencoder.py ADDED
@@ -0,0 +1,694 @@
1
+ import logging
2
+ import sys
3
+ from dataclasses import dataclass, field
4
+ from functools import partial
5
+ from typing import List, Optional, Tuple
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ from skimage import measure
11
+ from torch.nn import functional as F
12
+ from tqdm import tqdm
13
+
14
+ from cube3d.model.autoencoder.embedder import PhaseModulatedFourierEmbedder
15
+ from cube3d.model.autoencoder.grid import (
16
+ generate_dense_grid_points,
17
+ marching_cubes_with_warp,
18
+ )
19
+ from cube3d.model.autoencoder.spherical_vq import SphericalVectorQuantizer
20
+ from cube3d.model.transformers.attention import (
21
+ EncoderCrossAttentionLayer,
22
+ EncoderLayer,
23
+ init_linear,
24
+ init_tfixup,
25
+ )
26
+ from cube3d.model.transformers.norm import LayerNorm
27
+
28
+
29
+ def init_sort(x):
30
+ """
31
+ Sorts the input tensor `x` based on its pairwise distances to the first element.
32
+ This function computes the pairwise distances between all elements in `x` and the
33
+ first element of `x`. It then sorts the elements of `x` in ascending order of
34
+ their distances to the first element.
35
+ Args:
36
+ x (torch.Tensor): A 2D tensor where each row represents a data point.
37
+ Returns:
38
+ torch.Tensor: A tensor containing the rows of `x` sorted by their distances
39
+ to the first row of `x`.
40
+ """
41
+
42
+ distances = torch.cdist(x, x[:1])
43
+ _, indices = torch.sort(distances.squeeze(), dim=0)
44
+ x = x[indices]
45
+ return x
46
+
47
+
48
+ class MLPEmbedder(nn.Module):
49
+ def __init__(self, in_dim: int, embed_dim: int, bias: bool = True):
50
+ super().__init__()
51
+ self.in_layer = nn.Linear(in_dim, embed_dim, bias=bias)
52
+ self.silu = nn.SiLU()
53
+ self.out_layer = nn.Linear(embed_dim, embed_dim, bias=bias)
54
+
55
+ self.apply(partial(init_linear, embed_dim=embed_dim))
56
+
57
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
58
+ return self.out_layer(self.silu(self.in_layer(x)))
59
+
60
+
61
+ class OneDEncoder(nn.Module):
62
+ def __init__(
63
+ self,
64
+ embedder,
65
+ num_latents: int,
66
+ point_feats: int,
67
+ embed_point_feats: bool,
68
+ width: int,
69
+ num_heads: int,
70
+ num_layers: int,
71
+ with_cls_token: bool = False,
72
+ cross_attention_levels: Optional[List[int]] = None,
73
+ eps: float = 1e-6,
74
+ ) -> None:
75
+ """
76
+ Initializes the OneDEncoder model.
77
+ Args:
78
+ embedder: An embedding module that provides the input embedding functionality.
79
+ num_latents (int): The number of latent variables.
80
+ point_feats (int): The number of point features.
81
+ embed_point_feats (bool): Whether to embed point features or not.
82
+ width (int): The width of the embedding dimension.
83
+ num_heads (int): The number of attention heads.
84
+ num_layers (int): The number of encoder layers.
85
+ with_cls_token (bool, optional): Whether to include a classification token like in Vision Transformers (ViT). Defaults to False.
86
+ cross_attention_levels (Optional[List[int]], optional): The indices of layers where cross-attention is applied. Defaults to None.
87
+ eps (float, optional): A small value added for numerical stability in normalization layers. Defaults to 1e-6.
88
+ Returns:
89
+ None
90
+ """
91
+ super().__init__()
92
+
93
+ self.embedder = embedder
94
+
95
+ # add cls token like ViT
96
+ self.with_cls_token = with_cls_token
97
+ if self.with_cls_token:
98
+ query = torch.empty((1 + num_latents, width))
99
+ else:
100
+ query = torch.empty((num_latents, width))
101
+
102
+ # initialize then sort query to potentially get better ordering
103
+ query.uniform_(-1.0, 1.0)
104
+ query = init_sort(query)
105
+
106
+ # set parameter
107
+ self.query = nn.Parameter(query)
108
+
109
+ self.embed_point_feats = embed_point_feats
110
+ in_dim = (
111
+ self.embedder.out_dim * 2
112
+ if self.embed_point_feats
113
+ else self.embedder.out_dim + point_feats
114
+ )
115
+ self.feat_in = MLPEmbedder(in_dim, embed_dim=width)
116
+
117
+ if cross_attention_levels is None:
118
+ cross_attention_levels = [0]
119
+
120
+ self.blocks = nn.ModuleList()
121
+ for i in range(num_layers):
122
+ if i in cross_attention_levels:
123
+ self.blocks.append(
124
+ EncoderCrossAttentionLayer(
125
+ embed_dim=width,
126
+ num_heads=num_heads,
127
+ eps=eps,
128
+ )
129
+ )
130
+ else:
131
+ self.blocks.append(
132
+ EncoderLayer(embed_dim=width, num_heads=num_heads, eps=eps)
133
+ )
134
+ self.ln_f = LayerNorm(width, eps=eps)
135
+
136
+ init_tfixup(self, num_layers)
137
+
138
+ def _forward(self, h, data, attn_mask=None):
139
+ """
140
+ Forward pass for the autoencoder model.
141
+
142
+ Args:
143
+ h (torch.Tensor): The input tensor to be processed, typically representing
144
+ the hidden state or intermediate representation.
145
+ data (torch.Tensor): The input data tensor to be transformed by the feature
146
+ extraction layer and used in cross-attention layers.
147
+ attn_mask (torch.Tensor, optional): An optional attention mask tensor to be
148
+ used in attention layers for masking specific positions. Defaults to None.
149
+ Returns:
150
+ torch.Tensor: The output tensor after processing through the layers and
151
+ applying final normalization.
152
+ """
153
+
154
+ data = self.feat_in(data)
155
+
156
+ for block in self.blocks:
157
+ if isinstance(block, EncoderCrossAttentionLayer):
158
+ h = block(h, data)
159
+ else:
160
+ h = block(h, attn_mask=attn_mask)
161
+
162
+ h = self.ln_f(h)
163
+ return h
164
+
165
+ def forward(
166
+ self, pts: torch.Tensor, feats: torch.Tensor
167
+ ) -> torch.Tensor:
168
+ """
169
+ Forward pass of the 1D autoencoder model.
170
+ Args:
171
+ pts (torch.Tensor): Input tensor representing points with shape (batch_size, num_points, point_dim).
172
+ feats (torch.Tensor): Input tensor representing features with shape (batch_size, num_points, feature_dim).
173
+ Can be None if no features are provided.
174
+ Returns:
175
+ torch.Tensor: The encoded latent tokens produced from the input point cloud.
178
+ """
179
+
180
+ b = pts.shape[0]
181
+ data = self.embedder(pts)
182
+
183
+ if feats is not None:
184
+ if self.embed_point_feats:
185
+ feats = self.embedder(feats)
186
+ data = torch.cat([data, feats], dim=-1)
187
+
188
+ # prepare query and data
189
+ h = self.query.unsqueeze(0).expand(b, -1, -1)
190
+ return self._forward(h, data, attn_mask=None)
191
+
192
+
193
+ class OneDBottleNeck(nn.Module):
194
+ def __init__(
195
+ self,
196
+ block,
197
+ ) -> None:
198
+ """
199
+ Initializes the OneDBottleNeck class.
200
+ Args:
201
+ block: The building block or module used within the autoencoder.
202
+ """
203
+ super().__init__()
204
+
205
+ self.block = block
206
+
207
+ def forward(self, h: torch.Tensor) -> Tuple[torch.Tensor, dict]:
208
+ """
209
+ Forward pass of the OneDBottleNeck function.
210
+ Args:
211
+ h (torch.Tensor): Input tensor to the model.
212
+ Returns:
213
+ Tuple[torch.Tensor, dict]: A tuple containing:
214
+ - The transformed tensor `z` after passing through the block (if applicable).
215
+ - A dictionary `ret_dict` containing additional information:
216
+ - "indices": Indices from the block output (if present).
217
+ - "z_q": Quantized tensor from the block output (if present).
218
+
219
+ """
220
+
221
+ z = h
222
+ ret_dict = {}
223
+ if self.block is not None:
224
+ z, d = self.block(z)
225
+
226
+ key_mappings = {
227
+ "q": "indices",
228
+ "z_q": "z_q",
229
+ }
230
+ for in_key, out_key in key_mappings.items():
231
+ if in_key in d:
232
+ ret_dict[out_key] = d[in_key]
233
+
234
+ return z, ret_dict
235
+
236
+
237
+ class OneDDecoder(nn.Module):
238
+ def __init__(
239
+ self,
240
+ num_latents: int,
241
+ width: int,
242
+ num_heads: int,
243
+ num_layers: int,
244
+ eps: float = 1e-6,
245
+ ) -> None:
246
+ """
247
+ Initializes the OneDDecoder class.
248
+ Args:
249
+ num_latents (int): The number of latent variables.
250
+ width (int): The width of the embedding dimension.
251
+ num_heads (int): The number of attention heads in each encoder layer.
252
+ num_layers (int): The number of encoder layers.
253
+ eps (float, optional): A small value added for numerical stability. Defaults to 1e-6.
254
+ """
255
+ super().__init__()
256
+
257
+ self.register_buffer("query", torch.empty([0, width]), persistent=False)
258
+ self.positional_encodings = nn.Parameter(
259
+ init_sort(F.normalize(torch.empty(num_latents, width).normal_()))
260
+ )
261
+ self.blocks = nn.ModuleList(
262
+ [
263
+ EncoderLayer(embed_dim=width, num_heads=num_heads, eps=eps)
264
+ for _ in range(num_layers)
265
+ ]
266
+ )
267
+
268
+ init_tfixup(self, num_layers)
269
+
270
+ def _forward(self, h):
271
+ """
272
+ Applies a sequence of operations to the input tensor `h` using the blocks
273
+ defined in the model.
274
+ Args:
275
+ h (torch.Tensor): The input tensor to be processed by the blocks.
276
+ Returns:
277
+ torch.Tensor: The output tensor after applying all blocks sequentially.
278
+ """
279
+
280
+ for block in self.blocks:
281
+ h = block(h)
282
+ return h
283
+
284
+ def forward(self, z):
285
+ """
286
+ This method processes the input tensor `z` by padding it to a fixed length,
287
+ adding positional encodings, and then passing it through the `_forward` method.
288
+
289
+ Args:
290
+ z (torch.Tensor): Input tensor.
291
+ Returns:
292
+ torch.Tensor: Output tensor after processing through the autoencoder.
293
+ Notes:
294
+ - If the `query` attribute has a non-zero shape, the input tensor `z` is padded
295
+ to match the required length using slices of `query`.
296
+ - Positional encodings are added to the padded input tensor before passing it
297
+ to the `_forward` method.
298
+ """
299
+
300
+ # pad input to fixed length
301
+ if self.query.shape[0] > 0:
302
+ pad_len = self.query.shape[0] + 1 - z.shape[1]
303
+ paddings = self.query[:pad_len, ...].unsqueeze(0).expand(z.shape[0], -1, -1)
304
+ z = torch.cat([paddings, z], dim=1)
305
+ h = z + self.positional_encodings[: z.shape[1], :].unsqueeze(0).expand(
306
+ z.shape[0], -1, -1
307
+ )
308
+
309
+ return self._forward(h)
310
+
311
+
312
+ class OneDOccupancyDecoder(nn.Module):
313
+ def __init__(
314
+ self, embedder, out_features: int, width: int, num_heads: int, eps=1e-6
315
+ ) -> None:
316
+ """
317
+ Initializes the OneDOccupancyDecoder module.
318
+ Args:
319
+ embedder: An embedding module that provides input embeddings.
320
+ out_features (int): The number of output features for the final linear layer.
321
+ width (int): The width of the intermediate layers.
322
+ num_heads (int): The number of attention heads for the cross-attention layer.
323
+ eps (float, optional): A small value added for numerical stability in layer normalization. Defaults to 1e-6.
324
+ """
325
+ super().__init__()
326
+
327
+ self.embedder = embedder
328
+ self.query_in = MLPEmbedder(self.embedder.out_dim, width)
329
+
330
+ self.attn_out = EncoderCrossAttentionLayer(embed_dim=width, num_heads=num_heads)
331
+ self.ln_f = LayerNorm(width, eps=eps)
332
+ self.c_head = nn.Linear(width, out_features)
333
+
334
+ def query(self, queries: torch.Tensor):
335
+ """
336
+ Processes the input tensor through the embedder and query_in layers.
337
+ Args:
338
+ queries (torch.Tensor): A tensor containing the input data to be processed.
339
+ Returns:
340
+ torch.Tensor: The output tensor after being processed by the embedder and query_in layers.
341
+ """
342
+
343
+ return self.query_in(self.embedder(queries))
344
+
345
+ def forward(self, queries: torch.Tensor, latents: torch.Tensor):
346
+ """
347
+ Defines the forward pass of the model.
348
+ Args:
349
+ queries (torch.Tensor): Input tensor representing the queries.
350
+ latents (torch.Tensor): Input tensor representing the latent representations.
351
+ Returns:
352
+ torch.Tensor: Output tensor after applying the query transformation,
353
+ attention mechanism, and final processing layers.
354
+ """
355
+ queries = self.query(queries)
356
+ x = self.attn_out(queries, latents)
357
+ x = self.c_head(self.ln_f(x))
358
+ return x
359
+
360
+
361
+ class OneDAutoEncoder(nn.Module):
362
+ @dataclass
363
+ class Config:
364
+ checkpoint_path: str = ""
365
+
366
+ # network params
367
+ num_encoder_latents: int = 256
368
+ num_decoder_latents: int = 256
369
+ embed_dim: int = 12
370
+ width: int = 768
371
+ num_heads: int = 12
372
+ out_dim: int = 1
373
+ eps: float = 1e-6
374
+
375
+ # grid features embedding
376
+ num_freqs: int = 128
377
+ point_feats: int = 0
378
+ embed_point_feats: bool = False
379
+
380
+ num_encoder_layers: int = 1
381
+ encoder_cross_attention_levels: list[int] = field(default_factory=list)
382
+ num_decoder_layers: int = 23
383
+
384
+ encoder_with_cls_token: bool = True
385
+ num_codes: int = 16384
386
+
387
+ def __init__(self, cfg: Config) -> None:
388
+ """
389
+ Initializes the OneDAutoencoder model.
390
+ Args:
391
+ cfg (Config): Configuration object containing the parameters for the model.
392
+ Attributes:
393
+ cfg (Config): Stores the configuration object.
394
+ embedder (PhaseModulatedFourierEmbedder): Embeds input data using phase-modulated Fourier features.
395
+ encoder (OneDEncoder): Encodes the input data into latent representations.
396
+ bottleneck (OneDBottleNeck): Bottleneck layer containing a spherical vector quantizer for dimensionality reduction.
397
+ decoder (OneDDecoder): Decodes latent representations back into the original data space.
398
+ occupancy_decoder (OneDOccupancyDecoder): Decodes occupancy information from latent representations.
399
+ """
400
+
401
+ super().__init__()
402
+
403
+ self.cfg = cfg
404
+
405
+ self.embedder = PhaseModulatedFourierEmbedder(
406
+ num_freqs=self.cfg.num_freqs, input_dim=3
407
+ )
408
+
409
+ self.encoder = OneDEncoder(
410
+ embedder=self.embedder,
411
+ num_latents=self.cfg.num_encoder_latents,
412
+ with_cls_token=self.cfg.encoder_with_cls_token,
413
+ point_feats=self.cfg.point_feats,
414
+ embed_point_feats=self.cfg.embed_point_feats,
415
+ width=self.cfg.width,
416
+ num_heads=self.cfg.num_heads,
417
+ num_layers=self.cfg.num_encoder_layers,
418
+ cross_attention_levels=self.cfg.encoder_cross_attention_levels,
419
+ eps=self.cfg.eps,
420
+ )
421
+
422
+ block = SphericalVectorQuantizer(
423
+ self.cfg.embed_dim,
424
+ self.cfg.num_codes,
425
+ self.cfg.width,
426
+ codebook_regularization="kl",
427
+ )
428
+ self.bottleneck = OneDBottleNeck(block=block)
429
+
430
+ self.decoder = OneDDecoder(
431
+ num_latents=self.cfg.num_encoder_latents,
432
+ width=self.cfg.width,
433
+ num_heads=self.cfg.num_heads,
434
+ num_layers=self.cfg.num_decoder_layers,
435
+ eps=self.cfg.eps,
436
+ )
437
+
438
+ self.occupancy_decoder = OneDOccupancyDecoder(
439
+ embedder=self.embedder,
440
+ out_features=self.cfg.out_dim,
441
+ width=self.cfg.width,
442
+ num_heads=self.cfg.num_heads,
443
+ eps=self.cfg.eps,
444
+ )
445
+
446
+ @torch.no_grad()
447
+ def decode_indices(self, shape_ids: torch.Tensor):
448
+ """
449
+ Decodes the given shape indices into latent representations.
450
+ Args:
451
+ shape_ids (torch.Tensor): A tensor containing the shape indices to be decoded.
452
+ Returns:
453
+ torch.Tensor: The decoded latent representations corresponding to the input shape indices.
454
+ """
455
+
456
+ z_q = self.bottleneck.block.lookup_codebook(shape_ids)
457
+ latents = self.decode(z_q)
458
+ return latents
459
+
460
+ @torch.no_grad()
461
+ def query_embeds(self, shape_ids: torch.Tensor):
462
+ """
463
+ Retrieves the latent embeddings corresponding to the given shape IDs.
464
+ Args:
465
+ shape_ids (torch.Tensor): A tensor containing the IDs of the shapes
466
+ for which the latent embeddings are to be queried.
467
+ Returns:
468
+ torch.Tensor: A tensor containing the latent embeddings retrieved
469
+ from the codebook for the provided shape IDs.
470
+ """
471
+
472
+ z_q = self.bottleneck.block.lookup_codebook_latents(shape_ids)
473
+ return z_q
474
+
475
+ @torch.no_grad()
476
+ def query_indices(self, shape_embs: torch.Tensor):
477
+ """
478
+ Queries the indices of the quantized embeddings from the bottleneck layer.
479
+ Args:
480
+ shape_embs (torch.Tensor): The input tensor containing shape embeddings
481
+ to be quantized.
482
+ Returns:
483
+ torch.Tensor: A tensor containing the quantized indices.
484
+ """
485
+
486
+ _, ret_dict = self.bottleneck.block.quantize(shape_embs)
487
+ return ret_dict["q"]
488
+
489
+ def encode(self, x: torch.Tensor, **kwargs):
490
+ """
491
+ Encodes the input tensor using the encoder and bottleneck layers.
492
+ Args:
493
+ x (torch.Tensor): Input tensor with shape (..., N), where the first 3
494
+ dimensions represent points (pts) and the remaining dimensions
495
+ represent features (feats).
496
+ **kwargs: Additional keyword arguments.
497
+ Returns:
498
+ Tuple[torch.Tensor, torch.Tensor, None, dict]: A tuple containing:
499
+ - z_e (torch.Tensor): Encoded tensor before bottleneck processing.
500
+ - z (torch.Tensor): Encoded tensor after bottleneck processing.
501
+ - None: Placeholder for compatibility with other methods.
502
+ - d (dict): Dictionary containing additional information, including:
503
+ - "z_cls" (torch.Tensor, optional): Class token if
504
+ `self.cfg.encoder_with_cls_token` is True.
505
+ """
506
+
507
+ pts, feats = x[..., :3], x[..., 3:]
508
+ z_e = self.encoder(pts, feats)
509
+
510
+ # split class token
511
+ if self.cfg.encoder_with_cls_token:
512
+ z_cls = z_e[:, 0, ...]
513
+ z_e = z_e[:, 1:, ...]
514
+
515
+ # quantize or kl
516
+ z, d = self.bottleneck(z_e)
517
+
518
+ if self.cfg.encoder_with_cls_token:
519
+ d["z_cls"] = z_cls
520
+ return z_e, z, None, d
521
+
522
+ def decode(self, z: torch.Tensor):
523
+ """
524
+ Decodes the latent representation `z` using the decoder network.
525
+ Args:
526
+ z (torch.Tensor): The latent representation tensor to be decoded.
527
+ Returns:
528
+ torch.Tensor: The decoded output tensor.
529
+ """
530
+
531
+ h = self.decoder(z)
532
+ return h
533
+
534
+ def query(self, queries: torch.Tensor, latents: torch.Tensor):
535
+ """
536
+ Computes the logits by decoding the given queries and latent representations.
537
+ Args:
538
+ queries (torch.Tensor): A tensor containing the query points to be decoded.
539
+ latents (torch.Tensor): A tensor containing the latent representations corresponding to the queries.
540
+ Returns:
541
+ torch.Tensor: A tensor containing the decoded logits for the given queries and latents.
542
+ """
543
+
544
+ logits = self.occupancy_decoder(queries, latents).squeeze(-1)
545
+ return logits
546
+
547
+ def forward(self, surface, queries, **kwargs):
548
+ """
549
+ Perform a forward pass through the autoencoder model.
550
+ Args:
551
+ surface (torch.Tensor): The input surface tensor to be encoded.
552
+ queries (torch.Tensor): The query tensor used for generating logits.
553
+ **kwargs: Additional keyword arguments.
554
+ Returns:
555
+ tuple: A tuple containing:
556
+ - z (torch.Tensor): The latent representation of the input surface.
557
+ - latents (torch.Tensor): The decoded output from the latent representation.
558
+ - None: Placeholder for a potential future return value.
559
+ - logits (torch.Tensor): The logits generated from the queries and latents.
560
+ - d (torch.Tensor): Additional output from the encoding process.
561
+ """
562
+
563
+ _, z, _, d = self.encode(surface)
564
+
565
+ latents = self.decode(z)
566
+ logits = self.query(queries, latents)
567
+
568
+ return z, latents, None, logits, d
569
+
570
+ @torch.no_grad()
571
+ def extract_geometry(
572
+ self,
573
+ latents: torch.FloatTensor,
574
+ bounds: list[float] = [
575
+ -1.05,
576
+ -1.05,
577
+ -1.05,
578
+ 1.05,
579
+ 1.05,
580
+ 1.05,
581
+ ],
582
+ resolution_base: float = 9.0,
583
+ chunk_size: int = 2_000_000,
584
+ use_warp: bool = False,
585
+ ):
586
+ """
587
+ Extracts 3D geometry from latent representations using a dense grid sampling
588
+ and marching cubes algorithm.
589
+ Args:
590
+ latents (torch.FloatTensor): A tensor of latent representations with shape
591
+ (batch_size, latent_dim).
592
+ bounds (list[float], optional): A list of six floats defining the bounding box
593
+ for the 3D grid in the format [xmin, ymin, zmin, xmax, ymax, zmax].
594
+ Defaults to [-1.05, -1.05, -1.05, 1.05, 1.05, 1.05].
595
+ resolution_base (float, optional): The base resolution for the grid. Higher
596
+ values result in finer grids. Defaults to 9.0.
597
+ chunk_size (int, optional): The number of grid points to process in a single
598
+ chunk. Defaults to 2,000,000.
599
+ use_warp (bool, optional): Whether to use a GPU-accelerated marching cubes
600
+ implementation. If False, falls back to a CPU implementation. Defaults to False.
601
+ Returns:
602
+ tuple:
603
+ - mesh_v_f (list[tuple]): A list of tuples containing vertices and faces
604
+ for each batch element. Each tuple is of the form
605
+ (vertices, faces), where:
606
+ - vertices (np.ndarray): Array of vertex coordinates with shape
607
+ (num_vertices, 3).
608
+ - faces (np.ndarray): Array of face indices with shape
609
+ (num_faces, 3).
610
+ If geometry extraction fails for a batch element, the tuple will be
611
+ (None, None).
612
+ - has_surface (np.ndarray): A boolean array indicating whether a surface
613
+ was successfully extracted for each batch element.
614
+ Raises:
615
+ Exception: Logs warnings or errors if geometry extraction fails for any
616
+ batch element or if the marching cubes algorithm encounters issues.
617
+ """
618
+ bbox_min = np.array(bounds[0:3])
619
+ bbox_max = np.array(bounds[3:6])
620
+ bbox_size = bbox_max - bbox_min
621
+
622
+ xyz_samples, grid_size, length = generate_dense_grid_points(
623
+ bbox_min=bbox_min,
624
+ bbox_max=bbox_max,
625
+ resolution_base=resolution_base,
626
+ indexing="ij",
627
+ )
628
+ xyz_samples = torch.FloatTensor(xyz_samples)
629
+ batch_size = latents.shape[0]
630
+
631
+ batch_logits = []
632
+
633
+ progress_bar = tqdm(
634
+ range(0, xyz_samples.shape[0], chunk_size),
635
+ desc="extracting geometry",
636
+ unit="chunk",
637
+ )
638
+ for start in progress_bar:
639
+ queries = xyz_samples[start : start + chunk_size, :]
640
+
641
+ num_queries = queries.shape[0]
642
+ if start > 0 and num_queries < chunk_size:
643
+ queries = F.pad(queries, [0, 0, 0, chunk_size - num_queries])
644
+ batch_queries = queries.unsqueeze(0).expand(batch_size, -1, -1).to(latents)
645
+
646
+ logits = self.query(batch_queries, latents)[:, :num_queries]
647
+ batch_logits.append(logits)
648
+
649
+ grid_logits = (
650
+ torch.cat(batch_logits, dim=1)
651
+ .detach()
652
+ .view((batch_size, grid_size[0], grid_size[1], grid_size[2]))
653
+ .float()
654
+ )
655
+
656
+ mesh_v_f = []
657
+ has_surface = np.zeros((batch_size,), dtype=np.bool_)
658
+ for i in range(batch_size):
659
+ try:
660
+ warp_success = False
661
+ if use_warp:
662
+ try:
663
+ vertices, faces = marching_cubes_with_warp(
664
+ grid_logits[i],
665
+ level=0.0,
666
+ device=grid_logits.device,
667
+ )
668
+ warp_success = True
669
+ except Exception as e:
670
+ logging.warning(
671
+ f"Warning: error in marching cubes with warp: {e}"
672
+ )
673
+ warp_success = False # Fall back to CPU version
674
+
675
+ if not warp_success:
676
+ logging.warning(
677
+ "Warning: falling back to CPU version of marching cubes using skimage measure"
678
+ )
679
+ vertices, faces, _, _ = measure.marching_cubes(
680
+ grid_logits[i].cpu().numpy(), 0, method="lewiner"
681
+ )
682
+
683
+ vertices = vertices / grid_size * bbox_size + bbox_min
684
+ faces = faces[:, [2, 1, 0]]
685
+ mesh_v_f.append(
686
+ (vertices.astype(np.float32), np.ascontiguousarray(faces))
687
+ )
688
+ has_surface[i] = True
689
+ except Exception as e:
690
+ logging.error(f"Error: error in extract_geometry: {e}")
691
+ mesh_v_f.append((None, None))
692
+ has_surface[i] = False
693
+
694
+ return mesh_v_f, has_surface
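A hedged interface sketch for the autoencoder above, with a deliberately tiny untrained configuration and a random point cloud used only to exercise tensor shapes (a real surface point cloud is expected in practice; running it assumes torch, warp and scikit-image are installed, since this module imports them):

import torch
from cube3d.model.autoencoder.one_d_autoencoder import OneDAutoEncoder

cfg = OneDAutoEncoder.Config(
    num_encoder_latents=16, num_decoder_latents=16, embed_dim=8, width=64,
    num_heads=4, num_freqs=8, num_encoder_layers=1, num_decoder_layers=2,
    encoder_cross_attention_levels=[0], num_codes=256,
)
model = OneDAutoEncoder(cfg).eval()

surface = torch.rand(1, 1024, 3) * 2 - 1        # (batch, num_points, xyz) in [-1, 1]
queries = torch.rand(1, 32, 3) * 2 - 1
with torch.no_grad():
    _, z_q, _, extras = model.encode(surface)   # quantized latents; indices in extras
    latents = model.decode(z_q)
    logits = model.query(queries, latents)      # occupancy logits, shape (1, 32)
print(logits.shape)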
cube/cube3d/model/autoencoder/spherical_vq.py ADDED
@@ -0,0 +1,159 @@
1
+ import sys
2
+ from typing import Literal, Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from cube3d.model.transformers.norm import RMSNorm
9
+
10
+
11
+ class SphericalVectorQuantizer(nn.Module):
12
+ def __init__(
13
+ self,
14
+ embed_dim: int,
15
+ num_codes: int,
16
+ width: Optional[int] = None,
17
+ codebook_regularization: Literal["batch_norm", "kl"] = "batch_norm",
18
+ ):
19
+ """
20
+ Initializes the SphericalVQ module.
21
+ Args:
22
+ embed_dim (int): The dimensionality of the embeddings.
23
+ num_codes (int): The number of codes in the codebook.
24
+ width (Optional[int], optional): The width of the input. Defaults to None.
25
+ codebook_regularization (Literal["batch_norm", "kl"], optional): Regularization applied to the codebook weights. Defaults to "batch_norm".
27
+ """
28
+ super().__init__()
29
+
30
+ self.num_codes = num_codes
31
+
32
+ self.codebook = nn.Embedding(num_codes, embed_dim)
33
+ self.codebook.weight.data.uniform_(-1.0 / num_codes, 1.0 / num_codes)
34
+
35
+ width = width or embed_dim
36
+ if width != embed_dim:
37
+ self.c_in = nn.Linear(width, embed_dim)
38
+ self.c_x = nn.Linear(width, embed_dim) # shortcut
39
+ self.c_out = nn.Linear(embed_dim, width)
40
+ else:
41
+ self.c_in = self.c_out = self.c_x = nn.Identity()
42
+
43
+ self.norm = RMSNorm(embed_dim, elementwise_affine=False)
44
+ self.cb_reg = codebook_regularization
45
+ if self.cb_reg == "batch_norm":
46
+ self.cb_norm = nn.BatchNorm1d(embed_dim, track_running_stats=False)
47
+ else:
48
+ self.cb_weight = nn.Parameter(torch.ones([embed_dim]))
49
+ self.cb_bias = nn.Parameter(torch.zeros([embed_dim]))
50
+ self.cb_norm = lambda x: x.mul(self.cb_weight).add_(self.cb_bias)
51
+
52
+ def get_codebook(self):
53
+ """
54
+ Retrieves the normalized codebook weights.
55
+ This method applies a series of normalization operations to the
56
+ codebook weights, ensuring they are properly scaled and normalized
57
+ before being returned.
58
+ Returns:
59
+ torch.Tensor: The normalized weights of the codebook.
60
+ """
61
+
62
+ return self.norm(self.cb_norm(self.codebook.weight))
63
+
64
+ @torch.no_grad()
66
+ def lookup_codebook(self, q: torch.Tensor):
67
+ """
68
+ Perform a lookup in the codebook and process the result.
69
+ This method takes an input tensor of indices, retrieves the corresponding
70
+ embeddings from the codebook, and applies a transformation to the retrieved
71
+ embeddings.
72
+ Args:
73
+ q (torch.Tensor): A tensor containing indices to look up in the codebook.
74
+ Returns:
75
+ torch.Tensor: The transformed embeddings retrieved from the codebook.
76
+ """
77
+
78
+ # normalize codebook
79
+ z_q = F.embedding(q, self.get_codebook())
80
+ z_q = self.c_out(z_q)
81
+ return z_q
82
+
83
+ @torch.no_grad()
84
+ def lookup_codebook_latents(self, q: torch.Tensor):
85
+ """
86
+ Retrieves the latent representations from the codebook corresponding to the given indices.
87
+ Args:
88
+ q (torch.Tensor): A tensor containing the indices of the codebook entries to retrieve.
89
+ The indices should be integers and correspond to the rows in the codebook.
90
+ Returns:
91
+ torch.Tensor: A tensor containing the latent representations retrieved from the codebook.
92
+ The shape of the returned tensor depends on the shape of the input indices
93
+ and the dimensionality of the codebook entries.
94
+ """
95
+
96
+ # normalize codebook
97
+ z_q = F.embedding(q, self.get_codebook())
98
+ return z_q
99
+
100
+ def quantize(self, z: torch.Tensor):
101
+ """
102
+ Quantizes the latent codes z with the codebook
103
+
104
+ Args:
105
+ z (Tensor): B x ... x F
106
+ """
107
+
108
+ # normalize codebook
109
+ codebook = self.get_codebook()
110
+ # the process of finding quantized codes is non differentiable
111
+ with torch.no_grad():
112
+ # flatten z
113
+ z_flat = z.view(-1, z.shape[-1])
114
+
115
+ # calculate distance and find the closest code
116
+ d = torch.cdist(z_flat, codebook)
117
+ q = torch.argmin(d, dim=1) # num_ele
118
+
119
+ z_q = codebook[q, :].reshape(*z.shape[:-1], -1)
120
+ q = q.view(*z.shape[:-1])
121
+
122
+ return z_q, {"z": z.detach(), "q": q}
123
+
124
+ def straight_through_approximation(self, z, z_q):
125
+ """passed gradient from z_q to z"""
126
+ z_q = z + (z_q - z).detach()
127
+ return z_q
128
+
129
+ def forward(self, z: torch.Tensor):
130
+ """
131
+ Forward pass of the spherical vector quantization autoencoder.
132
+ Args:
133
+ z (torch.Tensor): Input tensor of shape (batch_size, ..., feature_dim).
134
+ Returns:
135
+ Tuple[torch.Tensor, Dict[str, Any]]:
136
+ - z_q (torch.Tensor): The quantized output tensor after applying the
137
+ straight-through approximation and output projection.
138
+ - ret_dict (Dict[str, Any]): A dictionary containing additional
139
+ information:
140
+ - "z_q" (torch.Tensor): Detached quantized tensor.
141
+ - "q" (torch.Tensor): Indices of the quantized vectors.
142
+ - "z" (torch.Tensor): Detached input latent prior to quantization.
145
+ """
146
+
147
+ with torch.autocast(device_type=z.device.type, enabled=False):
148
+ # work in full precision
149
+ z = z.float()
150
+
151
+ # project and normalize
152
+ z_e = self.norm(self.c_in(z))
153
+ z_q, ret_dict = self.quantize(z_e)
154
+
155
+ ret_dict["z_q"] = z_q.detach()
156
+ z_q = self.straight_through_approximation(z_e, z_q)
157
+ z_q = self.c_out(z_q)
158
+
159
+ return z_q, ret_dict
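A small round trip through the quantizer above; dimensions are illustrative and the codebook is untrained:

import torch
from cube3d.model.autoencoder.spherical_vq import SphericalVectorQuantizer

vq = SphericalVectorQuantizer(embed_dim=8, num_codes=32, width=16)
z = torch.randn(2, 10, 16)        # (batch, tokens, width)
z_q, info = vq(z)
print(z_q.shape)                  # torch.Size([2, 10, 16]), projected back to width
print(info["q"].shape)            # torch.Size([2, 10]), codebook indices
recovered = vq.lookup_codebook(info["q"])
print(recovered.shape)            # torch.Size([2, 10, 16])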
cube/cube3d/model/gpt/__init__.py ADDED
File without changes
cube/cube3d/model/gpt/dual_stream_roformer.py ADDED
@@ -0,0 +1,279 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ from cube3d.model.transformers.cache import Cache
8
+ from cube3d.model.transformers.dual_stream_attention import (
9
+ DualStreamDecoderLayerWithRotaryEmbedding,
10
+ )
11
+ from cube3d.model.transformers.norm import LayerNorm
12
+ from cube3d.model.transformers.roformer import DecoderLayerWithRotaryEmbedding
13
+ from cube3d.model.transformers.rope import precompute_freqs_cis
14
+
15
+
16
+ class DualStreamRoformer(nn.Module):
17
+ @dataclass
18
+ class Config:
19
+ checkpoint_path: str = ""
20
+ n_layer: int = 12
21
+ n_single_layer: int = 0
22
+ rope_theta: float = 1000
23
+
24
+ n_head: int = 16
25
+ n_embd: int = 2048
26
+ bias: bool = False # bias in Linears and LayerNorms
27
+ eps: float = 1e-6 # Norm eps
28
+
29
+ shape_model_vocab_size: int = 4096
30
+ shape_model_embed_dim: int = 16
31
+
32
+ text_model_embed_dim: int = 512
33
+ use_pooled_text_embed: bool = False
34
+
35
+ encoder_with_cls_token: bool = True
36
+
37
+ def __init__(self, cfg: Config) -> None:
38
+ """
39
+ Initializes the DualStreamRoFormer model.
40
+ Args:
41
+ cfg (Config): Configuration object containing model parameters.
42
+ Attributes:
43
+ cfg (Config): Stores the configuration object.
44
+ text_proj (nn.Linear): Linear layer to project text model embeddings to the desired embedding dimension.
45
+ shape_proj (nn.Linear, optional): Linear layer to project shape model embeddings to the desired embedding
46
+ dimension
47
+ vocab_size (int): Vocabulary size for the shape model, including special tokens.
48
+ shape_bos_id (int): Token ID for the beginning-of-sequence (BOS) token for the shape model.
49
+ shape_eos_id (int): Token ID for the end-of-sequence (EOS) token for the shape model.
50
+ padding_id (int): Token ID for the padding token.
51
+ transformer (nn.ModuleDict): Dictionary containing the following components:
52
+ - wte (nn.Embedding): Embedding layer for the vocabulary.
53
+ - dual_blocks (nn.ModuleList): List of dual-stream decoder layers with rotary embeddings.
54
+ - single_blocks (nn.ModuleList): List of single-stream decoder layers with rotary embeddings.
55
+ - ln_f (LayerNorm): Layer normalization applied to the final output.
56
+ lm_head (nn.Linear): Linear layer mapping the final embeddings to the vocabulary size for language modeling.
57
+ """
58
+
59
+ super().__init__()
60
+
61
+ self.cfg = cfg
62
+
63
+ self.text_proj = nn.Linear(
64
+ in_features=self.cfg.text_model_embed_dim,
65
+ out_features=self.cfg.n_embd,
66
+ bias=self.cfg.bias,
67
+ )
68
+
69
+ self.shape_proj = nn.Linear(self.cfg.shape_model_embed_dim, self.cfg.n_embd)
70
+
71
+ self.vocab_size = self.cfg.shape_model_vocab_size
72
+
73
+ def add_special_token():
74
+ token_id = self.vocab_size
75
+ self.vocab_size += 1
76
+ return token_id
77
+
78
+ self.shape_bos_id = add_special_token()
79
+ self.shape_eos_id = add_special_token()
80
+ self.padding_id = add_special_token()
81
+
82
+ self.transformer = nn.ModuleDict(
83
+ dict(
84
+ wte=nn.Embedding(
85
+ self.vocab_size,
86
+ self.cfg.n_embd,
87
+ padding_idx=self.padding_id,
88
+ ),
89
+ dual_blocks=nn.ModuleList(
90
+ [
91
+ DualStreamDecoderLayerWithRotaryEmbedding.from_config(
92
+ self.cfg, cond_pre_only=(i == self.cfg.n_layer - 1)
93
+ )
94
+ for i in range(self.cfg.n_layer)
95
+ ]
96
+ ),
97
+ single_blocks=nn.ModuleList(
98
+ [
99
+ DecoderLayerWithRotaryEmbedding.from_config(self.cfg)
100
+ for _ in range(self.cfg.n_single_layer)
101
+ ]
102
+ ),
103
+ ln_f=LayerNorm(
104
+ self.cfg.n_embd, elementwise_affine=False, eps=self.cfg.eps
105
+ ),
106
+ )
107
+ )
108
+
109
+ self.lm_head = nn.Linear(self.cfg.n_embd, self.vocab_size, bias=False)
110
+
111
+ def encode_text(self, text_embed):
112
+ """
113
+ Encodes the given text embeddings by projecting them through a linear transformation.
114
+ Args:
115
+ text_embed (torch.Tensor): A tensor representing the text embeddings to be encoded.
116
+ Returns:
117
+ torch.Tensor: The projected text embeddings after applying the linear transformation.
118
+ """
119
+
120
+ return self.text_proj(text_embed)
121
+
122
+ def encode_token(self, tokens):
123
+ """
124
+ Encodes the input tokens using the word token embedding layer of the transformer model.
125
+ Args:
126
+ tokens (torch.Tensor): A tensor containing the input tokens to be encoded.
127
+ Returns:
128
+ torch.Tensor: A tensor containing the encoded token embeddings.
129
+ """
130
+
131
+ return self.transformer.wte(tokens)
132
+
133
+ def init_kv_cache(
134
+ self,
135
+ batch_size: int,
136
+ cond_len: int,
137
+ max_shape_tokens: int,
138
+ dtype: torch.dtype,
139
+ device: torch.device,
140
+ ) -> list[Cache]:
141
+ """
142
+ Initializes the key-value cache for the transformer model.
143
+ This method creates a list of `Cache` objects to store the key and value
144
+ states for both dual-stream and single-stream transformer blocks. The
145
+ cache is pre-allocated with zeros and is used to optimize the computation
146
+ of attention mechanisms during model inference.
147
+ Args:
148
+ batch_size (int): The batch size for the input data.
149
+ cond_len (int): The length of the conditioning sequence.
150
+ max_shape_tokens (int): The maximum number of tokens in the shape sequence.
151
+ dtype (torch.dtype): The data type for the tensors (e.g., torch.float32).
152
+ device (torch.device): The device on which the tensors will be allocated
153
+ (e.g., torch.device('cuda') or torch.device('cpu')).
154
+ Returns:
155
+ list[Cache]: A list of `Cache` objects containing pre-allocated key and
156
+ value states for each transformer block.
157
+ """
158
+ num_heads = self.cfg.n_head
159
+ max_all_tokens = cond_len + max_shape_tokens
160
+ per_head_dim = self.cfg.n_embd // num_heads
161
+
162
+ kv_cache = [
163
+ Cache(
164
+ key_states=torch.zeros(
165
+ (batch_size, num_heads, max_all_tokens, per_head_dim),
166
+ dtype=dtype,
167
+ device=device,
168
+ ),
169
+ value_states=torch.zeros(
170
+ (batch_size, num_heads, max_all_tokens, per_head_dim),
171
+ dtype=dtype,
172
+ device=device,
173
+ ),
174
+ )
175
+ for _ in range(len(self.transformer.dual_blocks))
176
+ ]
177
+ kv_cache += [
178
+ Cache(
179
+ key_states=torch.zeros(
180
+ (batch_size, num_heads, max_shape_tokens, per_head_dim),
181
+ dtype=dtype,
182
+ device=device,
183
+ ),
184
+ value_states=torch.zeros(
185
+ (batch_size, num_heads, max_shape_tokens, per_head_dim),
186
+ dtype=dtype,
187
+ device=device,
188
+ ),
189
+ )
190
+ for _ in range(len(self.transformer.single_blocks))
191
+ ]
192
+ return kv_cache
193
+
194
+ def forward(
195
+ self,
196
+ embed: torch.Tensor,
197
+ cond: torch.Tensor,
198
+ kv_cache: Optional[list[Cache]] = None,
199
+ curr_pos_id: Optional[torch.Tensor] = None,
200
+ decode: bool = False,
201
+ ):
202
+ """
203
+ Forward pass for the dual-stream RoFormer model.
204
+ Args:
205
+ embed (torch.Tensor): The input embedding tensor.
206
+ cond (torch.Tensor): The conditioning tensor.
207
+ kv_cache (Optional[list[Cache]]): A list of key-value caches for each layer, used for decoding. Default is None.
208
+ curr_pos_id (Optional[torch.Tensor]): The current position ID tensor of shape (batch_size,). Required if `decode` is True. Default is None.
209
+ decode (bool): Whether the model is in decoding mode. Default is False.
210
+ Returns:
211
+ torch.Tensor: The output logits tensor.
212
+ """
213
+ b, l = embed.shape[:2]
214
+ s = cond.shape[1]
215
+ device = embed.device
216
+
217
+ attn_mask = torch.tril(
218
+ torch.ones(s + l, s + l, dtype=torch.bool, device=device)
219
+ )
220
+
221
+ position_ids = torch.arange(l, dtype=torch.long, device=device) # shape (t)
222
+ position_ids = position_ids.unsqueeze_(0).expand(b, -1)
223
+
224
+ s_freqs_cis = precompute_freqs_cis(
225
+ dim=self.cfg.n_embd // self.cfg.n_head,
226
+ t=position_ids,
227
+ theta=self.cfg.rope_theta,
228
+ )
229
+
230
+ position_ids = torch.cat(
231
+ [
232
+ torch.zeros([b, s], dtype=torch.long, device=position_ids.device),
233
+ position_ids,
234
+ ],
235
+ dim=1,
236
+ )
237
+ d_freqs_cis = precompute_freqs_cis(
238
+ dim=self.cfg.n_embd // self.cfg.n_head,
239
+ t=position_ids,
240
+ theta=self.cfg.rope_theta,
241
+ )
242
+
243
+ if kv_cache is not None and decode:
244
+ assert curr_pos_id is not None
245
+ embed = embed[:, curr_pos_id, :]
246
+
247
+ h = embed
248
+ c = cond
249
+
250
+ layer_idx = 0
251
+ for block in self.transformer.dual_blocks:
252
+ h, c = block(
253
+ h,
254
+ c=c,
255
+ freqs_cis=d_freqs_cis,
256
+ attn_mask=attn_mask,
257
+ is_causal=True,
258
+ kv_cache=kv_cache[layer_idx] if kv_cache is not None else None,
259
+ curr_pos_id=curr_pos_id + s if curr_pos_id is not None else None,
260
+ decode=decode,
261
+ )
262
+ layer_idx += 1
263
+ for block in self.transformer.single_blocks:
264
+ h = block(
265
+ h,
266
+ freqs_cis=s_freqs_cis,
267
+ attn_mask=None,
268
+ is_causal=True,
269
+ kv_cache=kv_cache[layer_idx] if kv_cache is not None else None,
270
+ curr_pos_id=curr_pos_id,
271
+ decode=decode,
272
+ )
273
+ layer_idx += 1
274
+
275
+ # Normalization
276
+ h = self.transformer.ln_f(h)
277
+ logits = self.lm_head(h)
278
+
279
+ return logits
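An interface sketch for the decoder above with a tiny untrained configuration; the dual- and single-stream layer implementations live in files not shown here, so treat this as an illustration of the calling convention rather than a tested snippet:

import torch
from cube3d.model.gpt.dual_stream_roformer import DualStreamRoformer

cfg = DualStreamRoformer.Config(
    n_layer=2, n_single_layer=1, n_head=4, n_embd=64,
    shape_model_vocab_size=128, shape_model_embed_dim=8, text_model_embed_dim=32,
)
model = DualStreamRoformer(cfg).eval()
print(model.vocab_size)                          # 131: 128 shape codes + BOS/EOS/padding

tokens = torch.randint(0, 128, (1, 5))
embed = model.encode_token(tokens)               # (1, 5, 64)
cond = model.encode_text(torch.randn(1, 7, 32))  # (1, 7, 64)

kv_cache = model.init_kv_cache(
    batch_size=1, cond_len=7, max_shape_tokens=5,
    dtype=torch.float32, device=torch.device("cpu"),
)
print(kv_cache[0].key_states.shape)              # (1, 4, 12, 16)

with torch.no_grad():
    logits = model(embed, cond)                  # expected shape (1, 5, model.vocab_size)
print(logits.shape)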
cube/cube3d/model/transformers/__init__.py ADDED
File without changes
cube/cube3d/model/transformers/attention.py ADDED
@@ -0,0 +1,300 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from cube3d.model.transformers.norm import LayerNorm, RMSNorm
7
+
8
+
9
+ def init_linear(module, embed_dim: int):
10
+ """
11
+ Initializes the weights and biases of a given linear module.
12
+ Args:
13
+ module (nn.Module): The module to initialize. Expected to be an instance of nn.Linear.
14
+ embed_dim (int): The embedding dimension used to calculate the standard deviation
15
+ for weight initialization.
16
+ Returns:
17
+ None
18
+ """
19
+
20
+ if isinstance(module, nn.Linear):
21
+ nn.init.normal_(module.weight, std=math.sqrt(1.0 / embed_dim))
22
+ if module.bias is not None:
23
+ torch.nn.init.zeros_(module.bias)
24
+
25
+
26
+ def init_tfixup(module: nn.Module, num_layers: int):
27
+ """Special initialization from https://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
28
+
29
+ Args:
30
+ module (nn.Module): decoder/encoder module
31
+ num_layers (int): number of layers in the module
32
+ """
33
+ with torch.no_grad():
34
+ for pn, p in module.named_parameters():
35
+ if (
36
+ pn.endswith("c_proj.weight")
37
+ or pn.endswith("up_proj.weight")
38
+ or pn.endswith("down_proj.weight")
39
+ ):
40
+ p *= (4 * num_layers) ** (-0.25)
41
+ elif pn.endswith("c_v.weight"):
42
+ p *= (4 * num_layers) ** (-0.25) * math.sqrt(2)
43
+
44
+
45
+ class MLP(nn.Module):
46
+ def __init__(self, embed_dim, hidden_dim, bias=True, approximate="none"):
47
+ """
48
+ MLP with GELU activation function.
49
+ """
50
+
51
+ super().__init__()
52
+ self.up_proj = nn.Linear(embed_dim, hidden_dim, bias=bias)
53
+ self.down_proj = nn.Linear(hidden_dim, embed_dim, bias=bias)
54
+ self.act_fn = nn.GELU(approximate=approximate)
55
+
56
+ def forward(self, x):
57
+ return self.down_proj(self.act_fn(self.up_proj(x)))
58
+
59
+
60
+ class SelfAttention(nn.Module):
61
+ def __init__(
62
+ self,
63
+ embed_dim: int,
64
+ num_heads: int,
65
+ bias: bool = True,
66
+ eps: float = 1e-6,
67
+ ):
68
+ """
69
+ Initializes the self attention mechanism.
70
+ Args:
71
+ embed_dim (int): The dimensionality of the embedding space.
72
+ num_heads (int): The number of attention heads.
73
+ bias (bool, optional): Whether to include bias terms in the linear layers. Defaults to True.
74
+ eps (float, optional): A small value added for numerical stability. Defaults to 1e-6.
75
+ Raises:
76
+ AssertionError: If `embed_dim` is not divisible by `num_heads`.
77
+ """
78
+
79
+ super().__init__()
80
+ assert embed_dim % num_heads == 0
81
+ self.num_heads = num_heads
82
+ self.c_qk = nn.Linear(embed_dim, 2 * embed_dim, bias=bias)
83
+ self.c_v = nn.Linear(embed_dim, embed_dim, bias=bias)
84
+ self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
85
+
86
+ head_dim = embed_dim // num_heads
87
+ self.q_norm = RMSNorm(head_dim)
88
+ self.k_norm = RMSNorm(head_dim)
89
+
90
+ def forward(self, x, attn_mask=None, is_causal: bool = False):
91
+ """
92
+ Performs the forward pass of the attention mechanism.
93
+ Args:
94
+ x (torch.Tensor): Input tensor.
95
+ attn_mask (Optional[torch.Tensor]): Attention mask to apply. Default is None.
96
+ is_causal (bool): If True, applies a causal mask to prevent attending to future positions.
97
+ Default is False.
98
+ Returns:
99
+ torch.Tensor: Output tensor after applying
100
+ the attention mechanism and projection.
101
+ """
102
+
103
+ b, l, d = x.shape
104
+
105
+ q, k = self.c_qk(x).chunk(2, dim=-1)
106
+ v = self.c_v(x)
107
+
108
+ q = q.view(b, l, self.num_heads, -1).transpose(1, 2) # (B, nh, T, hs)
109
+ k = k.view(b, l, self.num_heads, -1).transpose(1, 2) # (B, nh, T, hs)
110
+ v = v.view(b, l, self.num_heads, -1).transpose(1, 2) # (B, nh, T, hs)
111
+
112
+ q = self.q_norm(q)
113
+ k = self.k_norm(k)
114
+
115
+ is_causal = is_causal and attn_mask is None
116
+ y = torch.nn.functional.scaled_dot_product_attention(
117
+ q,
118
+ k,
119
+ v,
120
+ attn_mask=attn_mask,
121
+ dropout_p=0.0,
122
+ is_causal=is_causal,
123
+ )
124
+
125
+ y = y.transpose(1, 2).contiguous().view(b, l, d)
126
+
127
+ y = self.c_proj(y)
128
+ return y
129
+
130
+
131
+ class CrossAttention(nn.Module):
132
+ def __init__(
133
+ self,
134
+ embed_dim: int,
135
+ num_heads: int,
136
+ q_dim=None,
137
+ kv_dim=None,
138
+ bias: bool = True,
139
+ ):
140
+ """
141
+ Initializes the cross attention mechanism.
142
+ Args:
143
+ embed_dim (int): The dimensionality of the embedding space.
144
+ num_heads (int): The number of attention heads.
145
+ q_dim (int, optional): The dimensionality of the query input. Defaults to `embed_dim`.
146
+ kv_dim (int, optional): The dimensionality of the key and value inputs. Defaults to `embed_dim`.
147
+ bias (bool, optional): Whether to include a bias term in the linear projections. Defaults to True.
148
+ Raises:
149
+ AssertionError: If `embed_dim` is not divisible by `num_heads`.
150
+ """
151
+ super().__init__()
152
+ assert embed_dim % num_heads == 0
153
+
154
+ q_dim = q_dim or embed_dim
155
+ kv_dim = kv_dim or embed_dim
156
+
157
+ self.c_q = nn.Linear(q_dim, embed_dim, bias=bias)
158
+ self.c_k = nn.Linear(kv_dim, embed_dim, bias=bias)
159
+ self.c_v = nn.Linear(kv_dim, embed_dim, bias=bias)
160
+ self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
161
+ self.num_heads = num_heads
162
+
163
+ def forward(self, x, c, attn_mask=None, is_causal: bool = False):
164
+ """
165
+ Forward pass for the attention mechanism.
166
+ Args:
167
+ x (torch.Tensor): Input tensor.
168
+ c (torch.Tensor): Context tensor.
169
+ attn_mask (torch.Tensor, optional): Attention mask.
170
+ Defaults to None.
171
+ is_causal (bool, optional): Whether to apply causal masking. Defaults to False.
172
+ Returns:
173
+ torch.Tensor: Output tensor.
174
+ """
175
+
176
+ q, k = self.c_q(x), self.c_k(c)
177
+ v = self.c_v(c)
178
+
179
+ b, l, d = q.shape
180
+ s = k.shape[1]
181
+
182
+ q = q.view(b, l, self.num_heads, -1).transpose(1, 2) # (B, nh, T, hs)
183
+ k = k.view(b, s, self.num_heads, -1).transpose(1, 2) # (B, nh, T, hs)
184
+ v = v.view(b, s, self.num_heads, -1).transpose(1, 2) # (B, nh, T, hs)
185
+
186
+ y = torch.nn.functional.scaled_dot_product_attention(
187
+ q,
188
+ k,
189
+ v,
190
+ attn_mask=attn_mask,
191
+ dropout_p=0.0,
192
+ is_causal=(attn_mask is None) and is_causal,
193
+ )
194
+
195
+ y = y.transpose(1, 2).contiguous().view(b, l, d)
196
+
197
+ y = self.c_proj(y)
198
+ return y
199
+
200
+
201
+ class EncoderLayer(nn.Module):
202
+ def __init__(
203
+ self,
204
+ embed_dim: int,
205
+ num_heads: int,
206
+ bias: bool = True,
207
+ eps: float = 1e-6,
208
+ ) -> None:
209
+ """
210
+ Initializes the EncoderLayer module.
211
+ Args:
212
+ embed_dim (int): The dimensionality of the embedding space.
213
+ num_heads (int): The number of attention heads.
214
+ bias (bool, optional): Whether to include bias terms in the layers. Defaults to True.
215
+ eps (float, optional): A small value added for numerical stability in normalization layers. Defaults to 1e-6.
216
+ """
217
+ super().__init__()
218
+ self.ln_1 = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
219
+ self.attn = SelfAttention(embed_dim, num_heads, bias=bias, eps=eps)
220
+ self.ln_2 = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
221
+ self.mlp = MLP(embed_dim=embed_dim, hidden_dim=embed_dim * 4, bias=bias)
222
+
223
+ def forward(self, x, attn_mask=None, is_causal: bool = False):
224
+ """
225
+ Performs the forward pass of the transformer block.
226
+ Args:
227
+ x (torch.Tensor): The input tensor.
228
+ attn_mask (torch.Tensor, optional): An optional attention mask tensor to apply during the
229
+ attention computation. Default is None.
230
+ is_causal (bool, optional): If True, applies a causal mask to prevent attention to future
231
+ positions. Default is False.
232
+ Returns:
233
+ torch.Tensor: The output tensor of the same shape as the input.
234
+ """
235
+
236
+ x = x + self.attn(self.ln_1(x), attn_mask=attn_mask, is_causal=is_causal)
237
+ x = x + self.mlp(self.ln_2(x))
238
+ return x
239
+
240
+
241
+ class EncoderCrossAttentionLayer(nn.Module):
242
+ def __init__(
243
+ self,
244
+ embed_dim: int,
245
+ num_heads: int,
246
+ q_dim=None,
247
+ kv_dim=None,
248
+ bias: bool = True,
249
+ eps: float = 1e-6,
250
+ ) -> None:
251
+ """
252
+ Initializes the EncoderCrossAttentionLayer module with cross-attention,
253
+ and a feed-forward MLP.
254
+ Args:
255
+ embed_dim (int): The dimensionality of the embedding space.
256
+ num_heads (int): The number of attention heads.
257
+ q_dim (int, optional): Dimensionality of the query input. Defaults to `embed_dim`.
258
+ kv_dim (int, optional): Dimensionality of the key and value inputs. Defaults to `embed_dim`.
259
+ bias (bool, optional): Whether to include bias terms in the layers. Defaults to True.
260
+ eps (float, optional): A small value added to the denominator for numerical stability
261
+ in layer normalization. Defaults to 1e-6.
262
+ """
263
+ super().__init__()
264
+
265
+ q_dim = q_dim or embed_dim
266
+ kv_dim = kv_dim or embed_dim
267
+
268
+ self.attn = CrossAttention(
269
+ embed_dim,
270
+ num_heads,
271
+ q_dim=q_dim,
272
+ kv_dim=kv_dim,
273
+ bias=bias,
274
+ )
275
+
276
+ self.ln_1 = LayerNorm(q_dim, elementwise_affine=False, eps=eps)
277
+ self.ln_2 = LayerNorm(kv_dim, elementwise_affine=False, eps=eps)
278
+
279
+ self.ln_f = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
280
+ self.mlp = MLP(embed_dim=embed_dim, hidden_dim=embed_dim * 4, bias=bias)
281
+
282
+ def forward(self, x, c, attn_mask=None, is_causal: bool = False):
283
+ """
284
+ Forward pass for the attention mechanism.
285
+ Args:
286
+ x (torch.Tensor): The input tensor to the attention mechanism.
287
+ c (torch.Tensor): The context tensor used for cross-attention.
288
+ attn_mask (torch.Tensor, optional): An optional attention mask to control
289
+ which positions can attend to others. Defaults to None.
290
+ is_causal (bool, optional): If True, applies a causal mask to prevent
291
+ attending to future positions. Defaults to False.
292
+ Returns:
293
+ torch.Tensor: The output tensor after applying attention and MLP layers.
294
+ """
295
+
296
+ x = x + self.attn(
297
+ self.ln_1(x), self.ln_2(c), attn_mask=attn_mask, is_causal=is_causal
298
+ )
299
+ x = x + self.mlp(self.ln_f(x))
300
+ return x
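For orientation, a minimal usage sketch of the encoder layers defined above (not part of the commit): shapes are illustrative, and the module path cube3d.model.transformers.attention is an assumption inferred from the sibling modules, since this hunk's file header is not shown here.

    import torch
    from cube3d.model.transformers.attention import (  # assumed module path
        EncoderLayer,
        EncoderCrossAttentionLayer,
    )

    self_layer = EncoderLayer(embed_dim=512, num_heads=8)
    cross_layer = EncoderCrossAttentionLayer(embed_dim=512, num_heads=8, kv_dim=768)

    x = torch.randn(2, 64, 512)   # (batch, tokens, embed_dim)
    c = torch.randn(2, 77, 768)   # (batch, context tokens, kv_dim)

    x = self_layer(x)             # pre-norm self-attention + MLP, residual connections
    x = cross_layer(x, c)         # cross-attend to the context stream
    print(x.shape)                # torch.Size([2, 64, 512])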
cube/cube3d/model/transformers/cache.py ADDED
@@ -0,0 +1,9 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+
5
+
6
+ @dataclass
7
+ class Cache:
8
+ key_states: torch.Tensor
9
+ value_states: torch.Tensor
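The Cache dataclass is just a pair of preallocated key/value tensors that the attention layers write into with copy_ / index_copy_. A hedged sketch of allocating one: the (batch, heads, max_len, head_dim) layout follows the indexing used by the attention code below, while the concrete sizes are arbitrary choices for illustration.

    import torch
    from cube3d.model.transformers.cache import Cache

    batch, num_heads, max_len, head_dim = 1, 8, 1024, 64
    kv_cache = Cache(
        key_states=torch.zeros(batch, num_heads, max_len, head_dim),
        value_states=torch.zeros(batch, num_heads, max_len, head_dim),
    )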
cube/cube3d/model/transformers/dual_stream_attention.py ADDED
@@ -0,0 +1,348 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from cube3d.model.transformers.cache import Cache
7
+ from cube3d.model.transformers.norm import LayerNorm, RMSNorm
8
+ from cube3d.model.transformers.roformer import SwiGLUMLP
9
+ from cube3d.model.transformers.rope import scaled_dot_product_attention_with_rotary_emb
10
+
11
+
12
+ class DismantledPreAttention(nn.Module):
13
+ def __init__(
14
+ self,
15
+ embed_dim: int,
16
+ num_heads: int,
17
+ query: bool = True,
18
+ bias: bool = True,
19
+ ) -> None:
20
+ """
21
+ Initializes the DismantledPreAttention module.
22
+ Args:
23
+ embed_dim (int): The dimensionality of the embedding space.
24
+ num_heads (int): The number of attention heads.
25
+ query (bool, optional): Whether to include query-key projection. Defaults to True.
26
+ bias (bool, optional): Whether to include bias in linear layers. Defaults to True.
27
+ Raises:
28
+ AssertionError: If `embed_dim` is not divisible by `num_heads`.
29
+ """
30
+ super().__init__()
31
+ assert embed_dim % num_heads == 0
32
+ self.query = query
33
+
34
+ head_dim = embed_dim // num_heads
35
+ # key, query, value projections for all heads, but in a batch
36
+ if query:
37
+ self.c_qk = nn.Linear(embed_dim, 2 * embed_dim, bias=False)
38
+ self.q_norm = RMSNorm(head_dim)
39
+ else:
40
+ self.c_k = nn.Linear(embed_dim, embed_dim, bias=bias)
41
+ self.k_norm = RMSNorm(head_dim)
42
+ self.c_v = nn.Linear(embed_dim, embed_dim, bias=bias)
43
+
44
+ # (B, T, C) -> (B, nh, T, hs)
45
+ self.to_mha = lambda x: x.view(*x.shape[:2], num_heads, -1).transpose(1, 2)
46
+
47
+ def forward(self, x):
48
+ """
49
+ Forward pass for the dismantled pre-attention mechanism.
50
+ Args:
51
+ x (torch.Tensor): Input tensor of shape (..., input_dim).
52
+ Returns:
53
+ tuple: A tuple containing:
54
+ - q (torch.Tensor or None): Query tensor after normalization and transformation,
55
+ or None if `self.query` is False.
56
+ - k (torch.Tensor): Key tensor after normalization and transformation.
57
+ - v (torch.Tensor): Value tensor after transformation.
58
+ """
59
+
60
+ if self.query:
61
+ q, k = self.c_qk(x).chunk(2, dim=-1)
62
+ q = self.q_norm(self.to_mha(q))
63
+ else:
64
+ q = None
65
+ k = self.c_k(x)
66
+
67
+ k = self.k_norm(self.to_mha(k))
68
+ v = self.to_mha(self.c_v(x))
69
+
70
+ return (q, k, v)
71
+
72
+
73
+ class DismantledPostAttention(nn.Module):
74
+ def __init__(
75
+ self,
76
+ embed_dim,
77
+ bias: bool = True,
78
+ eps: float = 1e-6,
79
+ ) -> None:
80
+ """
81
+ Initializes the DismantledPostAttention module.
82
+ Args:
83
+ embed_dim (int): The dimensionality of the embedding space.
84
+ bias (bool, optional): Whether to include a bias term in the linear projection. Defaults to True.
85
+ eps (float, optional): A small value added to the denominator for numerical stability in layer normalization. Defaults to 1e-6.
86
+ """
87
+ super().__init__()
88
+ self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
89
+ self.ln_3 = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
90
+ self.mlp = SwiGLUMLP(embed_dim, embed_dim * 4, bias=bias)
91
+
92
+ def forward(self, x, a):
93
+ """
94
+ Forward pass of the dual stream attention mechanism.
95
+ Args:
96
+ x (torch.Tensor): The input tensor to the model.
97
+ a (torch.Tensor): The attention tensor to be combined with the input.
98
+ Returns:
99
+ torch.Tensor: The output tensor after applying the projection,
100
+ layer normalization, and MLP transformations.
101
+ """
102
+
103
+ x = x + self.c_proj(a)
104
+ x = x + self.mlp(self.ln_3(x))
105
+ return x
106
+
107
+
108
+ class DualStreamAttentionWithRotaryEmbedding(nn.Module):
109
+ def __init__(
110
+ self,
111
+ embed_dim: int,
112
+ num_heads: int,
113
+ cond_pre_only: bool = False,
114
+ bias: bool = True,
115
+ ):
116
+ """
117
+ Initializes the DualStreamAttention module.
118
+ Args:
119
+ embed_dim (int): The dimensionality of the embedding space.
120
+ num_heads (int): The number of attention heads.
121
+ cond_pre_only (bool, optional): If True, the conditional pre-attention
122
+ will only process the key and value, not the query. Defaults to False.
123
+ bias (bool, optional): Whether to include a bias term in the attention layers.
124
+ Defaults to True.
125
+ """
126
+ super().__init__()
127
+
128
+ self.cond_pre_only = cond_pre_only
129
+
130
+ self.pre_x = DismantledPreAttention(
131
+ embed_dim=embed_dim, num_heads=num_heads, query=True, bias=bias
132
+ )
133
+
134
+ self.pre_c = DismantledPreAttention(
135
+ embed_dim=embed_dim, num_heads=num_heads, query=not cond_pre_only, bias=bias
136
+ )
137
+
138
+ def forward(
139
+ self,
140
+ x,
141
+ c: Optional[torch.Tensor],
142
+ freqs_cis,
143
+ attn_mask: Optional[torch.Tensor] = None,
144
+ is_causal: bool = False,
145
+ kv_cache: Optional[Cache] = None,
146
+ curr_pos_id: Optional[torch.Tensor] = None,
147
+ decode: bool = False,
148
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
149
+ """
150
+ Forward pass for dual stream Multi-Head Attention.
151
+
152
+ Efficient single weight matrix multiplication with results split into query, key, value.
153
+
154
+ Parameters
155
+ ----------
156
+ x : torch.Tensor
157
+ Hidden states [B, L, D]
158
+ c : torch.Tensor
159
+ Condition [B, S, D]
160
+ freqs_cis: torch.Tensor
161
+ Precomputed RoPE matrix from precompute_freqs_cis [B, S+L, Hd]
162
+ attn_mask : torch.Tensor, optional
163
+ Attention mask [B, S+L, S+L], by default None
164
+ kv_cache : Cache, optional
165
+ preallocated key/value cache; disabled when None. During prefill the keys and
166
+ values of the full (condition + hidden) sequence are written into it; during
167
+ decode only the current position is updated via `curr_pos_id`.
168
+ curr_pos_id : torch.Tensor, optional
169
+ position index of the current token, required when `decode` is True.
170
+
171
+ Returns
172
+ -------
173
+ torch.Tensor
174
+ Hidden state output [B, L, D]
175
+ """
176
+ if kv_cache is None or not decode:
177
+ # Either training or prefill
178
+ qkv_c = self.pre_c(c)
179
+ qkv_x = self.pre_x(x)
180
+ # prepend condition stream
181
+ # (B, nh, Tc, hs) + (B, nh, Tx, hs) -> (B, nh, Tc+Tx, hs)
182
+ if self.cond_pre_only:
183
+ q = qkv_x[0]
184
+ else:
185
+ q = torch.cat([qkv_c[0], qkv_x[0]], dim=2)
186
+ k = torch.cat([qkv_c[1], qkv_x[1]], dim=2)
187
+ v = torch.cat([qkv_c[2], qkv_x[2]], dim=2)
188
+
189
+ else:
190
+ # if using kv cache, query would only be the last token in the sequence, hence is_causal is False
191
+ assert x.shape[1] == 1
192
+ is_causal = False
193
+ q, k, v = self.pre_x(x)
194
+
195
+ if kv_cache is not None:
196
+ if not decode:
197
+ kv_cache.key_states[:, :, : k.shape[2], :].copy_(k)
198
+ kv_cache.value_states[:, :, : k.shape[2], :].copy_(v)
199
+ else:
200
+ assert curr_pos_id is not None
201
+ kv_cache.key_states.index_copy_(2, curr_pos_id, k)
202
+ kv_cache.value_states.index_copy_(2, curr_pos_id, v)
203
+ k = kv_cache.key_states
204
+ v = kv_cache.value_states
205
+
206
+ if attn_mask is not None:
207
+ # trim attention mask to length
208
+ if decode:
209
+ assert curr_pos_id is not None
210
+ attn_mask = attn_mask[..., curr_pos_id, :]
211
+ else:
212
+ attn_mask = attn_mask[..., -q.shape[2] :, :]
213
+
214
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
215
+ # efficient attention using Flash Attention CUDA kernels
216
+ y = scaled_dot_product_attention_with_rotary_emb(
217
+ q,
218
+ k,
219
+ v,
220
+ freqs_cis=freqs_cis,
221
+ attn_mask=attn_mask,
222
+ curr_pos_id=curr_pos_id if decode else None,
223
+ is_causal=is_causal,
224
+ )
225
+
226
+ # re-assemble all head outputs side by side
227
+ y = y.transpose(1, 2).contiguous().view(x.shape[0], -1, x.shape[2])
228
+
229
+ if y.shape[1] == x.shape[1]:
230
+ y_c = None
231
+ y_x = y
232
+ else:
233
+ assert c is not None, "Conditioning is required for dual stream attention"
234
+ y_c, y_x = torch.split(y, [c.shape[1], x.shape[1]], dim=1)
235
+ return y_x, y_c
236
+
237
+
238
+ class DualStreamDecoderLayerWithRotaryEmbedding(nn.Module):
239
+ """Nicely wrapped decoder layer block for dual stream GPT model"""
240
+
241
+ def __init__(
242
+ self,
243
+ embed_dim,
244
+ num_heads: int,
245
+ cond_pre_only: bool = False,
246
+ bias: bool = True,
247
+ eps: float = 1.0e-6,
248
+ ) -> None:
249
+ """
250
+ Initializes the DualStreamDecoderLayerWithRotaryEmbedding module with optional conditional pre-only mode.
251
+ Args:
252
+ embed_dim (int): The dimensionality of the embedding space.
253
+ num_heads (int): The number of attention heads.
254
+ cond_pre_only (bool, optional): If True, applies conditional processing only before attention. Defaults to False.
255
+ bias (bool, optional): If True, includes bias terms in the attention and post-attention layers. Defaults to True.
256
+ eps (float, optional): A small value added for numerical stability in layer normalization. Defaults to 1.0e-6.
257
+ """
258
+ super().__init__()
259
+
260
+ self.ln_1 = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
261
+ self.ln_2 = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
262
+
263
+ self.attn = DualStreamAttentionWithRotaryEmbedding(
264
+ embed_dim=embed_dim,
265
+ num_heads=num_heads,
266
+ cond_pre_only=cond_pre_only,
267
+ bias=bias,
268
+ )
269
+
270
+ self.post_1 = DismantledPostAttention(embed_dim, bias=bias, eps=eps)
271
+ if not cond_pre_only:
272
+ self.post_2 = DismantledPostAttention(embed_dim, bias=bias, eps=eps)
273
+
274
+ @classmethod
275
+ def from_config(cls, cfg, cond_pre_only: bool = False):
276
+ """
277
+ Create an instance of the class using the provided configuration.
278
+ Args:
279
+ cfg: A configuration object containing the necessary parameters:
280
+ - n_embd (int): The size of the embedding dimension.
281
+ - n_head (int): The number of attention heads.
282
+ - bias (bool): Whether to include a bias term.
283
+ - eps (float): A small value added for numerical stability.
284
+ cond_pre_only (bool, optional): If True, applies conditioning only in the pre-processing step.
285
+ Defaults to False.
286
+ Returns:
287
+ An instance of the class initialized with the specified configuration.
288
+ """
289
+
290
+ return cls(
291
+ cfg.n_embd,
292
+ num_heads=cfg.n_head,
293
+ cond_pre_only=cond_pre_only,
294
+ bias=cfg.bias,
295
+ eps=cfg.eps,
296
+ )
297
+
298
+ def forward(
299
+ self,
300
+ x,
301
+ c,
302
+ freqs_cis: torch.Tensor,
303
+ attn_mask: Optional[torch.Tensor] = None,
304
+ is_causal: bool = True,
305
+ kv_cache: Optional[Cache] = None,
306
+ curr_pos_id: Optional[torch.Tensor] = None,
307
+ decode: bool = False,
308
+ ):
309
+ """
310
+ Forward pass for DualStreamDecoderLayerWithRotaryEmbedding.
311
+
312
+ Parameters
313
+ ----------
314
+ x : torch.Tensor
315
+ Hidden states [B, L, D]
316
+ c : torch.Tensor
317
+ Condition [B, S, D]
318
+ freqs_cis: torch.Tensor
319
+ Positional embedding from RoPE [B, S+L, hd]
320
+ attn_mask : torch.Tensor, optional
321
+ Attention mask [B, S+L, S+L], by default None
322
+ kv_cache : Cache, optional
323
+ preallocated key/value cache, by default None
324
+
325
+ Returns
326
+ -------
327
+ torch.Tensor
328
+ Hidden state output [B, L, D]
329
+ torch.Tensor
330
+ Condition stream output [B, S, D], or None when the condition is pre-only or served from the KV cache
331
+ """
332
+ a_x, a_c = self.attn(
333
+ self.ln_1(x),
334
+ # NOTE condition could be none if using kv cache
335
+ self.ln_2(c) if c is not None else None,
336
+ freqs_cis=freqs_cis,
337
+ attn_mask=attn_mask,
338
+ is_causal=is_causal,
339
+ kv_cache=kv_cache,
340
+ curr_pos_id=curr_pos_id,
341
+ decode=decode,
342
+ )
343
+ x = self.post_1(x, a_x)
344
+ if a_c is not None:
345
+ c = self.post_2(c, a_c)
346
+ else:
347
+ c = None
348
+ return x, c
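For orientation, a hedged sketch of a prefill-style forward pass through the dual-stream layer above (dimensions are illustrative; precompute_freqs_cis comes from cube3d.model.transformers.rope, added further below in this commit).

    import torch
    from cube3d.model.transformers.dual_stream_attention import (
        DualStreamDecoderLayerWithRotaryEmbedding,
    )
    from cube3d.model.transformers.rope import precompute_freqs_cis

    embed_dim, num_heads = 512, 8
    head_dim = embed_dim // num_heads
    layer = DualStreamDecoderLayerWithRotaryEmbedding(embed_dim, num_heads=num_heads)

    B, S, L = 2, 77, 64                               # batch, condition length, hidden length
    c = torch.randn(B, S, embed_dim)                  # condition stream
    x = torch.randn(B, L, embed_dim)                  # hidden-state stream
    pos = torch.arange(S + L).unsqueeze(0).expand(B, -1)
    freqs_cis = precompute_freqs_cis(head_dim, pos)   # (B, S+L, head_dim // 2), complex

    # both streams keep their shapes; no KV cache is used in this sketch
    x, c = layer(x, c, freqs_cis=freqs_cis, is_causal=True)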
cube/cube3d/model/transformers/norm.py ADDED
@@ -0,0 +1,46 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ def fused_rms_norm(x: torch.Tensor, weight: nn.Parameter, eps: float):
6
+ """
7
+ Applies a fused Root Mean Square (RMS) normalization to the input tensor.
8
+ Args:
9
+ x (torch.Tensor): The input tensor to be normalized. Expected to have
10
+ at least one dimension.
11
+ weight (nn.Parameter): A learnable parameter used to scale the normalized
12
+ tensor. Its shape must be broadcastable to the shape of `x`.
13
+ eps (float): A small constant added to the denominator for numerical
14
+ stability during normalization.
15
+ Returns:
16
+ torch.Tensor: The normalized and scaled tensor with the same shape as `x`.
17
+ """
18
+
19
+ x = x.float()
20
+ return (x * torch.rsqrt((x * x).mean(-1, keepdim=True).add_(eps))) * weight
21
+
22
+
23
+ class LayerNorm(nn.LayerNorm):
24
+ def forward(self, input: torch.Tensor):
25
+ """
26
+ Wrapper to ensure that the input tensor is cast to float before normalization.
27
+ """
28
+ y = super().forward(input.float())
29
+ return y.type_as(input)
30
+
31
+
32
+ class RMSNorm(nn.Module):
33
+ def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine: bool = True):
34
+ """
35
+ Initializes the normalization layer.
36
+ Args:
37
+ dim (int): The number of features in the input tensor.
38
+ eps (float, optional): A small value added to the denominator for numerical stability. Defaults to 1e-5.
39
+ elementwise_affine (bool, optional): If True, this layer will have learnable per-element affine parameters. Defaults to True.
40
+ """
41
+ super().__init__()
42
+ self.eps = eps
43
+ self.weight = nn.Parameter(torch.ones(dim), requires_grad=elementwise_affine)
44
+
45
+ def forward(self, x):
46
+ return fused_rms_norm(x, weight=self.weight, eps=self.eps).type_as(x)
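RMSNorm above scales each feature vector by the reciprocal root-mean-square of its entries. A quick illustrative check against that formula (with the default all-ones weight):

    import torch
    from cube3d.model.transformers.norm import RMSNorm

    norm = RMSNorm(dim=4, eps=1e-5)
    x = torch.randn(3, 4)
    # x / sqrt(mean(x^2) + eps), which is what fused_rms_norm computes with weight == 1
    expected = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)
    assert torch.allclose(norm(x), expected, atol=1e-6)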
cube/cube3d/model/transformers/roformer.py ADDED
@@ -0,0 +1,229 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from cube3d.model.transformers.cache import Cache
8
+ from cube3d.model.transformers.norm import LayerNorm, RMSNorm
9
+ from cube3d.model.transformers.rope import scaled_dot_product_attention_with_rotary_emb
10
+
11
+
12
+ class SwiGLUMLP(nn.Module):
13
+ def __init__(self, embed_dim, hidden_dim, bias=True, **kwargs):
14
+ """
15
+ A PyTorch implementation of the SwiGLU (Swish-Gated Linear Unit) MLP layer.
16
+ This module consists of three linear projections: `gate_proj`, `up_proj`, and `down_proj`.
17
+ It applies the SwiGLU activation function, which combines the Swish activation with a gating mechanism,
18
+ followed by a projection back to the original embedding dimension.
19
+ Args:
20
+ embed_dim (int): The dimensionality of the input embeddings.
21
+ hidden_dim (int): The dimensionality of the hidden layer.
22
+ bias (bool, optional): Whether to include bias terms in the linear layers. Defaults to True.
23
+ **kwargs: Additional keyword arguments (currently unused).
24
+ """
25
+ super().__init__()
26
+ self.gate_proj = nn.Linear(embed_dim, hidden_dim, bias=bias)
27
+ self.up_proj = nn.Linear(embed_dim, hidden_dim, bias=bias)
28
+ self.down_proj = nn.Linear(hidden_dim, embed_dim, bias=bias)
29
+
30
+ # Ignore copy
31
+ def forward(self, x):
32
+ """
33
+ Applies a forward pass.
34
+ Args:
35
+ x (torch.Tensor): The input tensor.
36
+ Returns:
37
+ torch.Tensor: The output tensor after applying the forward pass.
38
+ """
39
+
40
+ down_proj = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
41
+ return down_proj
42
+
43
+
44
+ class SelfAttentionWithRotaryEmbedding(nn.Module):
45
+ def __init__(
46
+ self,
47
+ embed_dim: int,
48
+ num_heads: int,
49
+ bias: bool = True,
50
+ eps: float = 1e-6,
51
+ ):
52
+ """
53
+ A PyTorch module implementing self-attention with rotary embeddings.
54
+
55
+ Args:
56
+ embed_dim (int): The dimensionality of the input embeddings.
57
+ num_heads (int): The number of attention heads.
58
+ bias (bool, optional): Whether to include bias terms in the linear projections. Defaults to True.
59
+ eps (float, optional): A small value added for numerical stability in normalization. Defaults to 1e-6.
60
+ """
61
+ super().__init__()
62
+ assert embed_dim % num_heads == 0
63
+ self.num_heads = num_heads
64
+ # key, query, value projections for all heads, but in a batch
65
+ self.c_qk = nn.Linear(embed_dim, 2 * embed_dim, bias=False)
66
+ self.c_v = nn.Linear(embed_dim, embed_dim, bias=bias)
67
+ # output projection
68
+ self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
69
+
70
+ head_dim = embed_dim // num_heads
71
+ self.q_norm = RMSNorm(head_dim)
72
+ self.k_norm = RMSNorm(head_dim)
73
+
74
+ def forward(
75
+ self,
76
+ x,
77
+ freqs_cis: torch.Tensor,
78
+ attn_mask=None,
79
+ is_causal: bool = False,
80
+ kv_cache: Optional[Cache] = None,
81
+ curr_pos_id: Optional[torch.Tensor] = None,
82
+ decode: bool = False,
83
+ ):
84
+ """
85
+ Forward pass for the SelfAttentionWithRotaryEmbedding instance.
86
+ Args:
87
+ x (torch.Tensor): Input tensor.
88
+ freqs_cis (torch.Tensor): Precomputed rotary positional embeddings.
89
+ attn_mask (Optional[torch.Tensor], optional): Attention mask to apply during self-attention. Defaults to None.
90
+ is_causal (bool, optional): Whether to apply causal masking for autoregressive decoding. Defaults to False.
91
+ kv_cache (Optional[Cache], optional): Cache object for storing key and value states for decoding. Defaults to None.
92
+ curr_pos_id (Optional[torch.Tensor], optional): Current position indices for decoding. Required if `decode` is True. Defaults to None.
93
+ decode (bool, optional): Whether the model is in decoding mode. Defaults to False.
94
+ Returns:
95
+ torch.Tensor: Output tensor after applying self-attention and projection.
96
+ """
97
+ # batch size, sequence length, embedding dim
98
+ b, l, d = x.shape
99
+
100
+ # compute q, k, v and then split per q, k, v
101
+ q, k = self.c_qk(x).chunk(2, dim=-1)
102
+ v = self.c_v(x)
103
+
104
+ # split per head
105
+ q = q.view(b, l, self.num_heads, -1).transpose(1, 2) # (B, nh, T, hs)
106
+ k = k.view(b, l, self.num_heads, -1).transpose(1, 2) # (B, nh, T, hs)
107
+ v = v.view(b, l, self.num_heads, -1).transpose(1, 2) # (B, nh, T, hs)
108
+
109
+ q = self.q_norm(q)
110
+ k = self.k_norm(k)
111
+
112
+ if kv_cache is not None:
113
+ if not decode:
114
+ kv_cache.key_states[:, :, : k.shape[2], :].copy_(k)
115
+ kv_cache.value_states[:, :, : k.shape[2], :].copy_(v)
116
+ else:
117
+ assert curr_pos_id is not None
118
+ kv_cache.key_states.index_copy_(2, curr_pos_id, k)
119
+ kv_cache.value_states.index_copy_(2, curr_pos_id, v)
120
+ k = kv_cache.key_states
121
+ v = kv_cache.value_states
122
+
123
+ # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
124
+ # efficient attention using Flash Attention CUDA kernels
125
+ y = scaled_dot_product_attention_with_rotary_emb(
126
+ q,
127
+ k,
128
+ v,
129
+ freqs_cis=freqs_cis,
130
+ attn_mask=attn_mask,
131
+ curr_pos_id=curr_pos_id if decode else None,
132
+ is_causal=is_causal,
133
+ )
134
+
135
+ y = (
136
+ y.transpose(1, 2).contiguous().view(b, l, d)
137
+ ) # re-assemble all head outputs side by side
138
+
139
+ # output projection
140
+ y = self.c_proj(y)
141
+ return y
142
+
143
+
144
+ class DecoderLayerWithRotaryEmbedding(nn.Module):
145
+ def __init__(
146
+ self,
147
+ embed_dim: int,
148
+ num_heads: int,
149
+ bias: bool = True,
150
+ eps: float = 1e-6,
151
+ ) -> None:
152
+ """
153
+ Initializes the transformer model with rotary embeddings.
154
+ Args:
155
+ embed_dim (int): The dimensionality of the embedding space.
156
+ num_heads (int): The number of attention heads.
157
+ bias (bool, optional): Whether to include bias terms in the layers. Defaults to True.
158
+ eps (float, optional): A small value added for numerical stability in layer normalization. Defaults to 1e-6.
159
+ """
160
+ super().__init__()
161
+
162
+ self.ln_1 = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
163
+ self.attn = SelfAttentionWithRotaryEmbedding(
164
+ embed_dim, num_heads=num_heads, bias=bias, eps=eps
165
+ )
166
+ self.ln_2 = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
167
+ self.mlp = SwiGLUMLP(embed_dim, embed_dim * 4, bias=bias)
168
+
169
+ @classmethod
170
+ def from_config(cls, cfg):
171
+ """
172
+ Create an instance of the class using the provided configuration.
173
+ Args:
174
+ cfg: A configuration object containing the following attributes:
175
+ - n_embd (int): The size of the embedding dimension.
176
+ - n_head (int): The number of attention heads.
177
+ - bias (bool): Whether to include a bias term.
178
+ - eps (float): A small value added for numerical stability.
179
+ Returns:
180
+ An instance of the class initialized with the specified configuration.
181
+ """
182
+
183
+ return cls(
184
+ cfg.n_embd,
185
+ num_heads=cfg.n_head,
186
+ bias=cfg.bias,
187
+ eps=cfg.eps,
188
+ )
189
+
190
+ def forward(
191
+ self,
192
+ x,
193
+ freqs_cis: torch.Tensor,
194
+ attn_mask=None,
195
+ is_causal: bool = True,
196
+ kv_cache: Optional[Cache] = None,
197
+ curr_pos_id: Optional[torch.Tensor] = None,
198
+ decode: bool = False,
199
+ ):
200
+ """
201
+ Forward pass for the transformer model.
202
+ Args:
203
+ x (torch.Tensor): Input tensor.
204
+ freqs_cis (torch.Tensor): Precomputed sinusoidal positional encodings.
205
+ attn_mask (Optional[torch.Tensor], optional): Attention mask to apply during self-attention.
206
+ Defaults to None.
207
+ is_causal (bool, optional): Whether to apply causal masking for autoregressive decoding.
208
+ Defaults to True.
209
+ kv_cache (Optional[Cache], optional): Key-value cache for efficient decoding.
210
+ Defaults to None.
211
+ curr_pos_id (Optional[torch.Tensor], optional): Current position IDs for decoding.
212
+ Defaults to None.
213
+ decode (bool, optional): Whether the model is in decoding mode.
214
+ Defaults to False.
215
+ Returns:
216
+ torch.Tensor: Output tensor.
217
+ """
218
+ out = self.attn(
219
+ self.ln_1(x),
220
+ freqs_cis=freqs_cis,
221
+ attn_mask=attn_mask,
222
+ is_causal=is_causal,
223
+ kv_cache=kv_cache,
224
+ curr_pos_id=curr_pos_id,
225
+ decode=decode,
226
+ )
227
+ x = x + out
228
+ x = x + self.mlp(self.ln_2(x))
229
+ return x
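For orientation, a hedged sketch of a plain (no KV-cache) forward pass through the rotary decoder layer defined above; the dimensions are illustrative only.

    import torch
    from cube3d.model.transformers.roformer import DecoderLayerWithRotaryEmbedding
    from cube3d.model.transformers.rope import precompute_freqs_cis

    embed_dim, num_heads = 512, 8
    layer = DecoderLayerWithRotaryEmbedding(embed_dim, num_heads=num_heads)

    x = torch.randn(2, 64, embed_dim)
    pos = torch.arange(64).unsqueeze(0).expand(2, -1)        # (B, T) position ids
    freqs_cis = precompute_freqs_cis(embed_dim // num_heads, pos)
    y = layer(x, freqs_cis=freqs_cis, is_causal=True)        # (2, 64, 512)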
cube/cube3d/model/transformers/rope.py ADDED
@@ -0,0 +1,91 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
+ def apply_rotary_emb(
8
+ x: torch.Tensor,
9
+ freqs_cis: torch.Tensor,
10
+ curr_pos_id: Optional[torch.Tensor] = None,
11
+ ) -> torch.Tensor:
12
+ """
13
+ Applies rotary positional embeddings to the input tensor.
14
+ Args:
15
+ x (torch.Tensor): The input tensor.
16
+ freqs_cis (torch.Tensor): A tensor containing the precomputed rotary
17
+ frequency components.
18
+ curr_pos_id (Optional[torch.Tensor]): An optional tensor specifying the
19
+ current position IDs to use for selecting a subset of `freqs_cis`.
20
+ If None, the function uses the last `seq_len` positions.
21
+ Returns:
22
+ torch.Tensor: The input tensor `x` with rotary positional embeddings
23
+ applied.
24
+ """
25
+ x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
26
+ if curr_pos_id is None:
27
+ freqs_cis = freqs_cis[:, -x.shape[2] :].unsqueeze(1)
28
+ else:
29
+ freqs_cis = freqs_cis[:, curr_pos_id, :].unsqueeze(1)
30
+ y = torch.view_as_real(x_ * freqs_cis).flatten(3)
31
+ return y.type_as(x)
32
+
33
+
34
+ @torch.no_grad
35
+ def precompute_freqs_cis(dim: int, t: torch.Tensor, theta: float = 10000.0):
36
+ """Calculate rotary embedding cos & sin, this is useful when every blocks in the network use same positional embedding.
37
+
38
+ Args:
39
+ dim (int): dimension of the single head of the transformer block
40
+ t (torch.Tensor): position ids [..., L]
41
+ theta (int, optional): rope theta. Defaults to 10000.
42
+
43
+ Returns:
44
+ torch.Tensor: complex-valued rotations (cos + i*sin) of shape [..., L, dim // 2]
45
+ """
46
+ assert dim % 2 == 0, (
47
+ "RoPE only supports embedding dimensions that are multiples of 2"
48
+ )
49
+ freqs = 1.0 / (
50
+ theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=t.device) / dim)
51
+ )
52
+ # [batch_size, seq_len, num_freqs]
53
+ freqs = torch.outer(t.contiguous().view(-1), freqs).reshape(*t.shape, -1)
54
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
55
+
56
+ return freqs_cis
57
+
58
+
59
+ def scaled_dot_product_attention_with_rotary_emb(
60
+ q: torch.Tensor,
61
+ k: torch.Tensor,
62
+ v: torch.Tensor,
63
+ freqs_cis: torch.Tensor,
64
+ attn_mask: Optional[torch.Tensor] = None,
65
+ curr_pos_id: Optional[torch.Tensor] = None,
66
+ is_causal: bool = False,
67
+ ) -> torch.Tensor:
68
+ """
69
+ Computes scaled dot product attention on query, key and value tensors
70
+ with rotary position embeddings on query and key.
71
+
72
+ Without caching enabled,
73
+ q should be (bs, nh, seqlen, hd).
74
+ k and v should stay unchanged, (bs, nh, seqlen, hd).
75
+ With caching enabled,
76
+ q should be (bs, nh, 1, hd).
77
+ k and v hold the full cached sequence, (bs, nh, cache_len, hd).
78
+ is_causal must be False.
79
+ """
80
+ q = apply_rotary_emb(q, freqs_cis, curr_pos_id=curr_pos_id) # (bs, nh, l, hd)
81
+ k = apply_rotary_emb(k, freqs_cis, curr_pos_id=None) # (bs, nh, s + l, hd)
82
+
83
+ x = F.scaled_dot_product_attention(
84
+ q,
85
+ k,
86
+ v,
87
+ attn_mask=attn_mask,
88
+ dropout_p=0.0,
89
+ is_causal=is_causal and attn_mask is None,
90
+ )
91
+ return x
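A small shape walk-through of the helpers above (values arbitrary): precompute_freqs_cis produces one complex rotation per position and frequency pair, and apply_rotary_emb rotates the last dimension of q or k in pairs, leaving the tensor shape unchanged.

    import torch
    from cube3d.model.transformers.rope import apply_rotary_emb, precompute_freqs_cis

    B, n_heads, T, head_dim = 2, 4, 10, 64
    pos = torch.arange(T).unsqueeze(0).expand(B, -1)   # (B, T) position ids
    freqs_cis = precompute_freqs_cis(head_dim, pos)    # (B, T, head_dim // 2), complex64

    q = torch.randn(B, n_heads, T, head_dim)
    q_rot = apply_rotary_emb(q, freqs_cis)             # same shape as q
    assert q_rot.shape == q.shape
    assert freqs_cis.shape == (B, T, head_dim // 2)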
cube/cube3d/renderer/blender_script.py ADDED
@@ -0,0 +1,723 @@
1
+ """
2
+ Blender script to render images of 3D models.
3
+
4
+ This script is adopted from the Trellis rendering script:
5
+ https://github.com/microsoft/TRELLIS/blob/main/dataset_toolkits/render.py
6
+
7
+ """
8
+
9
+ import argparse
10
+ import math
11
+ import os
12
+ import platform
13
+ import random
14
+ import sys
15
+ from pathlib import Path
16
+ from typing import Any, Callable, Dict, Generator, Literal, Optional, Tuple
17
+
18
+ import bpy
19
+ import numpy as np
20
+ from mathutils import Vector
21
+
22
+ pathdir = Path(__file__).parent
23
+ sys.path.append(pathdir.as_posix())
24
+
25
+ print(dir(bpy), bpy.__path__)
26
+
27
+ IMPORT_FUNCTIONS: Dict[str, Callable] = {
28
+ ".obj": bpy.ops.wm.obj_import,
29
+ ".glb": bpy.ops.import_scene.gltf,
30
+ ".gltf": bpy.ops.import_scene.gltf,
31
+ }
32
+
33
+
34
+ def center_and_scale_mesh(scale_value: float = 1.0) -> None:
35
+ """Centers and scales the scene to fit in a unit cube.
36
+ For example,
37
+ scale_value = 1.0 ==> [-0.5, 0.5]
38
+ scale_value = 2.0 ==> [-1.0, 1.0]
39
+ """
40
+ # Get all mesh objects
41
+ mesh_objects = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"]
42
+ if not mesh_objects:
43
+ return
44
+
45
+ # Calculate bounds
46
+ min_coords = Vector((float("inf"),) * 3)
47
+ max_coords = Vector((float("-inf"),) * 3)
48
+
49
+ for obj in mesh_objects:
50
+ # Get all vertices in world space
51
+ for vertex in obj.data.vertices:
52
+ world_coord = obj.matrix_world @ vertex.co
53
+ min_coords.x = min(min_coords.x, world_coord.x)
54
+ min_coords.y = min(min_coords.y, world_coord.y)
55
+ min_coords.z = min(min_coords.z, world_coord.z)
56
+ max_coords.x = max(max_coords.x, world_coord.x)
57
+ max_coords.y = max(max_coords.y, world_coord.y)
58
+ max_coords.z = max(max_coords.z, world_coord.z)
59
+
60
+ # Calculate center and dimensions
61
+ center = (min_coords + max_coords) / 2
62
+ dimensions = max_coords - min_coords
63
+ scale = scale_value / max(
64
+ dimensions.x, dimensions.y, dimensions.z
65
+ ) # Scale to fit in [-scale_value/2, scale_value/2] cube
66
+
67
+ # Create an empty to serve as the parent
68
+ empty = bpy.data.objects.new("Parent_Empty", None)
69
+ bpy.context.scene.collection.objects.link(empty)
70
+
71
+ # Parent all mesh objects to the empty
72
+ for obj in mesh_objects:
73
+ obj.parent = empty
74
+
75
+ # Move empty to center everything
76
+ empty.location = -center
77
+
78
+ # Apply scale to empty
79
+ empty.scale = (scale, scale, scale)
80
+
81
+ bpy.context.view_layer.update()
82
+ bpy.ops.object.select_all(action="DESELECT")
83
+ empty.select_set(True)
84
+ bpy.context.view_layer.objects.active = empty
85
+ bpy.ops.object.transform_apply(location=True, rotation=True, scale=True)
86
+ print(f"Empty location: {empty.location}")
87
+ print(f"Empty scale: {empty.scale}")
88
+
89
+ return scale
90
+
91
+
92
+ def normalize_scene() -> None:
93
+ """Normalizes the scene by scaling and translating it to fit in a unit cube centered
94
+ at the origin.
95
+
96
+ Mostly taken from the Point-E / Shap-E rendering script
97
+ (https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112),
98
+ but fix for multiple root objects: (see bug report here:
99
+ https://github.com/openai/shap-e/pull/60).
100
+
101
+ Returns:
102
+ The new parent object that all objects descend from, or None if the scene already had a single root object.
103
+ """
104
+ parent_empty = None
+ if len(list(get_scene_root_objects())) > 1:
105
+ # create an empty object to be used as a parent for all root objects
106
+ parent_empty = bpy.data.objects.new("ParentEmpty", None)
107
+ bpy.context.scene.collection.objects.link(parent_empty)
108
+
109
+ # parent all root objects to the empty object
110
+ for obj in get_scene_root_objects():
111
+ if obj != parent_empty:
112
+ obj.parent = parent_empty
113
+
114
+ bbox_min, bbox_max = scene_bbox()
115
+ scale = 1 / max(bbox_max - bbox_min)
116
+ for obj in get_scene_root_objects():
117
+ obj.scale = obj.scale * scale
118
+
119
+ # Apply scale to matrix_world.
120
+ bpy.context.view_layer.update()
121
+ bbox_min, bbox_max = scene_bbox()
122
+ offset = -(bbox_min + bbox_max) / 2
123
+ for obj in get_scene_root_objects():
124
+ obj.matrix_world.translation += offset
125
+ bpy.ops.object.select_all(action="DESELECT")
126
+ bbox_min, bbox_max = scene_bbox()
127
+ print(f"After normalize_scene: bbox_min: {bbox_min}, bbox_max: {bbox_max}")
128
+
129
+ # unparent the camera
130
+ bpy.data.objects["Camera"].parent = None
131
+
132
+ return parent_empty
133
+
134
+
135
+ def reset_cameras() -> None:
136
+ """Resets the cameras in the scene to a single default camera."""
137
+ # Delete all existing cameras
138
+ bpy.ops.object.select_all(action="DESELECT")
139
+ bpy.ops.object.select_by_type(type="CAMERA")
140
+ bpy.ops.object.delete()
141
+
142
+ # Create a new camera with default properties
143
+ bpy.ops.object.camera_add()
144
+
145
+ # Rename the new camera to 'NewDefaultCamera'
146
+ new_camera = bpy.context.active_object
147
+ new_camera.name = "Camera"
148
+
149
+ # Set the new camera as the active camera for the scene
150
+ scene.camera = new_camera
151
+
152
+
153
+ def get_camera_with_position(x, y, z, fov_degrees=40):
154
+ camera = bpy.data.objects["Camera"]
155
+ camera.data.angle = math.radians(fov_degrees)
156
+ camera.location = np.array([x, y, z])
157
+ direction = -camera.location
158
+ rot_quat = direction.to_track_quat("-Z", "Y")
159
+ camera.rotation_euler = rot_quat.to_euler()
160
+ return camera
161
+
162
+
163
+ def reset_scene() -> None:
164
+ """Resets the scene to a clean state.
165
+
166
+ Returns:
167
+ None
168
+ """
169
+ # delete everything that isn't part of a camera or a light
170
+ for obj in bpy.data.objects:
171
+ if obj.type not in {"CAMERA", "LIGHT"}:
172
+ bpy.data.objects.remove(obj, do_unlink=True)
173
+
174
+ # delete all the materials
175
+ for material in bpy.data.materials:
176
+ bpy.data.materials.remove(material, do_unlink=True)
177
+
178
+ # delete all the textures
179
+ for texture in bpy.data.textures:
180
+ bpy.data.textures.remove(texture, do_unlink=True)
181
+
182
+ # delete all the images
183
+ for image in bpy.data.images:
184
+ bpy.data.images.remove(image, do_unlink=True)
185
+
186
+
187
+ def load_object(object_path: str) -> None:
188
+ """Loads a model with a supported file extension into the scene.
189
+
190
+ Args:
191
+ object_path (str): Path to the model file.
192
+
193
+ Raises:
194
+ ValueError: If the file extension is not supported.
195
+
196
+ Returns:
197
+ None
198
+ """
199
+ file_extension = Path(object_path).suffix
200
+ if file_extension not in IMPORT_FUNCTIONS:
201
+ raise ValueError(f"Unsupported file type: {object_path}")
202
+
203
+ # load from existing import functions
204
+ import_function = IMPORT_FUNCTIONS[file_extension]
205
+
206
+ if file_extension in {".glb", ".gltf"}:
207
+ import_function(filepath=object_path, merge_vertices=True)
208
+ else:
209
+ import_function(filepath=object_path)
210
+
211
+
212
+ def clear_lights():
213
+ bpy.ops.object.select_all(action="DESELECT")
214
+ for obj in bpy.context.scene.objects.values():
215
+ if isinstance(obj.data, bpy.types.Light):
216
+ obj.select_set(True)
217
+ bpy.ops.object.delete()
218
+
219
+
220
+ def create_light(
221
+ location,
222
+ energy=1.0,
223
+ angle=0.5 * math.pi / 180,
224
+ light_type: Literal["POINT", "SUN", "SPOT", "AREA"] = "SUN",
225
+ ):
226
+ # https://blender.stackexchange.com/questions/215624/how-to-create-a-light-with-the-python-api-in-blender-2-92
227
+ light_data = bpy.data.lights.new(name="Light", type=light_type)
228
+ light_data.energy = energy
229
+ if light_type != "AREA" and light_type != "POINT":
230
+ light_data.angle = angle
231
+ light_object = bpy.data.objects.new(name="Light", object_data=light_data)
232
+
233
+ direction = -location
234
+ rot_quat = direction.to_track_quat("-Z", "Y")
235
+ light_object.rotation_euler = rot_quat.to_euler()
236
+ bpy.context.view_layer.update()
237
+
238
+ bpy.context.collection.objects.link(light_object)
239
+ light_object.location = location
240
+
241
+
242
+ def create_uniform_lights(
243
+ distance=2.0,
244
+ energy=3.0,
245
+ light_type: Literal["POINT", "SUN", "SPOT", "AREA"] = "SUN",
246
+ ):
247
+ clear_lights()
248
+ create_light(Vector([1, 0, 0]) * distance, energy=energy, light_type=light_type)
249
+ create_light(-Vector([1, 0, 0]) * distance, energy=energy, light_type=light_type)
250
+ create_light(Vector([0, 1, 0]) * distance, energy=energy, light_type=light_type)
251
+ create_light(-Vector([0, 1, 0]) * distance, energy=energy, light_type=light_type)
252
+ create_light(Vector([0, 0, 1]) * distance, energy=energy, light_type=light_type)
253
+ create_light(-Vector([0, 0, 1]) * distance, energy=energy, light_type=light_type)
254
+
255
+
256
+ def create_light_at_camera_position(
257
+ camera_position: Vector,
258
+ energy=1.5,
259
+ use_shadow=False,
260
+ light_type: Literal["POINT", "SUN", "SPOT", "AREA"] = "SUN",
261
+ ):
262
+ clear_lights()
263
+ create_light(camera_position, energy=energy, light_type=light_type)
264
+ # disable shadows
265
+ if not use_shadow:
266
+ for light in bpy.data.lights:
267
+ light.use_shadow = False
268
+
269
+
270
+ def set_world_background_color(
271
+ color: Tuple[float, float, float, float] = (1.0, 1.0, 1.0, 1.0),
272
+ ) -> None:
273
+ bpy.context.scene.world.use_nodes = True
274
+ bpy.context.scene.world.node_tree.nodes["Background"].inputs[
275
+ 0
276
+ ].default_value = color
277
+ bpy.context.scene.view_settings.view_transform = "Standard"
278
+
279
+
280
+ def scene_bbox(
281
+ single_obj: Optional[bpy.types.Object] = None, ignore_matrix: bool = False
282
+ ) -> Tuple[Vector, Vector]:
283
+ """Returns the bounding box of the scene.
284
+
285
+ Taken from Shap-E rendering script
286
+ (https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82)
287
+
288
+ Args:
289
+ single_obj (Optional[bpy.types.Object], optional): If not None, only computes
290
+ the bounding box for the given object. Defaults to None.
291
+ ignore_matrix (bool, optional): Whether to ignore the object's matrix. Defaults
292
+ to False.
293
+
294
+ Raises:
295
+ RuntimeError: If there are no objects in the scene.
296
+
297
+ Returns:
298
+ Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box.
299
+ """
300
+ bbox_min = (math.inf,) * 3
301
+ bbox_max = (-math.inf,) * 3
302
+ found = False
303
+ for obj in get_scene_meshes() if single_obj is None else [single_obj]:
304
+ found = True
305
+ for coord in obj.bound_box:
306
+ coord = Vector(coord)
307
+ if not ignore_matrix:
308
+ coord = obj.matrix_world @ coord
309
+ bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
310
+ bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
311
+
312
+ if not found:
313
+ raise RuntimeError("no objects in scene to compute bounding box for")
314
+
315
+ return Vector(bbox_min), Vector(bbox_max)
316
+
317
+
318
+ def get_scene_root_objects() -> Generator[bpy.types.Object, None, None]:
319
+ """Returns all root objects in the scene.
320
+
321
+ Yields:
322
+ Generator[bpy.types.Object, None, None]: Generator of all root objects in the
323
+ scene.
324
+ """
325
+ for obj in bpy.context.scene.objects.values():
326
+ if not obj.parent and not isinstance(obj.data, bpy.types.Light):
327
+ yield obj
328
+
329
+
330
+ def get_scene_meshes() -> Generator[bpy.types.Object, None, None]:
331
+ """Returns all meshes in the scene.
332
+
333
+ Yields:
334
+ Generator[bpy.types.Object, None, None]: Generator of all meshes in the scene.
335
+ """
336
+ for obj in bpy.context.scene.objects.values():
337
+ if isinstance(obj.data, (bpy.types.Mesh)):
338
+ yield obj
339
+
340
+
341
+ def delete_missing_textures() -> Dict[str, Any]:
342
+ """Deletes all missing textures in the scene.
343
+
344
+ Returns:
345
+ Dict[str, Any]: Dictionary with keys "count", "files", and "file_path_to_color".
346
+ "count" is the number of missing textures, "files" is a list of the missing
347
+ texture file paths, and "file_path_to_color" is a dictionary mapping the
348
+ missing texture file paths to a random color.
349
+ """
350
+ missing_file_count = 0
351
+ out_files = []
352
+ file_path_to_color = {}
353
+
354
+ # Check all materials in the scene
355
+ for material in bpy.data.materials:
356
+ if material.use_nodes:
357
+ for node in material.node_tree.nodes:
358
+ if node.type == "TEX_IMAGE":
359
+ image = node.image
360
+ if image is not None:
361
+ file_path = bpy.path.abspath(image.filepath)
362
+ if file_path == "":
363
+ # means it's embedded
364
+ continue
365
+
366
+ if not os.path.exists(file_path):
367
+ # Find the connected Principled BSDF node
368
+ connected_node = node.outputs[0].links[0].to_node
369
+
370
+ if connected_node.type == "BSDF_PRINCIPLED":
371
+ if file_path not in file_path_to_color:
372
+ # Set a random color for the unique missing file path
373
+ random_color = [random.random() for _ in range(3)]
374
+ file_path_to_color[file_path] = random_color + [1]
375
+
376
+ connected_node.inputs[
377
+ "Base Color"
378
+ ].default_value = file_path_to_color[file_path]
379
+
380
+ # Delete the TEX_IMAGE node
381
+ material.node_tree.nodes.remove(node)
382
+ missing_file_count += 1
383
+ out_files.append(image.filepath)
384
+ return {
385
+ "count": missing_file_count,
386
+ "files": out_files,
387
+ "file_path_to_color": file_path_to_color,
388
+ }
389
+
390
+
391
+ def setup_environment_lighting(envmap_path):
392
+ world = bpy.context.scene.world
393
+ world.use_nodes = True
394
+ nodes = world.node_tree.nodes
395
+ links = world.node_tree.links
396
+
397
+ # Clear existing nodes
398
+ for node in nodes:
399
+ nodes.remove(node)
400
+
401
+ # Create Background node
402
+ bg_node = nodes.new(type="ShaderNodeBackground")
403
+ bg_node.location = (0, 0)
404
+
405
+ # Create Environment Texture node
406
+ env_tex_node = nodes.new(type="ShaderNodeTexEnvironment")
407
+ env_tex_node.location = (-300, 0)
408
+
409
+ # Set the environment texture path (replace this with your file path)
410
+ env_tex_node.image = bpy.data.images.load(envmap_path)
411
+
412
+ # Create World Output node
413
+ world_output_node = nodes.new(type="ShaderNodeOutputWorld")
414
+ world_output_node.location = (300, 0)
415
+
416
+ # Link nodes
417
+ links.new(env_tex_node.outputs["Color"], bg_node.inputs["Color"])
418
+ links.new(bg_node.outputs["Background"], world_output_node.inputs["Surface"])
419
+
420
+
421
+ def create_solid_color_material(name, color):
422
+ mat = bpy.data.materials.new(name)
423
+ mat.use_nodes = True
424
+ node_tree = mat.node_tree
425
+ color_node = node_tree.nodes.new("ShaderNodeBsdfDiffuse")
426
+ color_node.inputs["Color"].default_value = color
427
+ mat_output = node_tree.nodes["Material Output"]
428
+ node_tree.links.new(color_node.outputs["BSDF"], mat_output.inputs["Surface"])
429
+ return mat
430
+
431
+
432
+ def create_phong_material(name, color):
433
+ mat = bpy.data.materials.new(name)
434
+ mat.use_nodes = True
435
+ node_tree = mat.node_tree
436
+ spec_node = node_tree.nodes.new("ShaderNodeBsdfPrincipled")
437
+ print(spec_node.inputs.keys())
438
+ spec_node.inputs["Base Color"].default_value = color
439
+ spec_node.inputs["Roughness"].default_value = 0.5
440
+ spec_node.inputs["Metallic"].default_value = 1.0
441
+ mat_output = node_tree.nodes["Material Output"]
442
+ node_tree.links.new(spec_node.outputs["BSDF"], mat_output.inputs["Surface"])
443
+ return mat
444
+
445
+
446
+ def render_object(
447
+ object_file: str,
448
+ num_renders: int,
449
+ output_dir: str,
450
+ transparent_background: bool = False,
451
+ environment_map: str = None,
452
+ ) -> None:
453
+ """Saves rendered images for given asset to specified output directory.
454
+
455
+ Args:
456
+ object_file (str): Path to the object file.
457
+ num_renders (int): Number of renders to save of the object.
458
+ output_dir (str): Path to the directory where the rendered images and metadata
459
+ will be saved. The rendered images will be saved in the subdirectory
460
+ `output_dir/stemname`.
461
+ transparent_background (bool): Whether to use transparent background,
462
+ otherwise the background is white.
463
+ Returns:
464
+ None
465
+ """
466
+ os.makedirs(output_dir, exist_ok=True)
467
+
468
+ # load the object
469
+ reset_scene()
470
+ load_object(object_file)
471
+
472
+ if transparent_background:
473
+ scene.render.film_transparent = True
474
+ else:
475
+ scene.render.film_transparent = False
476
+
477
+ set_world_background_color([0.2, 0.2, 0.2, 1.0])
478
+
479
+ # normalize the scene
480
+ _ = normalize_scene()
481
+
482
+ # Set up cameras
483
+ cam = scene.objects["Camera"]
484
+ fov_degrees = 40.0
485
+ cam.data.angle = np.radians(fov_degrees)
486
+
487
+ # Set up camera constraints
488
+ cam_constraint = cam.constraints.new(type="TRACK_TO")
489
+ cam_constraint.track_axis = "TRACK_NEGATIVE_Z"
490
+ cam_constraint.up_axis = "UP_Y"
491
+ empty = bpy.data.objects.new("Empty", None)
492
+ empty.location = (0, 0, 0)
493
+ scene.collection.objects.link(empty)
494
+ cam_constraint.target = empty
495
+ cam.parent = empty
496
+
497
+ # remove references to missing texture files, replacing them with random base colors
498
+ delete_missing_textures()
499
+
500
+ if environment_map:
501
+ setup_environment_lighting(environment_map)
502
+ else:
503
+ create_uniform_lights(energy=1.0, light_type="SUN")
504
+
505
+ camera_position = [0, -2, 0]
506
+
507
+ # determine how much to orbit camera by.
508
+ stepsize = 360.0 / num_renders
509
+
510
+ def render_views(name):
511
+ for i in range(num_renders):
512
+ # set camera
513
+ _ = get_camera_with_position(
514
+ camera_position[0],
515
+ camera_position[1],
516
+ camera_position[2],
517
+ fov_degrees=fov_degrees,
518
+ )
519
+
520
+ # Set output paths with absolute paths
521
+ render_path = os.path.abspath(
522
+ os.path.join(output_dir, f"{i:03d}_{name}.png")
523
+ )
524
+
525
+ # Set file output paths
526
+ scene.render.filepath = render_path
527
+
528
+ # Make sure the output directory exists
529
+ os.makedirs(output_dir, exist_ok=True)
530
+
531
+ # Render
532
+ bpy.ops.render.render(write_still=True)
533
+
534
+ context.view_layer.objects.active = empty
535
+ empty.rotation_euler[2] += math.radians(stepsize)
536
+
537
+ # ensure that all objects have materials, if not then add a default
538
+ # one.
539
+ textured_mat = create_solid_color_material("default texture", [0.6, 0.6, 0.6, 1])
540
+
541
+ for obj in get_scene_meshes():
542
+ if obj.active_material is None:
543
+ obj.active_material = textured_mat
544
+
545
+ render_views("textured")
546
+
547
+
548
+ def enable_gpus(device_type, use_cpus=False):
549
+ preferences = bpy.context.preferences
550
+ cycles_preferences = preferences.addons["cycles"].preferences
551
+ cycles_preferences.refresh_devices()
552
+ try:
553
+ devices = cycles_preferences.devices
554
+ except Exception:
555
+ print("No devices detected")
556
+ if device_type == "CPU":
557
+ return []
558
+ else:
559
+ raise RuntimeError(f"No devices detected, set use_cpus to True")
560
+
561
+ assert device_type in [
562
+ "CUDA",
563
+ "METAL",
564
+ "OPENCL",
565
+ "CPU",
566
+ "NONE",
567
+ ], f"Unsupported device type: {device_type}"
568
+
569
+ try:
570
+ # print(devices)
571
+ iter(devices)
572
+ except TypeError:
573
+ # print("Single GPU Detected")
574
+ devices = [devices]
575
+
576
+ activated_gpus = []
577
+ for device in devices:
578
+ if device.type == "CPU":
579
+ device.use = use_cpus
580
+ else:
581
+ device.use = True
582
+ activated_gpus.append(device.name)
583
+
584
+ if device_type == "CUDA":
585
+ cycles_preferences.compute_device_type = "CUDA"
586
+ bpy.context.scene.cycles.device = "GPU"
587
+ elif device_type == "METAL":
588
+ cycles_preferences.compute_device_type = "METAL"
589
+ bpy.context.scene.cycles.device = "GPU"
590
+ elif device_type == "OPENCL":
591
+ cycles_preferences.compute_device_type = "OPENCL"
592
+ bpy.context.scene.cycles.device = "GPU"
593
+ else:
594
+ raise RuntimeError(f"Unsupported device type: {device_type}")
595
+
596
+ return activated_gpus
597
+
598
+
599
+ def set_render_settings(engine, resolution):
600
+ # Set render settings
601
+ render.engine = engine
602
+ render.image_settings.file_format = "PNG"
603
+ render.image_settings.color_mode = "RGBA"
604
+ render.resolution_x = resolution
605
+ render.resolution_y = resolution
606
+ render.resolution_percentage = 100
607
+
608
+ # Set cycles settings
609
+ scene.cycles.device = "GPU"
610
+ scene.cycles.use_adaptive_sampling = True
611
+ scene.cycles.adaptive_threshold = 0.1
612
+ scene.cycles.samples = 64
613
+ scene.cycles.adaptive_min_samples = 1
614
+ scene.cycles.filter_width = 2
615
+ scene.cycles.use_fast_gi = True
616
+ scene.cycles.fast_gi_method = "REPLACE"
617
+ world.light_settings.ao_factor = 1.0
618
+ world.light_settings.distance = 10
619
+ scene.cycles.use_denoising = True # ML denoising
620
+ scene.cycles.denoising_use_gpu = True
621
+
622
+ # bake existing frames for faster future renders
623
+ scene.render.use_persistent_data = True
624
+
625
+ # Set eevee settings
626
+ scene.eevee.use_shadows = True
627
+ scene.eevee.use_raytracing = True
628
+ scene.eevee.ray_tracing_options.use_denoise = True
629
+ scene.eevee.use_fast_gi = True
630
+ scene.eevee.fast_gi_method = "GLOBAL_ILLUMINATION"
631
+ scene.eevee.ray_tracing_options.trace_max_roughness = 0.5
632
+ scene.eevee.fast_gi_resolution = "2"
633
+ scene.eevee.fast_gi_ray_count = 2
634
+ scene.eevee.fast_gi_step_count = 8
635
+
636
+
637
+ def print_devices():
638
+ print("Devices:")
639
+ preferences = bpy.context.preferences
640
+ cycles_preferences = preferences.addons["cycles"].preferences
641
+ cycles_preferences.refresh_devices()
642
+
643
+ devices = cycles_preferences.devices
644
+ for device in devices:
645
+ print(f' [{device.id}]<{device.type}> "{device.name}" Using: {device.use}')
646
+
647
+ print(f"Compute device type: {cycles_preferences.compute_device_type}")
648
+ print(f"Cycles device: {bpy.context.scene.cycles.device}")
649
+
650
+
651
+ if __name__ == "__main__":
652
+ parser = argparse.ArgumentParser()
653
+ parser.add_argument(
654
+ "--object_path",
655
+ type=str,
656
+ required=False,
657
+ help="Path to the object file",
658
+ )
659
+ parser.add_argument(
660
+ "--output_dir",
661
+ type=str,
662
+ required=True,
663
+ help="Path to the directory where the rendered images and metadata will be saved.",
664
+ )
665
+ parser.add_argument(
666
+ "--engine",
667
+ type=str,
668
+ default="BLENDER_EEVEE_NEXT", # BLENDER_BLENDER_EEVEE_NEXT rasterization, better than nvdifrast, CYCLES
669
+ choices=["CYCLES", "BLENDER_EEVEE_NEXT"],
670
+ )
671
+ parser.add_argument(
672
+ "--num_renders",
673
+ type=int,
674
+ default=12,
675
+ help="Number of renders to save of the object.",
676
+ )
677
+ parser.add_argument(
678
+ "--render_resolution",
679
+ type=int,
680
+ default=512,
681
+ help="Resolution of the rendered images.",
682
+ )
683
+ parser.add_argument(
684
+ "--transparent_background",
685
+ action="store_true",
686
+ help="Whether to use transparent background",
687
+ )
688
+ parser.add_argument(
689
+ "--environment_map",
690
+ default=None,
691
+ type=str,
692
+ help="Use the given environment map for lighting",
693
+ )
694
+
695
+ argv = sys.argv[sys.argv.index("--") + 1 :]
696
+ args = parser.parse_args(argv)
697
+
698
+ context = bpy.context
699
+ scene = context.scene
700
+ render = scene.render
701
+ world = bpy.data.worlds["World"]
702
+
703
+ set_render_settings(args.engine, args.render_resolution)
704
+
705
+ # detect platform and activate GPUs
706
+ system_name = platform.system()
707
+ if system_name == "Darwin":
708
+ activated_gpus = enable_gpus("METAL", use_cpus=True)
709
+ elif system_name == "Linux":
710
+ activated_gpus = enable_gpus("CUDA", use_cpus=False)
711
+ else:
712
+ raise RuntimeError("Unsupported platform")
713
+ print(f"Activated GPUs: {activated_gpus}")
714
+
715
+ print_devices()
716
+
717
+ render_object(
718
+ object_file=args.object_path,
719
+ num_renders=args.num_renders,
720
+ output_dir=args.output_dir,
721
+ transparent_background=args.transparent_background,
722
+ environment_map=args.environment_map,
723
+ )
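Note (not part of the commit): because the script parses its own arguments after the `--` separator, it has to be launched through Blender's bundled Python rather than a plain interpreter. A typical headless invocation, with placeholder asset and output paths, would look roughly like:

blender --background -noaudio --python cube/cube3d/renderer/blender_script.py -- --object_path assets/chair.glb --output_dir renders/chair --num_renders 12 --render_resolution 512 --transparent_background --engine CYCLES

This mirrors the command that renderer.py (added below) assembles via subprocess.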
cube/cube3d/renderer/renderer.py ADDED
@@ -0,0 +1,88 @@
1
+ import argparse
2
+ import os
3
+ import subprocess
4
+ import sys
5
+ from pathlib import Path
6
+
7
+ from PIL import Image
8
+
9
+
10
+ def render_asset(
11
+ asset_path,
12
+ output_dir,
13
+ nviews=24,
14
+ img_resolution=512,
15
+ ):
16
+ """
17
+ Render given asset into output_dir and return the saved image paths.
18
+ Assumes that Blender is installed and available on your PATH.
19
+
20
+ nviews : number of views to render
21
+ img_resolution : resolution of each rendered view in pixels
22
+ """
23
+
24
+ curr_file_path = __file__
25
+ curr_dir = os.path.dirname(curr_file_path)
26
+
27
+ command = [
28
+ "blender",
29
+ "--background",
30
+ "-noaudio",
31
+ "--python",
32
+ f"{curr_dir}/blender_script.py",
33
+ "--",
34
+ "--object_path",
35
+ asset_path,
36
+ "--num_renders",
37
+ str(nviews),
38
+ "--output_dir",
39
+ output_dir,
40
+ "--render_resolution",
41
+ str(img_resolution),
42
+ "--transparent_background",
43
+ "--engine",
44
+ "CYCLES",
45
+ ]
46
+
47
+ subprocess.run(command, check=True)
48
+
49
+ # return the saved images paths
50
+ images = []
51
+
52
+ for i in range(nviews):
53
+ fp = os.path.abspath(os.path.join(output_dir, f"{i:03d}_textured.png"))
54
+ images.append(fp)
55
+
56
+ return images
57
+
58
+
59
+ def save_gif(image_paths, outfile):
60
+ images = [Image.open(img) for img in image_paths]
61
+ if len(images) > 1:
62
+ background = Image.new("RGBA", images[0].size, (255, 255, 255))
63
+ images = [
64
+ Image.alpha_composite(background, png).convert("RGB") for png in images
65
+ ]
66
+ images[0].save(
67
+ outfile, save_all=True, append_images=images[1:], duration=100, loop=0
68
+ )
69
+
70
+
71
+ def render_turntable(obj_path, output_dir, output_name="turntable"):
72
+ """
73
+ Render a turntable GIF of the mesh. Assumes that Blender is installed and available on your PATH.
74
+ obj_path : path to the mesh file to render
75
+ output_dir : directory to save the GIF; the result is written as `<output_name>.gif` (default `turntable.gif`)
76
+ """
77
+ image_paths = render_asset(obj_path, output_dir)
78
+ gif_turntable_outfile = Path(output_dir) / f"{output_name}.gif"
79
+ save_gif(image_paths, gif_turntable_outfile)
80
+ return gif_turntable_outfile
81
+
82
+
83
+ if __name__ == "__main__":
84
+ parser = argparse.ArgumentParser()
85
+ parser.add_argument("-i", "--input")
86
+ parser.add_argument("-o", "--output_dir")
87
+ args = parser.parse_args(sys.argv[1:])
88
+ render_turntable(args.input, args.output_dir)
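A minimal usage sketch for this module (illustrative only; the asset path and output directory are placeholders, the cube3d package is assumed to be importable, and a blender binary is assumed to be on PATH):

from cube3d.renderer.renderer import render_asset, render_turntable

# Render 24 views and assemble them into renders/chair/turntable.gif
gif_path = render_turntable("assets/chair.obj", "renders/chair")
print(f"Turntable GIF written to {gif_path}")

# Or keep the individual frames for downstream use
frame_paths = render_asset("assets/chair.obj", "renders/chair", nviews=12, img_resolution=256)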
cube/cube3d/vq_vae_encode_decode.py ADDED
@@ -0,0 +1,150 @@
1
+ import argparse
2
+ import logging
3
+
4
+ import numpy as np
5
+ import torch
6
+ import trimesh
7
+
8
+ from cube3d.inference.utils import load_config, load_model_weights, parse_structured
9
+ from cube3d.model.autoencoder.one_d_autoencoder import OneDAutoEncoder
10
+
11
+ MESH_SCALE = 0.96
12
+
13
+
14
+ def rescale(vertices: np.ndarray, mesh_scale: float = MESH_SCALE) -> np.ndarray:
15
+ """Rescale the vertices to a cube, e.g., [-1, -1, -1] to [1, 1, 1] when mesh_scale=1.0"""
16
+ vertices = vertices
17
+ bbmin = vertices.min(0)
18
+ bbmax = vertices.max(0)
19
+ center = (bbmin + bbmax) * 0.5
20
+ scale = 2.0 * mesh_scale / (bbmax - bbmin).max()
21
+ vertices = (vertices - center) * scale
22
+ return vertices
23
+
24
+
25
+ def load_scaled_mesh(file_path: str) -> trimesh.Trimesh:
26
+ """
27
+ Load a mesh, rescale it to a unit cube, and clean it.
28
+ Parameters:
29
+ file_path: str
30
+ Path to the mesh file to load.
31
+ Returns:
32
+ mesh: trimesh.Trimesh
33
+ """
34
+ mesh: trimesh.Trimesh = trimesh.load(file_path, force="mesh")
35
+ mesh.remove_infinite_values()
36
+ mesh.update_faces(mesh.nondegenerate_faces())
37
+ mesh.update_faces(mesh.unique_faces())
38
+ mesh.remove_unreferenced_vertices()
39
+ if len(mesh.vertices) == 0 or len(mesh.faces) == 0:
40
+ raise ValueError("Mesh has no vertices or faces after cleaning")
41
+ mesh.vertices = rescale(mesh.vertices)
42
+ return mesh
43
+
44
+
45
+ def load_and_process_mesh(file_path: str, n_samples: int = 8192):
46
+ """
47
+ Loads a 3D mesh from the specified file path, samples points from its surface,
48
+ and processes the sampled points into a point cloud with normals.
49
+ Args:
50
+ file_path (str): The file path to the 3D mesh file.
51
+ n_samples (int, optional): The number of points to sample from the mesh surface. Defaults to 8192.
52
+ Returns:
53
+ torch.Tensor: A tensor of shape (1, n_samples, 6) containing the processed point cloud.
54
+ Each point consists of its 3D position (x, y, z) and its normal vector (nx, ny, nz).
55
+ """
56
+
57
+ mesh = load_scaled_mesh(file_path)
58
+ positions, face_indices = trimesh.sample.sample_surface(mesh, n_samples)
59
+ normals = mesh.face_normals[face_indices]
60
+ point_cloud = np.concatenate(
61
+ [positions, normals], axis=1
62
+ ) # Shape: (num_samples, 6)
63
+ point_cloud = torch.from_numpy(point_cloud.reshape(1, -1, 6)).float()
64
+ return point_cloud
65
+
66
+
67
+ @torch.inference_mode()
68
+ def run_shape_decode(
69
+ shape_model: OneDAutoEncoder,
70
+ output_ids: torch.Tensor,
71
+ resolution_base: float = 8.0,
72
+ chunk_size: int = 100_000,
73
+ ):
74
+ """
75
+ Decodes the shape from the given output IDs and extracts the geometry.
76
+ Args:
77
+ shape_model (OneDAutoEncoder): The shape model.
78
+ output_ids (torch.Tensor): The tensor containing the output IDs.
79
+ resolution_base (float, optional): The base resolution for geometry extraction. Defaults to 8.0.
80
+ chunk_size (int, optional): The chunk size for processing. Defaults to 100,000.
81
+ Returns:
82
+ tuple: A tuple containing the vertices and faces of the mesh.
83
+ """
84
+ shape_ids = (
85
+ output_ids[:, : shape_model.cfg.num_encoder_latents, ...]
86
+ .clamp_(0, shape_model.cfg.num_codes - 1)
87
+ .view(-1, shape_model.cfg.num_encoder_latents)
88
+ )
89
+ latents = shape_model.decode_indices(shape_ids)
90
+ mesh_v_f, _ = shape_model.extract_geometry(
91
+ latents,
92
+ resolution_base=resolution_base,
93
+ chunk_size=chunk_size,
94
+ use_warp=True,
95
+ )
96
+ return mesh_v_f
97
+
98
+
99
+ if __name__ == "__main__":
100
+ parser = argparse.ArgumentParser(
101
+ description="cube shape encode and decode example script"
102
+ )
103
+ parser.add_argument(
104
+ "--mesh-path",
105
+ type=str,
106
+ required=True,
107
+ help="Path to the input mesh file.",
108
+ )
109
+ parser.add_argument(
110
+ "--config-path",
111
+ type=str,
112
+ default="cube3d/configs/open_model.yaml",
113
+ help="Path to the configuration YAML file.",
114
+ )
115
+ parser.add_argument(
116
+ "--shape-ckpt-path",
117
+ type=str,
118
+ required=True,
119
+ help="Path to the shape encoder/decoder checkpoint file.",
120
+ )
121
+ parser.add_argument(
122
+ "--recovered-mesh-path",
123
+ type=str,
124
+ default="recovered_mesh.obj",
125
+ help="Path to save the recovered mesh file.",
126
+ )
127
+ args = parser.parse_args()
128
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
129
+ logging.info(f"Using device: {device}")
130
+
131
+ cfg = load_config(args.config_path)
132
+
133
+ shape_model = OneDAutoEncoder(
134
+ parse_structured(OneDAutoEncoder.Config, cfg.shape_model)
135
+ )
136
+ load_model_weights(
137
+ shape_model,
138
+ args.shape_ckpt_path,
139
+ )
140
+ shape_model = shape_model.eval().to(device)
141
+ point_cloud = load_and_process_mesh(args.mesh_path)
142
+ output = shape_model.encode(point_cloud.to(device))
143
+ indices = output[3]["indices"]
144
+ print("Got the following shape indices:")
145
+ print(indices)
146
+ print("Indices shape: ", indices.shape)
147
+ mesh_v_f = run_shape_decode(shape_model, indices)
148
+ vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
149
+ mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
150
+ mesh.export(args.recovered_mesh_path)
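The encode/decode script is driven entirely by its CLI; a round-trip run would presumably look like the following (the mesh and checkpoint paths are placeholders, and the config path falls back to the script's default):

python cube3d/vq_vae_encode_decode.py --mesh-path my_mesh.obj --shape-ckpt-path shape_ckpt.safetensors --recovered-mesh-path recovered_mesh.obj

The script prints the discrete shape indices produced by the encoder and writes the decoded mesh to --recovered-mesh-path.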
cube/pyproject.toml ADDED
@@ -0,0 +1,38 @@
1
+ [build-system]
2
+ requires = ["setuptools", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "cube"
7
+ version = "0.1"
8
+ requires-python = ">=3.10"
9
+ description = "A generative 3D model to accelerate the creation of 3D assets, accessories, and experiences."
10
+ authors = [
11
+ { name = "Foundation AI", email = "[email protected]" }
12
+ ]
13
+ keywords = ["cube"]
14
+ classifiers = [
15
+ "Development Status :: 3 - Alpha",
16
+ "Intended Audience :: Developers",
17
+ "Programming Language :: Python :: 3.10",
18
+ ]
19
+ dependencies = [
20
+ "numpy",
21
+ "torch>=2.2.2",
22
+ "tqdm",
23
+ "transformers",
24
+ "omegaconf",
25
+ "warp-lang",
26
+ "accelerate>=0.26.0",
27
+ "scikit-image",
28
+ "huggingface_hub[cli]",
29
+ "trimesh"
30
+ ]
31
+ [project.optional-dependencies]
32
+ meshlab = ["pymeshlab"]
33
+ lint = ["ruff==0.9.10"]
34
+
35
+ [tool.setuptools.packages.find]
36
+ where = ["cube3d"]
37
+ include = ["cube/*"]
38
+ namespaces = false
cube/setup.py ADDED
@@ -0,0 +1,6 @@
1
+ from setuptools import setup
2
+
3
+ setup(
4
+ name="cube",
5
+ version="0.0.1",
6
+ )