mboss commited on
Commit
d945eeb
·
1 Parent(s): 816b4c8

Initial commit

Browse files
Files changed (42) hide show
  1. .gitattributes +2 -0
  2. LICENSE.md +51 -0
  3. README.md +11 -5
  4. app.py +357 -0
  5. demo_files/comp.gif +3 -0
  6. demo_files/examples/animal_character.png +3 -0
  7. demo_files/examples/animal_character_2.png +3 -0
  8. demo_files/examples/axe.png +3 -0
  9. demo_files/examples/chair1.png +3 -0
  10. demo_files/examples/character1.png +3 -0
  11. demo_files/examples/otter_samurai.png +3 -0
  12. demo_files/examples/raccoon_wizard.png +3 -0
  13. demo_files/examples/stylized-rocks.png +3 -0
  14. demo_files/examples/tree.png +3 -0
  15. demo_files/hdri/abandoned_tiled_room_1k.hdr +0 -0
  16. demo_files/hdri/metro_noord_1k.hdr +0 -0
  17. demo_files/hdri/neon_photostudio_1k.hdr +0 -0
  18. demo_files/hdri/peppermint_powerplant_1k.hdr +0 -0
  19. demo_files/hdri/rainforest_trail_1k.hdr +0 -0
  20. demo_files/hdri/studio_small_08_1k.hdr +0 -0
  21. demo_files/hdri/urban_alley_01_1k.hdr +0 -0
  22. demo_files/scatterplot.jpg +0 -0
  23. demo_files/teaser.gif +3 -0
  24. load/tets/160_tets.npz +3 -0
  25. requirements.txt +13 -0
  26. sf3d/box_uv_unwrap.py +610 -0
  27. sf3d/models/camera.py +32 -0
  28. sf3d/models/global_estimator/multi_head_estimator.py +118 -0
  29. sf3d/models/image_estimator/clip_based_estimator.py +168 -0
  30. sf3d/models/isosurface.py +229 -0
  31. sf3d/models/mesh.py +172 -0
  32. sf3d/models/network.py +195 -0
  33. sf3d/models/tokenizers/dinov2.py +1196 -0
  34. sf3d/models/tokenizers/image.py +99 -0
  35. sf3d/models/tokenizers/triplane.py +49 -0
  36. sf3d/models/transformers/attention.py +31 -0
  37. sf3d/models/transformers/backbone.py +515 -0
  38. sf3d/models/utils.py +292 -0
  39. sf3d/system.py +483 -0
  40. sf3d/texture_baker.py +87 -0
  41. sf3d/texture_baker.slang +93 -0
  42. sf3d/utils.py +91 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.gif filter=lfs diff=lfs merge=lfs -text
37
+ *.png filter=lfs diff=lfs merge=lfs -text
LICENSE.md ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ STABILITY AI COMMUNITY LICENSE AGREEMENT
2
+ Last Updated: July 5, 2024
3
+
4
+
5
+ I. INTRODUCTION
6
+
7
+ This Agreement applies to any individual person or entity ("You", "Your" or "Licensee") that uses or distributes any portion or element of the Stability AI Materials or Derivative Works thereof for any Research & Non-Commercial or Commercial purpose. Capitalized terms not otherwise defined herein are defined in Section V below.
8
+
9
+
10
+ This Agreement is intended to allow research, non-commercial, and limited commercial uses of the Models free of charge. In order to ensure that certain limited commercial uses of the Models continue to be allowed, this Agreement preserves free access to the Models for people or organizations generating annual revenue of less than US $1,000,000 (or local currency equivalent).
11
+
12
+
13
+ By clicking "I Accept" or by using or distributing or using any portion or element of the Stability Materials or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement. If You are acting on behalf of a company, organization or other entity, then "You" includes you and that entity, and You agree that You: (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and (ii) You agree to the terms of this Agreement on that entity's behalf.
14
+
15
+ II. RESEARCH & NON-COMMERCIAL USE LICENSE
16
+
17
+ Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose. "Research Purpose" means academic or scientific advancement, and in each case, is not primarily intended for commercial advantage or monetary compensation to You or others. "Non-Commercial Purpose" means any purpose other than a Research Purpose that is not primarily intended for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist) or evaluation and testing.
18
+
19
+ III. COMMERCIAL USE LICENSE
20
+
21
+ Subject to the terms of this Agreement (including the remainder of this Section III), Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Commercial Purpose. "Commercial Purpose" means any purpose other than a Research Purpose or Non-Commercial Purpose that is primarily intended for commercial advantage or monetary compensation to You or others, including but not limited to, (i) creating, modifying, or distributing Your product or service, including via a hosted service or application programming interface, and (ii) for Your business's or organization's internal operations.
22
+ If You are using or distributing the Stability AI Materials for a Commercial Purpose, You must register with Stability AI at (https://stability.ai/community-license). If at any time You or Your Affiliate(s), either individually or in aggregate, generate more than USD $1,000,000 in annual revenue (or the equivalent thereof in Your local currency), regardless of whether that revenue is generated directly or indirectly from the Stability AI Materials or Derivative Works, any licenses granted to You under this Agreement shall terminate as of such date. You must request a license from Stability AI at (https://stability.ai/enterprise) , which Stability AI may grant to You in its sole discretion. If you receive Stability AI Materials, or any Derivative Works thereof, from a Licensee as part of an integrated end user product, then Section III of this Agreement will not apply to you.
23
+
24
+ IV. GENERAL TERMS
25
+
26
+ Your Research, Non-Commercial, and Commercial License(s) under this Agreement are subject to the following terms.
27
+ a. Distribution & Attribution. If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product or service that uses any portion of them, You shall: (i) provide a copy of this Agreement to that third party, (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved", and (iii) prominently display "Powered by Stability AI" on a related website, user interface, blogpost, about page, or product documentation. If You create a Derivative Work, You may add your own attribution notice(s) to the "Notice" text file included with that Derivative Work, provided that You clearly indicate which attributions apply to the Stability AI Materials and state in the "Notice" text file that You changed the Stability AI Materials and how it was modified.
28
+ b. Use Restrictions. Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control Laws and equivalent regulations) and adhere to the Documentation and Stability AI's AUP, which is hereby incorporated by reference. Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model (excluding the Models or Derivative Works).
29
+ c. Intellectual Property.
30
+ (i) Trademark License. No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of its Affiliates, except as required under Section IV(a) herein.
31
+ (ii) Ownership of Derivative Works. As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AI's ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI.
32
+ (iii) Ownership of Outputs. As between You and Stability AI, You own any outputs generated from the Models or Derivative Works to the extent permitted by applicable law.
33
+ (iv) Disputes. If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of this Agreement.
34
+ (v) Feedback. From time to time, You may provide Stability AI with verbal and/or written suggestions, comments or other feedback related to Stability AI's existing or prospective technology, products or services (collectively, "Feedback"). You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant Stability AI a perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive, worldwide right and license to exploit the Feedback in any manner without restriction. Your Feedback is provided "AS IS" and You make no warranties whatsoever about any Feedback.
35
+ d. Disclaimer Of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE STABILITY AI MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OR LAWFULNESS OF USING OR REDISTRIBUTING THE STABILITY AI MATERIALS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE STABILITY AI MATERIALS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.
36
+ e. Limitation Of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
37
+ f. Term And Termination. The term of this Agreement will commence upon Your acceptance of this Agreement or access to the Stability AI Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You shall delete and cease use of any Stability AI Materials or Derivative Works. Section IV(d), (e), and (g) shall survive the termination of this Agreement.
38
+ g. Governing Law. This Agreement will be governed by and constructed in accordance with the laws of the United States and the State of California without regard to choice of law principles, and the UN Convention on Contracts for International Sale of Goods does not apply to this Agreement.
39
+
40
+ V. DEFINITIONS
41
+
42
+ "Affiliate(s)" means any entity that directly or indirectly controls, is controlled by, or is under common control with the subject entity; for purposes of this definition, "control" means direct or indirect ownership or control of more than 50% of the voting interests of the subject entity.
43
+ "Agreement" means this Stability AI Community License Agreement.
44
+ "AUP" means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.
45
+ "Derivative Work(s)" means (a) any derivative work of the Stability AI Materials as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model's output, including"fine tune" and "low-rank adaptation" models derived from a Model or a Model's output, but do not include the output of any Model.
46
+ "Documentation" means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software or Models.
47
+ "Model(s)" means, collectively, Stability AI's proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing listed on Stability's Core Models Webpage available at, https://stability.ai/core-models, as may be updated from time to time.
48
+ "Stability AI" or "we" means Stability AI Ltd. and its Affiliates.
49
+ "Software" means Stability AI's proprietary software made available under this Agreement now or in the future.
50
+ "Stability AI Materials" means, collectively, Stability's proprietary Models, Software and Documentation (and any portion or combination thereof) made available under this Agreement.
51
+ "Trade Control Laws" means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
README.md CHANGED
@@ -1,12 +1,18 @@
1
  ---
2
- title: Stable Fast 3d
3
- emoji: 🌖
4
- colorFrom: blue
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 4.38.1
 
8
  app_file: app.py
9
  pinned: false
 
 
 
 
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Stable Fast 3D
3
+ emoji: 🎮
4
+ colorFrom: purple
5
+ colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 4.31.4
8
+ python_version: 3.10.13
9
  app_file: app.py
10
  pinned: false
11
+ models:
12
+ - stabilityai/stable-fast-3d
13
+ license: other
14
+ license_name: stabilityai-ai-community
15
+ license_link: LICENSE.md
16
  ---
17
 
18
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import time
4
+ from functools import lru_cache
5
+ from typing import Any
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ import rembg
10
+ import torch
11
+ from gradio_litmodel3d import LitModel3D
12
+ import spaces
13
+ from PIL import Image
14
+
15
+ import sf3d.utils as sf3d_utils
16
+ from sf3d.system import SF3D
17
+
18
+ rembg_session = rembg.new_session()
19
+
20
+ COND_WIDTH = 512
21
+ COND_HEIGHT = 512
22
+ COND_DISTANCE = 1.6
23
+ COND_FOVY_DEG = 40
24
+ BACKGROUND_COLOR = [0.5, 0.5, 0.5]
25
+
26
+ # Cached. Doesn't change
27
+ c2w_cond = sf3d_utils.default_cond_c2w(COND_DISTANCE)
28
+ intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(
29
+ COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
30
+ )
31
+
32
+
33
+ model = SF3D.from_pretrained(
34
+ "stabilityai/stable-fast-3d",
35
+ config_name="config.yaml",
36
+ weight_name="model.safetensors",
37
+ )
38
+ model.eval().cuda()
39
+
40
+ example_files = [
41
+ os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples")
42
+ ]
43
+
44
+
45
+ @spaces.GPU
46
+ def run_model(input_image):
47
+ start = time.time()
48
+ with torch.no_grad():
49
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
50
+ model_batch = create_batch(input_image)
51
+ model_batch = {k: v.cuda() for k, v in model_batch.items()}
52
+ trimesh_mesh, _glob_dict = model.generate_mesh(model_batch, 1024)
53
+ trimesh_mesh = trimesh_mesh[0]
54
+
55
+ # Create new tmp file
56
+ tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".glb")
57
+
58
+ trimesh_mesh.export(tmp_file.name, file_type="glb")
59
+
60
+ print("Generation took:", time.time() - start, "s")
61
+
62
+ return tmp_file.name
63
+
64
+
65
+ def create_batch(input_image: Image) -> dict[str, Any]:
66
+ img_cond = (
67
+ torch.from_numpy(
68
+ np.asarray(input_image.resize((COND_WIDTH, COND_HEIGHT))).astype(np.float32)
69
+ / 255.0
70
+ )
71
+ .float()
72
+ .clip(0, 1)
73
+ )
74
+ mask_cond = img_cond[:, :, -1:]
75
+ rgb_cond = torch.lerp(
76
+ torch.tensor(BACKGROUND_COLOR)[None, None, :], img_cond[:, :, :3], mask_cond
77
+ )
78
+
79
+ batch_elem = {
80
+ "rgb_cond": rgb_cond,
81
+ "mask_cond": mask_cond,
82
+ "c2w_cond": c2w_cond.unsqueeze(0),
83
+ "intrinsic_cond": intrinsic.unsqueeze(0),
84
+ "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
85
+ }
86
+ # Add batch dim
87
+ batched = {k: v.unsqueeze(0) for k, v in batch_elem.items()}
88
+ return batched
89
+
90
+
91
+ @lru_cache
92
+ def checkerboard(squares: int, size: int, min_value: float = 0.5):
93
+ base = np.zeros((squares, squares)) + min_value
94
+ base[1::2, ::2] = 1
95
+ base[::2, 1::2] = 1
96
+
97
+ repeat_mult = size // squares
98
+ return (
99
+ base.repeat(repeat_mult, axis=0)
100
+ .repeat(repeat_mult, axis=1)[:, :, None]
101
+ .repeat(3, axis=-1)
102
+ )
103
+
104
+
105
+ def remove_background(input_image: Image) -> Image:
106
+ return rembg.remove(input_image, session=rembg_session)
107
+
108
+
109
+ def resize_foreground(
110
+ image: Image,
111
+ ratio: float,
112
+ ) -> Image:
113
+ image = np.array(image)
114
+ assert image.shape[-1] == 4
115
+ alpha = np.where(image[..., 3] > 0)
116
+ y1, y2, x1, x2 = (
117
+ alpha[0].min(),
118
+ alpha[0].max(),
119
+ alpha[1].min(),
120
+ alpha[1].max(),
121
+ )
122
+ # crop the foreground
123
+ fg = image[y1:y2, x1:x2]
124
+ # pad to square
125
+ size = max(fg.shape[0], fg.shape[1])
126
+ ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
127
+ ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
128
+ new_image = np.pad(
129
+ fg,
130
+ ((ph0, ph1), (pw0, pw1), (0, 0)),
131
+ mode="constant",
132
+ constant_values=((0, 0), (0, 0), (0, 0)),
133
+ )
134
+
135
+ # compute padding according to the ratio
136
+ new_size = int(new_image.shape[0] / ratio)
137
+ # pad to size, double side
138
+ ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
139
+ ph1, pw1 = new_size - size - ph0, new_size - size - pw0
140
+ new_image = np.pad(
141
+ new_image,
142
+ ((ph0, ph1), (pw0, pw1), (0, 0)),
143
+ mode="constant",
144
+ constant_values=((0, 0), (0, 0), (0, 0)),
145
+ )
146
+ new_image = Image.fromarray(new_image, mode="RGBA").resize(
147
+ (COND_WIDTH, COND_HEIGHT)
148
+ )
149
+ return new_image
150
+
151
+
152
+ def square_crop(input_image: Image) -> Image:
153
+ # Perform a center square crop
154
+ min_size = min(input_image.size)
155
+ left = (input_image.size[0] - min_size) // 2
156
+ top = (input_image.size[1] - min_size) // 2
157
+ right = (input_image.size[0] + min_size) // 2
158
+ bottom = (input_image.size[1] + min_size) // 2
159
+ return input_image.crop((left, top, right, bottom)).resize(
160
+ (COND_WIDTH, COND_HEIGHT)
161
+ )
162
+
163
+
164
+ def show_mask_img(input_image: Image) -> Image:
165
+ img_numpy = np.array(input_image)
166
+ alpha = img_numpy[:, :, 3] / 255.0
167
+ chkb = checkerboard(32, 512) * 255
168
+ new_img = img_numpy[..., :3] * alpha[:, :, None] + chkb * (1 - alpha[:, :, None])
169
+ return Image.fromarray(new_img.astype(np.uint8), mode="RGB")
170
+
171
+
172
+ def run_button(run_btn, input_image, background_state, foreground_ratio):
173
+ if run_btn == "Run":
174
+ glb_file: str = run_model(background_state)
175
+
176
+ return (
177
+ gr.update(),
178
+ gr.update(),
179
+ gr.update(),
180
+ gr.update(),
181
+ gr.update(value=glb_file, visible=True),
182
+ gr.update(visible=True),
183
+ )
184
+ elif run_btn == "Remove Background":
185
+ rem_removed = remove_background(input_image)
186
+
187
+ sqr_crop = square_crop(rem_removed)
188
+ fr_res = resize_foreground(sqr_crop, foreground_ratio)
189
+
190
+ return (
191
+ gr.update(value="Run", visible=True),
192
+ sqr_crop,
193
+ fr_res,
194
+ gr.update(value=show_mask_img(fr_res), visible=True),
195
+ gr.update(value=None, visible=False),
196
+ gr.update(visible=False),
197
+ )
198
+
199
+
200
+ def requires_bg_remove(image, fr):
201
+ if image is None:
202
+ return (
203
+ gr.update(visible=False, value="Run"),
204
+ None,
205
+ None,
206
+ gr.update(value=None, visible=False),
207
+ gr.update(visible=False),
208
+ gr.update(visible=False),
209
+ )
210
+ alpha_channel = np.array(image.getchannel("A"))
211
+ min_alpha = alpha_channel.min()
212
+
213
+ if min_alpha == 0:
214
+ print("Already has alpha")
215
+ sqr_crop = square_crop(image)
216
+ fr_res = resize_foreground(sqr_crop, fr)
217
+ return (
218
+ gr.update(value="Run", visible=True),
219
+ sqr_crop,
220
+ fr_res,
221
+ gr.update(value=show_mask_img(fr_res), visible=True),
222
+ gr.update(visible=False),
223
+ gr.update(visible=False),
224
+ )
225
+ return (
226
+ gr.update(value="Remove Background", visible=True),
227
+ None,
228
+ None,
229
+ gr.update(value=None, visible=False),
230
+ gr.update(visible=False),
231
+ gr.update(visible=False),
232
+ )
233
+
234
+
235
+ def update_foreground_ratio(img_proc, fr):
236
+ foreground_res = resize_foreground(img_proc, fr)
237
+ return (
238
+ foreground_res,
239
+ gr.update(value=show_mask_img(foreground_res)),
240
+ )
241
+
242
+
243
+ with gr.Blocks() as demo:
244
+ img_proc_state = gr.State()
245
+ background_remove_state = gr.State()
246
+ gr.Markdown("""
247
+ # SF3D: Stable Fast 3D Mesh Reconstruction with UV-unwrapping and Illumination Disentanglement
248
+
249
+ **SF3D** is a state-of-the-art method for 3D mesh reconstruction from a single image.
250
+ This demo allows you to upload an image and generate a 3D mesh model from it.
251
+
252
+ **Tips**
253
+ 1. If the image already has an alpha channel, you can skip the background removal step.
254
+ 2. You can adjust the foreground ratio to control the size of the foreground object. This can influence the shape
255
+ 3. You can upload your own HDR environment map to light the 3D model.
256
+ """)
257
+ with gr.Row(variant="panel"):
258
+ with gr.Column():
259
+ with gr.Row():
260
+ input_img = gr.Image(
261
+ type="pil", label="Input Image", sources="upload", image_mode="RGBA"
262
+ )
263
+ preview_removal = gr.Image(
264
+ label="Preview Background Removal",
265
+ type="pil",
266
+ image_mode="RGB",
267
+ interactive=False,
268
+ visible=False,
269
+ )
270
+
271
+ foreground_ratio = gr.Slider(
272
+ label="Foreground Ratio",
273
+ minimum=0.5,
274
+ maximum=1.0,
275
+ value=0.85,
276
+ step=0.05,
277
+ )
278
+
279
+ foreground_ratio.change(
280
+ update_foreground_ratio,
281
+ inputs=[img_proc_state, foreground_ratio],
282
+ outputs=[background_remove_state, preview_removal],
283
+ )
284
+
285
+ run_btn = gr.Button("Run", variant="primary", visible=False)
286
+
287
+ with gr.Column():
288
+ output_3d = LitModel3D(
289
+ label="3D Model",
290
+ visible=False,
291
+ clear_color=[0.0, 0.0, 0.0, 0.0],
292
+ tonemapping="aces",
293
+ contrast=1.0,
294
+ scale=1.0,
295
+ )
296
+ with gr.Column(visible=False, scale=1.0) as hdr_row:
297
+ gr.Markdown("""## HDR Environment Map
298
+
299
+ Select an HDR environment map to light the 3D model. You can also upload your own HDR environment maps.
300
+ """)
301
+
302
+ with gr.Row():
303
+ hdr_illumination_file = gr.File(
304
+ label="HDR Env Map", file_types=[".hdr"], file_count="single"
305
+ )
306
+ example_hdris = [
307
+ os.path.join("demo_files/hdri", f)
308
+ for f in os.listdir("demo_files/hdri")
309
+ ]
310
+ hdr_illumination_example = gr.Examples(
311
+ examples=example_hdris,
312
+ inputs=hdr_illumination_file,
313
+ )
314
+
315
+ hdr_illumination_file.change(
316
+ lambda x: gr.update(env_map=x.name if x is not None else None),
317
+ inputs=hdr_illumination_file,
318
+ outputs=[output_3d],
319
+ )
320
+
321
+ examples = gr.Examples(
322
+ examples=example_files,
323
+ inputs=input_img,
324
+ )
325
+
326
+ input_img.change(
327
+ requires_bg_remove,
328
+ inputs=[input_img, foreground_ratio],
329
+ outputs=[
330
+ run_btn,
331
+ img_proc_state,
332
+ background_remove_state,
333
+ preview_removal,
334
+ output_3d,
335
+ hdr_row,
336
+ ],
337
+ )
338
+
339
+ run_btn.click(
340
+ run_button,
341
+ inputs=[
342
+ run_btn,
343
+ input_img,
344
+ background_remove_state,
345
+ foreground_ratio,
346
+ ],
347
+ outputs=[
348
+ run_btn,
349
+ img_proc_state,
350
+ background_remove_state,
351
+ preview_removal,
352
+ output_3d,
353
+ hdr_row,
354
+ ],
355
+ )
356
+
357
+ demo.launch()
demo_files/comp.gif ADDED

Git LFS Details

  • SHA256: 1d5e060d90f29889c55c1c5681dbeb4b4c2408709d18f7451bb0a6f02c6e9bc5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.93 MB
demo_files/examples/animal_character.png ADDED

Git LFS Details

  • SHA256: 5949f60c651e71a41b7291197f91bb8be2c8861472765fc884e604e18b7806a0
  • Pointer size: 132 Bytes
  • Size of remote file: 1.39 MB
demo_files/examples/animal_character_2.png ADDED

Git LFS Details

  • SHA256: ffc3f10c629afd64798d38dad2cc419eb343c7106149426f78634a91367bf031
  • Pointer size: 132 Bytes
  • Size of remote file: 1.6 MB
demo_files/examples/axe.png ADDED

Git LFS Details

  • SHA256: 94be53862906806ac28367017cd9d794edf416df4d33c1bc223ef6f6eed3b39e
  • Pointer size: 131 Bytes
  • Size of remote file: 277 kB
demo_files/examples/chair1.png ADDED

Git LFS Details

  • SHA256: 2503c12a74419d91a4c6c9f1affc48fee6e2b8b9091956ca6211e91ada57b5bf
  • Pointer size: 131 Bytes
  • Size of remote file: 115 kB
demo_files/examples/character1.png ADDED

Git LFS Details

  • SHA256: 39cccb99b31a614144a6d147f0e0a8d52b986d6e73587c6e697da07d0a7112f2
  • Pointer size: 131 Bytes
  • Size of remote file: 120 kB
demo_files/examples/otter_samurai.png ADDED

Git LFS Details

  • SHA256: 3f3c68fa49d43908f18087cde98aba486da814e938bd59909fcb70d996e9f77b
  • Pointer size: 131 Bytes
  • Size of remote file: 980 kB
demo_files/examples/raccoon_wizard.png ADDED

Git LFS Details

  • SHA256: 32cc3850d9f48548882c7b148e508e8ab149bc4f363611e9739adcbd38e8b16d
  • Pointer size: 131 Bytes
  • Size of remote file: 774 kB
demo_files/examples/stylized-rocks.png ADDED

Git LFS Details

  • SHA256: 386c3be3a6f24ee52e13f130c1ebc02a1bc46eb2c0ebe90d79ce6f38751f0fc6
  • Pointer size: 131 Bytes
  • Size of remote file: 439 kB
demo_files/examples/tree.png ADDED

Git LFS Details

  • SHA256: b258278b4d85a75f9ea3f795d3692fc58304a1a3a7daf8935a9549107bfee170
  • Pointer size: 131 Bytes
  • Size of remote file: 693 kB
demo_files/hdri/abandoned_tiled_room_1k.hdr ADDED
Binary file (478 kB). View file
 
demo_files/hdri/metro_noord_1k.hdr ADDED
Binary file (467 kB). View file
 
demo_files/hdri/neon_photostudio_1k.hdr ADDED
Binary file (438 kB). View file
 
demo_files/hdri/peppermint_powerplant_1k.hdr ADDED
Binary file (473 kB). View file
 
demo_files/hdri/rainforest_trail_1k.hdr ADDED
Binary file (512 kB). View file
 
demo_files/hdri/studio_small_08_1k.hdr ADDED
Binary file (412 kB). View file
 
demo_files/hdri/urban_alley_01_1k.hdr ADDED
Binary file (458 kB). View file
 
demo_files/scatterplot.jpg ADDED
demo_files/teaser.gif ADDED

Git LFS Details

  • SHA256: 1d5dcb4fbe710e94c0fa70cc2c783d66e327222cb5e74839cfd003e619bc2e1d
  • Pointer size: 132 Bytes
  • Size of remote file: 2.81 MB
load/tets/160_tets.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f4be37efc604d28d55a1a78c2aabefeeab7e63149f541aa45f9dd858ee35bb9
3
+ size 15408790
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.1.2
2
+ torchvision==0.16.2
3
+ einops==0.7.0
4
+ jaxtyping==0.2.31
5
+ omegaconf==2.3.0
6
+ transformers==4.42.3
7
+ slangtorch==1.2.2
8
+ open_clip_torch==2.24.0
9
+ trimesh==4.4.1
10
+ numpy==1.26.4
11
+ huggingface-hub==0.23.4
12
+ rembg[gpu]==2.0.57
13
+ gradio-litmodel3d==0.0.1
sf3d/box_uv_unwrap.py ADDED
@@ -0,0 +1,610 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from jaxtyping import Float, Integer
7
+ from torch import Tensor
8
+
9
+ from sf3d.models.utils import dot, triangle_intersection_2d
10
+
11
+
12
+ def _box_assign_vertex_to_cube_face(
13
+ vertex_positions: Float[Tensor, "Nv 3"],
14
+ vertex_normals: Float[Tensor, "Nv 3"],
15
+ triangle_idxs: Integer[Tensor, "Nf 3"],
16
+ bbox: Float[Tensor, "2 3"],
17
+ ) -> Tuple[Float[Tensor, "Nf 3 2"], Integer[Tensor, "Nf 3"]]:
18
+ # Test to not have a scaled model to fit the space better
19
+ # bbox_min = bbox[:1].mean(-1, keepdim=True)
20
+ # bbox_max = bbox[1:].mean(-1, keepdim=True)
21
+ # v_pos_normalized = (vertex_positions - bbox_min) / (bbox_max - bbox_min)
22
+
23
+ # Create a [0, 1] normalized vertex position
24
+ v_pos_normalized = (vertex_positions - bbox[:1]) / (bbox[1:] - bbox[:1])
25
+ # And to [-1, 1]
26
+ v_pos_normalized = 2.0 * v_pos_normalized - 1.0
27
+
28
+ # Get all vertex positions for each triangle
29
+ # Now how do we define to which face the triangle belongs? Mean face pos? Max vertex pos?
30
+ v0 = v_pos_normalized[triangle_idxs[:, 0]]
31
+ v1 = v_pos_normalized[triangle_idxs[:, 1]]
32
+ v2 = v_pos_normalized[triangle_idxs[:, 2]]
33
+ tri_stack = torch.stack([v0, v1, v2], dim=1)
34
+
35
+ vn0 = vertex_normals[triangle_idxs[:, 0]]
36
+ vn1 = vertex_normals[triangle_idxs[:, 1]]
37
+ vn2 = vertex_normals[triangle_idxs[:, 2]]
38
+ tri_stack_nrm = torch.stack([vn0, vn1, vn2], dim=1)
39
+
40
+ # Just average the normals per face
41
+ face_normal = F.normalize(torch.sum(tri_stack_nrm, 1), eps=1e-6, dim=-1)
42
+
43
+ # Now decide based on the face normal in which box map we project
44
+ # abs_x, abs_y, abs_z = tri_stack_nrm.abs().unbind(-1)
45
+ abs_x, abs_y, abs_z = tri_stack.abs().unbind(-1)
46
+
47
+ axis = torch.tensor(
48
+ [
49
+ [1, 0, 0], # 0
50
+ [-1, 0, 0], # 1
51
+ [0, 1, 0], # 2
52
+ [0, -1, 0], # 3
53
+ [0, 0, 1], # 4
54
+ [0, 0, -1], # 5
55
+ ],
56
+ device=face_normal.device,
57
+ dtype=face_normal.dtype,
58
+ )
59
+ face_normal_axis = (face_normal[:, None] * axis[None]).sum(-1)
60
+ index = face_normal_axis.argmax(-1)
61
+
62
+ max_axis, uc, vc = (
63
+ torch.ones_like(abs_x),
64
+ torch.zeros_like(tri_stack[..., :1]),
65
+ torch.zeros_like(tri_stack[..., :1]),
66
+ )
67
+ mask_pos_x = index == 0
68
+ max_axis[mask_pos_x] = abs_x[mask_pos_x]
69
+ uc[mask_pos_x] = tri_stack[mask_pos_x][..., 1:2]
70
+ vc[mask_pos_x] = -tri_stack[mask_pos_x][..., -1:]
71
+
72
+ mask_neg_x = index == 1
73
+ max_axis[mask_neg_x] = abs_x[mask_neg_x]
74
+ uc[mask_neg_x] = tri_stack[mask_neg_x][..., 1:2]
75
+ vc[mask_neg_x] = -tri_stack[mask_neg_x][..., -1:]
76
+
77
+ mask_pos_y = index == 2
78
+ max_axis[mask_pos_y] = abs_y[mask_pos_y]
79
+ uc[mask_pos_y] = tri_stack[mask_pos_y][..., 0:1]
80
+ vc[mask_pos_y] = -tri_stack[mask_pos_y][..., -1:]
81
+
82
+ mask_neg_y = index == 3
83
+ max_axis[mask_neg_y] = abs_y[mask_neg_y]
84
+ uc[mask_neg_y] = tri_stack[mask_neg_y][..., 0:1]
85
+ vc[mask_neg_y] = -tri_stack[mask_neg_y][..., -1:]
86
+
87
+ mask_pos_z = index == 4
88
+ max_axis[mask_pos_z] = abs_z[mask_pos_z]
89
+ uc[mask_pos_z] = tri_stack[mask_pos_z][..., 0:1]
90
+ vc[mask_pos_z] = tri_stack[mask_pos_z][..., 1:2]
91
+
92
+ mask_neg_z = index == 5
93
+ max_axis[mask_neg_z] = abs_z[mask_neg_z]
94
+ uc[mask_neg_z] = tri_stack[mask_neg_z][..., 0:1]
95
+ vc[mask_neg_z] = -tri_stack[mask_neg_z][..., 1:2]
96
+
97
+ # UC from [-1, 1] to [0, 1]
98
+ max_dim_div = max_axis.max(dim=0, keepdims=True).values
99
+ uc = ((uc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
100
+ vc = ((vc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
101
+
102
+ uv = torch.stack([uc, vc], dim=-1)
103
+
104
+ return uv, index
105
+
106
+
107
+ def _assign_faces_uv_to_atlas_index(
108
+ vertex_positions: Float[Tensor, "Nv 3"],
109
+ triangle_idxs: Integer[Tensor, "Nf 3"],
110
+ face_uv: Float[Tensor, "Nf 3 2"],
111
+ face_index: Integer[Tensor, "Nf 3"],
112
+ ) -> Integer[Tensor, "Nf"]: # noqa: F821
113
+ triangle_pos = vertex_positions[triangle_idxs]
114
+ # We need to do perform 3 overlap checks.
115
+ # The first set is placed in the upper two thirds of the UV atlas.
116
+ # Conceptually, this is the direct visible surfaces from the each cube side
117
+ # The second set is placed in the lower thirds and the left half of the UV atlas.
118
+ # This is the first set of occluded surfaces. They will also be saved in the projected fashion
119
+ # The third pass finds all non assigned faces. They will be placed in the bottom right half of
120
+ # the UV atlas in scattered fashion.
121
+ assign_idx = face_index.clone()
122
+ for overlap_step in range(3):
123
+ overlapping_indicator = torch.zeros_like(assign_idx, dtype=torch.bool)
124
+ for i in range(overlap_step * 6, (overlap_step + 1) * 6):
125
+ mask = assign_idx == i
126
+ if not mask.any():
127
+ continue
128
+ # Get all elements belonging to the projection face
129
+ uv_triangle = face_uv[mask]
130
+ cur_triangle_pos = triangle_pos[mask]
131
+ # Find the center of the uv coordinates
132
+ center_uv = uv_triangle.mean(dim=1, keepdim=True)
133
+ # And also the radius of the triangle
134
+ uv_triangle_radius = (uv_triangle - center_uv).norm(dim=-1).max(-1).values
135
+
136
+ potentially_overlapping_mask = (
137
+ # Find all close triangles
138
+ (center_uv[None, ...] - center_uv[:, None]).norm(dim=-1)
139
+ # Do not select the same element by offseting with an large valued identity matrix
140
+ + torch.eye(
141
+ uv_triangle.shape[0],
142
+ device=uv_triangle.device,
143
+ dtype=uv_triangle.dtype,
144
+ ).unsqueeze(-1)
145
+ * 1000
146
+ )
147
+ # Mark all potentially overlapping triangles to reduce the number of triangle intersection tests
148
+ potentially_overlapping_mask = (
149
+ potentially_overlapping_mask
150
+ <= (uv_triangle_radius.view(-1, 1, 1) * 3.0)
151
+ ).squeeze(-1)
152
+ overlap_coords = torch.stack(torch.where(potentially_overlapping_mask), -1)
153
+
154
+ # Only unique triangles (A|B and B|A should be the same)
155
+ f = torch.min(overlap_coords, dim=-1).values
156
+ s = torch.max(overlap_coords, dim=-1).values
157
+ overlap_coords = torch.unique(torch.stack([f, s], dim=1), dim=0)
158
+ first, second = overlap_coords.unbind(-1)
159
+
160
+ # Get the triangles
161
+ tri_1 = uv_triangle[first]
162
+ tri_2 = uv_triangle[second]
163
+
164
+ # Perform the actual set with the reduced number of potentially overlapping triangles
165
+ its = triangle_intersection_2d(tri_1, tri_2, eps=1e-6)
166
+
167
+ # So we now need to detect which triangles are the occluded ones.
168
+ # We always assume the first to be the visible one (the others should move)
169
+ # In the previous step we use a lexigraphical sort to get the unique pairs
170
+ # In this we use a sort based on the orthographic projection
171
+ ax = 0 if i < 2 else 1 if i < 4 else 2
172
+ use_max = i % 2 == 1
173
+
174
+ tri1_c = cur_triangle_pos[first].mean(dim=1)
175
+ tri2_c = cur_triangle_pos[second].mean(dim=1)
176
+
177
+ mark_first = (
178
+ (tri1_c[..., ax] > tri2_c[..., ax])
179
+ if use_max
180
+ else (tri1_c[..., ax] < tri2_c[..., ax])
181
+ )
182
+ first[mark_first] = second[mark_first]
183
+
184
+ # Lastly the same index can be tested multiple times.
185
+ # If one marks it as overlapping we keep it marked as such.
186
+ # We do this by testing if it has been marked at least once.
187
+ unique_idx, rev_idx = torch.unique(first, return_inverse=True)
188
+
189
+ add = torch.zeros_like(unique_idx, dtype=torch.float32)
190
+ add.index_add_(0, rev_idx, its.float())
191
+ its_mask = add > 0
192
+
193
+ # And fill it in the overlapping indicator
194
+ idx = torch.where(mask)[0][unique_idx]
195
+ overlapping_indicator[idx] = its_mask
196
+
197
+ # Move the index to the overlap regions (shift by 6)
198
+ assign_idx[overlapping_indicator] += 6
199
+
200
+ # We do not care about the correct face placement after the first 2 slices
201
+ max_idx = 6 * 2
202
+ return assign_idx.clamp(0, max_idx)
203
+
204
+
205
+ def _find_slice_offset_and_scale(
206
+ index: Integer[Tensor, "Nf"], # noqa: F821
207
+ ) -> Tuple[
208
+ Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"] # noqa: F821
209
+ ]: # noqa: F821
210
+ # 6 due to the 6 cube faces
211
+ off = 1 / 3
212
+ dupl_off = 1 / 6
213
+
214
+ # Here, we need to decide how to pack the textures in the case of overlap
215
+ def x_offset_calc(x, i):
216
+ offset_calc = i // 6
217
+ # Initial coordinates - just 3x2 grid
218
+ if offset_calc == 0:
219
+ return off * x
220
+ else:
221
+ # Smaller 3x2 grid plus eventual shift to right for
222
+ # second overlap
223
+ return dupl_off * x + min(offset_calc - 1, 1) * 0.5
224
+
225
+ def y_offset_calc(x, i):
226
+ offset_calc = i // 6
227
+ # Initial coordinates - just a 3x2 grid
228
+ if offset_calc == 0:
229
+ return off * x
230
+ else:
231
+ # Smaller coordinates in the lowest row
232
+ return dupl_off * x + off * 2
233
+
234
+ offset_x = torch.zeros_like(index, dtype=torch.float32)
235
+ offset_y = torch.zeros_like(index, dtype=torch.float32)
236
+ offset_x_vals = [0, 1, 2, 0, 1, 2]
237
+ offset_y_vals = [0, 0, 0, 1, 1, 1]
238
+ for i in range(index.max().item() + 1):
239
+ mask = index == i
240
+ if not mask.any():
241
+ continue
242
+ offset_x[mask] = x_offset_calc(offset_x_vals[i % 6], i)
243
+ offset_y[mask] = y_offset_calc(offset_y_vals[i % 6], i)
244
+
245
+ div_x = torch.full_like(index, 6 // 2, dtype=torch.float32)
246
+ # All overlap elements are saved in half scale
247
+ div_x[index >= 6] = 6
248
+ div_y = div_x.clone() # Same for y
249
+ # Except for the random overlaps
250
+ div_x[index >= 12] = 2
251
+ # But the random overlaps are saved in a large block in the lower thirds
252
+ div_y[index >= 12] = 3
253
+
254
+ return offset_x, offset_y, div_x, div_y
255
+
256
+
257
+ def rotation_flip_matrix_2d(
258
+ rad: float, flip_x: bool = False, flip_y: bool = False
259
+ ) -> Float[Tensor, "2 2"]:
260
+ cos = math.cos(rad)
261
+ sin = math.sin(rad)
262
+ rot_mat = torch.tensor([[cos, -sin], [sin, cos]], dtype=torch.float32)
263
+ flip_mat = torch.tensor(
264
+ [
265
+ [-1 if flip_x else 1, 0],
266
+ [0, -1 if flip_y else 1],
267
+ ],
268
+ dtype=torch.float32,
269
+ )
270
+
271
+ return flip_mat @ rot_mat
272
+
273
+
274
+ def calculate_tangents(
275
+ vertex_positions: Float[Tensor, "Nv 3"],
276
+ vertex_normals: Float[Tensor, "Nv 3"],
277
+ triangle_idxs: Integer[Tensor, "Nf 3"],
278
+ face_uv: Float[Tensor, "Nf 3 2"],
279
+ ) -> Float[Tensor, "Nf 3 4"]: # noqa: F821
280
+ vn_idx = [None] * 3
281
+ pos = [None] * 3
282
+ tex = face_uv.unbind(1)
283
+ for i in range(0, 3):
284
+ pos[i] = vertex_positions[triangle_idxs[:, i]]
285
+ # t_nrm_idx is always the same as t_pos_idx
286
+ vn_idx[i] = triangle_idxs[:, i]
287
+
288
+ tangents = torch.zeros_like(vertex_normals)
289
+ tansum = torch.zeros_like(vertex_normals)
290
+
291
+ # Compute tangent space for each triangle
292
+ duv1 = tex[1] - tex[0]
293
+ duv2 = tex[2] - tex[0]
294
+ dpos1 = pos[1] - pos[0]
295
+ dpos2 = pos[2] - pos[0]
296
+
297
+ tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
298
+
299
+ denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]
300
+
301
+ # Avoid division by zero for degenerated texture coordinates
302
+ denom_safe = denom.clip(1e-6)
303
+ tang = tng_nom / denom_safe
304
+
305
+ # Update all 3 vertices
306
+ for i in range(0, 3):
307
+ idx = vn_idx[i][:, None].repeat(1, 3)
308
+ tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang
309
+ tansum.scatter_add_(
310
+ 0, idx, torch.ones_like(tang)
311
+ ) # tansum[n_i] = tansum[n_i] + 1
312
+ # Also normalize it. Here we do not normalize the individual triangles first so larger area
313
+ # triangles influence the tangent space more
314
+ tangents = tangents / tansum
315
+
316
+ # Normalize and make sure tangent is perpendicular to normal
317
+ tangents = F.normalize(tangents, dim=1)
318
+ tangents = F.normalize(tangents - dot(tangents, vertex_normals) * vertex_normals)
319
+
320
+ return tangents
321
+
322
+
323
+ def _rotate_uv_slices_consistent_space(
324
+ vertex_positions: Float[Tensor, "Nv 3"],
325
+ vertex_normals: Float[Tensor, "Nv 3"],
326
+ triangle_idxs: Integer[Tensor, "Nf 3"],
327
+ uv: Float[Tensor, "Nf 3 2"],
328
+ index: Integer[Tensor, "Nf"], # noqa: F821
329
+ ):
330
+ tangents = calculate_tangents(vertex_positions, vertex_normals, triangle_idxs, uv)
331
+ pos_stack = torch.stack(
332
+ [
333
+ -vertex_positions[..., 1],
334
+ vertex_positions[..., 0],
335
+ torch.zeros_like(vertex_positions[..., 0]),
336
+ ],
337
+ dim=-1,
338
+ )
339
+ expected_tangents = F.normalize(
340
+ torch.linalg.cross(
341
+ vertex_normals, torch.linalg.cross(pos_stack, vertex_normals)
342
+ ),
343
+ -1,
344
+ )
345
+
346
+ actual_tangents = tangents[triangle_idxs]
347
+ expected_tangents = expected_tangents[triangle_idxs]
348
+
349
+ def rotation_matrix_2d(theta):
350
+ c, s = torch.cos(theta), torch.sin(theta)
351
+ return torch.tensor([[c, -s], [s, c]])
352
+
353
+ # Now find the rotation
354
+ index_mod = index % 6 # Shouldn't happen. Just for safety
355
+ for i in range(6):
356
+ mask = index_mod == i
357
+ if not mask.any():
358
+ continue
359
+
360
+ actual_mean_tangent = actual_tangents[mask].mean(dim=(0, 1))
361
+ expected_mean_tangent = expected_tangents[mask].mean(dim=(0, 1))
362
+
363
+ dot_product = torch.dot(actual_mean_tangent, expected_mean_tangent)
364
+ cross_product = (
365
+ actual_mean_tangent[0] * expected_mean_tangent[1]
366
+ - actual_mean_tangent[1] * expected_mean_tangent[0]
367
+ )
368
+ angle = torch.atan2(cross_product, dot_product)
369
+
370
+ rot_matrix = rotation_matrix_2d(angle).to(mask.device)
371
+ # Center the uv coordinate to be in the range of -1 to 1 and 0 centered
372
+ uv_cur = uv[mask] * 2 - 1 # Center it first
373
+ # Rotate it
374
+ uv[mask] = torch.einsum("ij,nfj->nfi", rot_matrix, uv_cur)
375
+
376
+ # Rescale uv[mask] to be within the 0-1 range
377
+ uv[mask] = (uv[mask] - uv[mask].min()) / (uv[mask].max() - uv[mask].min())
378
+
379
+ return uv
380
+
381
+
382
+ def _handle_slice_uvs(
383
+ uv: Float[Tensor, "Nf 3 2"],
384
+ index: Integer[Tensor, "Nf"], # noqa: F821
385
+ island_padding: float,
386
+ max_index: int = 6 * 2,
387
+ ) -> Float[Tensor, "Nf 3 2"]: # noqa: F821
388
+ uc, vc = uv.unbind(-1)
389
+
390
+ # Get the second slice (The first overlap)
391
+ index_filter = [index == i for i in range(6, max_index)]
392
+
393
+ # Normalize them to always fully fill the atlas patch
394
+ for i, fi in enumerate(index_filter):
395
+ if fi.sum() > 0:
396
+ # Scale the slice but only up to a factor of 2
397
+ # This keeps the texture resolution with the first slice in line (Half space in UV)
398
+ uc[fi] = (uc[fi] - uc[fi].min()) / (uc[fi].max() - uc[fi].min()).clip(0.5)
399
+ vc[fi] = (vc[fi] - vc[fi].min()) / (vc[fi].max() - vc[fi].min()).clip(0.5)
400
+
401
+ uc_padded = (uc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
402
+ vc_padded = (vc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
403
+
404
+ return torch.stack([uc_padded, vc_padded], dim=-1)
405
+
406
+
407
+ def _handle_remaining_uvs(
408
+ uv: Float[Tensor, "Nf 3 2"],
409
+ index: Integer[Tensor, "Nf"], # noqa: F821
410
+ island_padding: float,
411
+ ) -> Float[Tensor, "Nf 3 2"]:
412
+ uc, vc = uv.unbind(-1)
413
+ # Get all remaining elements
414
+ remaining_filter = index >= 6 * 2
415
+ squares_left = remaining_filter.sum()
416
+
417
+ if squares_left == 0:
418
+ return uv
419
+
420
+ uc = uc[remaining_filter]
421
+ vc = vc[remaining_filter]
422
+
423
+ # Or remaining triangles are distributed in a rectangle
424
+ # The rectangle takes 0.5 of the entire uv space in width and 1/3 in height
425
+ ratio = 0.5 * (1 / 3) # 1.5
426
+ # sqrt(744/(0.5*(1/3)))
427
+
428
+ mult = math.sqrt(squares_left / ratio)
429
+ num_square_width = int(math.ceil(0.5 * mult))
430
+ num_square_height = int(math.ceil(squares_left / num_square_width))
431
+
432
+ width = 1 / num_square_width
433
+ height = 1 / num_square_height
434
+
435
+ # The idea is again to keep the texture resolution consistent with the first slice
436
+ # This only occupys half the region in the texture chart but the scaling on the squares
437
+ # assumes full coverage.
438
+ clip_val = min(width, height) * 1.5
439
+ # Now normalize the UVs with taking into account the maximum scaling
440
+ uc = (uc - uc.min(dim=1, keepdim=True).values) / (
441
+ uc.amax(dim=1, keepdim=True) - uc.amin(dim=1, keepdim=True)
442
+ ).clip(clip_val)
443
+ vc = (vc - vc.min(dim=1, keepdim=True).values) / (
444
+ vc.amax(dim=1, keepdim=True) - vc.amin(dim=1, keepdim=True)
445
+ ).clip(clip_val)
446
+ # Add a small padding
447
+ uc = (
448
+ uc * (1 - island_padding * num_square_width * 0.5)
449
+ + island_padding * num_square_width * 0.25
450
+ ).clip(0, 1)
451
+ vc = (
452
+ vc * (1 - island_padding * num_square_height * 0.5)
453
+ + island_padding * num_square_height * 0.25
454
+ ).clip(0, 1)
455
+
456
+ uc = uc * width
457
+ vc = vc * height
458
+
459
+ # And calculate offsets for each element
460
+ idx = torch.arange(uc.shape[0], device=uc.device, dtype=torch.int32)
461
+ x_idx = idx % num_square_width
462
+ y_idx = idx // num_square_width
463
+ # And move each triangle to its own spot
464
+ uc = uc + x_idx[:, None] * width
465
+ vc = vc + y_idx[:, None] * height
466
+
467
+ uc = (uc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
468
+ vc = (vc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
469
+
470
+ uv[remaining_filter] = torch.stack([uc, vc], dim=-1)
471
+
472
+ return uv
473
+
474
+
475
+ def _distribute_individual_uvs_in_atlas(
476
+ face_uv: Float[Tensor, "Nf 3 2"],
477
+ assigned_faces: Integer[Tensor, "Nf"], # noqa: F821
478
+ offset_x: Float[Tensor, "Nf"], # noqa: F821
479
+ offset_y: Float[Tensor, "Nf"], # noqa: F821
480
+ div_x: Float[Tensor, "Nf"], # noqa: F821
481
+ div_y: Float[Tensor, "Nf"], # noqa: F821
482
+ island_padding: float,
483
+ ):
484
+ # Place the slice first
485
+ placed_uv = _handle_slice_uvs(face_uv, assigned_faces, island_padding)
486
+ # Then handle the remaining overlap elements
487
+ placed_uv = _handle_remaining_uvs(placed_uv, assigned_faces, island_padding)
488
+
489
+ uc, vc = placed_uv.unbind(-1)
490
+ uc = uc / div_x[:, None] + offset_x[:, None]
491
+ vc = vc / div_y[:, None] + offset_y[:, None]
492
+
493
+ uv = torch.stack([uc, vc], dim=-1).view(-1, 2)
494
+
495
+ return uv
496
+
497
+
498
+ def _get_unique_face_uv(
499
+ uv: Float[Tensor, "Nf 3 2"],
500
+ ) -> Tuple[Float[Tensor, "Utex 3"], Integer[Tensor, "Nf"]]: # noqa: F821
501
+ unique_uv, unique_idx = torch.unique(uv, return_inverse=True, dim=0)
502
+ # And add the face to uv index mapping
503
+ vtex_idx = unique_idx.view(-1, 3)
504
+
505
+ return unique_uv, vtex_idx
506
+
507
+
508
+ def _align_mesh_with_main_axis(
509
+ vertex_positions: Float[Tensor, "Nv 3"], vertex_normals: Float[Tensor, "Nv 3"]
510
+ ) -> Tuple[Float[Tensor, "Nv 3"], Float[Tensor, "Nv 3"]]:
511
+ # Use pca to find the 2 main axis (third is derived by cross product)
512
+ # Set the random seed so it's repeatable
513
+ torch.manual_seed(0)
514
+ _, _, v = torch.pca_lowrank(vertex_positions, q=2)
515
+ main_axis, seconday_axis = v[:, 0], v[:, 1]
516
+
517
+ main_axis: Float[Tensor, "3"] = F.normalize(main_axis, eps=1e-6, dim=-1)
518
+ # Orthogonalize the second axis
519
+ seconday_axis: Float[Tensor, "3"] = F.normalize(
520
+ seconday_axis - dot(seconday_axis, main_axis) * main_axis, eps=1e-6, dim=-1
521
+ )
522
+ # Create perpendicular third axis
523
+ third_axis: Float[Tensor, "3"] = F.normalize(
524
+ torch.cross(main_axis, seconday_axis), dim=-1, eps=1e-6
525
+ )
526
+
527
+ # Check to which canonical axis each aligns
528
+ main_axis_max_idx = main_axis.abs().argmax().item()
529
+ seconday_axis_max_idx = seconday_axis.abs().argmax().item()
530
+ third_axis_max_idx = third_axis.abs().argmax().item()
531
+
532
+ # Now sort the axes based on the argmax so they align with thecanonoical axes
533
+ # If two axes have the same argmax move one of them
534
+ all_possible_axis = {0, 1, 2}
535
+ cur_index = 1
536
+ while len(set([main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx])) != 3:
537
+ # Find missing axis
538
+ missing_axis = all_possible_axis - set(
539
+ [main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx]
540
+ )
541
+ missing_axis = missing_axis.pop()
542
+ # Just assign it to third axis as it had the smallest contribution to the
543
+ # overall shape
544
+ if cur_index == 1:
545
+ third_axis_max_idx = missing_axis
546
+ elif cur_index == 2:
547
+ seconday_axis_max_idx = missing_axis
548
+ else:
549
+ raise ValueError("Could not find 3 unique axis")
550
+ cur_index += 1
551
+
552
+ if len({main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx}) != 3:
553
+ raise ValueError("Could not find 3 unique axis")
554
+
555
+ axes = [None] * 3
556
+ axes[main_axis_max_idx] = main_axis
557
+ axes[seconday_axis_max_idx] = seconday_axis
558
+ axes[third_axis_max_idx] = third_axis
559
+ # Create rotation matrix from the individual axes
560
+ rot_mat = torch.stack(axes, dim=1).T
561
+
562
+ # Now rotate the vertex positions and vertex normals so the mesh aligns with the main axis
563
+ vertex_positions = torch.einsum("ij,nj->ni", rot_mat, vertex_positions)
564
+ vertex_normals = torch.einsum("ij,nj->ni", rot_mat, vertex_normals)
565
+
566
+ return vertex_positions, vertex_normals
567
+
568
+
569
+ def box_projection_uv_unwrap(
570
+ vertex_positions: Float[Tensor, "Nv 3"],
571
+ vertex_normals: Float[Tensor, "Nv 3"],
572
+ triangle_idxs: Integer[Tensor, "Nf 3"],
573
+ island_padding: float,
574
+ ) -> Tuple[Float[Tensor, "Utex 3"], Integer[Tensor, "Nf"]]: # noqa: F821
575
+ # Align the mesh with main axis directions first
576
+ vertex_positions, vertex_normals = _align_mesh_with_main_axis(
577
+ vertex_positions, vertex_normals
578
+ )
579
+
580
+ bbox: Float[Tensor, "2 3"] = torch.stack(
581
+ [vertex_positions.min(dim=0).values, vertex_positions.max(dim=0).values], dim=0
582
+ )
583
+ # First decide in which cube face the triangle is placed
584
+ face_uv, face_index = _box_assign_vertex_to_cube_face(
585
+ vertex_positions, vertex_normals, triangle_idxs, bbox
586
+ )
587
+
588
+ # Rotate the UV islands in a way that they align with the radial z tangent space
589
+ face_uv = _rotate_uv_slices_consistent_space(
590
+ vertex_positions, vertex_normals, triangle_idxs, face_uv, face_index
591
+ )
592
+
593
+ # Then find where where the face is placed in the atlas.
594
+ # This has to detect potential overlaps
595
+ assigned_atlas_index = _assign_faces_uv_to_atlas_index(
596
+ vertex_positions, triangle_idxs, face_uv, face_index
597
+ )
598
+
599
+ # Then figure out the final place in the atlas based on the assignment
600
+ offset_x, offset_y, div_x, div_y = _find_slice_offset_and_scale(
601
+ assigned_atlas_index
602
+ )
603
+
604
+ # Next distribute the faces in the uv atlas
605
+ placed_uv = _distribute_individual_uvs_in_atlas(
606
+ face_uv, assigned_atlas_index, offset_x, offset_y, div_x, div_y, island_padding
607
+ )
608
+
609
+ # And get the unique per-triangle UV coordinates
610
+ return _get_unique_face_uv(placed_uv)
sf3d/models/camera.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import List
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from sf3d.models.utils import BaseModule
8
+
9
+
10
+ class LinearCameraEmbedder(BaseModule):
11
+ @dataclass
12
+ class Config(BaseModule.Config):
13
+ in_channels: int = 25
14
+ out_channels: int = 768
15
+ conditions: List[str] = field(default_factory=list)
16
+
17
+ cfg: Config
18
+
19
+ def configure(self) -> None:
20
+ self.linear = nn.Linear(self.cfg.in_channels, self.cfg.out_channels)
21
+
22
+ def forward(self, **kwargs):
23
+ cond_tensors = []
24
+ for cond_name in self.cfg.conditions:
25
+ assert cond_name in kwargs
26
+ cond = kwargs[cond_name]
27
+ # cond in shape (B, Nv, ...)
28
+ cond_tensors.append(cond.view(*cond.shape[:2], -1))
29
+ cond_tensor = torch.cat(cond_tensors, dim=-1)
30
+ assert cond_tensor.shape[-1] == self.cfg.in_channels
31
+ embedding = self.linear(cond_tensor)
32
+ return embedding
sf3d/models/global_estimator/multi_head_estimator.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Any, List, Optional
3
+
4
+ import torch.nn as nn
5
+ from jaxtyping import Float
6
+ from torch import Tensor
7
+
8
+ from sf3d.models.network import get_activation
9
+ from sf3d.models.utils import BaseModule
10
+
11
+
12
+ @dataclass
13
+ class HeadSpec:
14
+ name: str
15
+ out_channels: int
16
+ n_hidden_layers: int
17
+ output_activation: Optional[str] = None
18
+ output_bias: float = 0.0
19
+ add_to_decoder_features: bool = False
20
+ shape: Optional[list[int]] = None
21
+
22
+
23
+ class MultiHeadEstimator(BaseModule):
24
+ @dataclass
25
+ class Config(BaseModule.Config):
26
+ triplane_features: int = 1024
27
+
28
+ n_layers: int = 2
29
+ hidden_features: int = 512
30
+ activation: str = "relu"
31
+
32
+ pool: str = "max"
33
+ # Literal["mean", "max"] = "mean" # noqa: F821
34
+
35
+ heads: List[HeadSpec] = field(default_factory=lambda: [])
36
+
37
+ cfg: Config
38
+
39
+ def configure(self):
40
+ layers = []
41
+ cur_features = self.cfg.triplane_features * 3
42
+ for _ in range(self.cfg.n_layers):
43
+ layers.append(
44
+ nn.Conv2d(
45
+ cur_features,
46
+ self.cfg.hidden_features,
47
+ kernel_size=3,
48
+ padding=0,
49
+ stride=2,
50
+ )
51
+ )
52
+ layers.append(self.make_activation(self.cfg.activation))
53
+
54
+ cur_features = self.cfg.hidden_features
55
+
56
+ self.layers = nn.Sequential(*layers)
57
+
58
+ assert len(self.cfg.heads) > 0
59
+ heads = {}
60
+ for head in self.cfg.heads:
61
+ head_layers = []
62
+ for i in range(head.n_hidden_layers):
63
+ head_layers += [
64
+ nn.Linear(
65
+ self.cfg.hidden_features,
66
+ self.cfg.hidden_features,
67
+ ),
68
+ self.make_activation(self.cfg.activation),
69
+ ]
70
+ head_layers += [
71
+ nn.Linear(
72
+ self.cfg.hidden_features,
73
+ head.out_channels,
74
+ ),
75
+ ]
76
+ heads[head.name] = nn.Sequential(*head_layers)
77
+ self.heads = nn.ModuleDict(heads)
78
+
79
+ def make_activation(self, activation):
80
+ if activation == "relu":
81
+ return nn.ReLU(inplace=True)
82
+ elif activation == "silu":
83
+ return nn.SiLU(inplace=True)
84
+ else:
85
+ raise NotImplementedError
86
+
87
+ def forward(
88
+ self,
89
+ triplane: Float[Tensor, "B 3 F Ht Wt"],
90
+ ) -> dict[str, Any]:
91
+ x = self.layers(
92
+ triplane.reshape(
93
+ triplane.shape[0], -1, triplane.shape[-2], triplane.shape[-1]
94
+ )
95
+ )
96
+
97
+ if self.cfg.pool == "max":
98
+ x = x.amax(dim=[-2, -1])
99
+ elif self.cfg.pool == "mean":
100
+ x = x.mean(dim=[-2, -1])
101
+ else:
102
+ raise NotImplementedError
103
+
104
+ out = {
105
+ ("decoder_" if head.add_to_decoder_features else "")
106
+ + head.name: get_activation(head.output_activation)(
107
+ self.heads[head.name](x) + head.output_bias
108
+ )
109
+ for head in self.cfg.heads
110
+ }
111
+ for head in self.cfg.heads:
112
+ if head.shape:
113
+ head_name = (
114
+ "decoder_" if head.add_to_decoder_features else ""
115
+ ) + head.name
116
+ out[head_name] = out[head_name].reshape(*head.shape)
117
+
118
+ return out
sf3d/models/image_estimator/clip_based_estimator.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Any, List, Optional
3
+
4
+ import open_clip
5
+ import torch
6
+ import torch.nn as nn
7
+ from jaxtyping import Float
8
+ from torch import Tensor
9
+ from torchvision.transforms import Normalize
10
+
11
+ from sf3d.models.network import get_activation
12
+ from sf3d.models.utils import BaseModule
13
+
14
+
15
+ @dataclass
16
+ class HeadSpec:
17
+ name: str
18
+ out_channels: int
19
+ n_hidden_layers: int
20
+ output_activation: Optional[str] = None
21
+ output_bias: float = 0.0
22
+ add_to_decoder_features: bool = False
23
+ shape: Optional[list[int]] = None
24
+
25
+
26
+ class ClipBasedHeadEstimator(BaseModule):
27
+ @dataclass
28
+ class Config(BaseModule.Config):
29
+ model: str = "ViT-B-32"
30
+ pretrain: str = "laion2b_s34b_b79k"
31
+
32
+ distribution: str = "beta"
33
+
34
+ # ["mean", "mode", "sample", "sample_mean"]
35
+ distribution_eval: str = "mode"
36
+
37
+ activation: str = "relu"
38
+ hidden_features: int = 512
39
+ heads: List[HeadSpec] = field(default_factory=lambda: [])
40
+
41
+ cfg: Config
42
+
43
+ def configure(self):
44
+ self.model, _, self.preprocess = open_clip.create_model_and_transforms(
45
+ self.cfg.model, pretrained=self.cfg.pretrain
46
+ )
47
+ self.model.eval()
48
+
49
+ # Do not add the weights in self.model to the optimizer
50
+ for param in self.model.parameters():
51
+ param.requires_grad = False
52
+
53
+ assert len(self.cfg.heads) > 0
54
+ heads = {}
55
+ for head in self.cfg.heads:
56
+ head_layers = []
57
+
58
+ for i in range(head.n_hidden_layers):
59
+ head_layers += [
60
+ nn.Linear(
61
+ self.cfg.hidden_features,
62
+ self.cfg.hidden_features,
63
+ ),
64
+ self.make_activation(self.cfg.activation),
65
+ ]
66
+
67
+ head_layers = [nn.Sequential(*head_layers)]
68
+ head_layers += [
69
+ nn.Sequential(
70
+ nn.Linear(
71
+ self.cfg.hidden_features,
72
+ self.cfg.hidden_features,
73
+ ),
74
+ self.make_activation(self.cfg.activation),
75
+ nn.Linear(self.cfg.hidden_features, 1),
76
+ )
77
+ for _ in range(2)
78
+ ]
79
+ heads[head.name] = nn.ModuleList(head_layers)
80
+ self.heads = nn.ModuleDict(heads)
81
+
82
+ def make_activation(self, activation):
83
+ if activation == "relu":
84
+ return nn.ReLU(inplace=True)
85
+ elif activation == "silu":
86
+ return nn.SiLU(inplace=True)
87
+ else:
88
+ raise NotImplementedError
89
+
90
+ def forward(
91
+ self,
92
+ cond_image: Float[Tensor, "B 1 H W 3"],
93
+ sample: bool = True,
94
+ ) -> dict[str, Any]:
95
+ # Run the model
96
+ # Resize cond_image to 224
97
+ cond_image = nn.functional.interpolate(
98
+ cond_image.flatten(0, 1).permute(0, 3, 1, 2),
99
+ size=(224, 224),
100
+ mode="bilinear",
101
+ align_corners=False,
102
+ )
103
+ cond_image = Normalize(
104
+ mean=open_clip.constants.OPENAI_DATASET_MEAN,
105
+ std=open_clip.constants.OPENAI_DATASET_STD,
106
+ )(cond_image)
107
+ image_features = self.model.encode_image(cond_image)
108
+
109
+ # Run the heads
110
+ outputs = {}
111
+
112
+ for head_dict in self.cfg.heads:
113
+ head_name = head_dict.name
114
+ shared_head, d1_h, d2_h = self.heads[head_name]
115
+ shared_features = shared_head(image_features)
116
+ d1, d2 = [head(shared_features).squeeze(-1) for head in [d1_h, d2_h]]
117
+ if self.cfg.distribution == "normal":
118
+ mean = d1
119
+ var = d2
120
+ if mean.shape[-1] == 1:
121
+ outputs[head_name] = torch.distributions.Normal(
122
+ mean + head_dict.output_bias,
123
+ torch.nn.functional.softplus(var),
124
+ )
125
+ else:
126
+ outputs[head_name] = torch.distributions.MultivariateNormal(
127
+ mean + head_dict.output_bias,
128
+ torch.nn.functional.softplus(var).diag_embed(),
129
+ )
130
+ elif self.cfg.distribution == "beta":
131
+ outputs[head_name] = torch.distributions.Beta(
132
+ torch.nn.functional.softplus(d1 + head_dict.output_bias),
133
+ torch.nn.functional.softplus(d2 + head_dict.output_bias),
134
+ )
135
+ else:
136
+ raise NotImplementedError
137
+
138
+ if sample:
139
+ for head_dict in self.cfg.heads:
140
+ head_name = head_dict.name
141
+ dist = outputs[head_name]
142
+
143
+ if self.cfg.distribution_eval == "mean":
144
+ out = dist.mean
145
+ elif self.cfg.distribution_eval == "mode":
146
+ out = dist.mode
147
+ elif self.cfg.distribution_eval == "sample_mean":
148
+ out = dist.sample([10]).mean(-1)
149
+ else:
150
+ # use rsample if gradient is needed
151
+ out = dist.rsample() if self.training else dist.sample()
152
+
153
+ outputs[head_name] = get_activation(head_dict.output_activation)(out)
154
+ outputs[f"{head_name}_dist"] = dist
155
+
156
+ for head in self.cfg.heads:
157
+ if head.shape:
158
+ if not sample:
159
+ raise ValueError(
160
+ "Cannot reshape non-sampled probabilisitic outputs"
161
+ )
162
+ outputs[head.name] = outputs[head.name].reshape(*head.shape)
163
+
164
+ if head.add_to_decoder_features:
165
+ outputs[f"decoder_{head.name}"] = outputs[head.name]
166
+ del outputs[head.name]
167
+
168
+ return outputs
sf3d/models/isosurface.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from jaxtyping import Float, Integer
7
+ from torch import Tensor
8
+
9
+ from .mesh import Mesh
10
+
11
+
12
+ class IsosurfaceHelper(nn.Module):
13
+ points_range: Tuple[float, float] = (0, 1)
14
+
15
+ @property
16
+ def grid_vertices(self) -> Float[Tensor, "N 3"]:
17
+ raise NotImplementedError
18
+
19
+ @property
20
+ def requires_instance_per_batch(self) -> bool:
21
+ return False
22
+
23
+
24
+ class MarchingTetrahedraHelper(IsosurfaceHelper):
25
+ def __init__(self, resolution: int, tets_path: str):
26
+ super().__init__()
27
+ self.resolution = resolution
28
+ self.tets_path = tets_path
29
+
30
+ self.triangle_table: Float[Tensor, "..."]
31
+ self.register_buffer(
32
+ "triangle_table",
33
+ torch.as_tensor(
34
+ [
35
+ [-1, -1, -1, -1, -1, -1],
36
+ [1, 0, 2, -1, -1, -1],
37
+ [4, 0, 3, -1, -1, -1],
38
+ [1, 4, 2, 1, 3, 4],
39
+ [3, 1, 5, -1, -1, -1],
40
+ [2, 3, 0, 2, 5, 3],
41
+ [1, 4, 0, 1, 5, 4],
42
+ [4, 2, 5, -1, -1, -1],
43
+ [4, 5, 2, -1, -1, -1],
44
+ [4, 1, 0, 4, 5, 1],
45
+ [3, 2, 0, 3, 5, 2],
46
+ [1, 3, 5, -1, -1, -1],
47
+ [4, 1, 2, 4, 3, 1],
48
+ [3, 0, 4, -1, -1, -1],
49
+ [2, 0, 1, -1, -1, -1],
50
+ [-1, -1, -1, -1, -1, -1],
51
+ ],
52
+ dtype=torch.long,
53
+ ),
54
+ persistent=False,
55
+ )
56
+ self.num_triangles_table: Integer[Tensor, "..."]
57
+ self.register_buffer(
58
+ "num_triangles_table",
59
+ torch.as_tensor(
60
+ [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long
61
+ ),
62
+ persistent=False,
63
+ )
64
+ self.base_tet_edges: Integer[Tensor, "..."]
65
+ self.register_buffer(
66
+ "base_tet_edges",
67
+ torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long),
68
+ persistent=False,
69
+ )
70
+
71
+ tets = np.load(self.tets_path)
72
+ self._grid_vertices: Float[Tensor, "..."]
73
+ self.register_buffer(
74
+ "_grid_vertices",
75
+ torch.from_numpy(tets["vertices"]).float(),
76
+ persistent=False,
77
+ )
78
+ self.indices: Integer[Tensor, "..."]
79
+ self.register_buffer(
80
+ "indices", torch.from_numpy(tets["indices"]).long(), persistent=False
81
+ )
82
+
83
+ self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None
84
+
85
+ center_indices, boundary_indices = self.get_center_boundary_index(
86
+ self._grid_vertices
87
+ )
88
+ self.center_indices: Integer[Tensor, "..."]
89
+ self.register_buffer("center_indices", center_indices, persistent=False)
90
+ self.boundary_indices: Integer[Tensor, "..."]
91
+ self.register_buffer("boundary_indices", boundary_indices, persistent=False)
92
+
93
+ def get_center_boundary_index(self, verts):
94
+ magn = torch.sum(verts**2, dim=-1)
95
+
96
+ center_idx = torch.argmin(magn)
97
+ boundary_neg = verts == verts.max()
98
+ boundary_pos = verts == verts.min()
99
+
100
+ boundary = torch.bitwise_or(boundary_pos, boundary_neg)
101
+ boundary = torch.sum(boundary.float(), dim=-1)
102
+
103
+ boundary_idx = torch.nonzero(boundary)
104
+ return center_idx, boundary_idx.squeeze(dim=-1)
105
+
106
+ def normalize_grid_deformation(
107
+ self, grid_vertex_offsets: Float[Tensor, "Nv 3"]
108
+ ) -> Float[Tensor, "Nv 3"]:
109
+ return (
110
+ (self.points_range[1] - self.points_range[0])
111
+ / self.resolution # half tet size is approximately 1 / self.resolution
112
+ * torch.tanh(grid_vertex_offsets)
113
+ ) # FIXME: hard-coded activation
114
+
115
+ @property
116
+ def grid_vertices(self) -> Float[Tensor, "Nv 3"]:
117
+ return self._grid_vertices
118
+
119
+ @property
120
+ def all_edges(self) -> Integer[Tensor, "Ne 2"]:
121
+ if self._all_edges is None:
122
+ # compute edges on GPU, or it would be VERY SLOW (basically due to the unique operation)
123
+ edges = torch.tensor(
124
+ [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3],
125
+ dtype=torch.long,
126
+ device=self.indices.device,
127
+ )
128
+ _all_edges = self.indices[:, edges].reshape(-1, 2)
129
+ _all_edges_sorted = torch.sort(_all_edges, dim=1)[0]
130
+ _all_edges = torch.unique(_all_edges_sorted, dim=0)
131
+ self._all_edges = _all_edges
132
+ return self._all_edges
133
+
134
+ def sort_edges(self, edges_ex2):
135
+ with torch.no_grad():
136
+ order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
137
+ order = order.unsqueeze(dim=1)
138
+
139
+ a = torch.gather(input=edges_ex2, index=order, dim=1)
140
+ b = torch.gather(input=edges_ex2, index=1 - order, dim=1)
141
+
142
+ return torch.stack([a, b], -1)
143
+
144
+ def _forward(self, pos_nx3, sdf_n, tet_fx4):
145
+ with torch.no_grad():
146
+ occ_n = sdf_n > 0
147
+ occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
148
+ occ_sum = torch.sum(occ_fx4, -1)
149
+ valid_tets = (occ_sum > 0) & (occ_sum < 4)
150
+ occ_sum = occ_sum[valid_tets]
151
+
152
+ # find all vertices
153
+ all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
154
+ all_edges = self.sort_edges(all_edges)
155
+ unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
156
+
157
+ unique_edges = unique_edges.long()
158
+ mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
159
+ mapping = (
160
+ torch.ones(
161
+ (unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device
162
+ )
163
+ * -1
164
+ )
165
+ mapping[mask_edges] = torch.arange(
166
+ mask_edges.sum(), dtype=torch.long, device=pos_nx3.device
167
+ )
168
+ idx_map = mapping[idx_map] # map edges to verts
169
+
170
+ interp_v = unique_edges[mask_edges]
171
+ edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
172
+ edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
173
+ edges_to_interp_sdf[:, -1] *= -1
174
+
175
+ denominator = edges_to_interp_sdf.sum(1, keepdim=True)
176
+
177
+ edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
178
+ verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
179
+
180
+ idx_map = idx_map.reshape(-1, 6)
181
+
182
+ v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device))
183
+ tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
184
+ num_triangles = self.num_triangles_table[tetindex]
185
+
186
+ # Generate triangle indices
187
+ faces = torch.cat(
188
+ (
189
+ torch.gather(
190
+ input=idx_map[num_triangles == 1],
191
+ dim=1,
192
+ index=self.triangle_table[tetindex[num_triangles == 1]][:, :3],
193
+ ).reshape(-1, 3),
194
+ torch.gather(
195
+ input=idx_map[num_triangles == 2],
196
+ dim=1,
197
+ index=self.triangle_table[tetindex[num_triangles == 2]][:, :6],
198
+ ).reshape(-1, 3),
199
+ ),
200
+ dim=0,
201
+ )
202
+
203
+ return verts, faces
204
+
205
+ def forward(
206
+ self,
207
+ level: Float[Tensor, "N3 1"],
208
+ deformation: Optional[Float[Tensor, "N3 3"]] = None,
209
+ ) -> Mesh:
210
+ if deformation is not None:
211
+ grid_vertices = self.grid_vertices + self.normalize_grid_deformation(
212
+ deformation
213
+ )
214
+ else:
215
+ grid_vertices = self.grid_vertices
216
+
217
+ v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices)
218
+
219
+ mesh = Mesh(
220
+ v_pos=v_pos,
221
+ t_pos_idx=t_pos_idx,
222
+ # extras
223
+ grid_vertices=grid_vertices,
224
+ tet_edges=self.all_edges,
225
+ grid_level=level,
226
+ grid_deformation=deformation,
227
+ )
228
+
229
+ return mesh
sf3d/models/mesh.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict, Optional
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from jaxtyping import Float, Integer
8
+ from torch import Tensor
9
+
10
+ from sf3d.box_uv_unwrap import box_projection_uv_unwrap
11
+ from sf3d.models.utils import dot
12
+
13
+
14
+ class Mesh:
15
+ def __init__(
16
+ self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs
17
+ ) -> None:
18
+ self.v_pos: Float[Tensor, "Nv 3"] = v_pos
19
+ self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
20
+ self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
21
+ self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
22
+ self._v_tex: Optional[Float[Tensor, "Nt 3"]] = None
23
+ self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
24
+ self.extras: Dict[str, Any] = {}
25
+ for k, v in kwargs.items():
26
+ self.add_extra(k, v)
27
+
28
+ def add_extra(self, k, v) -> None:
29
+ self.extras[k] = v
30
+
31
+ @property
32
+ def requires_grad(self):
33
+ return self.v_pos.requires_grad
34
+
35
+ @property
36
+ def v_nrm(self):
37
+ if self._v_nrm is None:
38
+ self._v_nrm = self._compute_vertex_normal()
39
+ return self._v_nrm
40
+
41
+ @property
42
+ def v_tng(self):
43
+ if self._v_tng is None:
44
+ self._v_tng = self._compute_vertex_tangent()
45
+ return self._v_tng
46
+
47
+ @property
48
+ def v_tex(self):
49
+ if self._v_tex is None:
50
+ self.unwrap_uv()
51
+ return self._v_tex
52
+
53
+ @property
54
+ def edges(self):
55
+ if self._edges is None:
56
+ self._edges = self._compute_edges()
57
+ return self._edges
58
+
59
+ def _compute_vertex_normal(self):
60
+ i0 = self.t_pos_idx[:, 0]
61
+ i1 = self.t_pos_idx[:, 1]
62
+ i2 = self.t_pos_idx[:, 2]
63
+
64
+ v0 = self.v_pos[i0, :]
65
+ v1 = self.v_pos[i1, :]
66
+ v2 = self.v_pos[i2, :]
67
+
68
+ face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
69
+
70
+ # Splat face normals to vertices
71
+ v_nrm = torch.zeros_like(self.v_pos)
72
+ v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
73
+ v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
74
+ v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
75
+
76
+ # Normalize, replace zero (degenerated) normals with some default value
77
+ v_nrm = torch.where(
78
+ dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
79
+ )
80
+ v_nrm = F.normalize(v_nrm, dim=1)
81
+
82
+ if torch.is_anomaly_enabled():
83
+ assert torch.all(torch.isfinite(v_nrm))
84
+
85
+ return v_nrm
86
+
87
+ def _compute_vertex_tangent(self):
88
+ vn_idx = [None] * 3
89
+ pos = [None] * 3
90
+ tex = [None] * 3
91
+ for i in range(0, 3):
92
+ pos[i] = self.v_pos[self.t_pos_idx[:, i]]
93
+ tex[i] = self.v_tex[self.t_pos_idx[:, i]]
94
+ # t_nrm_idx is always the same as t_pos_idx
95
+ vn_idx[i] = self.t_pos_idx[:, i]
96
+
97
+ tangents = torch.zeros_like(self.v_nrm)
98
+ tansum = torch.zeros_like(self.v_nrm)
99
+
100
+ # Compute tangent space for each triangle
101
+ duv1 = tex[1] - tex[0]
102
+ duv2 = tex[2] - tex[0]
103
+ dpos1 = pos[1] - pos[0]
104
+ dpos2 = pos[2] - pos[0]
105
+
106
+ tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
107
+
108
+ denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]
109
+
110
+ # Avoid division by zero for degenerated texture coordinates
111
+ denom_safe = denom.clip(1e-6)
112
+ tang = tng_nom / denom_safe
113
+
114
+ # Update all 3 vertices
115
+ for i in range(0, 3):
116
+ idx = vn_idx[i][:, None].repeat(1, 3)
117
+ tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang
118
+ tansum.scatter_add_(
119
+ 0, idx, torch.ones_like(tang)
120
+ ) # tansum[n_i] = tansum[n_i] + 1
121
+ # Also normalize it. Here we do not normalize the individual triangles first so larger area
122
+ # triangles influence the tangent space more
123
+ tangents = tangents / tansum
124
+
125
+ # Normalize and make sure tangent is perpendicular to normal
126
+ tangents = F.normalize(tangents, dim=1)
127
+ tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm)
128
+
129
+ if torch.is_anomaly_enabled():
130
+ assert torch.all(torch.isfinite(tangents))
131
+
132
+ return tangents
133
+
134
+ @torch.no_grad()
135
+ def unwrap_uv(
136
+ self,
137
+ island_padding: float = 0.02,
138
+ ) -> Mesh:
139
+ uv, indices = box_projection_uv_unwrap(
140
+ self.v_pos, self.v_nrm, self.t_pos_idx, island_padding
141
+ )
142
+
143
+ # Do store per vertex UVs.
144
+ # This means we need to duplicate some vertices at the seams
145
+ individual_vertices = self.v_pos[self.t_pos_idx].reshape(-1, 3)
146
+ individual_faces = torch.arange(
147
+ individual_vertices.shape[0],
148
+ device=individual_vertices.device,
149
+ dtype=self.t_pos_idx.dtype,
150
+ ).reshape(-1, 3)
151
+ uv_flat = uv[indices].reshape((-1, 2))
152
+ # uv_flat[:, 1] = 1 - uv_flat[:, 1]
153
+
154
+ self.v_pos = individual_vertices
155
+ self.t_pos_idx = individual_faces
156
+ self._v_tex = uv_flat
157
+ self._v_nrm = self._compute_vertex_normal()
158
+ self._v_tng = self._compute_vertex_tangent()
159
+
160
+ def _compute_edges(self):
161
+ # Compute edges
162
+ edges = torch.cat(
163
+ [
164
+ self.t_pos_idx[:, [0, 1]],
165
+ self.t_pos_idx[:, [1, 2]],
166
+ self.t_pos_idx[:, [2, 0]],
167
+ ],
168
+ dim=0,
169
+ )
170
+ edges = edges.sort()[0]
171
+ edges = torch.unique(edges, dim=0)
172
+ return edges
sf3d/models/network.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Callable, List, Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from jaxtyping import Float
9
+ from torch import Tensor
10
+ from torch.autograd import Function
11
+ from torch.cuda.amp import custom_bwd, custom_fwd
12
+
13
+ from sf3d.models.utils import BaseModule, normalize
14
+
15
+
16
+ class PixelShuffleUpsampleNetwork(BaseModule):
17
+ @dataclass
18
+ class Config(BaseModule.Config):
19
+ in_channels: int = 1024
20
+ out_channels: int = 40
21
+ scale_factor: int = 4
22
+
23
+ conv_layers: int = 4
24
+ conv_kernel_size: int = 3
25
+
26
+ cfg: Config
27
+
28
+ def configure(self) -> None:
29
+ layers = []
30
+ output_channels = self.cfg.out_channels * self.cfg.scale_factor**2
31
+
32
+ in_channels = self.cfg.in_channels
33
+ for i in range(self.cfg.conv_layers):
34
+ cur_out_channels = (
35
+ in_channels if i != self.cfg.conv_layers - 1 else output_channels
36
+ )
37
+ layers.append(
38
+ nn.Conv2d(
39
+ in_channels,
40
+ cur_out_channels,
41
+ self.cfg.conv_kernel_size,
42
+ padding=(self.cfg.conv_kernel_size - 1) // 2,
43
+ )
44
+ )
45
+ if i != self.cfg.conv_layers - 1:
46
+ layers.append(nn.ReLU(inplace=True))
47
+
48
+ layers.append(nn.PixelShuffle(self.cfg.scale_factor))
49
+
50
+ self.upsample = nn.Sequential(*layers)
51
+
52
+ def forward(
53
+ self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"]
54
+ ) -> Float[Tensor, "B 3 Co Hp2 Wp2"]:
55
+ return rearrange(
56
+ self.upsample(
57
+ rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
58
+ ),
59
+ "(B Np) Co Hp Wp -> B Np Co Hp Wp",
60
+ Np=3,
61
+ )
62
+
63
+
64
+ class _TruncExp(Function): # pylint: disable=abstract-method
65
+ # Implementation from torch-ngp:
66
+ # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
67
+ @staticmethod
68
+ @custom_fwd(cast_inputs=torch.float32)
69
+ def forward(ctx, x): # pylint: disable=arguments-differ
70
+ ctx.save_for_backward(x)
71
+ return torch.exp(x)
72
+
73
+ @staticmethod
74
+ @custom_bwd
75
+ def backward(ctx, g): # pylint: disable=arguments-differ
76
+ x = ctx.saved_tensors[0]
77
+ return g * torch.exp(torch.clamp(x, max=15))
78
+
79
+
80
+ trunc_exp = _TruncExp.apply
81
+
82
+
83
+ def get_activation(name) -> Callable:
84
+ if name is None:
85
+ return lambda x: x
86
+ name = name.lower()
87
+ if name == "none" or name == "linear" or name == "identity":
88
+ return lambda x: x
89
+ elif name == "lin2srgb":
90
+ return lambda x: torch.where(
91
+ x > 0.0031308,
92
+ torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
93
+ 12.92 * x,
94
+ ).clamp(0.0, 1.0)
95
+ elif name == "exp":
96
+ return lambda x: torch.exp(x)
97
+ elif name == "shifted_exp":
98
+ return lambda x: torch.exp(x - 1.0)
99
+ elif name == "trunc_exp":
100
+ return trunc_exp
101
+ elif name == "shifted_trunc_exp":
102
+ return lambda x: trunc_exp(x - 1.0)
103
+ elif name == "sigmoid":
104
+ return lambda x: torch.sigmoid(x)
105
+ elif name == "tanh":
106
+ return lambda x: torch.tanh(x)
107
+ elif name == "shifted_softplus":
108
+ return lambda x: F.softplus(x - 1.0)
109
+ elif name == "scale_-11_01":
110
+ return lambda x: x * 0.5 + 0.5
111
+ elif name == "negative":
112
+ return lambda x: -x
113
+ elif name == "normalize_channel_last":
114
+ return lambda x: normalize(x)
115
+ elif name == "normalize_channel_first":
116
+ return lambda x: normalize(x, dim=1)
117
+ else:
118
+ try:
119
+ return getattr(F, name)
120
+ except AttributeError:
121
+ raise ValueError(f"Unknown activation function: {name}")
122
+
123
+
124
+ @dataclass
125
+ class HeadSpec:
126
+ name: str
127
+ out_channels: int
128
+ n_hidden_layers: int
129
+ output_activation: Optional[str] = None
130
+ out_bias: float = 0.0
131
+
132
+
133
+ class MaterialMLP(BaseModule):
134
+ @dataclass
135
+ class Config(BaseModule.Config):
136
+ in_channels: int = 120
137
+ n_neurons: int = 64
138
+ activation: str = "silu"
139
+ heads: List[HeadSpec] = field(default_factory=lambda: [])
140
+
141
+ cfg: Config
142
+
143
+ def configure(self) -> None:
144
+ assert len(self.cfg.heads) > 0
145
+ heads = {}
146
+ for head in self.cfg.heads:
147
+ head_layers = []
148
+ for i in range(head.n_hidden_layers):
149
+ head_layers += [
150
+ nn.Linear(
151
+ self.cfg.in_channels if i == 0 else self.cfg.n_neurons,
152
+ self.cfg.n_neurons,
153
+ ),
154
+ self.make_activation(self.cfg.activation),
155
+ ]
156
+ head_layers += [
157
+ nn.Linear(
158
+ self.cfg.n_neurons,
159
+ head.out_channels,
160
+ ),
161
+ ]
162
+ heads[head.name] = nn.Sequential(*head_layers)
163
+ self.heads = nn.ModuleDict(heads)
164
+
165
+ def make_activation(self, activation):
166
+ if activation == "relu":
167
+ return nn.ReLU(inplace=True)
168
+ elif activation == "silu":
169
+ return nn.SiLU(inplace=True)
170
+ else:
171
+ raise NotImplementedError
172
+
173
+ def keys(self):
174
+ return self.heads.keys()
175
+
176
+ def forward(
177
+ self, x, include: Optional[List] = None, exclude: Optional[List] = None
178
+ ):
179
+ if include is not None and exclude is not None:
180
+ raise ValueError("Cannot specify both include and exclude.")
181
+ if include is not None:
182
+ heads = [h for h in self.cfg.heads if h.name in include]
183
+ elif exclude is not None:
184
+ heads = [h for h in self.cfg.heads if h.name not in exclude]
185
+ else:
186
+ heads = self.cfg.heads
187
+
188
+ out = {
189
+ head.name: get_activation(head.output_activation)(
190
+ self.heads[head.name](x) + head.out_bias
191
+ )
192
+ for head in heads
193
+ }
194
+
195
+ return out
sf3d/models/tokenizers/dinov2.py ADDED
@@ -0,0 +1,1196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch DINOv2 model."""
16
+
17
+ import collections.abc
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import Dict, List, Optional, Set, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+ from transformers.activations import ACT2FN
28
+ from transformers.modeling_outputs import (
29
+ BackboneOutput,
30
+ BaseModelOutput,
31
+ BaseModelOutputWithPooling,
32
+ ImageClassifierOutput,
33
+ )
34
+ from transformers.modeling_utils import PreTrainedModel
35
+ from transformers.models.dinov2.configuration_dinov2 import Dinov2Config
36
+ from transformers.pytorch_utils import (
37
+ find_pruneable_heads_and_indices,
38
+ prune_linear_layer,
39
+ )
40
+ from transformers.utils import (
41
+ add_code_sample_docstrings,
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from transformers.utils.backbone_utils import BackboneMixin
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+ # General docstring
52
+ _CONFIG_FOR_DOC = "Dinov2Config"
53
+
54
+ # Base docstring
55
+ _CHECKPOINT_FOR_DOC = "facebook/dinov2-base"
56
+ _EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
57
+
58
+ # Image classification docstring
59
+ _IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base"
60
+
61
+
62
+ DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
63
+ "facebook/dinov2-base",
64
+ # See all DINOv2 models at https://huggingface.co/models?filter=dinov2
65
+ ]
66
+
67
+
68
+ class Dinov2Embeddings(nn.Module):
69
+ """
70
+ Construct the CLS token, mask token, position and patch embeddings.
71
+ """
72
+
73
+ def __init__(self, config: Dinov2Config) -> None:
74
+ super().__init__()
75
+
76
+ self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
77
+ # register as mask token as it's not used in optimization
78
+ # to avoid the use of find_unused_parameters_true
79
+ # self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
80
+ self.register_buffer("mask_token", torch.zeros(1, config.hidden_size))
81
+ self.patch_embeddings = Dinov2PatchEmbeddings(config)
82
+ num_patches = self.patch_embeddings.num_patches
83
+ self.position_embeddings = nn.Parameter(
84
+ torch.randn(1, num_patches + 1, config.hidden_size)
85
+ )
86
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
87
+ self.config = config
88
+
89
+ def interpolate_pos_encoding(
90
+ self, embeddings: torch.Tensor, height: int, width: int
91
+ ) -> torch.Tensor:
92
+ """
93
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
94
+ resolution images.
95
+
96
+ Source:
97
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
98
+ """
99
+
100
+ num_patches = embeddings.shape[1] - 1
101
+ num_positions = self.position_embeddings.shape[1] - 1
102
+ if num_patches == num_positions and height == width:
103
+ return self.position_embeddings
104
+ class_pos_embed = self.position_embeddings[:, 0]
105
+ patch_pos_embed = self.position_embeddings[:, 1:]
106
+ dim = embeddings.shape[-1]
107
+ height = height // self.config.patch_size
108
+ width = width // self.config.patch_size
109
+ # we add a small number to avoid floating point error in the interpolation
110
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
111
+ height, width = height + 0.1, width + 0.1
112
+ patch_pos_embed = patch_pos_embed.reshape(
113
+ 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
114
+ )
115
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
116
+ patch_pos_embed = nn.functional.interpolate(
117
+ patch_pos_embed,
118
+ scale_factor=(
119
+ height / math.sqrt(num_positions),
120
+ width / math.sqrt(num_positions),
121
+ ),
122
+ mode="bicubic",
123
+ align_corners=False,
124
+ )
125
+ if (
126
+ int(height) != patch_pos_embed.shape[-2]
127
+ or int(width) != patch_pos_embed.shape[-1]
128
+ ):
129
+ raise ValueError(
130
+ "Width or height does not match with the interpolated position embeddings"
131
+ )
132
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
133
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
134
+
135
+ def forward(
136
+ self,
137
+ pixel_values: torch.Tensor,
138
+ bool_masked_pos: Optional[torch.Tensor] = None,
139
+ ) -> torch.Tensor:
140
+ batch_size, _, height, width = pixel_values.shape
141
+ patch_embeddings = self.patch_embeddings(pixel_values)
142
+ embeddings = patch_embeddings
143
+
144
+ if bool_masked_pos is not None:
145
+ embeddings = torch.where(
146
+ bool_masked_pos.unsqueeze(-1),
147
+ self.mask_token.to(embeddings.dtype).unsqueeze(0),
148
+ embeddings,
149
+ )
150
+
151
+ # add the [CLS] token to the embedded patch tokens
152
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
153
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
154
+
155
+ # add positional encoding to each token
156
+ embeddings = embeddings + self.interpolate_pos_encoding(
157
+ embeddings, height, width
158
+ )
159
+
160
+ embeddings = self.dropout(embeddings)
161
+
162
+ return embeddings
163
+
164
+
165
+ class Dinov2PatchEmbeddings(nn.Module):
166
+ """
167
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
168
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
169
+ Transformer.
170
+ """
171
+
172
+ def __init__(self, config):
173
+ super().__init__()
174
+ image_size, patch_size = config.image_size, config.patch_size
175
+ num_channels, hidden_size = config.num_channels, config.hidden_size
176
+
177
+ image_size = (
178
+ image_size
179
+ if isinstance(image_size, collections.abc.Iterable)
180
+ else (image_size, image_size)
181
+ )
182
+ patch_size = (
183
+ patch_size
184
+ if isinstance(patch_size, collections.abc.Iterable)
185
+ else (patch_size, patch_size)
186
+ )
187
+ num_patches = (image_size[1] // patch_size[1]) * (
188
+ image_size[0] // patch_size[0]
189
+ )
190
+ self.image_size = image_size
191
+ self.patch_size = patch_size
192
+ self.num_channels = num_channels
193
+ self.num_patches = num_patches
194
+
195
+ self.projection = nn.Conv2d(
196
+ num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
197
+ )
198
+
199
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
200
+ """
201
+ num_channels = pixel_values.shape[1]
202
+ if num_channels != self.num_channels:
203
+ raise ValueError(
204
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
205
+ f" Expected {self.num_channels} but got {num_channels}."
206
+ )
207
+ """
208
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
209
+ return embeddings
210
+
211
+
212
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2
213
+ class Dinov2SelfAttention(nn.Module):
214
+ def __init__(self, config: Dinov2Config) -> None:
215
+ super().__init__()
216
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
217
+ config, "embedding_size"
218
+ ):
219
+ raise ValueError(
220
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
221
+ f"heads {config.num_attention_heads}."
222
+ )
223
+
224
+ self.num_attention_heads = config.num_attention_heads
225
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
226
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
227
+ self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
228
+
229
+ self.query = nn.Linear(
230
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
231
+ )
232
+ self.key = nn.Linear(
233
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
234
+ )
235
+ self.value = nn.Linear(
236
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
237
+ )
238
+
239
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
240
+
241
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
242
+ new_x_shape = x.size()[:-1] + (
243
+ self.num_attention_heads,
244
+ self.attention_head_size,
245
+ )
246
+ x = x.view(new_x_shape)
247
+ return x.permute(0, 2, 1, 3)
248
+
249
+ def forward(
250
+ self,
251
+ hidden_states,
252
+ head_mask: Optional[torch.Tensor] = None,
253
+ output_attentions: bool = False,
254
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
255
+ mixed_query_layer = self.query(hidden_states)
256
+
257
+ if hasattr(F, "scaled_dot_product_attention"):
258
+ assert head_mask is None and not output_attentions
259
+ new_size = hidden_states.size()[:-1] + (
260
+ self.num_attention_heads,
261
+ self.attention_head_size,
262
+ )
263
+ key_layer = self.key(hidden_states).reshape(new_size).transpose(1, 2)
264
+ value_layer = self.value(hidden_states).reshape(new_size).transpose(1, 2)
265
+ query_layer = mixed_query_layer.reshape(new_size).transpose(1, 2)
266
+ context_layer = F.scaled_dot_product_attention(
267
+ query_layer,
268
+ key_layer,
269
+ value_layer,
270
+ dropout_p=self.attention_probs_dropout_prob,
271
+ is_causal=False,
272
+ )
273
+ context_layer = context_layer.transpose(1, 2).reshape(
274
+ *hidden_states.size()[:-1], -1
275
+ )
276
+ else:
277
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
278
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
279
+ query_layer = self.transpose_for_scores(mixed_query_layer)
280
+
281
+ # Take the dot product between "query" and "key" to get the raw attention scores.
282
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
283
+
284
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
285
+
286
+ # Normalize the attention scores to probabilities.
287
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
288
+
289
+ # This is actually dropping out entire tokens to attend to, which might
290
+ # seem a bit unusual, but is taken from the original Transformer paper.
291
+ attention_probs = self.dropout(attention_probs)
292
+
293
+ # Mask heads if we want to
294
+ if head_mask is not None:
295
+ attention_probs = attention_probs * head_mask
296
+
297
+ context_layer = torch.matmul(attention_probs, value_layer)
298
+
299
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
300
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
301
+ context_layer = context_layer.view(new_context_layer_shape)
302
+
303
+ outputs = (
304
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
305
+ )
306
+
307
+ return outputs
308
+
309
+
310
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
311
+ class Dinov2SelfOutput(nn.Module):
312
+ """
313
+ The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the
314
+ layernorm applied before each block.
315
+ """
316
+
317
+ def __init__(self, config: Dinov2Config) -> None:
318
+ super().__init__()
319
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
320
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
321
+
322
+ def forward(
323
+ self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
324
+ ) -> torch.Tensor:
325
+ hidden_states = self.dense(hidden_states)
326
+ hidden_states = self.dropout(hidden_states)
327
+
328
+ return hidden_states
329
+
330
+
331
+ # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2
332
+ class Dinov2Attention(nn.Module):
333
+ def __init__(self, config: Dinov2Config) -> None:
334
+ super().__init__()
335
+ self.attention = Dinov2SelfAttention(config)
336
+ self.output = Dinov2SelfOutput(config)
337
+ self.pruned_heads = set()
338
+
339
+ def prune_heads(self, heads: Set[int]) -> None:
340
+ if len(heads) == 0:
341
+ return
342
+ heads, index = find_pruneable_heads_and_indices(
343
+ heads,
344
+ self.attention.num_attention_heads,
345
+ self.attention.attention_head_size,
346
+ self.pruned_heads,
347
+ )
348
+
349
+ # Prune linear layers
350
+ self.attention.query = prune_linear_layer(self.attention.query, index)
351
+ self.attention.key = prune_linear_layer(self.attention.key, index)
352
+ self.attention.value = prune_linear_layer(self.attention.value, index)
353
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
354
+
355
+ # Update hyper params and store pruned heads
356
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(
357
+ heads
358
+ )
359
+ self.attention.all_head_size = (
360
+ self.attention.attention_head_size * self.attention.num_attention_heads
361
+ )
362
+ self.pruned_heads = self.pruned_heads.union(heads)
363
+
364
+ def forward(
365
+ self,
366
+ hidden_states: torch.Tensor,
367
+ head_mask: Optional[torch.Tensor] = None,
368
+ output_attentions: bool = False,
369
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
370
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
371
+
372
+ attention_output = self.output(self_outputs[0], hidden_states)
373
+
374
+ outputs = (attention_output,) + self_outputs[
375
+ 1:
376
+ ] # add attentions if we output them
377
+ return outputs
378
+
379
+
380
+ class Dinov2LayerScale(nn.Module):
381
+ def __init__(self, config) -> None:
382
+ super().__init__()
383
+ self.lambda1 = nn.Parameter(
384
+ config.layerscale_value * torch.ones(config.hidden_size)
385
+ )
386
+
387
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
388
+ return hidden_state * self.lambda1
389
+
390
+
391
+ # Copied from transformers.models.beit.modeling_beit.drop_path
392
+ def drop_path(
393
+ input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
394
+ ) -> torch.Tensor:
395
+ """
396
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
397
+
398
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
399
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
400
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
401
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
402
+ argument.
403
+ """
404
+ if drop_prob == 0.0 or not training:
405
+ return input
406
+ keep_prob = 1 - drop_prob
407
+ shape = (input.shape[0],) + (1,) * (
408
+ input.ndim - 1
409
+ ) # work with diff dim tensors, not just 2D ConvNets
410
+ random_tensor = keep_prob + torch.rand(
411
+ shape, dtype=input.dtype, device=input.device
412
+ )
413
+ random_tensor.floor_() # binarize
414
+ output = input.div(keep_prob) * random_tensor
415
+ return output
416
+
417
+
418
+ # Copied from transformers.models.beit.modeling_beit.BeitDropPath
419
+ class Dinov2DropPath(nn.Module):
420
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
421
+
422
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
423
+ super().__init__()
424
+ self.drop_prob = drop_prob
425
+
426
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
427
+ return drop_path(hidden_states, self.drop_prob, self.training)
428
+
429
+ def extra_repr(self) -> str:
430
+ return "p={}".format(self.drop_prob)
431
+
432
+
433
+ class Dinov2MLP(nn.Module):
434
+ def __init__(self, config) -> None:
435
+ super().__init__()
436
+ in_features = out_features = config.hidden_size
437
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
438
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
439
+ if isinstance(config.hidden_act, str):
440
+ self.activation = ACT2FN[config.hidden_act]
441
+ else:
442
+ self.activation = config.hidden_act
443
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
444
+
445
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
446
+ hidden_state = self.fc1(hidden_state)
447
+ hidden_state = self.activation(hidden_state)
448
+ hidden_state = self.fc2(hidden_state)
449
+ return hidden_state
450
+
451
+
452
+ class Dinov2SwiGLUFFN(nn.Module):
453
+ def __init__(self, config) -> None:
454
+ super().__init__()
455
+ in_features = out_features = config.hidden_size
456
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
457
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
458
+
459
+ self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
460
+ self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
461
+
462
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
463
+ hidden_state = self.weights_in(hidden_state)
464
+ x1, x2 = hidden_state.chunk(2, dim=-1)
465
+ hidden = nn.functional.silu(x1) * x2
466
+ return self.weights_out(hidden)
467
+
468
+
469
+ class Dinov2Layer(nn.Module):
470
+ """This corresponds to the Block class in the original implementation."""
471
+
472
+ def __init__(self, config: Dinov2Config) -> None:
473
+ super().__init__()
474
+
475
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
476
+ self.norm1_modulation = None
477
+ self.attention = Dinov2Attention(config)
478
+ self.layer_scale1 = Dinov2LayerScale(config)
479
+ self.drop_path1 = (
480
+ Dinov2DropPath(config.drop_path_rate)
481
+ if config.drop_path_rate > 0.0
482
+ else nn.Identity()
483
+ )
484
+
485
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
486
+ self.norm2_modulation = None
487
+
488
+ if config.use_swiglu_ffn:
489
+ self.mlp = Dinov2SwiGLUFFN(config)
490
+ else:
491
+ self.mlp = Dinov2MLP(config)
492
+ self.layer_scale2 = Dinov2LayerScale(config)
493
+ self.drop_path2 = (
494
+ Dinov2DropPath(config.drop_path_rate)
495
+ if config.drop_path_rate > 0.0
496
+ else nn.Identity()
497
+ )
498
+
499
+ def forward(
500
+ self,
501
+ hidden_states: torch.Tensor,
502
+ head_mask: Optional[torch.Tensor] = None,
503
+ modulation_cond: Optional[torch.Tensor] = None,
504
+ output_attentions: bool = False,
505
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
506
+ hidden_states_norm = self.norm1(hidden_states)
507
+ if self.norm1_modulation is not None:
508
+ assert modulation_cond is not None
509
+ hidden_states_norm = self.norm1_modulation(
510
+ hidden_states_norm, modulation_cond
511
+ )
512
+ self_attention_outputs = self.attention(
513
+ hidden_states_norm, # in Dinov2, layernorm is applied before self-attention
514
+ head_mask,
515
+ output_attentions=output_attentions,
516
+ )
517
+ attention_output = self_attention_outputs[0]
518
+
519
+ attention_output = self.layer_scale1(attention_output)
520
+ outputs = self_attention_outputs[
521
+ 1:
522
+ ] # add self attentions if we output attention weights
523
+
524
+ # first residual connection
525
+ hidden_states = attention_output + hidden_states
526
+
527
+ # in Dinov2, layernorm is also applied after self-attention
528
+ layer_output = self.norm2(hidden_states)
529
+ if self.norm2_modulation is not None:
530
+ assert modulation_cond is not None
531
+ layer_output = self.norm2_modulation(layer_output, modulation_cond)
532
+ layer_output = self.mlp(layer_output)
533
+ layer_output = self.layer_scale2(layer_output)
534
+
535
+ # second residual connection
536
+ layer_output = layer_output + hidden_states
537
+
538
+ outputs = (layer_output,) + outputs
539
+
540
+ return outputs
541
+
542
+ def register_ada_norm_modulation(self, norm1_mod: nn.Module, norm2_mod: nn.Module):
543
+ self.norm1_modulation = norm1_mod
544
+ self.norm2_modulation = norm2_mod
545
+
546
+
547
+ # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2
548
+ class Dinov2Encoder(nn.Module):
549
+ def __init__(self, config: Dinov2Config) -> None:
550
+ super().__init__()
551
+ self.config = config
552
+ self.layer = nn.ModuleList(
553
+ [Dinov2Layer(config) for _ in range(config.num_hidden_layers)]
554
+ )
555
+ self.gradient_checkpointing = False
556
+
557
+ def forward(
558
+ self,
559
+ hidden_states: torch.Tensor,
560
+ head_mask: Optional[torch.Tensor] = None,
561
+ modulation_cond: Optional[torch.Tensor] = None,
562
+ output_attentions: bool = False,
563
+ output_hidden_states: bool = False,
564
+ return_dict: bool = True,
565
+ ) -> Union[tuple, BaseModelOutput]:
566
+ all_hidden_states = () if output_hidden_states else None
567
+ all_self_attentions = () if output_attentions else None
568
+
569
+ for i, layer_module in enumerate(self.layer):
570
+ if output_hidden_states:
571
+ all_hidden_states = all_hidden_states + (hidden_states,)
572
+
573
+ layer_head_mask = head_mask[i] if head_mask is not None else None
574
+
575
+ if self.gradient_checkpointing and self.training:
576
+
577
+ def create_custom_forward(module):
578
+ def custom_forward(*inputs):
579
+ return module(*inputs, output_attentions)
580
+
581
+ return custom_forward
582
+
583
+ layer_outputs = torch.utils.checkpoint.checkpoint(
584
+ create_custom_forward(layer_module),
585
+ hidden_states,
586
+ layer_head_mask,
587
+ modulation_cond,
588
+ use_reentrant=False,
589
+ )
590
+ else:
591
+ layer_outputs = layer_module(
592
+ hidden_states, layer_head_mask, modulation_cond, output_attentions
593
+ )
594
+
595
+ hidden_states = layer_outputs[0]
596
+
597
+ if output_attentions:
598
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
599
+
600
+ if output_hidden_states:
601
+ all_hidden_states = all_hidden_states + (hidden_states,)
602
+
603
+ if not return_dict:
604
+ return tuple(
605
+ v
606
+ for v in [hidden_states, all_hidden_states, all_self_attentions]
607
+ if v is not None
608
+ )
609
+ return BaseModelOutput(
610
+ last_hidden_state=hidden_states,
611
+ hidden_states=all_hidden_states,
612
+ attentions=all_self_attentions,
613
+ )
614
+
615
+
616
+ class Dinov2PreTrainedModel(PreTrainedModel):
617
+ """
618
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
619
+ models.
620
+ """
621
+
622
+ config_class = Dinov2Config
623
+ base_model_prefix = "dinov2"
624
+ main_input_name = "pixel_values"
625
+ supports_gradient_checkpointing = True
626
+
627
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
628
+ """Initialize the weights"""
629
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
630
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
631
+ # `trunc_normal_cpu` not implemented in `half` issues
632
+ module.weight.data = nn.init.trunc_normal_(
633
+ module.weight.data.to(torch.float32),
634
+ mean=0.0,
635
+ std=self.config.initializer_range,
636
+ ).to(module.weight.dtype)
637
+ if module.bias is not None:
638
+ module.bias.data.zero_()
639
+ elif isinstance(module, nn.LayerNorm):
640
+ module.bias.data.zero_()
641
+ module.weight.data.fill_(1.0)
642
+ elif isinstance(module, Dinov2Embeddings):
643
+ module.position_embeddings.data = nn.init.trunc_normal_(
644
+ module.position_embeddings.data.to(torch.float32),
645
+ mean=0.0,
646
+ std=self.config.initializer_range,
647
+ ).to(module.position_embeddings.dtype)
648
+
649
+ module.cls_token.data = nn.init.trunc_normal_(
650
+ module.cls_token.data.to(torch.float32),
651
+ mean=0.0,
652
+ std=self.config.initializer_range,
653
+ ).to(module.cls_token.dtype)
654
+
655
+ def _set_gradient_checkpointing(
656
+ self, module: Dinov2Encoder, value: bool = False
657
+ ) -> None:
658
+ if isinstance(module, Dinov2Encoder):
659
+ module.gradient_checkpointing = value
660
+
661
+
662
+ DINOV2_START_DOCSTRING = r"""
663
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
664
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
665
+ behavior.
666
+
667
+ Parameters:
668
+ config ([`Dinov2Config`]): Model configuration class with all the parameters of the model.
669
+ Initializing with a config file does not load the weights associated with the model, only the
670
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
671
+ """
672
+
673
+ DINOV2_BASE_INPUTS_DOCSTRING = r"""
674
+ Args:
675
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
676
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
677
+ [`BitImageProcessor.preprocess`] for details.
678
+
679
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
680
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
681
+ pre-training.
682
+
683
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
684
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
685
+
686
+ - 1 indicates the head is **not masked**,
687
+ - 0 indicates the head is **masked**.
688
+
689
+ output_attentions (`bool`, *optional*):
690
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
691
+ tensors for more detail.
692
+ output_hidden_states (`bool`, *optional*):
693
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
694
+ more detail.
695
+ return_dict (`bool`, *optional*):
696
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
697
+ """
698
+
699
+ DINOV2_INPUTS_DOCSTRING = r"""
700
+ Args:
701
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
702
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
703
+ [`BitImageProcessor.preprocess`] for details.
704
+
705
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
706
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
707
+
708
+ - 1 indicates the head is **not masked**,
709
+ - 0 indicates the head is **masked**.
710
+
711
+ output_attentions (`bool`, *optional*):
712
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
713
+ tensors for more detail.
714
+ output_hidden_states (`bool`, *optional*):
715
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
716
+ more detail.
717
+ return_dict (`bool`, *optional*):
718
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
719
+ """
720
+
721
+
722
+ @dataclass
723
+ class CustomBaseModelOutputWithPooling(BaseModelOutputWithPooling):
724
+ patch_embeddings: Optional[torch.FloatTensor] = None
725
+
726
+
727
+ @add_start_docstrings(
728
+ "The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top.",
729
+ DINOV2_START_DOCSTRING,
730
+ )
731
+ class Dinov2Model(Dinov2PreTrainedModel):
732
+ def __init__(self, config: Dinov2Config):
733
+ super().__init__(config)
734
+ self.config = config
735
+
736
+ self.embeddings = Dinov2Embeddings(config)
737
+ self.encoder = Dinov2Encoder(config)
738
+
739
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
740
+
741
+ # Initialize weights and apply final processing
742
+ self.post_init()
743
+
744
+ def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
745
+ return self.embeddings.patch_embeddings
746
+
747
+ def expand_input_channels(self, extra_input_channels: int) -> None:
748
+ if extra_input_channels == 0:
749
+ return
750
+ conv_old = self.embeddings.patch_embeddings.projection
751
+ conv_new = nn.Conv2d(
752
+ self.config.num_channels + extra_input_channels,
753
+ self.config.hidden_size,
754
+ kernel_size=self.config.patch_size,
755
+ stride=self.config.patch_size,
756
+ ).to(self.device)
757
+ with torch.no_grad():
758
+ conv_new.weight[:, :3] = conv_old.weight
759
+ conv_new.bias = conv_old.bias
760
+ self.embeddings.patch_embeddings.projection = conv_new
761
+ del conv_old
762
+
763
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
764
+ """
765
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
766
+ class PreTrainedModel
767
+ """
768
+ for layer, heads in heads_to_prune.items():
769
+ self.encoder.layer[layer].attention.prune_heads(heads)
770
+
771
+ @add_start_docstrings_to_model_forward(DINOV2_BASE_INPUTS_DOCSTRING)
772
+ @add_code_sample_docstrings(
773
+ checkpoint=_CHECKPOINT_FOR_DOC,
774
+ output_type=BaseModelOutputWithPooling,
775
+ config_class=_CONFIG_FOR_DOC,
776
+ modality="vision",
777
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
778
+ )
779
+ def forward(
780
+ self,
781
+ pixel_values: Optional[torch.Tensor] = None,
782
+ bool_masked_pos: Optional[torch.Tensor] = None,
783
+ head_mask: Optional[torch.Tensor] = None,
784
+ modulation_cond: Optional[torch.Tensor] = None,
785
+ output_attentions: Optional[bool] = None,
786
+ output_hidden_states: Optional[bool] = None,
787
+ return_dict: Optional[bool] = None,
788
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
789
+ output_attentions = (
790
+ output_attentions
791
+ if output_attentions is not None
792
+ else self.config.output_attentions
793
+ )
794
+ output_hidden_states = (
795
+ output_hidden_states
796
+ if output_hidden_states is not None
797
+ else self.config.output_hidden_states
798
+ )
799
+ return_dict = (
800
+ return_dict if return_dict is not None else self.config.use_return_dict
801
+ )
802
+
803
+ if pixel_values is None:
804
+ raise ValueError("You have to specify pixel_values")
805
+
806
+ # Prepare head mask if needed
807
+ # 1.0 in head_mask indicate we keep the head
808
+ # attention_probs has shape bsz x n_heads x N x N
809
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
810
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
811
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
812
+
813
+ embedding_output = self.embeddings(
814
+ pixel_values, bool_masked_pos=bool_masked_pos
815
+ )
816
+
817
+ encoder_outputs = self.encoder(
818
+ embedding_output,
819
+ head_mask=head_mask,
820
+ modulation_cond=modulation_cond,
821
+ output_attentions=output_attentions,
822
+ output_hidden_states=output_hidden_states,
823
+ return_dict=return_dict,
824
+ )
825
+ sequence_output = encoder_outputs[0]
826
+ sequence_output = self.layernorm(sequence_output)
827
+ pooled_output = sequence_output[:, 0, :]
828
+
829
+ if not return_dict:
830
+ head_outputs = (sequence_output, pooled_output)
831
+ return head_outputs + encoder_outputs[1:]
832
+
833
+ return CustomBaseModelOutputWithPooling(
834
+ last_hidden_state=sequence_output,
835
+ pooler_output=pooled_output,
836
+ hidden_states=encoder_outputs.hidden_states,
837
+ attentions=encoder_outputs.attentions,
838
+ patch_embeddings=embedding_output,
839
+ )
840
+
841
+ def set_gradient_checkpointing(self, value: bool = False) -> None:
842
+ self._set_gradient_checkpointing(self.encoder, value)
843
+
844
+
845
+ @add_start_docstrings(
846
+ """
847
+ Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
848
+ of the [CLS] token) e.g. for ImageNet.
849
+ """,
850
+ DINOV2_START_DOCSTRING,
851
+ )
852
+ class Dinov2ForImageClassification(Dinov2PreTrainedModel):
853
+ def __init__(self, config: Dinov2Config) -> None:
854
+ super().__init__(config)
855
+
856
+ self.num_labels = config.num_labels
857
+ self.dinov2 = Dinov2Model(config)
858
+
859
+ # Classifier head
860
+ self.classifier = (
861
+ nn.Linear(config.hidden_size * 2, config.num_labels)
862
+ if config.num_labels > 0
863
+ else nn.Identity()
864
+ )
865
+
866
+ # Initialize weights and apply final processing
867
+ self.post_init()
868
+
869
+ @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
870
+ @add_code_sample_docstrings(
871
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
872
+ output_type=ImageClassifierOutput,
873
+ config_class=_CONFIG_FOR_DOC,
874
+ )
875
+ def forward(
876
+ self,
877
+ pixel_values: Optional[torch.Tensor] = None,
878
+ head_mask: Optional[torch.Tensor] = None,
879
+ labels: Optional[torch.Tensor] = None,
880
+ output_attentions: Optional[bool] = None,
881
+ output_hidden_states: Optional[bool] = None,
882
+ return_dict: Optional[bool] = None,
883
+ ) -> Union[tuple, ImageClassifierOutput]:
884
+ r"""
885
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
886
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
887
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
888
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
889
+ """
890
+ return_dict = (
891
+ return_dict if return_dict is not None else self.config.use_return_dict
892
+ )
893
+
894
+ outputs = self.dinov2(
895
+ pixel_values,
896
+ head_mask=head_mask,
897
+ output_attentions=output_attentions,
898
+ output_hidden_states=output_hidden_states,
899
+ return_dict=return_dict,
900
+ )
901
+
902
+ sequence_output = outputs[0] # batch_size, sequence_length, hidden_size
903
+
904
+ cls_token = sequence_output[:, 0]
905
+ patch_tokens = sequence_output[:, 1:]
906
+
907
+ linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
908
+
909
+ logits = self.classifier(linear_input)
910
+
911
+ loss = None
912
+ if labels is not None:
913
+ # move labels to correct device to enable model parallelism
914
+ labels = labels.to(logits.device)
915
+ if self.config.problem_type is None:
916
+ if self.num_labels == 1:
917
+ self.config.problem_type = "regression"
918
+ elif self.num_labels > 1 and (
919
+ labels.dtype == torch.long or labels.dtype == torch.int
920
+ ):
921
+ self.config.problem_type = "single_label_classification"
922
+ else:
923
+ self.config.problem_type = "multi_label_classification"
924
+
925
+ if self.config.problem_type == "regression":
926
+ loss_fct = MSELoss()
927
+ if self.num_labels == 1:
928
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
929
+ else:
930
+ loss = loss_fct(logits, labels)
931
+ elif self.config.problem_type == "single_label_classification":
932
+ loss_fct = CrossEntropyLoss()
933
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
934
+ elif self.config.problem_type == "multi_label_classification":
935
+ loss_fct = BCEWithLogitsLoss()
936
+ loss = loss_fct(logits, labels)
937
+
938
+ if not return_dict:
939
+ output = (logits,) + outputs[2:]
940
+ return ((loss,) + output) if loss is not None else output
941
+
942
+ return ImageClassifierOutput(
943
+ loss=loss,
944
+ logits=logits,
945
+ hidden_states=outputs.hidden_states,
946
+ attentions=outputs.attentions,
947
+ )
948
+
949
+
950
+ @add_start_docstrings(
951
+ """
952
+ Dinov2 backbone, to be used with frameworks like DETR and MaskFormer.
953
+ """,
954
+ DINOV2_START_DOCSTRING,
955
+ )
956
+ class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin):
957
+ def __init__(self, config):
958
+ super().__init__(config)
959
+ super()._init_backbone(config)
960
+
961
+ self.num_features = [
962
+ config.hidden_size for _ in range(config.num_hidden_layers + 1)
963
+ ]
964
+ self.embeddings = Dinov2Embeddings(config)
965
+ self.encoder = Dinov2Encoder(config)
966
+
967
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
968
+
969
+ # Initialize weights and apply final processing
970
+ self.post_init()
971
+
972
+ def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
973
+ return self.embeddings.patch_embeddings
974
+
975
+ @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
976
+ @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
977
+ def forward(
978
+ self,
979
+ pixel_values: torch.Tensor,
980
+ output_hidden_states: Optional[bool] = None,
981
+ output_attentions: Optional[bool] = None,
982
+ return_dict: Optional[bool] = None,
983
+ ) -> BackboneOutput:
984
+ """
985
+ Returns:
986
+
987
+ Examples:
988
+
989
+ ```python
990
+ >>> from transformers import AutoImageProcessor, AutoBackbone
991
+ >>> import torch
992
+ >>> from PIL import Image
993
+ >>> import requests
994
+
995
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
996
+ >>> image = Image.open(requests.get(url, stream=True).raw)
997
+
998
+ >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
999
+ >>> model = AutoBackbone.from_pretrained(
1000
+ ... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
1001
+ ... )
1002
+
1003
+ >>> inputs = processor(image, return_tensors="pt")
1004
+
1005
+ >>> outputs = model(**inputs)
1006
+ >>> feature_maps = outputs.feature_maps
1007
+ >>> list(feature_maps[-1].shape)
1008
+ [1, 768, 16, 16]
1009
+ ```"""
1010
+ return_dict = (
1011
+ return_dict if return_dict is not None else self.config.use_return_dict
1012
+ )
1013
+ output_hidden_states = (
1014
+ output_hidden_states
1015
+ if output_hidden_states is not None
1016
+ else self.config.output_hidden_states
1017
+ )
1018
+ output_attentions = (
1019
+ output_attentions
1020
+ if output_attentions is not None
1021
+ else self.config.output_attentions
1022
+ )
1023
+
1024
+ embedding_output = self.embeddings(pixel_values)
1025
+
1026
+ outputs = self.encoder(
1027
+ embedding_output,
1028
+ output_hidden_states=True,
1029
+ output_attentions=output_attentions,
1030
+ return_dict=return_dict,
1031
+ )
1032
+
1033
+ hidden_states = outputs.hidden_states if return_dict else outputs[1]
1034
+
1035
+ feature_maps = ()
1036
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
1037
+ if stage in self.out_features:
1038
+ if self.config.apply_layernorm:
1039
+ hidden_state = self.layernorm(hidden_state)
1040
+ if self.config.reshape_hidden_states:
1041
+ batch_size, _, height, width = pixel_values.shape
1042
+ patch_size = self.config.patch_size
1043
+ hidden_state = hidden_state[:, 1:, :].reshape(
1044
+ batch_size, width // patch_size, height // patch_size, -1
1045
+ )
1046
+ hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
1047
+ feature_maps += (hidden_state,)
1048
+
1049
+ if not return_dict:
1050
+ if output_hidden_states:
1051
+ output = (feature_maps,) + outputs[1:]
1052
+ else:
1053
+ output = (feature_maps,) + outputs[2:]
1054
+ return output
1055
+
1056
+ return BackboneOutput(
1057
+ feature_maps=feature_maps,
1058
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
1059
+ attentions=outputs.attentions if output_attentions else None,
1060
+ )
1061
+
1062
+
1063
+ class CustomPatchEmbeddings(nn.Module):
1064
+ """
1065
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
1066
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
1067
+ Transformer.
1068
+ """
1069
+
1070
+ def __init__(
1071
+ self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
1072
+ ):
1073
+ super().__init__()
1074
+
1075
+ image_size = (
1076
+ image_size
1077
+ if isinstance(image_size, collections.abc.Iterable)
1078
+ else (image_size, image_size)
1079
+ )
1080
+ patch_size = (
1081
+ patch_size
1082
+ if isinstance(patch_size, collections.abc.Iterable)
1083
+ else (patch_size, patch_size)
1084
+ )
1085
+ num_patches = (image_size[1] // patch_size[1]) * (
1086
+ image_size[0] // patch_size[0]
1087
+ )
1088
+ self.image_size = image_size
1089
+ self.patch_size = patch_size
1090
+ self.num_channels = num_channels
1091
+ self.num_patches = num_patches
1092
+
1093
+ self.projection = nn.Conv2d(
1094
+ num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
1095
+ )
1096
+
1097
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
1098
+ num_channels = pixel_values.shape[1]
1099
+ if num_channels != self.num_channels:
1100
+ raise ValueError(
1101
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
1102
+ f" Expected {self.num_channels} but got {num_channels}."
1103
+ )
1104
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
1105
+ return embeddings
1106
+
1107
+
1108
+ class CustomEmbeddings(nn.Module):
1109
+ """
1110
+ Construct the CLS token, mask token, position and patch embeddings.
1111
+ """
1112
+
1113
+ def __init__(
1114
+ self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
1115
+ ) -> None:
1116
+ super().__init__()
1117
+
1118
+ self.image_size = image_size
1119
+ self.patch_size = patch_size
1120
+ self.num_channels = num_channels
1121
+ self.hidden_size = hidden_size
1122
+
1123
+ self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))
1124
+
1125
+ self.patch_embeddings = CustomPatchEmbeddings(
1126
+ image_size, patch_size, num_channels, hidden_size
1127
+ )
1128
+ num_patches = self.patch_embeddings.num_patches
1129
+ self.position_embeddings = nn.Parameter(
1130
+ torch.randn(1, num_patches + 1, self.hidden_size)
1131
+ )
1132
+
1133
+ def interpolate_pos_encoding(
1134
+ self, embeddings: torch.Tensor, height: int, width: int
1135
+ ) -> torch.Tensor:
1136
+ """
1137
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
1138
+ resolution images.
1139
+
1140
+ Source:
1141
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
1142
+ """
1143
+
1144
+ num_patches = embeddings.shape[1] - 1
1145
+ num_positions = self.position_embeddings.shape[1] - 1
1146
+ if num_patches == num_positions and height == width:
1147
+ return self.position_embeddings
1148
+ class_pos_embed = self.position_embeddings[:, 0]
1149
+ patch_pos_embed = self.position_embeddings[:, 1:]
1150
+ dim = embeddings.shape[-1]
1151
+ height = height // self.patch_size
1152
+ width = width // self.patch_size
1153
+ # we add a small number to avoid floating point error in the interpolation
1154
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
1155
+ height, width = height + 0.1, width + 0.1
1156
+ patch_pos_embed = patch_pos_embed.reshape(
1157
+ 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
1158
+ )
1159
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
1160
+ patch_pos_embed = nn.functional.interpolate(
1161
+ patch_pos_embed,
1162
+ scale_factor=(
1163
+ height / math.sqrt(num_positions),
1164
+ width / math.sqrt(num_positions),
1165
+ ),
1166
+ mode="bicubic",
1167
+ align_corners=False,
1168
+ )
1169
+ if (
1170
+ int(height) != patch_pos_embed.shape[-2]
1171
+ or int(width) != patch_pos_embed.shape[-1]
1172
+ ):
1173
+ raise ValueError(
1174
+ "Width or height does not match with the interpolated position embeddings"
1175
+ )
1176
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
1177
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
1178
+
1179
+ def forward(
1180
+ self,
1181
+ pixel_values: torch.Tensor,
1182
+ ) -> torch.Tensor:
1183
+ batch_size, _, height, width = pixel_values.shape
1184
+ patch_embeddings = self.patch_embeddings(pixel_values)
1185
+ embeddings = patch_embeddings
1186
+
1187
+ # add the [CLS] token to the embedded patch tokens
1188
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
1189
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
1190
+
1191
+ # add positional encoding to each token
1192
+ embeddings = embeddings + self.interpolate_pos_encoding(
1193
+ embeddings, height, width
1194
+ )
1195
+
1196
+ return embeddings
sf3d/models/tokenizers/image.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+ from jaxtyping import Float
8
+ from torch import Tensor
9
+
10
+ from sf3d.models.tokenizers.dinov2 import Dinov2Model
11
+ from sf3d.models.transformers.attention import Modulation
12
+ from sf3d.models.utils import BaseModule
13
+
14
+
15
+ class DINOV2SingleImageTokenizer(BaseModule):
16
+ @dataclass
17
+ class Config(BaseModule.Config):
18
+ pretrained_model_name_or_path: str = "facebook/dinov2-large"
19
+ width: int = 512
20
+ height: int = 512
21
+ modulation_cond_dim: int = 768
22
+
23
+ cfg: Config
24
+
25
+ def configure(self) -> None:
26
+ self.model = Dinov2Model.from_pretrained(self.cfg.pretrained_model_name_or_path)
27
+
28
+ for p in self.model.parameters():
29
+ p.requires_grad_(False)
30
+ self.model.eval()
31
+
32
+ self.model.set_gradient_checkpointing(False)
33
+
34
+ # add modulation
35
+ modulations = []
36
+ for layer in self.model.encoder.layer:
37
+ norm1_modulation = Modulation(
38
+ self.model.config.hidden_size,
39
+ self.cfg.modulation_cond_dim,
40
+ zero_init=True,
41
+ single_layer=True,
42
+ )
43
+ norm2_modulation = Modulation(
44
+ self.model.config.hidden_size,
45
+ self.cfg.modulation_cond_dim,
46
+ zero_init=True,
47
+ single_layer=True,
48
+ )
49
+ layer.register_ada_norm_modulation(norm1_modulation, norm2_modulation)
50
+ modulations += [norm1_modulation, norm2_modulation]
51
+ self.modulations = nn.ModuleList(modulations)
52
+
53
+ self.register_buffer(
54
+ "image_mean",
55
+ torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
56
+ persistent=False,
57
+ )
58
+ self.register_buffer(
59
+ "image_std",
60
+ torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
61
+ persistent=False,
62
+ )
63
+
64
+ def forward(
65
+ self,
66
+ images: Float[Tensor, "B *N C H W"],
67
+ modulation_cond: Optional[Float[Tensor, "B *N Cc"]],
68
+ **kwargs,
69
+ ) -> Float[Tensor, "B *N Ct Nt"]:
70
+ model = self.model
71
+
72
+ packed = False
73
+ if images.ndim == 4:
74
+ packed = True
75
+ images = images.unsqueeze(1)
76
+ if modulation_cond is not None:
77
+ assert modulation_cond.ndim == 2
78
+ modulation_cond = modulation_cond.unsqueeze(1)
79
+
80
+ batch_size, n_input_views = images.shape[:2]
81
+ images = (images - self.image_mean) / self.image_std
82
+ out = model(
83
+ rearrange(images, "B N C H W -> (B N) C H W"),
84
+ modulation_cond=rearrange(modulation_cond, "B N Cc -> (B N) Cc")
85
+ if modulation_cond is not None
86
+ else None,
87
+ )
88
+ local_features = out.last_hidden_state
89
+ local_features = local_features.permute(0, 2, 1)
90
+ local_features = rearrange(
91
+ local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
92
+ )
93
+ if packed:
94
+ local_features = local_features.squeeze(1)
95
+
96
+ return local_features
97
+
98
+ def detokenize(self, *args, **kwargs):
99
+ raise NotImplementedError
sf3d/models/tokenizers/triplane.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange, repeat
7
+ from jaxtyping import Float
8
+ from torch import Tensor
9
+
10
+ from sf3d.models.utils import BaseModule
11
+
12
+
13
+ class TriplaneLearnablePositionalEmbedding(BaseModule):
14
+ @dataclass
15
+ class Config(BaseModule.Config):
16
+ plane_size: int = 96
17
+ num_channels: int = 1024
18
+
19
+ cfg: Config
20
+
21
+ def configure(self) -> None:
22
+ self.embeddings = nn.Parameter(
23
+ torch.randn(
24
+ (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
25
+ dtype=torch.float32,
26
+ )
27
+ * 1
28
+ / math.sqrt(self.cfg.num_channels)
29
+ )
30
+
31
+ def forward(self, batch_size: int) -> Float[Tensor, "B Ct Nt"]:
32
+ return rearrange(
33
+ repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size),
34
+ "B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
35
+ )
36
+
37
+ def detokenize(
38
+ self, tokens: Float[Tensor, "B Ct Nt"]
39
+ ) -> Float[Tensor, "B 3 Ct Hp Wp"]:
40
+ batch_size, Ct, Nt = tokens.shape
41
+ assert Nt == self.cfg.plane_size**2 * 3
42
+ assert Ct == self.cfg.num_channels
43
+ return rearrange(
44
+ tokens,
45
+ "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
46
+ Np=3,
47
+ Hp=self.cfg.plane_size,
48
+ Wp=self.cfg.plane_size,
49
+ )
sf3d/models/transformers/attention.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class Modulation(nn.Module):
6
+ def __init__(
7
+ self,
8
+ embedding_dim: int,
9
+ condition_dim: int,
10
+ zero_init: bool = False,
11
+ single_layer: bool = False,
12
+ ):
13
+ super().__init__()
14
+ self.silu = nn.SiLU()
15
+ if single_layer:
16
+ self.linear1 = nn.Identity()
17
+ else:
18
+ self.linear1 = nn.Linear(condition_dim, condition_dim)
19
+
20
+ self.linear2 = nn.Linear(condition_dim, embedding_dim * 2)
21
+
22
+ # Only zero init the last linear layer
23
+ if zero_init:
24
+ nn.init.zeros_(self.linear2.weight)
25
+ nn.init.zeros_(self.linear2.bias)
26
+
27
+ def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
28
+ emb = self.linear2(self.silu(self.linear1(condition)))
29
+ scale, shift = torch.chunk(emb, 2, dim=1)
30
+ x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
31
+ return x
sf3d/models/transformers/backbone.py ADDED
@@ -0,0 +1,515 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+ from sf3d.models.utils import BaseModule
9
+
10
+
11
+ class GEGLU(nn.Module):
12
+ r"""
13
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
14
+
15
+ Parameters:
16
+ dim_in (`int`): The number of channels in the input.
17
+ dim_out (`int`): The number of channels in the output.
18
+ """
19
+
20
+ def __init__(self, dim_in: int, dim_out: int):
21
+ super().__init__()
22
+ self.proj = nn.Linear(dim_in, dim_out * 2)
23
+
24
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
25
+ if gate.device.type != "mps":
26
+ return F.gelu(gate)
27
+ # mps: gelu is not implemented for float16
28
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
29
+
30
+ def forward(self, hidden_states, scale: float = 1.0):
31
+ args = ()
32
+ hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
33
+ return hidden_states * self.gelu(gate)
34
+
35
+
36
+ class CrossAttention(nn.Module):
37
+ def __init__(
38
+ self,
39
+ dim,
40
+ kv_dim=None,
41
+ num_heads=16,
42
+ qkv_bias=False,
43
+ attn_drop=0.0,
44
+ proj_drop=0.0,
45
+ ):
46
+ super().__init__()
47
+ self.num_heads = num_heads
48
+ head_dim = dim // num_heads
49
+ self.scale = head_dim**-0.5
50
+ kv_dim = dim if not kv_dim else kv_dim
51
+ self.wq = nn.Linear(dim, dim, bias=qkv_bias)
52
+ self.wk = nn.Linear(kv_dim, dim, bias=qkv_bias)
53
+ self.wv = nn.Linear(kv_dim, dim, bias=qkv_bias)
54
+ self.attn_drop = attn_drop
55
+ self.proj = nn.Linear(dim, dim)
56
+ self.proj_drop = nn.Dropout(proj_drop)
57
+
58
+ def forward(self, x_q, x_kv):
59
+ B, N_q, C = x_q.shape
60
+ B, N_kv, _ = x_kv.shape
61
+ # [B, N_q, C] -> [B, N_q, H, C/H]
62
+ q = self.wq(x_q).reshape(B, N_q, self.num_heads, C // self.num_heads)
63
+ # [B, N_kv, C] -> [B, N_kv, H, C/H]
64
+ k = self.wk(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads)
65
+ v = self.wv(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads)
66
+
67
+ # attention
68
+ x = torch.nn.functional.scaled_dot_product_attention(
69
+ q.permute(0, 2, 1, 3),
70
+ k.permute(0, 2, 1, 3),
71
+ v.permute(0, 2, 1, 3),
72
+ attn_mask=None,
73
+ dropout_p=self.attn_drop,
74
+ scale=self.scale,
75
+ ).permute(0, 2, 1, 3)
76
+
77
+ # [B, N_q, H, C/H] -> [B, N_q, C]
78
+ x = x.reshape(B, N_q, C)
79
+ x = self.proj(x)
80
+ x = self.proj_drop(x)
81
+ return x
82
+
83
+
84
+ class FeedForward(nn.Module):
85
+ def __init__(
86
+ self,
87
+ dim: int,
88
+ dim_out: Optional[int] = None,
89
+ mult: int = 4,
90
+ dropout: float = 0.0,
91
+ ):
92
+ super().__init__()
93
+ inner_dim = int(dim * mult)
94
+ dim_out = dim_out if dim_out is not None else dim
95
+ act_fn = GEGLU(dim, inner_dim)
96
+ self.net = nn.ModuleList([])
97
+ self.net.append(act_fn)
98
+ self.net.append(nn.Dropout(dropout))
99
+ self.net.append(nn.Linear(inner_dim, dim_out))
100
+
101
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
102
+ for module in self.net:
103
+ x = module(x)
104
+ return x
105
+
106
+
107
+ class BasicBlock(nn.Module):
108
+ def __init__(
109
+ self,
110
+ dim: int,
111
+ kv_dim: Optional[int] = None,
112
+ num_heads: int = 16,
113
+ qkv_bias: bool = False,
114
+ attn_drop: float = 0.0,
115
+ proj_drop: float = 0.0,
116
+ ff_drop: float = 0.0,
117
+ ):
118
+ super().__init__()
119
+ self.norm1 = nn.LayerNorm(dim)
120
+ self.attn1 = CrossAttention(
121
+ dim,
122
+ kv_dim=dim,
123
+ num_heads=num_heads,
124
+ qkv_bias=qkv_bias,
125
+ attn_drop=attn_drop,
126
+ proj_drop=proj_drop,
127
+ )
128
+ self.norm2 = nn.LayerNorm(dim)
129
+ self.attn2 = CrossAttention(
130
+ dim,
131
+ kv_dim=kv_dim,
132
+ num_heads=num_heads,
133
+ qkv_bias=qkv_bias,
134
+ attn_drop=attn_drop,
135
+ proj_drop=proj_drop,
136
+ )
137
+ self.norm3 = nn.LayerNorm(dim)
138
+ self.ff = FeedForward(dim, dropout=ff_drop)
139
+
140
+ def forward(self, z, x):
141
+ z_norm = self.norm1(z)
142
+ z = z + self.attn1(z_norm, z_norm)
143
+ # TODO: do we need to have the second attention when x is None?
144
+ z_norm = self.norm2(z)
145
+ z = z + self.attn2(z_norm, x if x is not None else z_norm)
146
+ z_norm = self.norm3(z)
147
+ z = z + self.ff(z_norm)
148
+ return z
149
+
150
+
151
+ class SingleStreamTransformer(BaseModule):
152
+ @dataclass
153
+ class Config(BaseModule.Config):
154
+ num_attention_heads: int = 16
155
+ attention_head_dim: int = 88
156
+ in_channels: Optional[int] = None
157
+ out_channels: Optional[int] = None
158
+ num_layers: int = 16
159
+ dropout: float = 0.0
160
+ norm_num_groups: int = 32
161
+ cross_attention_dim: Optional[int] = None
162
+ attention_bias: bool = False
163
+
164
+ cfg: Config
165
+
166
+ def configure(self) -> None:
167
+ self.num_attention_heads = self.cfg.num_attention_heads
168
+ self.attention_head_dim = self.cfg.attention_head_dim
169
+ inner_dim = self.num_attention_heads * self.attention_head_dim
170
+
171
+ # Define input layers
172
+ self.norm = torch.nn.GroupNorm(
173
+ num_groups=self.cfg.norm_num_groups,
174
+ num_channels=self.cfg.in_channels,
175
+ eps=1e-6,
176
+ affine=True,
177
+ )
178
+ self.proj_in = nn.Linear(self.cfg.in_channels, inner_dim)
179
+
180
+ # Define transformers blocks
181
+ self.transformer_blocks = nn.ModuleList(
182
+ [
183
+ BasicBlock(
184
+ inner_dim,
185
+ kv_dim=self.cfg.cross_attention_dim,
186
+ num_heads=self.num_attention_heads,
187
+ qkv_bias=self.cfg.attention_bias,
188
+ proj_drop=self.cfg.dropout,
189
+ ff_drop=self.cfg.dropout,
190
+ )
191
+ for d in range(self.cfg.num_layers)
192
+ ]
193
+ )
194
+
195
+ # 4. Define output layers
196
+ self.proj_out = nn.Linear(inner_dim, self.cfg.in_channels)
197
+
198
+ def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
199
+ residual = hidden_states
200
+ hidden_states = self.norm(hidden_states)
201
+ hidden_states = hidden_states.permute(0, 2, 1)
202
+ hidden_states = self.proj_in(hidden_states)
203
+ for block in self.transformer_blocks:
204
+ hidden_states = block(hidden_states, encoder_hidden_states)
205
+ hidden_states = self.proj_out(hidden_states).permute(0, 2, 1).contiguous()
206
+ # TODO: do we really need to add the residual?
207
+ hidden_states = hidden_states + residual
208
+ return hidden_states
209
+
210
+
211
+ class FuseBlock(nn.Module):
212
+ """
213
+ Fuse X in to Z with cross attention
214
+ """
215
+
216
+ def __init__(
217
+ self,
218
+ dim_z: int,
219
+ dim_x: int,
220
+ num_heads: int = 16,
221
+ qkv_bias: bool = False,
222
+ attn_drop: float = 0.0,
223
+ proj_drop: float = 0.0,
224
+ ff_drop: float = 0.0,
225
+ norm_x_input: bool = True,
226
+ ):
227
+ super().__init__()
228
+ self.norm_x_input = norm_x_input
229
+ if self.norm_x_input:
230
+ self.norm_x = nn.LayerNorm(dim_x)
231
+ self.attn = CrossAttention(
232
+ dim_z,
233
+ kv_dim=dim_x,
234
+ num_heads=num_heads,
235
+ qkv_bias=qkv_bias,
236
+ attn_drop=attn_drop,
237
+ proj_drop=proj_drop,
238
+ )
239
+ self.norm_z1 = nn.LayerNorm(dim_z)
240
+ self.norm_z2 = nn.LayerNorm(dim_z)
241
+ self.ff = FeedForward(dim_z, dropout=ff_drop)
242
+
243
+ def forward(self, z, x):
244
+ # TODO: do we need to normalize x?
245
+ z = z + self.attn(self.norm_z1(z), self.norm_x(x) if self.norm_x_input else x)
246
+ z = z + self.ff(self.norm_z2(z))
247
+ return z
248
+
249
+
250
+ @torch.no_grad()
251
+ def get_triplane_attention_mask(res):
252
+ N = 3 * res * res
253
+ attn_mask = torch.zeros(3, res, res, 3, res, res)
254
+
255
+ i, j = torch.meshgrid(torch.arange(res), torch.arange(res))
256
+
257
+ attn_mask[0, i, j, 1, i, :] = 1.0
258
+ attn_mask[0, i, j, 2, j, :] = 1.0
259
+ attn_mask[1, i, j, 0, i, :] = 1.0
260
+ attn_mask[1, i, j, 2, :, j] = 1.0
261
+ attn_mask[2, i, j, 0, :, i] = 1.0
262
+ attn_mask[2, i, j, 1, :, j] = 1.0
263
+ attn_mask = attn_mask.bool()
264
+
265
+ attn_bias = torch.empty_like(attn_mask, dtype=torch.float)
266
+ attn_bias.masked_fill_(attn_mask, 0.0)
267
+ attn_bias.masked_fill_(~attn_mask, float("-inf"))
268
+
269
+ return attn_bias.reshape(N, N)
270
+
271
+
272
+ class TriplaneAttention(nn.Module):
273
+ def __init__(
274
+ self,
275
+ dim: int,
276
+ resolution: int,
277
+ num_heads: int = 16,
278
+ qkv_bias: bool = False,
279
+ attn_drop: float = 0.0,
280
+ proj_drop: float = 0.0,
281
+ full_attention: bool = False,
282
+ ):
283
+ super().__init__()
284
+ self.num_heads = num_heads
285
+ head_dim = dim // num_heads
286
+ self.scale = head_dim**-0.5
287
+ self.wq = nn.Linear(dim, dim, bias=qkv_bias)
288
+ self.wk = nn.Linear(dim, dim, bias=qkv_bias)
289
+ self.wv = nn.Linear(dim, dim, bias=qkv_bias)
290
+ self.attn_drop = attn_drop
291
+ self.proj = nn.Linear(dim, dim)
292
+ self.proj_drop = nn.Dropout(proj_drop)
293
+
294
+ self.resolution = resolution
295
+ self.full_attention = full_attention
296
+ self.attn_mask = (
297
+ get_triplane_attention_mask(resolution) if not full_attention else None
298
+ )
299
+
300
+ def forward(self, x):
301
+ B, N, C = x.shape
302
+ # [B, N, C] -> [B, N, H, C/H]
303
+ q = self.wq(x).reshape(B, N, self.num_heads, C // self.num_heads)
304
+ k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads)
305
+ v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads)
306
+
307
+ # detokenize the planes
308
+ assert N == self.resolution**2 * 3
309
+ attn_bias = (
310
+ self.attn_mask.to(q)
311
+ .unsqueeze(0)
312
+ .unsqueeze(0)
313
+ .expand(B, self.num_heads, -1, -1)
314
+ if not self.full_attention
315
+ else None
316
+ )
317
+
318
+ # full attention
319
+ x = torch.nn.functional.scaled_dot_product_attention(
320
+ q.permute(0, 2, 1, 3),
321
+ k.permute(0, 2, 1, 3),
322
+ v.permute(0, 2, 1, 3),
323
+ attn_mask=attn_bias,
324
+ dropout_p=self.attn_drop,
325
+ scale=self.scale,
326
+ ).permute(0, 2, 1, 3)
327
+
328
+ # [B, N_q, H, C/H] -> [B, N_q, C]
329
+ x = x.reshape(B, N, C)
330
+ x = self.proj(x)
331
+ x = self.proj_drop(x)
332
+ return x
333
+
334
+
335
+ class TwoStreamBlock(nn.Module):
336
+ def __init__(
337
+ self,
338
+ dim_latent: int,
339
+ dim_input: int,
340
+ num_basic_blocks: int = 4,
341
+ num_heads: int = 16,
342
+ qkv_bias: bool = False,
343
+ attn_drop: float = 0.0,
344
+ proj_drop: float = 0.0,
345
+ ff_drop: float = 0.0,
346
+ norm_x_input: bool = True,
347
+ dim_cross: Optional[int] = None,
348
+ ):
349
+ super().__init__()
350
+
351
+ # Define the fuse block that fuse the input into the latent
352
+ self.fuse_block_in = FuseBlock(
353
+ dim_latent,
354
+ dim_input,
355
+ num_heads=num_heads,
356
+ qkv_bias=qkv_bias,
357
+ attn_drop=attn_drop,
358
+ proj_drop=proj_drop,
359
+ ff_drop=ff_drop,
360
+ norm_x_input=norm_x_input,
361
+ )
362
+
363
+ # Define the transformer block that process the latent
364
+ self.transformer_block = nn.ModuleList(
365
+ [
366
+ BasicBlock(
367
+ dim_latent,
368
+ kv_dim=dim_cross,
369
+ num_heads=num_heads,
370
+ qkv_bias=qkv_bias,
371
+ proj_drop=proj_drop,
372
+ ff_drop=ff_drop,
373
+ )
374
+ for _ in range(num_basic_blocks)
375
+ ]
376
+ )
377
+
378
+ # Define the fuse block that fuse the latent into the input
379
+ self.fuse_block_out = FuseBlock(
380
+ dim_input,
381
+ dim_latent,
382
+ num_heads=num_heads,
383
+ qkv_bias=qkv_bias,
384
+ attn_drop=attn_drop,
385
+ proj_drop=proj_drop,
386
+ ff_drop=ff_drop,
387
+ norm_x_input=norm_x_input,
388
+ )
389
+
390
+ def forward(self, latent, input, cross_input):
391
+ latent = self.fuse_block_in(latent, input)
392
+ for block in self.transformer_block:
393
+ latent = block(latent, cross_input)
394
+ input = self.fuse_block_out(input, latent)
395
+ return latent, input
396
+
397
+
398
+ class TwoStreamInterleaveTransformer(BaseModule):
399
+ @dataclass
400
+ class Config(BaseModule.Config):
401
+ num_attention_heads: int = 16
402
+ attention_head_dim: int = 64
403
+ raw_triplane_channels: int = 1024
404
+ triplane_channels: int = 1024
405
+ raw_image_channels: int = 1024
406
+ num_latents: int = 1792
407
+ num_blocks: int = 4
408
+ num_basic_blocks: int = 3
409
+ dropout: float = 0.0
410
+ latent_init_std: float = 0.02
411
+ norm_num_groups: int = 32
412
+ attention_bias: bool = False
413
+ norm_x_input: bool = False
414
+ cross_attention_dim: int = 1024
415
+ mix_latent: bool = True
416
+
417
+ cfg: Config
418
+
419
+ def configure(self) -> None:
420
+ self.mix_latent = self.cfg.mix_latent
421
+
422
+ # Define the dimensions
423
+ self.num_attention_heads = self.cfg.num_attention_heads
424
+ self.attention_head_dim = self.cfg.attention_head_dim
425
+ self.num_latents = self.cfg.num_latents
426
+ self.latent_dim = self.num_attention_heads * self.attention_head_dim
427
+
428
+ # Define input layers
429
+ if self.cfg.norm_num_groups > 0:
430
+ self.norm_triplane = torch.nn.GroupNorm(
431
+ num_groups=self.cfg.norm_num_groups,
432
+ num_channels=self.cfg.raw_triplane_channels,
433
+ eps=1e-6,
434
+ affine=True,
435
+ )
436
+ else:
437
+ self.norm_triplane = nn.LayerNorm(self.cfg.raw_triplane_channels)
438
+ self.proj_triplane = nn.Linear(
439
+ self.cfg.raw_triplane_channels, self.cfg.triplane_channels
440
+ )
441
+ if self.mix_latent:
442
+ self.norm_image = nn.LayerNorm(self.cfg.raw_image_channels)
443
+ self.proj_image = nn.Linear(self.cfg.raw_image_channels, self.latent_dim)
444
+ self.norm_latent = nn.LayerNorm(self.latent_dim)
445
+ self.proj_latent = nn.Linear(self.latent_dim, self.latent_dim)
446
+
447
+ # Define the latents
448
+ self.latent_init = nn.Parameter(
449
+ torch.zeros(1, self.num_latents, self.latent_dim)
450
+ )
451
+ nn.init.normal_(self.latent_init, std=self.cfg.latent_init_std)
452
+
453
+ # Define the transformer blocks
454
+ self.main_blocks = nn.ModuleList(
455
+ [
456
+ TwoStreamBlock(
457
+ self.latent_dim,
458
+ self.cfg.triplane_channels,
459
+ num_basic_blocks=self.cfg.num_basic_blocks,
460
+ num_heads=self.num_attention_heads,
461
+ qkv_bias=self.cfg.attention_bias,
462
+ proj_drop=self.cfg.dropout,
463
+ ff_drop=self.cfg.dropout,
464
+ norm_x_input=self.cfg.norm_x_input,
465
+ dim_cross=self.cfg.cross_attention_dim,
466
+ )
467
+ for _ in range(self.cfg.num_blocks)
468
+ ]
469
+ )
470
+
471
+ # 4. Define output layers
472
+ self.proj_out = nn.Linear(
473
+ self.cfg.triplane_channels, self.cfg.raw_triplane_channels
474
+ )
475
+
476
+ def forward(self, hidden_states, encoder_hidden_states, **kwargs):
477
+ # hidden_states: [B, triplane_dim, N_triplane] is triplane tokens
478
+ # encoder_hidden_states: [B, N_image, image_dim] is the image tokens
479
+ if isinstance(self.norm_triplane, nn.GroupNorm):
480
+ triplane_tokens = self.norm_triplane(hidden_states)
481
+ triplane_tokens = triplane_tokens.permute(
482
+ 0, 2, 1
483
+ ) # [B, N_triplane, triplane_dim]
484
+ elif isinstance(self.norm_triplane, nn.LayerNorm):
485
+ triplane_tokens = self.norm_triplane(hidden_states.permute(0, 2, 1))
486
+ else:
487
+ raise ValueError("Unknown normalization layer")
488
+ triplane_tokens = self.proj_triplane(triplane_tokens)
489
+ if self.mix_latent:
490
+ image_tokens = self.norm_image(
491
+ encoder_hidden_states
492
+ ) # [B, N_image, image_dim]
493
+ image_tokens = self.proj_image(image_tokens)
494
+ init_latents = self.latent_init.expand(
495
+ hidden_states.shape[0], -1, -1
496
+ ) # [B, N_latent_init, latent_dim]
497
+ init_latents = self.norm_latent(init_latents)
498
+ init_latents = self.proj_latent(init_latents)
499
+ if self.mix_latent:
500
+ latent_tokens = torch.cat(
501
+ [image_tokens, init_latents], dim=1
502
+ ) # [B, N_latent, latent_dim]
503
+ else:
504
+ latent_tokens = init_latents
505
+
506
+ # forward the main blocks
507
+ for block in self.main_blocks:
508
+ latent_tokens, triplane_tokens = block(
509
+ latent_tokens, triplane_tokens, encoder_hidden_states
510
+ )
511
+
512
+ # project the triplane tokens back to the original dimension
513
+ triplane_tokens = self.proj_out(triplane_tokens).permute(0, 2, 1).contiguous()
514
+ triplane_tokens = triplane_tokens + hidden_states
515
+ return triplane_tokens
sf3d/models/utils.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import importlib
3
+ import math
4
+ from dataclasses import dataclass
5
+ from typing import Any, List, Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+ import PIL
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from jaxtyping import Bool, Float, Int, Num
13
+ from omegaconf import DictConfig, OmegaConf
14
+ from torch import Tensor
15
+
16
+
17
+ class BaseModule(nn.Module):
18
+ @dataclass
19
+ class Config:
20
+ pass
21
+
22
+ cfg: Config # add this to every subclass of BaseModule to enable static type checking
23
+
24
+ def __init__(
25
+ self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
26
+ ) -> None:
27
+ super().__init__()
28
+ self.cfg = parse_structured(self.Config, cfg)
29
+ self.configure(*args, **kwargs)
30
+
31
+ def configure(self, *args, **kwargs) -> None:
32
+ raise NotImplementedError
33
+
34
+
35
+ def find_class(cls_string):
36
+ module_string = ".".join(cls_string.split(".")[:-1])
37
+ cls_name = cls_string.split(".")[-1]
38
+ module = importlib.import_module(module_string, package=None)
39
+ cls = getattr(module, cls_name)
40
+ return cls
41
+
42
+
43
+ def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
44
+ # Check if cfg.keys are in fields
45
+ cfg_ = cfg.copy()
46
+ keys = list(cfg_.keys())
47
+
48
+ field_names = {f.name for f in dataclasses.fields(fields)}
49
+ for key in keys:
50
+ # This is helpful when swapping out modules from CLI
51
+ if key not in field_names:
52
+ print(f"Ignoring {key} as it's not supported by {fields}")
53
+ cfg_.pop(key)
54
+ scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg_)
55
+ return scfg
56
+
57
+
58
+ EPS_DTYPE = {
59
+ torch.float16: 1e-4,
60
+ torch.bfloat16: 1e-4,
61
+ torch.float32: 1e-7,
62
+ torch.float64: 1e-8,
63
+ }
64
+
65
+
66
+ def dot(x, y, dim=-1):
67
+ return torch.sum(x * y, dim, keepdim=True)
68
+
69
+
70
+ def reflect(x, n):
71
+ return x - 2 * dot(x, n) * n
72
+
73
+
74
+ def normalize(x, dim=-1, eps=None):
75
+ if eps is None:
76
+ eps = EPS_DTYPE[x.dtype]
77
+ return F.normalize(x, dim=dim, p=2, eps=eps)
78
+
79
+
80
+ def tri_winding(tri: Float[Tensor, "*B 3 2"]) -> Float[Tensor, "*B 3 3"]:
81
+ # One pad for determinant
82
+ tri_sq = F.pad(tri, (0, 1), "constant", 1.0)
83
+ det_tri = torch.det(tri_sq)
84
+ tri_rev = torch.cat(
85
+ (tri_sq[..., 0:1, :], tri_sq[..., 2:3, :], tri_sq[..., 1:2, :]), -2
86
+ )
87
+ tri_sq[det_tri < 0] = tri_rev[det_tri < 0]
88
+ return tri_sq
89
+
90
+
91
+ def triangle_intersection_2d(
92
+ t1: Float[Tensor, "*B 3 2"],
93
+ t2: Float[Tensor, "*B 3 2"],
94
+ eps=1e-12,
95
+ ) -> Float[Tensor, "*B"]: # noqa: F821
96
+ """Returns True if triangles collide, False otherwise"""
97
+
98
+ def chk_edge(x: Float[Tensor, "*B 3 3"]) -> Bool[Tensor, "*B"]: # noqa: F821
99
+ logdetx = torch.logdet(x.double())
100
+ if eps is None:
101
+ return ~torch.isfinite(logdetx)
102
+ return ~(torch.isfinite(logdetx) & (logdetx > math.log(eps)))
103
+
104
+ t1s = tri_winding(t1)
105
+ t2s = tri_winding(t2)
106
+
107
+ # Assume the triangles do not collide in the begging
108
+ ret = torch.zeros(t1.shape[0], dtype=torch.bool, device=t1.device)
109
+ for i in range(3):
110
+ edge = torch.roll(t1s, i, dims=1)[:, :2, :]
111
+ # Check if all points of triangle 2 lay on the external side of edge E.
112
+ # If this is the case the triangle do not collide
113
+ upd = (
114
+ chk_edge(torch.cat((edge, t2s[:, 0:1]), 1))
115
+ & chk_edge(torch.cat((edge, t2s[:, 1:2]), 1))
116
+ & chk_edge(torch.cat((edge, t2s[:, 2:3]), 1))
117
+ )
118
+ # Here no collision is still True due to inversion
119
+ ret = ret | upd
120
+
121
+ for i in range(3):
122
+ edge = torch.roll(t2s, i, dims=1)[:, :2, :]
123
+
124
+ upd = (
125
+ chk_edge(torch.cat((edge, t1s[:, 0:1]), 1))
126
+ & chk_edge(torch.cat((edge, t1s[:, 1:2]), 1))
127
+ & chk_edge(torch.cat((edge, t1s[:, 2:3]), 1))
128
+ )
129
+ # Here no collision is still True due to inversion
130
+ ret = ret | upd
131
+
132
+ return ~ret # Do the inversion
133
+
134
+
135
+ ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]
136
+
137
+
138
+ def scale_tensor(
139
+ dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale
140
+ ):
141
+ if inp_scale is None:
142
+ inp_scale = (0, 1)
143
+ if tgt_scale is None:
144
+ tgt_scale = (0, 1)
145
+ if isinstance(tgt_scale, Tensor):
146
+ assert dat.shape[-1] == tgt_scale.shape[-1]
147
+ dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
148
+ dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
149
+ return dat
150
+
151
+
152
+ def dilate_fill(img, mask, iterations=10):
153
+ oldMask = mask.float()
154
+ oldImg = img
155
+
156
+ mask_kernel = torch.ones(
157
+ (1, 1, 3, 3),
158
+ dtype=oldMask.dtype,
159
+ device=oldMask.device,
160
+ )
161
+
162
+ for i in range(iterations):
163
+ newMask = torch.nn.functional.max_pool2d(oldMask, 3, 1, 1)
164
+
165
+ # Fill the extension with mean color of old valid regions
166
+ img_unfold = F.unfold(oldImg, (3, 3)).view(1, 3, 3 * 3, -1)
167
+ mask_unfold = F.unfold(oldMask, (3, 3)).view(1, 1, 3 * 3, -1)
168
+ new_mask_unfold = F.unfold(newMask, (3, 3)).view(1, 1, 3 * 3, -1)
169
+
170
+ # Average color of the valid region
171
+ mean_color = (img_unfold.sum(dim=2) / mask_unfold.sum(dim=2).clip(1)).unsqueeze(
172
+ 2
173
+ )
174
+ # Extend it to the new region
175
+ fill_color = (mean_color * new_mask_unfold).view(1, 3 * 3 * 3, -1)
176
+
177
+ mask_conv = F.conv2d(
178
+ newMask, mask_kernel, padding=1
179
+ ) # Get the sum for each kernel patch
180
+ newImg = F.fold(
181
+ fill_color, (img.shape[-2], img.shape[-1]), (3, 3)
182
+ ) / mask_conv.clamp(1)
183
+
184
+ diffMask = newMask - oldMask
185
+
186
+ oldMask = newMask
187
+ oldImg = torch.lerp(oldImg, newImg, diffMask)
188
+
189
+ return oldImg
190
+
191
+
192
+ def float32_to_uint8_np(
193
+ x: Float[np.ndarray, "*B H W C"],
194
+ dither: bool = True,
195
+ dither_mask: Optional[Float[np.ndarray, "*B H W C"]] = None,
196
+ dither_strength: float = 1.0,
197
+ ) -> Int[np.ndarray, "*B H W C"]:
198
+ if dither:
199
+ dither = (
200
+ dither_strength * np.random.rand(*x[..., :1].shape).astype(np.float32) - 0.5
201
+ )
202
+ if dither_mask is not None:
203
+ dither = dither * dither_mask
204
+ return np.clip(np.floor((256.0 * x + dither)), 0, 255).astype(np.uint8)
205
+ return np.clip(np.floor((256.0 * x)), 0, 255).astype(torch.uint8)
206
+
207
+
208
+ def convert_data(data):
209
+ if data is None:
210
+ return None
211
+ elif isinstance(data, np.ndarray):
212
+ return data
213
+ elif isinstance(data, torch.Tensor):
214
+ if data.dtype in [torch.float16, torch.bfloat16]:
215
+ data = data.float()
216
+ return data.detach().cpu().numpy()
217
+ elif isinstance(data, list):
218
+ return [convert_data(d) for d in data]
219
+ elif isinstance(data, dict):
220
+ return {k: convert_data(v) for k, v in data.items()}
221
+ else:
222
+ raise TypeError(
223
+ "Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting",
224
+ type(data),
225
+ )
226
+
227
+
228
+ class ImageProcessor:
229
+ def convert_and_resize(
230
+ self,
231
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
232
+ size: int,
233
+ ):
234
+ if isinstance(image, PIL.Image.Image):
235
+ image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
236
+ elif isinstance(image, np.ndarray):
237
+ if image.dtype == np.uint8:
238
+ image = torch.from_numpy(image.astype(np.float32) / 255.0)
239
+ else:
240
+ image = torch.from_numpy(image)
241
+ elif isinstance(image, torch.Tensor):
242
+ pass
243
+
244
+ batched = image.ndim == 4
245
+
246
+ if not batched:
247
+ image = image[None, ...]
248
+ image = F.interpolate(
249
+ image.permute(0, 3, 1, 2),
250
+ (size, size),
251
+ mode="bilinear",
252
+ align_corners=False,
253
+ antialias=True,
254
+ ).permute(0, 2, 3, 1)
255
+ if not batched:
256
+ image = image[0]
257
+ return image
258
+
259
+ def __call__(
260
+ self,
261
+ image: Union[
262
+ PIL.Image.Image,
263
+ np.ndarray,
264
+ torch.FloatTensor,
265
+ List[PIL.Image.Image],
266
+ List[np.ndarray],
267
+ List[torch.FloatTensor],
268
+ ],
269
+ size: int,
270
+ ) -> Any:
271
+ if isinstance(image, (np.ndarray, torch.FloatTensor)) and image.ndim == 4:
272
+ image = self.convert_and_resize(image, size)
273
+ else:
274
+ if not isinstance(image, list):
275
+ image = [image]
276
+ image = [self.convert_and_resize(im, size) for im in image]
277
+ image = torch.stack(image, dim=0)
278
+ return image
279
+
280
+
281
+ def get_intrinsic_from_fov(fov, H, W, bs=-1):
282
+ focal_length = 0.5 * H / np.tan(0.5 * fov)
283
+ intrinsic = np.identity(3, dtype=np.float32)
284
+ intrinsic[0, 0] = focal_length
285
+ intrinsic[1, 1] = focal_length
286
+ intrinsic[0, 2] = W / 2.0
287
+ intrinsic[1, 2] = H / 2.0
288
+
289
+ if bs > 0:
290
+ intrinsic = intrinsic[None].repeat(bs, axis=0)
291
+
292
+ return torch.from_numpy(intrinsic)
sf3d/system.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataclasses import dataclass, field
3
+ from typing import Any, List, Optional, Tuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import trimesh
9
+ from einops import rearrange
10
+ from huggingface_hub import hf_hub_download
11
+ from jaxtyping import Float
12
+ from omegaconf import OmegaConf
13
+ from PIL import Image
14
+ from safetensors.torch import load_model
15
+ from torch import Tensor
16
+
17
+ from sf3d.models.isosurface import MarchingTetrahedraHelper
18
+ from sf3d.models.mesh import Mesh
19
+ from sf3d.models.utils import (
20
+ BaseModule,
21
+ ImageProcessor,
22
+ convert_data,
23
+ dilate_fill,
24
+ dot,
25
+ find_class,
26
+ float32_to_uint8_np,
27
+ normalize,
28
+ scale_tensor,
29
+ )
30
+ from sf3d.utils import create_intrinsic_from_fov_deg, default_cond_c2w
31
+
32
+ from .texture_baker import TextureBaker
33
+
34
+
35
+ class SF3D(BaseModule):
36
+ @dataclass
37
+ class Config(BaseModule.Config):
38
+ cond_image_size: int
39
+ isosurface_resolution: int
40
+ isosurface_threshold: float = 10.0
41
+ radius: float = 1.0
42
+ background_color: list[float] = field(default_factory=lambda: [0.5, 0.5, 0.5])
43
+ default_fovy_deg: float = 40.0
44
+ default_distance: float = 1.6
45
+
46
+ camera_embedder_cls: str = ""
47
+ camera_embedder: dict = field(default_factory=dict)
48
+
49
+ image_tokenizer_cls: str = ""
50
+ image_tokenizer: dict = field(default_factory=dict)
51
+
52
+ tokenizer_cls: str = ""
53
+ tokenizer: dict = field(default_factory=dict)
54
+
55
+ backbone_cls: str = ""
56
+ backbone: dict = field(default_factory=dict)
57
+
58
+ post_processor_cls: str = ""
59
+ post_processor: dict = field(default_factory=dict)
60
+
61
+ decoder_cls: str = ""
62
+ decoder: dict = field(default_factory=dict)
63
+
64
+ image_estimator_cls: str = ""
65
+ image_estimator: dict = field(default_factory=dict)
66
+
67
+ global_estimator_cls: str = ""
68
+ global_estimator: dict = field(default_factory=dict)
69
+
70
+ cfg: Config
71
+
72
+ @classmethod
73
+ def from_pretrained(
74
+ cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
75
+ ):
76
+ if os.path.isdir(pretrained_model_name_or_path):
77
+ config_path = os.path.join(pretrained_model_name_or_path, config_name)
78
+ weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
79
+ else:
80
+ config_path = hf_hub_download(
81
+ repo_id=pretrained_model_name_or_path, filename=config_name
82
+ )
83
+ weight_path = hf_hub_download(
84
+ repo_id=pretrained_model_name_or_path, filename=weight_name
85
+ )
86
+
87
+ cfg = OmegaConf.load(config_path)
88
+ OmegaConf.resolve(cfg)
89
+ model = cls(cfg)
90
+ load_model(model, weight_path)
91
+ return model
92
+
93
+ @property
94
+ def device(self):
95
+ return next(self.parameters()).device
96
+
97
+ def configure(self):
98
+ self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
99
+ self.cfg.image_tokenizer
100
+ )
101
+ self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
102
+ self.camera_embedder = find_class(self.cfg.camera_embedder_cls)(
103
+ self.cfg.camera_embedder
104
+ )
105
+ self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
106
+ self.post_processor = find_class(self.cfg.post_processor_cls)(
107
+ self.cfg.post_processor
108
+ )
109
+ self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
110
+ self.image_estimator = find_class(self.cfg.image_estimator_cls)(
111
+ self.cfg.image_estimator
112
+ )
113
+ self.global_estimator = find_class(self.cfg.global_estimator_cls)(
114
+ self.cfg.global_estimator
115
+ )
116
+
117
+ self.bbox: Float[Tensor, "2 3"]
118
+ self.register_buffer(
119
+ "bbox",
120
+ torch.as_tensor(
121
+ [
122
+ [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
123
+ [self.cfg.radius, self.cfg.radius, self.cfg.radius],
124
+ ],
125
+ dtype=torch.float32,
126
+ ),
127
+ )
128
+ self.isosurface_helper = MarchingTetrahedraHelper(
129
+ self.cfg.isosurface_resolution,
130
+ os.path.join(
131
+ os.path.dirname(__file__),
132
+ "..",
133
+ "load",
134
+ "tets",
135
+ f"{self.cfg.isosurface_resolution}_tets.npz",
136
+ ),
137
+ )
138
+
139
+ self.baker = TextureBaker()
140
+ self.image_processor = ImageProcessor()
141
+
142
+ def triplane_to_meshes(
143
+ self, triplanes: Float[Tensor, "B 3 Cp Hp Wp"]
144
+ ) -> list[Mesh]:
145
+ meshes = []
146
+ for i in range(triplanes.shape[0]):
147
+ triplane = triplanes[i]
148
+ grid_vertices = scale_tensor(
149
+ self.isosurface_helper.grid_vertices.to(triplanes.device),
150
+ self.isosurface_helper.points_range,
151
+ self.bbox,
152
+ )
153
+
154
+ values = self.query_triplane(grid_vertices, triplane)
155
+ decoded = self.decoder(values, include=["vertex_offset", "density"])
156
+ sdf = decoded["density"] - self.cfg.isosurface_threshold
157
+
158
+ deform = decoded["vertex_offset"].squeeze(0)
159
+
160
+ mesh: Mesh = self.isosurface_helper(
161
+ sdf.view(-1, 1), deform.view(-1, 3) if deform is not None else None
162
+ )
163
+ mesh.v_pos = scale_tensor(
164
+ mesh.v_pos, self.isosurface_helper.points_range, self.bbox
165
+ )
166
+
167
+ meshes.append(mesh)
168
+
169
+ return meshes
170
+
171
+ def query_triplane(
172
+ self,
173
+ positions: Float[Tensor, "*B N 3"],
174
+ triplanes: Float[Tensor, "*B 3 Cp Hp Wp"],
175
+ ) -> Float[Tensor, "*B N F"]:
176
+ batched = positions.ndim == 3
177
+ if not batched:
178
+ # no batch dimension
179
+ triplanes = triplanes[None, ...]
180
+ positions = positions[None, ...]
181
+ assert triplanes.ndim == 5 and positions.ndim == 3
182
+
183
+ positions = scale_tensor(
184
+ positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
185
+ )
186
+
187
+ indices2D: Float[Tensor, "B 3 N 2"] = torch.stack(
188
+ (positions[..., [0, 1]], positions[..., [0, 2]], positions[..., [1, 2]]),
189
+ dim=-3,
190
+ ).to(triplanes.dtype)
191
+ out: Float[Tensor, "B3 Cp 1 N"] = F.grid_sample(
192
+ rearrange(triplanes, "B Np Cp Hp Wp -> (B Np) Cp Hp Wp", Np=3).float(),
193
+ rearrange(indices2D, "B Np N Nd -> (B Np) () N Nd", Np=3).float(),
194
+ align_corners=True,
195
+ mode="bilinear",
196
+ )
197
+ out = rearrange(out, "(B Np) Cp () N -> B N (Np Cp)", Np=3)
198
+
199
+ return out
200
+
201
+ def get_scene_codes(self, batch) -> Float[Tensor, "B 3 C H W"]:
202
+ # if batch[rgb_cond] is only one view, add a view dimension
203
+ if len(batch["rgb_cond"].shape) == 4:
204
+ batch["rgb_cond"] = batch["rgb_cond"].unsqueeze(1)
205
+ batch["mask_cond"] = batch["mask_cond"].unsqueeze(1)
206
+ batch["c2w_cond"] = batch["c2w_cond"].unsqueeze(1)
207
+ batch["intrinsic_cond"] = batch["intrinsic_cond"].unsqueeze(1)
208
+ batch["intrinsic_normed_cond"] = batch["intrinsic_normed_cond"].unsqueeze(1)
209
+ batch_size, n_input_views = batch["rgb_cond"].shape[:2]
210
+
211
+ camera_embeds: Optional[Float[Tensor, "B Nv Cc"]]
212
+ camera_embeds = self.camera_embedder(**batch)
213
+
214
+ input_image_tokens: Float[Tensor, "B Nv Cit Nit"] = self.image_tokenizer(
215
+ rearrange(batch["rgb_cond"], "B Nv H W C -> B Nv C H W"),
216
+ modulation_cond=camera_embeds,
217
+ )
218
+
219
+ input_image_tokens = rearrange(
220
+ input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=n_input_views
221
+ )
222
+
223
+ tokens: Float[Tensor, "B Ct Nt"] = self.tokenizer(batch_size)
224
+
225
+ tokens = self.backbone(
226
+ tokens,
227
+ encoder_hidden_states=input_image_tokens,
228
+ modulation_cond=None,
229
+ )
230
+
231
+ direct_codes = self.tokenizer.detokenize(tokens)
232
+ scene_codes = self.post_processor(direct_codes)
233
+ return scene_codes, direct_codes
234
+
235
+ def run_image(
236
+ self,
237
+ image: Image,
238
+ bake_resolution: int,
239
+ estimate_illumination: bool = False,
240
+ ) -> Tuple[trimesh.Trimesh, dict[str, Any]]:
241
+ if image.mode != "RGBA":
242
+ raise ValueError("Image must be in RGBA mode")
243
+ img_cond = (
244
+ torch.from_numpy(
245
+ np.asarray(
246
+ image.resize((self.cfg.cond_image_size, self.cfg.cond_image_size))
247
+ ).astype(np.float32)
248
+ / 255.0
249
+ )
250
+ .float()
251
+ .clip(0, 1)
252
+ .to(self.device)
253
+ )
254
+ mask_cond = img_cond[:, :, -1:]
255
+ rgb_cond = torch.lerp(
256
+ torch.tensor(self.cfg.background_color, device=self.device)[None, None, :],
257
+ img_cond[:, :, :3],
258
+ mask_cond,
259
+ )
260
+
261
+ c2w_cond = default_cond_c2w(self.cfg.default_distance).to(self.device)
262
+ intrinsic, intrinsic_normed_cond = create_intrinsic_from_fov_deg(
263
+ self.cfg.default_fovy_deg,
264
+ self.cfg.cond_image_size,
265
+ self.cfg.cond_image_size,
266
+ )
267
+
268
+ batch = {
269
+ "rgb_cond": rgb_cond,
270
+ "mask_cond": mask_cond,
271
+ "c2w_cond": c2w_cond.unsqueeze(0),
272
+ "intrinsic_cond": intrinsic.to(self.device).unsqueeze(0),
273
+ "intrinsic_normed_cond": intrinsic_normed_cond.to(self.device).unsqueeze(0),
274
+ }
275
+
276
+ meshes, global_dict = self.generate_mesh(
277
+ batch, bake_resolution, estimate_illumination
278
+ )
279
+ return meshes[0], global_dict
280
+
281
+ def generate_mesh(
282
+ self,
283
+ batch,
284
+ bake_resolution: int,
285
+ estimate_illumination: bool = False,
286
+ ) -> Tuple[List[trimesh.Trimesh], dict[str, Any]]:
287
+ batch["rgb_cond"] = self.image_processor(
288
+ batch["rgb_cond"], self.cfg.cond_image_size
289
+ )
290
+ batch["mask_cond"] = self.image_processor(
291
+ batch["mask_cond"], self.cfg.cond_image_size
292
+ )
293
+ scene_codes, non_postprocessed_codes = self.get_scene_codes(batch)
294
+
295
+ global_dict = {}
296
+ if self.image_estimator is not None:
297
+ global_dict.update(
298
+ self.image_estimator(batch["rgb_cond"] * batch["mask_cond"])
299
+ )
300
+ if self.global_estimator is not None and estimate_illumination:
301
+ global_dict.update(self.global_estimator(non_postprocessed_codes))
302
+
303
+ with torch.no_grad():
304
+ with torch.autocast(device_type="cuda", enabled=False):
305
+ meshes = self.triplane_to_meshes(scene_codes)
306
+
307
+ rets = []
308
+ for i, mesh in enumerate(meshes):
309
+ # Check for empty mesh
310
+ if mesh.v_pos.shape[0] == 0:
311
+ rets.append(trimesh.Trimesh())
312
+ continue
313
+
314
+ mesh.unwrap_uv()
315
+
316
+ # Build textures
317
+ rast = self.baker.rasterize(
318
+ mesh.v_tex, mesh.t_pos_idx, bake_resolution
319
+ )
320
+ bake_mask = self.baker.get_mask(rast)
321
+
322
+ pos_bake = self.baker.interpolate(
323
+ mesh.v_pos,
324
+ rast,
325
+ mesh.t_pos_idx,
326
+ mesh.v_tex,
327
+ )
328
+ gb_pos = pos_bake[bake_mask]
329
+
330
+ tri_query = self.query_triplane(gb_pos, scene_codes[i])[0]
331
+ decoded = self.decoder(
332
+ tri_query, exclude=["density", "vertex_offset"]
333
+ )
334
+
335
+ nrm = self.baker.interpolate(
336
+ mesh.v_nrm,
337
+ rast,
338
+ mesh.t_pos_idx,
339
+ mesh.v_tex,
340
+ )
341
+ gb_nrm = F.normalize(nrm[bake_mask], dim=-1)
342
+ decoded["normal"] = gb_nrm
343
+
344
+ # Check if any keys in global_dict start with decoded_
345
+ for k, v in global_dict.items():
346
+ if k.startswith("decoder_"):
347
+ decoded[k.replace("decoder_", "")] = v[i]
348
+
349
+ mat_out = {
350
+ "albedo": decoded["features"],
351
+ "roughness": decoded["roughness"],
352
+ "metallic": decoded["metallic"],
353
+ "normal": normalize(decoded["perturb_normal"]),
354
+ "bump": None,
355
+ }
356
+
357
+ for k, v in mat_out.items():
358
+ if v is None:
359
+ continue
360
+ if v.shape[0] == 1:
361
+ # Skip and directly add a single value
362
+ mat_out[k] = v[0]
363
+ else:
364
+ f = torch.zeros(
365
+ bake_resolution,
366
+ bake_resolution,
367
+ v.shape[-1],
368
+ dtype=v.dtype,
369
+ device=v.device,
370
+ )
371
+ if v.shape == f.shape:
372
+ continue
373
+ if k == "normal":
374
+ # Use un-normalized tangents here so that larger smaller tris
375
+ # Don't effect the tangents that much
376
+ tng = self.baker.interpolate(
377
+ mesh.v_tng,
378
+ rast,
379
+ mesh.t_pos_idx,
380
+ mesh.v_tex,
381
+ )
382
+ gb_tng = tng[bake_mask]
383
+ gb_tng = F.normalize(gb_tng, dim=-1)
384
+ gb_btng = F.normalize(
385
+ torch.cross(gb_tng, gb_nrm, dim=-1), dim=-1
386
+ )
387
+ normal = F.normalize(mat_out["normal"], dim=-1)
388
+
389
+ bump = torch.cat(
390
+ # Check if we have to flip some things
391
+ (
392
+ dot(normal, gb_tng),
393
+ dot(normal, gb_btng),
394
+ dot(normal, gb_nrm).clip(
395
+ 0.3, 1
396
+ ), # Never go below 0.3. This would indicate a flipped (or close to one) normal
397
+ ),
398
+ -1,
399
+ )
400
+ bump[..., :2] *= 0.5
401
+ bump = (bump * 0.5 + 0.5).clamp(0, 1)
402
+
403
+ f[bake_mask] = bump.view(-1, 3)
404
+ mat_out["bump"] = f
405
+ else:
406
+ f[bake_mask] = v.view(-1, v.shape[-1])
407
+ mat_out[k] = f
408
+
409
+ def uv_padding(arr):
410
+ if arr.ndim == 1:
411
+ return arr
412
+ return (
413
+ dilate_fill(
414
+ arr.permute(2, 0, 1)[None, ...],
415
+ bake_mask.unsqueeze(0).unsqueeze(0),
416
+ iterations=bake_resolution // 150,
417
+ )
418
+ .squeeze(0)
419
+ .permute(1, 2, 0)
420
+ )
421
+
422
+ verts_np = convert_data(mesh.v_pos)
423
+ faces = convert_data(mesh.t_pos_idx)
424
+ uvs = convert_data(mesh.v_tex)
425
+
426
+ basecolor_tex = Image.fromarray(
427
+ float32_to_uint8_np(convert_data(uv_padding(mat_out["albedo"])))
428
+ ).convert("RGB")
429
+ basecolor_tex.format = "JPEG"
430
+
431
+ metallic = mat_out["metallic"].squeeze().cpu().item()
432
+ roughness = mat_out["roughness"].squeeze().cpu().item()
433
+
434
+ if "bump" in mat_out and mat_out["bump"] is not None:
435
+ bump_np = convert_data(uv_padding(mat_out["bump"]))
436
+ bump_up = np.ones_like(bump_np)
437
+ bump_up[..., :2] = 0.5
438
+ bump_up[..., 2:] = 1
439
+ bump_tex = Image.fromarray(
440
+ float32_to_uint8_np(
441
+ bump_np,
442
+ dither=True,
443
+ # Do not dither if something is perfectly flat
444
+ dither_mask=np.all(
445
+ bump_np == bump_up, axis=-1, keepdims=True
446
+ ).astype(np.float32),
447
+ )
448
+ ).convert("RGB")
449
+ bump_tex.format = (
450
+ "JPEG" # PNG would be better but the assets are larger
451
+ )
452
+ else:
453
+ bump_tex = None
454
+
455
+ material = trimesh.visual.material.PBRMaterial(
456
+ baseColorTexture=basecolor_tex,
457
+ roughnessFactor=roughness,
458
+ metallicFactor=metallic,
459
+ normalTexture=bump_tex,
460
+ )
461
+
462
+ tmesh = trimesh.Trimesh(
463
+ vertices=verts_np,
464
+ faces=faces,
465
+ visual=trimesh.visual.texture.TextureVisuals(
466
+ uv=uvs, material=material
467
+ ),
468
+ )
469
+ rot = trimesh.transformations.rotation_matrix(
470
+ np.radians(-90), [1, 0, 0]
471
+ )
472
+ tmesh.apply_transform(rot)
473
+ tmesh.apply_transform(
474
+ trimesh.transformations.rotation_matrix(
475
+ np.radians(90), [0, 1, 0]
476
+ )
477
+ )
478
+
479
+ tmesh.invert()
480
+
481
+ rets.append(tmesh)
482
+
483
+ return rets, global_dict
sf3d/texture_baker.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import slangtorch
4
+ import torch
5
+ import torch.nn as nn
6
+ from jaxtyping import Bool, Float
7
+ from torch import Tensor
8
+
9
+
10
+ class TextureBaker(nn.Module):
11
+ def __init__(self):
12
+ super().__init__()
13
+ self.baker = slangtorch.loadModule(
14
+ os.path.join(os.path.dirname(__file__), "texture_baker.slang")
15
+ )
16
+
17
+ def rasterize(
18
+ self,
19
+ uv: Float[Tensor, "Nv 2"],
20
+ face_indices: Float[Tensor, "Nf 3"],
21
+ bake_resolution: int,
22
+ ) -> Float[Tensor, "bake_resolution bake_resolution 4"]:
23
+ if not face_indices.is_cuda or not uv.is_cuda:
24
+ raise ValueError("All input tensors must be on cuda")
25
+
26
+ face_indices = face_indices.to(torch.int32)
27
+ uv = uv.to(torch.float32)
28
+
29
+ rast_result = torch.empty(
30
+ bake_resolution, bake_resolution, 4, device=uv.device, dtype=torch.float32
31
+ )
32
+
33
+ block_size = 16
34
+ grid_size = bake_resolution // block_size
35
+ self.baker.bake_uv(uv=uv, indices=face_indices, output=rast_result).launchRaw(
36
+ blockSize=(block_size, block_size, 1), gridSize=(grid_size, grid_size, 1)
37
+ )
38
+
39
+ return rast_result
40
+
41
+ def get_mask(
42
+ self, rast: Float[Tensor, "bake_resolution bake_resolution 4"]
43
+ ) -> Bool[Tensor, "bake_resolution bake_resolution"]:
44
+ return rast[..., -1] >= 0
45
+
46
+ def interpolate(
47
+ self,
48
+ attr: Float[Tensor, "Nv 3"],
49
+ rast: Float[Tensor, "bake_resolution bake_resolution 4"],
50
+ face_indices: Float[Tensor, "Nf 3"],
51
+ uv: Float[Tensor, "Nv 2"],
52
+ ) -> Float[Tensor, "bake_resolution bake_resolution 3"]:
53
+ # Make sure all input tensors are on torch
54
+ if not attr.is_cuda or not face_indices.is_cuda or not rast.is_cuda:
55
+ raise ValueError("All input tensors must be on cuda")
56
+
57
+ attr = attr.to(torch.float32)
58
+ face_indices = face_indices.to(torch.int32)
59
+ uv = uv.to(torch.float32)
60
+
61
+ pos_bake = torch.zeros(
62
+ rast.shape[0],
63
+ rast.shape[1],
64
+ 3,
65
+ device=attr.device,
66
+ dtype=attr.dtype,
67
+ )
68
+
69
+ block_size = 16
70
+ grid_size = rast.shape[0] // block_size
71
+ self.baker.interpolate(
72
+ attr=attr, indices=face_indices, rast=rast, output=pos_bake
73
+ ).launchRaw(
74
+ blockSize=(block_size, block_size, 1), gridSize=(grid_size, grid_size, 1)
75
+ )
76
+
77
+ return pos_bake
78
+
79
+ def forward(
80
+ self,
81
+ attr: Float[Tensor, "Nv 3"],
82
+ uv: Float[Tensor, "Nv 2"],
83
+ face_indices: Float[Tensor, "Nf 3"],
84
+ bake_resolution: int,
85
+ ) -> Float[Tensor, "bake_resolution bake_resolution 3"]:
86
+ rast = self.rasterize(uv, face_indices, bake_resolution)
87
+ return self.interpolate(attr, rast, face_indices, uv)
sf3d/texture_baker.slang ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // xy: 2D test position
2
+ // v1: vertex position 1
3
+ // v2: vertex position 2
4
+ // v3: vertex position 3
5
+ //
6
+ bool barycentric_coordinates(float2 xy, float2 v1, float2 v2, float2 v3, out float u, out float v, out float w)
7
+ {
8
+ // Return true if the point (x,y) is inside the triangle defined by the vertices v1, v2, v3.
9
+ // If the point is inside the triangle, the barycentric coordinates are stored in u, v, and w.
10
+ float2 v1v2 = v2 - v1;
11
+ float2 v1v3 = v3 - v1;
12
+ float2 xyv1 = xy - v1;
13
+
14
+ float d00 = dot(v1v2, v1v2);
15
+ float d01 = dot(v1v2, v1v3);
16
+ float d11 = dot(v1v3, v1v3);
17
+ float d20 = dot(xyv1, v1v2);
18
+ float d21 = dot(xyv1, v1v3);
19
+
20
+ float denom = d00 * d11 - d01 * d01;
21
+ v = (d11 * d20 - d01 * d21) / denom;
22
+ w = (d00 * d21 - d01 * d20) / denom;
23
+ u = 1.0 - v - w;
24
+
25
+ return (v >= 0.0) && (w >= 0.0) && (v + w <= 1.0);
26
+ }
27
+
28
+ [AutoPyBindCUDA]
29
+ [CUDAKernel]
30
+ void interpolate(
31
+ TensorView<float3> attr,
32
+ TensorView<int3> indices,
33
+ TensorView<float4> rast,
34
+ TensorView<float3> output)
35
+ {
36
+ // Interpolate the attr into output based on the rast result (barycentric coordinates, + triangle idx)
37
+
38
+ uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx();
39
+
40
+ if (dispatch_id.x > output.size(0) || dispatch_id.y > output.size(1))
41
+ return;
42
+
43
+ float4 barycentric = rast[dispatch_id.x, dispatch_id.y];
44
+ int triangle_idx = int(barycentric.w);
45
+
46
+ if (triangle_idx < 0) {
47
+ output[dispatch_id.x, dispatch_id.y] = float3(0.0, 0.0, 0.0);
48
+ return;
49
+ }
50
+
51
+ float3 v1 = attr[indices[triangle_idx].x];
52
+ float3 v2 = attr[indices[triangle_idx].y];
53
+ float3 v3 = attr[indices[triangle_idx].z];
54
+
55
+ output[dispatch_id.x, dispatch_id.y] = v1 * barycentric.x + v2 * barycentric.y + v3 * barycentric.z;
56
+ }
57
+
58
+ [AutoPyBindCUDA]
59
+ [CUDAKernel]
60
+ void bake_uv(
61
+ TensorView<float2> uv,
62
+ TensorView<int3> indices,
63
+ TensorView<float4> output)
64
+ {
65
+ uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx();
66
+
67
+ if (dispatch_id.y > output.size(0) || dispatch_id.x > output.size(1))
68
+ return;
69
+
70
+ // We index x,y but the orginal coords are HW. So swap them
71
+ float2 pixel_coord = float2(dispatch_id.y, dispatch_id.x);
72
+ // Normalize to [0, 1]
73
+ pixel_coord /= float2(output.size(1), output.size(0));
74
+ pixel_coord = clamp(pixel_coord, 0.0, 1.0);
75
+ // Flip x-axis
76
+ pixel_coord.y = 1 - pixel_coord.y;
77
+
78
+ for (int i = 0; i < indices.size(0); i++) {
79
+ float2 v1 = float2(uv[indices[i].x].x, uv[indices[i].x].y);
80
+ float2 v2 = float2(uv[indices[i].y].x, uv[indices[i].y].y);
81
+ float2 v3 = float2(uv[indices[i].z].x, uv[indices[i].z].y);
82
+
83
+ float u, v, w;
84
+ bool hit = barycentric_coordinates(pixel_coord, v1, v2, v3, u, v, w);
85
+
86
+ if (hit){
87
+ output[dispatch_id.x, dispatch_id.y] = float4(u, v, w, i);
88
+ return;
89
+ }
90
+ }
91
+
92
+ output[dispatch_id.x, dispatch_id.y] = float4(0.0, 0.0, 0.0, -1);
93
+ }
sf3d/utils.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ import numpy as np
4
+ import rembg
5
+ import torch
6
+ from PIL import Image
7
+
8
+ import sf3d.models.utils as sf3d_utils
9
+
10
+
11
+ def create_intrinsic_from_fov_deg(fov_deg: float, cond_height: int, cond_width: int):
12
+ intrinsic = sf3d_utils.get_intrinsic_from_fov(
13
+ np.deg2rad(fov_deg),
14
+ H=cond_height,
15
+ W=cond_width,
16
+ )
17
+ intrinsic_normed_cond = intrinsic.clone()
18
+ intrinsic_normed_cond[..., 0, 2] /= cond_width
19
+ intrinsic_normed_cond[..., 1, 2] /= cond_height
20
+ intrinsic_normed_cond[..., 0, 0] /= cond_width
21
+ intrinsic_normed_cond[..., 1, 1] /= cond_height
22
+
23
+ return intrinsic, intrinsic_normed_cond
24
+
25
+
26
+ def default_cond_c2w(distance: float):
27
+ c2w_cond = torch.as_tensor(
28
+ [
29
+ [0, 0, 1, distance],
30
+ [1, 0, 0, 0],
31
+ [0, 1, 0, 0],
32
+ [0, 0, 0, 1],
33
+ ]
34
+ ).float()
35
+ return c2w_cond
36
+
37
+
38
+ def remove_background(
39
+ image: Image,
40
+ rembg_session: Any = None,
41
+ force: bool = False,
42
+ **rembg_kwargs,
43
+ ) -> Image:
44
+ do_remove = True
45
+ if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
46
+ do_remove = False
47
+ do_remove = do_remove or force
48
+ if do_remove:
49
+ image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
50
+ return image
51
+
52
+
53
+ def resize_foreground(
54
+ image: Image,
55
+ ratio: float,
56
+ ) -> Image:
57
+ image = np.array(image)
58
+ assert image.shape[-1] == 4
59
+ alpha = np.where(image[..., 3] > 0)
60
+ y1, y2, x1, x2 = (
61
+ alpha[0].min(),
62
+ alpha[0].max(),
63
+ alpha[1].min(),
64
+ alpha[1].max(),
65
+ )
66
+ # crop the foreground
67
+ fg = image[y1:y2, x1:x2]
68
+ # pad to square
69
+ size = max(fg.shape[0], fg.shape[1])
70
+ ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
71
+ ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
72
+ new_image = np.pad(
73
+ fg,
74
+ ((ph0, ph1), (pw0, pw1), (0, 0)),
75
+ mode="constant",
76
+ constant_values=((0, 0), (0, 0), (0, 0)),
77
+ )
78
+
79
+ # compute padding according to the ratio
80
+ new_size = int(new_image.shape[0] / ratio)
81
+ # pad to size, double side
82
+ ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
83
+ ph1, pw1 = new_size - size - ph0, new_size - size - pw0
84
+ new_image = np.pad(
85
+ new_image,
86
+ ((ph0, ph1), (pw0, pw1), (0, 0)),
87
+ mode="constant",
88
+ constant_values=((0, 0), (0, 0), (0, 0)),
89
+ )
90
+ new_image = Image.fromarray(new_image, mode="RGBA")
91
+ return new_image