jammmmm mboss committed on
Commit
77d8010
·
1 Parent(s): e96dd77

Update to latest inference code

Browse files

Co-authored-by: Mark Boss <[email protected]>

.gitattributes CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.gif filter=lfs diff=lfs merge=lfs -text
37
  *.png filter=lfs diff=lfs merge=lfs -text
 
 
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.gif filter=lfs diff=lfs merge=lfs -text
37
  *.png filter=lfs diff=lfs merge=lfs -text
38
+ *.whl filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv*/
127
+ env/
128
+ venv*/
129
+ ENV/
130
+ env.bak/
131
+
132
+ # Spyder project settings
133
+ .spyderproject
134
+ .spyproject
135
+
136
+ # Rope project settings
137
+ .ropeproject
138
+
139
+ # mkdocs documentation
140
+ /site
141
+
142
+ # mypy
143
+ .mypy_cache/
144
+ .dmypy.json
145
+ dmypy.json
146
+
147
+ # Pyre type checker
148
+ .pyre/
149
+
150
+ # pytype static type analyzer
151
+ .pytype/
152
+
153
+ # Cython debug symbols
154
+ cython_debug/
155
+
156
+ # PyCharm
157
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
160
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161
+ #.idea/
162
+ .vs/
163
+ .idea/
164
+ .vscode/
165
+
166
+ stabilityai/
167
+ output/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ default_language_version:
2
+ python: python3
3
+
4
+ repos:
5
+ - repo: https://github.com/pre-commit/pre-commit-hooks
6
+ rev: v4.4.0
7
+ hooks:
8
+ - id: trailing-whitespace
9
+ - id: check-ast
10
+ - id: check-merge-conflict
11
+ - id: check-yaml
12
+ - id: end-of-file-fixer
13
+ - id: trailing-whitespace
14
+ args: [--markdown-linebreak-ext=md]
15
+
16
+ - repo: https://github.com/astral-sh/ruff-pre-commit
17
+ # Ruff version.
18
+ rev: v0.3.5
19
+ hooks:
20
+ # Run the linter.
21
+ - id: ruff
22
+ args: [ --fix ]
23
+ # Run the formatter.
24
+ - id: ruff-format
README.md CHANGED
@@ -4,9 +4,9 @@ emoji: 🎮
4
  colorFrom: purple
5
  colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 4.31.4
8
  python_version: 3.10.13
9
- app_file: app.py
10
  pinned: false
11
  models:
12
  - stabilityai/stable-fast-3d
@@ -14,5 +14,3 @@ license: other
14
  license_name: stabilityai-ai-community
15
  license_link: LICENSE.md
16
  ---
17
-
18
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
4
  colorFrom: purple
5
  colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 4.41.0
8
  python_version: 3.10.13
9
+ app_file: gradio_app.py
10
  pinned: false
11
  models:
12
  - stabilityai/stable-fast-3d
 
14
  license_name: stabilityai-ai-community
15
  license_link: LICENSE.md
16
  ---
 
 
__init__.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import logging
3
+ import os
4
+ import sys
5
+ from contextlib import nullcontext
6
+
7
+ import comfy.model_management
8
+ import folder_paths
9
+ import numpy as np
10
+ import torch
11
+ import trimesh
12
+ from PIL import Image
13
+ from trimesh.exchange import gltf
14
+
15
+ sys.path.append(os.path.dirname(__file__))
16
+ from sf3d.system import SF3D
17
+ from sf3d.utils import resize_foreground
18
+
19
+ SF3D_CATEGORY = "StableFast3D"
20
+ SF3D_MODEL_NAME = "stabilityai/stable-fast-3d"
21
+
22
+
23
class StableFast3DLoader:
    """ComfyUI node that loads the Stable Fast 3D model onto the active torch device."""

    CATEGORY = SF3D_CATEGORY
    FUNCTION = "load"
    RETURN_NAMES = ("sf3d_model",)
    RETURN_TYPES = ("SF3D_MODEL",)

    @classmethod
    def INPUT_TYPES(cls):
        # No configurable inputs: model id, config, and weight names are fixed constants.
        return {"required": {}}

    def load(self):
        """Fetch the pretrained SF3D weights and return the model in eval mode."""
        target_device = comfy.model_management.get_torch_device()
        sf3d_model = SF3D.from_pretrained(
            SF3D_MODEL_NAME,
            config_name="config.yaml",
            weight_name="model.safetensors",
        )
        sf3d_model.to(target_device)
        sf3d_model.eval()
        return (sf3d_model,)
44
+
45
+
46
class StableFast3DPreview:
    """ComfyUI output node that serializes meshes to base64 GLB blobs for in-browser preview."""

    CATEGORY = SF3D_CATEGORY
    FUNCTION = "preview"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"mesh": ("MESH",)}}

    def preview(self, mesh):
        """Export each mesh as a GLB scene and hand the encoded blobs to the UI."""
        encoded = [
            base64.b64encode(
                gltf.export_glb(trimesh.Scene(single_mesh), include_normals=True)
            ).decode("utf-8")
            for single_mesh in mesh
        ]
        return {"ui": {"glbs": encoded}}
64
+
65
+
66
class StableFast3DSampler:
    """ComfyUI node that runs SF3D on a single image and returns the generated mesh."""

    CATEGORY = SF3D_CATEGORY
    FUNCTION = "predict"
    RETURN_NAMES = ("mesh",)
    RETURN_TYPES = ("MESH",)

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("SF3D_MODEL",),
                "image": ("IMAGE",),
                "foreground_ratio": (
                    "FLOAT",
                    {"default": 0.85, "min": 0.0, "max": 1.0, "step": 0.01},
                ),
                "texture_resolution": (
                    "INT",
                    {"default": 1024, "min": 512, "max": 2048, "step": 256},
                ),
            },
            "optional": {
                "mask": ("MASK",),
                "remesh": (["none", "triangle", "quad"],),
                "vertex_count": (
                    "INT",
                    {"default": -1, "min": -1, "max": 20000, "step": 1},
                ),
            },
        }

    def predict(
        s,
        model,
        image,
        mask,
        foreground_ratio,
        texture_resolution,
        remesh="none",
        vertex_count=-1,
    ):
        """Generate a textured mesh from one input image.

        Args:
            model: Loaded SF3D model (from StableFast3DLoader).
            image: ComfyUI IMAGE tensor; only batch size 1 is supported.
            mask: Optional MASK tensor used as the alpha channel.
            foreground_ratio: Foreground-to-image size ratio for recentering.
            texture_resolution: Bake resolution of the texture atlas.
            remesh: Remeshing mode ("none", "triangle", or "quad").
            vertex_count: Target vertex count; -1 disables reduction.

        Returns:
            A one-element tuple containing a list with the generated mesh.

        Raises:
            ValueError: If more than one image is given or no subject is found.
        """
        if image.shape[0] != 1:
            raise ValueError("Only one image can be processed at a time")

        # ComfyUI images are float tensors in [0, 1]; convert to an 8-bit PIL image.
        pil_image = Image.fromarray(
            torch.clamp(torch.round(255.0 * image[0]), 0, 255)
            .type(torch.uint8)
            .cpu()
            .numpy()
        )

        if mask is not None:
            print("Using Mask")
            mask_np = np.clip(255.0 * mask[0].detach().cpu().numpy(), 0, 255).astype(
                np.uint8
            )
            mask_pil = Image.fromarray(mask_np, mode="L")
            pil_image.putalpha(mask_pil)
        else:
            if image.shape[3] != 4:
                print("No mask or alpha channel detected, Converting to RGBA")
                pil_image = pil_image.convert("RGBA")

        pil_image = resize_foreground(pil_image, foreground_ratio)
        # Fix: removed stray debug `print(remesh)` left over from development.
        # bfloat16 autocast only on CUDA; other devices run in full precision.
        with torch.no_grad():
            with torch.autocast(
                device_type="cuda", dtype=torch.bfloat16
            ) if "cuda" in comfy.model_management.get_torch_device().type else nullcontext():
                mesh, glob_dict = model.run_image(
                    pil_image,
                    bake_resolution=texture_resolution,
                    remesh=remesh,
                    vertex_count=vertex_count,
                )

        if mesh.vertices.shape[0] == 0:
            raise ValueError("No subject detected in the image")

        return ([mesh],)
146
+
147
+
148
class StableFast3DSave:
    """ComfyUI output node that writes meshes as GLB files into the output directory."""

    CATEGORY = SF3D_CATEGORY
    FUNCTION = "save"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "mesh": ("MESH",),
                "filename_prefix": ("STRING", {"default": "SF3D"}),
            }
        }

    def __init__(self):
        # Marks this node as an output node for ComfyUI's save-path resolution.
        self.type = "output"

    def save(self, mesh, filename_prefix):
        """Export each mesh as GLB, save it to disk, and return base64 blobs for the UI.

        Args:
            mesh: Iterable of trimesh meshes to export.
            filename_prefix: Prefix passed to ComfyUI's save-path helper.

        Returns:
            A UI dict mapping "glbs" to base64-encoded GLB payloads.
        """
        output_dir = folder_paths.get_output_directory()
        glbs = []
        for idx, m in enumerate(mesh):
            scene = trimesh.Scene(m)
            glb_data = gltf.export_glb(scene, include_normals=True)
            logging.info(f"Generated GLB model with {len(glb_data)} bytes")

            full_output_folder, filename, counter, subfolder, filename_prefix = (
                folder_paths.get_save_image_path(filename_prefix, output_dir)
            )
            filename = filename.replace("%batch_num%", str(idx))
            # Fix: the substituted `filename` was previously unused and the
            # output path contained a garbled literal; use the resolved name.
            out_path = os.path.join(full_output_folder, f"{filename}_{counter:05}_.glb")
            with open(out_path, "wb") as f:
                f.write(glb_data)
            glbs.append(base64.b64encode(glb_data).decode("utf-8"))
        return {"ui": {"glbs": glbs}}
183
+
184
+
185
# Human-readable labels shown in the ComfyUI node picker.
NODE_DISPLAY_NAME_MAPPINGS = {
    "StableFast3DLoader": "Stable Fast 3D Loader",
    "StableFast3DPreview": "Stable Fast 3D Preview",
    "StableFast3DSampler": "Stable Fast 3D Sampler",
    "StableFast3DSave": "Stable Fast 3D Save",
}

# Registry mapping node type names to their implementing classes.
NODE_CLASS_MAPPINGS = {
    "StableFast3DLoader": StableFast3DLoader,
    "StableFast3DPreview": StableFast3DPreview,
    "StableFast3DSampler": StableFast3DSampler,
    "StableFast3DSave": StableFast3DSave,
}

# Directory with frontend assets served by ComfyUI for these nodes.
WEB_DIRECTORY = "./comfyui"

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]
demo_files/scatterplot.jpg CHANGED
demo_files/workflows/sf3d_example.json ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "last_node_id": 10,
3
+ "last_link_id": 12,
4
+ "nodes": [
5
+ {
6
+ "id": 8,
7
+ "type": "StableFast3DSampler",
8
+ "pos": [
9
+ 756.9950672198843,
10
+ 9.735666739723854
11
+ ],
12
+ "size": {
13
+ "0": 315,
14
+ "1": 166
15
+ },
16
+ "flags": {},
17
+ "order": 3,
18
+ "mode": 0,
19
+ "inputs": [
20
+ {
21
+ "name": "model",
22
+ "type": "SF3D_MODEL",
23
+ "link": 8
24
+ },
25
+ {
26
+ "name": "image",
27
+ "type": "IMAGE",
28
+ "link": 10,
29
+ "slot_index": 1
30
+ },
31
+ {
32
+ "name": "mask",
33
+ "type": "MASK",
34
+ "link": 11
35
+ },
36
+ {
37
+ "name": "remesh",
38
+ "type": "none",
39
+ "link": null,
40
+ "slot_index": 3
41
+ }
42
+ ],
43
+ "outputs": [
44
+ {
45
+ "name": "mesh",
46
+ "type": "MESH",
47
+ "links": [
48
+ 9
49
+ ],
50
+ "shape": 3,
51
+ "slot_index": 0
52
+ }
53
+ ],
54
+ "properties": {
55
+ "Node name for S&R": "StableFast3DSampler"
56
+ },
57
+ "widgets_values": [
58
+ 0.85,
59
+ 1024,
60
+ "triangle"
61
+ ]
62
+ },
63
+ {
64
+ "id": 9,
65
+ "type": "StableFast3DSave",
66
+ "pos": [
67
+ 1116,
68
+ 8
69
+ ],
70
+ "size": [
71
+ 600,
72
+ 512
73
+ ],
74
+ "flags": {},
75
+ "order": 4,
76
+ "mode": 0,
77
+ "inputs": [
78
+ {
79
+ "name": "mesh",
80
+ "type": "MESH",
81
+ "link": 9
82
+ }
83
+ ],
84
+ "properties": {
85
+ "Node name for S&R": "StableFast3DSave"
86
+ },
87
+ "widgets_values": [
88
+ "SF3D",
89
+ null
90
+ ]
91
+ },
92
+ {
93
+ "id": 6,
94
+ "type": "InvertMask",
95
+ "pos": [
96
+ 485,
97
+ 132
98
+ ],
99
+ "size": {
100
+ "0": 210,
101
+ "1": 26
102
+ },
103
+ "flags": {},
104
+ "order": 2,
105
+ "mode": 0,
106
+ "inputs": [
107
+ {
108
+ "name": "mask",
109
+ "type": "MASK",
110
+ "link": 6
111
+ }
112
+ ],
113
+ "outputs": [
114
+ {
115
+ "name": "MASK",
116
+ "type": "MASK",
117
+ "links": [
118
+ 11
119
+ ],
120
+ "shape": 3,
121
+ "slot_index": 0
122
+ }
123
+ ],
124
+ "properties": {
125
+ "Node name for S&R": "InvertMask"
126
+ }
127
+ },
128
+ {
129
+ "id": 1,
130
+ "type": "LoadImage",
131
+ "pos": [
132
+ 105,
133
+ 26
134
+ ],
135
+ "size": {
136
+ "0": 315,
137
+ "1": 314
138
+ },
139
+ "flags": {},
140
+ "order": 0,
141
+ "mode": 0,
142
+ "outputs": [
143
+ {
144
+ "name": "IMAGE",
145
+ "type": "IMAGE",
146
+ "links": [
147
+ 10
148
+ ],
149
+ "shape": 3,
150
+ "slot_index": 0
151
+ },
152
+ {
153
+ "name": "MASK",
154
+ "type": "MASK",
155
+ "links": [
156
+ 6
157
+ ],
158
+ "shape": 3,
159
+ "slot_index": 1
160
+ }
161
+ ],
162
+ "properties": {
163
+ "Node name for S&R": "LoadImage"
164
+ },
165
+ "widgets_values": [
166
+ "axe (1).png",
167
+ "image"
168
+ ]
169
+ },
170
+ {
171
+ "id": 7,
172
+ "type": "StableFast3DLoader",
173
+ "pos": [
174
+ 478,
175
+ -27
176
+ ],
177
+ "size": {
178
+ "0": 210,
179
+ "1": 26
180
+ },
181
+ "flags": {},
182
+ "order": 1,
183
+ "mode": 0,
184
+ "outputs": [
185
+ {
186
+ "name": "sf3d_model",
187
+ "type": "SF3D_MODEL",
188
+ "links": [
189
+ 8
190
+ ],
191
+ "shape": 3,
192
+ "slot_index": 0
193
+ }
194
+ ],
195
+ "properties": {
196
+ "Node name for S&R": "StableFast3DLoader"
197
+ }
198
+ }
199
+ ],
200
+ "links": [
201
+ [
202
+ 6,
203
+ 1,
204
+ 1,
205
+ 6,
206
+ 0,
207
+ "MASK"
208
+ ],
209
+ [
210
+ 8,
211
+ 7,
212
+ 0,
213
+ 8,
214
+ 0,
215
+ "SF3D_MODEL"
216
+ ],
217
+ [
218
+ 9,
219
+ 8,
220
+ 0,
221
+ 9,
222
+ 0,
223
+ "MESH"
224
+ ],
225
+ [
226
+ 10,
227
+ 1,
228
+ 0,
229
+ 8,
230
+ 1,
231
+ "IMAGE"
232
+ ],
233
+ [
234
+ 11,
235
+ 6,
236
+ 0,
237
+ 8,
238
+ 2,
239
+ "MASK"
240
+ ]
241
+ ],
242
+ "groups": [],
243
+ "config": {},
244
+ "extra": {
245
+ "ds": {
246
+ "scale": 0.6209213230591552,
247
+ "offset": [
248
+ 80.89139921077967,
249
+ 610.3296066172098
250
+ ]
251
+ }
252
+ },
253
+ "version": 0.4
254
+ }
app.py → gradio_app.py RENAMED
@@ -1,6 +1,7 @@
1
  import os
2
  import tempfile
3
  import time
 
4
  from functools import lru_cache
5
  from typing import Any
6
 
@@ -11,9 +12,13 @@ import torch
11
  from gradio_litmodel3d import LitModel3D
12
  from PIL import Image
13
 
 
 
14
  import sf3d.utils as sf3d_utils
15
  from sf3d.system import SF3D
16
 
 
 
17
  rembg_session = rembg.new_session()
18
 
19
  COND_WIDTH = 512
@@ -28,32 +33,48 @@ intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(
28
  COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
29
  )
30
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  model = SF3D.from_pretrained(
33
  "stabilityai/stable-fast-3d",
34
  config_name="config.yaml",
35
  weight_name="model.safetensors",
36
  )
37
- model.eval().cuda()
 
38
 
39
  example_files = [
40
  os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples")
41
  ]
42
 
43
 
44
- def run_model(input_image):
45
  start = time.time()
46
  with torch.no_grad():
47
- with torch.autocast(device_type="cuda", dtype=torch.float16):
 
 
48
  model_batch = create_batch(input_image)
49
- model_batch = {k: v.cuda() for k, v in model_batch.items()}
50
- trimesh_mesh, _glob_dict = model.generate_mesh(model_batch, 1024)
 
 
51
  trimesh_mesh = trimesh_mesh[0]
52
 
53
  # Create new tmp file
54
  tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".glb")
55
 
56
  trimesh_mesh.export(tmp_file.name, file_type="glb", include_normals=True)
 
57
 
58
  print("Generation took:", time.time() - start, "s")
59
 
@@ -104,61 +125,6 @@ def remove_background(input_image: Image) -> Image:
104
  return rembg.remove(input_image, session=rembg_session)
105
 
106
 
107
- def resize_foreground(
108
- image: Image,
109
- ratio: float,
110
- ) -> Image:
111
- image = np.array(image)
112
- assert image.shape[-1] == 4
113
- alpha = np.where(image[..., 3] > 0)
114
- y1, y2, x1, x2 = (
115
- alpha[0].min(),
116
- alpha[0].max(),
117
- alpha[1].min(),
118
- alpha[1].max(),
119
- )
120
- # crop the foreground
121
- fg = image[y1:y2, x1:x2]
122
- # pad to square
123
- size = max(fg.shape[0], fg.shape[1])
124
- ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
125
- ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
126
- new_image = np.pad(
127
- fg,
128
- ((ph0, ph1), (pw0, pw1), (0, 0)),
129
- mode="constant",
130
- constant_values=((0, 0), (0, 0), (0, 0)),
131
- )
132
-
133
- # compute padding according to the ratio
134
- new_size = int(new_image.shape[0] / ratio)
135
- # pad to size, double side
136
- ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
137
- ph1, pw1 = new_size - size - ph0, new_size - size - pw0
138
- new_image = np.pad(
139
- new_image,
140
- ((ph0, ph1), (pw0, pw1), (0, 0)),
141
- mode="constant",
142
- constant_values=((0, 0), (0, 0), (0, 0)),
143
- )
144
- new_image = Image.fromarray(new_image, mode="RGBA").resize(
145
- (COND_WIDTH, COND_HEIGHT)
146
- )
147
- return new_image
148
-
149
-
150
- def square_crop(input_image: Image) -> Image:
151
- # Perform a center square crop
152
- min_size = min(input_image.size)
153
- left = (input_image.size[0] - min_size) // 2
154
- top = (input_image.size[1] - min_size) // 2
155
- right = (input_image.size[0] + min_size) // 2
156
- bottom = (input_image.size[1] + min_size) // 2
157
- return input_image.crop((left, top, right, bottom)).resize(
158
- (COND_WIDTH, COND_HEIGHT)
159
- )
160
-
161
-
162
  def show_mask_img(input_image: Image) -> Image:
163
  img_numpy = np.array(input_image)
164
  alpha = img_numpy[:, :, 3] / 255.0
@@ -167,9 +133,27 @@ def show_mask_img(input_image: Image) -> Image:
167
  return Image.fromarray(new_img.astype(np.uint8), mode="RGB")
168
 
169
 
170
- def run_button(run_btn, input_image, background_state, foreground_ratio):
 
 
 
 
 
 
 
 
171
  if run_btn == "Run":
172
- glb_file: str = run_model(background_state)
 
 
 
 
 
 
 
 
 
 
173
 
174
  return (
175
  gr.update(),
@@ -182,12 +166,13 @@ def run_button(run_btn, input_image, background_state, foreground_ratio):
182
  elif run_btn == "Remove Background":
183
  rem_removed = remove_background(input_image)
184
 
185
- sqr_crop = square_crop(rem_removed)
186
- fr_res = resize_foreground(sqr_crop, foreground_ratio)
 
187
 
188
  return (
189
  gr.update(value="Run", visible=True),
190
- sqr_crop,
191
  fr_res,
192
  gr.update(value=show_mask_img(fr_res), visible=True),
193
  gr.update(value=None, visible=False),
@@ -210,11 +195,12 @@ def requires_bg_remove(image, fr):
210
 
211
  if min_alpha == 0:
212
  print("Already has alpha")
213
- sqr_crop = square_crop(image)
214
- fr_res = resize_foreground(sqr_crop, fr)
 
215
  return (
216
  gr.update(value="Run", visible=True),
217
- sqr_crop,
218
  fr_res,
219
  gr.update(value=show_mask_img(fr_res), visible=True),
220
  gr.update(visible=False),
@@ -231,7 +217,9 @@ def requires_bg_remove(image, fr):
231
 
232
 
233
  def update_foreground_ratio(img_proc, fr):
234
- foreground_res = resize_foreground(img_proc, fr)
 
 
235
  return (
236
  foreground_res,
237
  gr.update(value=show_mask_img(foreground_res)),
@@ -250,7 +238,8 @@ with gr.Blocks() as demo:
250
  **Tips**
251
  1. If the image already has an alpha channel, you can skip the background removal step.
252
  2. You can adjust the foreground ratio to control the size of the foreground object. This can influence the shape
253
- 3. You can upload your own HDR environment map to light the 3D model.
 
254
  """)
255
  with gr.Row(variant="panel"):
256
  with gr.Column():
@@ -280,6 +269,30 @@ with gr.Blocks() as demo:
280
  outputs=[background_remove_state, preview_removal],
281
  )
282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  run_btn = gr.Button("Run", variant="primary", visible=False)
284
 
285
  with gr.Column():
@@ -341,6 +354,9 @@ with gr.Blocks() as demo:
341
  input_img,
342
  background_remove_state,
343
  foreground_ratio,
 
 
 
344
  ],
345
  outputs=[
346
  run_btn,
@@ -352,4 +368,4 @@ with gr.Blocks() as demo:
352
  ],
353
  )
354
 
355
- demo.launch()
 
1
  import os
2
  import tempfile
3
  import time
4
+ from contextlib import nullcontext
5
  from functools import lru_cache
6
  from typing import Any
7
 
 
12
  from gradio_litmodel3d import LitModel3D
13
  from PIL import Image
14
 
15
+ os.system("USE_CUDA=1 pip install -vv --no-build-isolation ./texture_baker ./uv_unwrapper")
16
+
17
  import sf3d.utils as sf3d_utils
18
  from sf3d.system import SF3D
19
 
20
+ os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.environ.get("TMPDIR", "/tmp"), "gradio")
21
+
22
  rembg_session = rembg.new_session()
23
 
24
  COND_WIDTH = 512
 
33
  COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
34
  )
35
 
36
+ generated_files = []
37
+
38
+ # Delete previous gradio temp dir folder
39
+ if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
40
+ print(f"Deleting {os.environ['GRADIO_TEMP_DIR']}")
41
+ import shutil
42
+
43
+ shutil.rmtree(os.environ["GRADIO_TEMP_DIR"])
44
+
45
+ device = sf3d_utils.get_device()
46
 
47
  model = SF3D.from_pretrained(
48
  "stabilityai/stable-fast-3d",
49
  config_name="config.yaml",
50
  weight_name="model.safetensors",
51
  )
52
+ model.eval()
53
+ model = model.to(device)
54
 
55
  example_files = [
56
  os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples")
57
  ]
58
 
59
 
60
+ def run_model(input_image, remesh_option, vertex_count, texture_size):
61
  start = time.time()
62
  with torch.no_grad():
63
+ with torch.autocast(
64
+ device_type=device, dtype=torch.bfloat16
65
+ ) if "cuda" in device else nullcontext():
66
  model_batch = create_batch(input_image)
67
+ model_batch = {k: v.to(device) for k, v in model_batch.items()}
68
+ trimesh_mesh, _glob_dict = model.generate_mesh(
69
+ model_batch, texture_size, remesh_option, vertex_count
70
+ )
71
  trimesh_mesh = trimesh_mesh[0]
72
 
73
  # Create new tmp file
74
  tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".glb")
75
 
76
  trimesh_mesh.export(tmp_file.name, file_type="glb", include_normals=True)
77
+ generated_files.append(tmp_file.name)
78
 
79
  print("Generation took:", time.time() - start, "s")
80
 
 
125
  return rembg.remove(input_image, session=rembg_session)
126
 
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  def show_mask_img(input_image: Image) -> Image:
129
  img_numpy = np.array(input_image)
130
  alpha = img_numpy[:, :, 3] / 255.0
 
133
  return Image.fromarray(new_img.astype(np.uint8), mode="RGB")
134
 
135
 
136
+ def run_button(
137
+ run_btn,
138
+ input_image,
139
+ background_state,
140
+ foreground_ratio,
141
+ remesh_option,
142
+ vertex_count,
143
+ texture_size,
144
+ ):
145
  if run_btn == "Run":
146
+ if torch.cuda.is_available():
147
+ torch.cuda.reset_peak_memory_stats()
148
+ glb_file: str = run_model(
149
+ background_state, remesh_option.lower(), vertex_count, texture_size
150
+ )
151
+ if torch.cuda.is_available():
152
+ print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
153
+ elif torch.backends.mps.is_available():
154
+ print(
155
+ "Peak Memory:", torch.mps.driver_allocated_memory() / 1024 / 1024, "MB"
156
+ )
157
 
158
  return (
159
  gr.update(),
 
166
  elif run_btn == "Remove Background":
167
  rem_removed = remove_background(input_image)
168
 
169
+ fr_res = sf3d_utils.resize_foreground(
170
+ rem_removed, foreground_ratio, out_size=(COND_WIDTH, COND_HEIGHT)
171
+ )
172
 
173
  return (
174
  gr.update(value="Run", visible=True),
175
+ rem_removed,
176
  fr_res,
177
  gr.update(value=show_mask_img(fr_res), visible=True),
178
  gr.update(value=None, visible=False),
 
195
 
196
  if min_alpha == 0:
197
  print("Already has alpha")
198
+ fr_res = sf3d_utils.resize_foreground(
199
+ image, foreground_ratio, out_size=(COND_WIDTH, COND_HEIGHT)
200
+ )
201
  return (
202
  gr.update(value="Run", visible=True),
203
+ image,
204
  fr_res,
205
  gr.update(value=show_mask_img(fr_res), visible=True),
206
  gr.update(visible=False),
 
217
 
218
 
219
  def update_foreground_ratio(img_proc, fr):
220
+ foreground_res = sf3d_utils.resize_foreground(
221
+ img_proc, fr, out_size=(COND_WIDTH, COND_HEIGHT)
222
+ )
223
  return (
224
  foreground_res,
225
  gr.update(value=show_mask_img(foreground_res)),
 
238
  **Tips**
239
  1. If the image already has an alpha channel, you can skip the background removal step.
240
  2. You can adjust the foreground ratio to control the size of the foreground object. This can influence the shape
241
+ 3. You can select the remeshing option to control the mesh topology. This can introduce artifacts in the mesh on thin surfaces and should be turned off in such cases.
242
+ 4. You can upload your own HDR environment map to light the 3D model.
243
  """)
244
  with gr.Row(variant="panel"):
245
  with gr.Column():
 
269
  outputs=[background_remove_state, preview_removal],
270
  )
271
 
272
+ remesh_option = gr.Radio(
273
+ choices=["None", "Triangle", "Quad"],
274
+ label="Remeshing",
275
+ value="None",
276
+ visible=True,
277
+ )
278
+
279
+ vertex_count_slider = gr.Slider(
280
+ label="Target Vertex Count",
281
+ minimum=-1,
282
+ maximum=20000,
283
+ value=-1,
284
+ visible=True,
285
+ )
286
+
287
+ texture_size = gr.Slider(
288
+ label="Texture Size",
289
+ minimum=512,
290
+ maximum=2048,
291
+ value=1024,
292
+ step=256,
293
+ visible=True,
294
+ )
295
+
296
  run_btn = gr.Button("Run", variant="primary", visible=False)
297
 
298
  with gr.Column():
 
354
  input_img,
355
  background_remove_state,
356
  foreground_ratio,
357
+ remesh_option,
358
+ vertex_count_slider,
359
+ texture_size,
360
  ],
361
  outputs=[
362
  run_btn,
 
368
  ],
369
  )
370
 
371
+ demo.queue().launch(share=False)
requirements.txt CHANGED
@@ -1,13 +1,21 @@
1
- torch==2.1.2
2
- torchvision==0.16.2
 
 
3
  einops==0.7.0
4
  jaxtyping==0.2.31
5
  omegaconf==2.3.0
6
  transformers==4.42.3
7
- slangtorch==1.2.2
8
  open_clip_torch==2.24.0
9
  trimesh==4.4.1
10
  numpy==1.26.4
11
  huggingface-hub==0.23.4
12
- rembg[gpu]==2.0.57
 
 
 
 
13
  gradio-litmodel3d==0.0.1
 
 
 
 
1
+ wheel
2
+ setuptools==69.5.1
3
+ torch==2.5.1
4
+ torchvision==0.20.1
5
  einops==0.7.0
6
  jaxtyping==0.2.31
7
  omegaconf==2.3.0
8
  transformers==4.42.3
 
9
  open_clip_torch==2.24.0
10
  trimesh==4.4.1
11
  numpy==1.26.4
12
  huggingface-hub==0.23.4
13
+ rembg[gpu]==2.0.57; sys_platform != 'darwin'
14
+ rembg==2.0.57; sys_platform == 'darwin'
15
+ pynanoinstantmeshes==0.0.3
16
+ gpytoolbox==0.2.0
17
+ gradio==4.41.0
18
  gradio-litmodel3d==0.0.1
19
+ # (HF hack) These are installed at runtime in gradio_app.py
20
+ # ./texture_baker/
21
+ # ./uv_unwrapper/
ruff.toml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [lint]
2
+ ignore = ["F722"]
3
+ extend-select = ["I"]
run.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import argparse
import os
from contextlib import nullcontext

import rembg
import torch
from PIL import Image
from tqdm import tqdm

from sf3d.system import SF3D
from sf3d.utils import get_device, remove_background, resize_foreground

if __name__ == "__main__":
    # CLI entry point: preprocess input image(s), run SF3D inference in
    # batches, and export the resulting mesh(es) as GLB files.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "image", type=str, nargs="+", help="Path to input image(s) or folder."
    )
    parser.add_argument(
        "--device",
        default=get_device(),
        type=str,
        help=f"Device to use. If no CUDA/MPS-compatible device is found, the baking will fail. Default: '{get_device()}'",
    )
    parser.add_argument(
        "--pretrained-model",
        default="stabilityai/stable-fast-3d",
        type=str,
        help="Path to the pretrained model. Could be either a huggingface model id is or a local path. Default: 'stabilityai/stable-fast-3d'",
    )
    parser.add_argument(
        "--foreground-ratio",
        default=0.85,
        type=float,
        help="Ratio of the foreground size to the image size. Only used when --no-remove-bg is not specified. Default: 0.85",
    )
    parser.add_argument(
        "--output-dir",
        default="output/",
        type=str,
        help="Output directory to save the results. Default: 'output/'",
    )
    parser.add_argument(
        "--texture-resolution",
        default=1024,
        type=int,
        help="Texture atlas resolution. Default: 1024",
    )
    parser.add_argument(
        "--remesh_option",
        choices=["none", "triangle", "quad"],
        default="none",
        help="Remeshing option",
    )
    parser.add_argument(
        "--target_vertex_count",
        type=int,
        help="Target vertex count. -1 does not perform a reduction.",
        default=-1,
    )
    parser.add_argument(
        "--batch_size", default=1, type=int, help="Batch size for inference"
    )
    args = parser.parse_args()

    # Accept only device strings containing cuda, mps, or cpu.
    devices = ["cuda", "mps", "cpu"]
    if not any(args.device in device for device in devices):
        raise ValueError("Invalid device. Use cuda, mps or cpu")

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    # Fall back to CPU when neither CUDA nor MPS is actually available.
    device = args.device
    if not (torch.cuda.is_available() or torch.backends.mps.is_available()):
        device = "cpu"

    print("Device used: ", device)

    model = SF3D.from_pretrained(
        args.pretrained_model,
        config_name="config.yaml",
        weight_name="model.safetensors",
    )
    model.to(device)
    model.eval()

    rembg_session = rembg.new_session()
    images = []
    idx = 0
    for image_path in args.image:

        def handle_image(image_path, idx):
            # Strip the background, recenter the foreground, save a copy of
            # the preprocessed input under output_dir/<idx>/, and queue it.
            image = remove_background(
                Image.open(image_path).convert("RGBA"), rembg_session
            )
            image = resize_foreground(image, args.foreground_ratio)
            os.makedirs(os.path.join(output_dir, str(idx)), exist_ok=True)
            image.save(os.path.join(output_dir, str(idx), "input.png"))
            images.append(image)

        if os.path.isdir(image_path):
            # Expand a directory argument into its contained image files.
            image_paths = [
                os.path.join(image_path, f)
                for f in os.listdir(image_path)
                if f.endswith((".png", ".jpg", ".jpeg"))
            ]
            for image_path in image_paths:
                handle_image(image_path, idx)
                idx += 1
        else:
            handle_image(image_path, idx)
            idx += 1

    for i in tqdm(range(0, len(images), args.batch_size)):
        image = images[i : i + args.batch_size]
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        # bfloat16 autocast only on CUDA; other devices run in full precision.
        with torch.no_grad():
            with torch.autocast(
                device_type=device, dtype=torch.bfloat16
            ) if "cuda" in device else nullcontext():
                mesh, glob_dict = model.run_image(
                    image,
                    bake_resolution=args.texture_resolution,
                    remesh=args.remesh_option,
                    vertex_count=args.target_vertex_count,
                )
        if torch.cuda.is_available():
            print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
        elif torch.backends.mps.is_available():
            print(
                "Peak Memory:", torch.mps.driver_allocated_memory() / 1024 / 1024, "MB"
            )

        if len(image) == 1:
            # Single-image batch: run_image returns one mesh object here —
            # NOTE(review): presumed from this branch; confirm against SF3D.run_image.
            out_mesh_path = os.path.join(output_dir, str(i), "mesh.glb")
            mesh.export(out_mesh_path, include_normals=True)
        else:
            for j in range(len(mesh)):
                out_mesh_path = os.path.join(output_dir, str(i + j), "mesh.glb")
                mesh[j].export(out_mesh_path, include_normals=True)
sf3d/models/image_estimator/clip_based_estimator.py CHANGED
@@ -95,7 +95,7 @@ class ClipBasedHeadEstimator(BaseModule):
95
  # Run the model
96
  # Resize cond_image to 224
97
  cond_image = nn.functional.interpolate(
98
- cond_image.flatten(0, 1).permute(0, 3, 1, 2),
99
  size=(224, 224),
100
  mode="bilinear",
101
  align_corners=False,
 
95
  # Run the model
96
  # Resize cond_image to 224
97
  cond_image = nn.functional.interpolate(
98
+ cond_image.flatten(0, 1).permute(0, 3, 1, 2).contiguous(),
99
  size=(224, 224),
100
  mode="bilinear",
101
  align_corners=False,
sf3d/models/mesh.py CHANGED
@@ -1,15 +1,30 @@
1
  from __future__ import annotations
2
 
 
3
  from typing import Any, Dict, Optional
4
 
 
 
 
5
  import torch
6
  import torch.nn.functional as F
 
7
  from jaxtyping import Float, Integer
8
  from torch import Tensor
9
 
10
- from sf3d.box_uv_unwrap import box_projection_uv_unwrap
11
  from sf3d.models.utils import dot
12
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  class Mesh:
15
  def __init__(
@@ -25,6 +40,8 @@ class Mesh:
25
  for k, v in kwargs.items():
26
  self.add_extra(k, v)
27
 
 
 
28
  def add_extra(self, k, v) -> None:
29
  self.extras[k] = v
30
 
@@ -131,12 +148,112 @@ class Mesh:
131
 
132
  return tangents
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  @torch.no_grad()
135
  def unwrap_uv(
136
  self,
137
  island_padding: float = 0.02,
138
  ) -> Mesh:
139
- uv, indices = box_projection_uv_unwrap(
140
  self.v_pos, self.v_nrm, self.t_pos_idx, island_padding
141
  )
142
 
 
1
  from __future__ import annotations
2
 
3
+ import math
4
  from typing import Any, Dict, Optional
5
 
6
+ import gpytoolbox
7
+ import numpy as np
8
+ import pynanoinstantmeshes
9
  import torch
10
  import torch.nn.functional as F
11
+ import trimesh
12
  from jaxtyping import Float, Integer
13
  from torch import Tensor
14
 
 
15
  from sf3d.models.utils import dot
16
 
17
+ try:
18
+ from uv_unwrapper import Unwrapper
19
+ except ImportError:
20
+ import logging
21
+
22
+ logging.warning(
23
+ "Could not import uv_unwrapper. Please install it via `pip install uv_unwrapper/`"
24
+ )
25
+ # Exit early to avoid further errors
26
+ raise ImportError("uv_unwrapper not found")
27
+
28
 
29
  class Mesh:
30
  def __init__(
 
40
  for k, v in kwargs.items():
41
  self.add_extra(k, v)
42
 
43
+ self.unwrapper = Unwrapper()
44
+
45
  def add_extra(self, k, v) -> None:
46
  self.extras[k] = v
47
 
 
148
 
149
  return tangents
150
 
151
+ def quad_remesh(
152
+ self,
153
+ quad_vertex_count: int = -1,
154
+ quad_rosy: int = 4,
155
+ quad_crease_angle: float = -1.0,
156
+ quad_smooth_iter: int = 2,
157
+ quad_align_to_boundaries: bool = False,
158
+ ) -> Mesh:
159
+ if quad_vertex_count < 0:
160
+ quad_vertex_count = self.v_pos.shape[0]
161
+ v_pos = self.v_pos.detach().cpu().numpy().astype(np.float32)
162
+ t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.uint32)
163
+
164
+ new_vert, new_faces = pynanoinstantmeshes.remesh(
165
+ v_pos,
166
+ t_pos_idx,
167
+ quad_vertex_count // 4,
168
+ rosy=quad_rosy,
169
+ posy=4,
170
+ creaseAngle=quad_crease_angle,
171
+ align_to_boundaries=quad_align_to_boundaries,
172
+ smooth_iter=quad_smooth_iter,
173
+ deterministic=False,
174
+ )
175
+
176
+ # Briefly load in trimesh
177
+ mesh = trimesh.Trimesh(vertices=new_vert, faces=new_faces.astype(np.int32))
178
+
179
+ v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos).contiguous()
180
+ t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx).contiguous()
181
+
182
+ # Create new mesh
183
+ return Mesh(v_pos, t_pos_idx)
184
+
185
+ def triangle_remesh(
186
+ self,
187
+ triangle_average_edge_length_multiplier: Optional[float] = None,
188
+ triangle_remesh_steps: int = 10,
189
+ triangle_vertex_count=-1,
190
+ ):
191
+ if triangle_vertex_count > 0:
192
+ reduction = triangle_vertex_count / self.v_pos.shape[0]
193
+ print("Triangle reduction:", reduction)
194
+ v_pos = self.v_pos.detach().cpu().numpy().astype(np.float32)
195
+ t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.int32)
196
+ if reduction > 1.0:
197
+ subdivide_iters = int(math.ceil(math.log(reduction) / math.log(2)))
198
+ print("Subdivide iters:", subdivide_iters)
199
+ v_pos, t_pos_idx = gpytoolbox.subdivide(
200
+ v_pos,
201
+ t_pos_idx,
202
+ iters=subdivide_iters,
203
+ )
204
+ reduction = triangle_vertex_count / v_pos.shape[0]
205
+
206
+ # Simplify
207
+ points_out, faces_out, _, _ = gpytoolbox.decimate(
208
+ v_pos,
209
+ t_pos_idx,
210
+ face_ratio=reduction,
211
+ )
212
+
213
+ # Convert back to torch
214
+ self.v_pos = torch.from_numpy(points_out).to(self.v_pos)
215
+ self.t_pos_idx = torch.from_numpy(faces_out).to(self.t_pos_idx)
216
+ self._edges = None
217
+ triangle_average_edge_length_multiplier = None
218
+
219
+ edges = self.edges
220
+ if triangle_average_edge_length_multiplier is None:
221
+ h = None
222
+ else:
223
+ h = float(
224
+ torch.linalg.norm(
225
+ self.v_pos[edges[:, 0]] - self.v_pos[edges[:, 1]], dim=1
226
+ )
227
+ .mean()
228
+ .item()
229
+ * triangle_average_edge_length_multiplier
230
+ )
231
+
232
+ # Convert to numpy
233
+ v_pos = self.v_pos.detach().cpu().numpy().astype(np.float64)
234
+ t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.int32)
235
+
236
+ # Remesh
237
+ v_remesh, f_remesh = gpytoolbox.remesh_botsch(
238
+ v_pos,
239
+ t_pos_idx,
240
+ triangle_remesh_steps,
241
+ h,
242
+ )
243
+
244
+ # Convert back to torch
245
+ v_pos = torch.from_numpy(v_remesh).to(self.v_pos).contiguous()
246
+ t_pos_idx = torch.from_numpy(f_remesh).to(self.t_pos_idx).contiguous()
247
+
248
+ # Create new mesh
249
+ return Mesh(v_pos, t_pos_idx)
250
+
251
  @torch.no_grad()
252
  def unwrap_uv(
253
  self,
254
  island_padding: float = 0.02,
255
  ) -> Mesh:
256
+ uv, indices = self.unwrapper(
257
  self.v_pos, self.v_nrm, self.t_pos_idx, island_padding
258
  )
259
 
sf3d/models/network.py CHANGED
@@ -7,10 +7,23 @@ import torch.nn.functional as F
7
  from einops import rearrange
8
  from jaxtyping import Float
9
  from torch import Tensor
 
10
  from torch.autograd import Function
11
- from torch.cuda.amp import custom_bwd, custom_fwd
12
 
13
  from sf3d.models.utils import BaseModule, normalize
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
 
16
  class PixelShuffleUpsampleNetwork(BaseModule):
@@ -65,13 +78,18 @@ class _TruncExp(Function): # pylint: disable=abstract-method
65
  # Implementation from torch-ngp:
66
  # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
67
  @staticmethod
68
- @custom_fwd(cast_inputs=torch.float32)
 
 
 
 
 
69
  def forward(ctx, x): # pylint: disable=arguments-differ
70
  ctx.save_for_backward(x)
71
  return torch.exp(x)
72
 
73
  @staticmethod
74
- @custom_bwd
75
  def backward(ctx, g): # pylint: disable=arguments-differ
76
  x = ctx.saved_tensors[0]
77
  return g * torch.exp(torch.clamp(x, max=15))
 
7
  from einops import rearrange
8
  from jaxtyping import Float
9
  from torch import Tensor
10
+ from torch.amp import custom_bwd, custom_fwd
11
  from torch.autograd import Function
 
12
 
13
  from sf3d.models.utils import BaseModule, normalize
14
+ from sf3d.utils import get_device
15
+
16
+
17
+ def conditional_decorator(decorator_with_args, condition, *args, **kwargs):
18
+ def wrapper(fn):
19
+ if condition:
20
+ if len(kwargs) == 0:
21
+ return decorator_with_args
22
+ return decorator_with_args(*args, **kwargs)(fn)
23
+ else:
24
+ return fn
25
+
26
+ return wrapper
27
 
28
 
29
  class PixelShuffleUpsampleNetwork(BaseModule):
 
78
  # Implementation from torch-ngp:
79
  # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
80
  @staticmethod
81
+ @conditional_decorator(
82
+ custom_fwd,
83
+ "cuda" in get_device(),
84
+ cast_inputs=torch.float32,
85
+ device_type="cuda",
86
+ )
87
  def forward(ctx, x): # pylint: disable=arguments-differ
88
  ctx.save_for_backward(x)
89
  return torch.exp(x)
90
 
91
  @staticmethod
92
+ @conditional_decorator(custom_bwd, "cuda" in get_device())
93
  def backward(ctx, g): # pylint: disable=arguments-differ
94
  x = ctx.saved_tensors[0]
95
  return g * torch.exp(torch.clamp(x, max=15))
sf3d/models/utils.py CHANGED
@@ -1,6 +1,5 @@
1
  import dataclasses
2
  import importlib
3
- import math
4
  from dataclasses import dataclass
5
  from typing import Any, List, Optional, Tuple, Union
6
 
@@ -9,7 +8,7 @@ import PIL
9
  import torch
10
  import torch.nn as nn
11
  import torch.nn.functional as F
12
- from jaxtyping import Bool, Float, Int, Num
13
  from omegaconf import DictConfig, OmegaConf
14
  from torch import Tensor
15
 
@@ -77,61 +76,6 @@ def normalize(x, dim=-1, eps=None):
77
  return F.normalize(x, dim=dim, p=2, eps=eps)
78
 
79
 
80
- def tri_winding(tri: Float[Tensor, "*B 3 2"]) -> Float[Tensor, "*B 3 3"]:
81
- # One pad for determinant
82
- tri_sq = F.pad(tri, (0, 1), "constant", 1.0)
83
- det_tri = torch.det(tri_sq)
84
- tri_rev = torch.cat(
85
- (tri_sq[..., 0:1, :], tri_sq[..., 2:3, :], tri_sq[..., 1:2, :]), -2
86
- )
87
- tri_sq[det_tri < 0] = tri_rev[det_tri < 0]
88
- return tri_sq
89
-
90
-
91
- def triangle_intersection_2d(
92
- t1: Float[Tensor, "*B 3 2"],
93
- t2: Float[Tensor, "*B 3 2"],
94
- eps=1e-12,
95
- ) -> Float[Tensor, "*B"]: # noqa: F821
96
- """Returns True if triangles collide, False otherwise"""
97
-
98
- def chk_edge(x: Float[Tensor, "*B 3 3"]) -> Bool[Tensor, "*B"]: # noqa: F821
99
- logdetx = torch.logdet(x.double())
100
- if eps is None:
101
- return ~torch.isfinite(logdetx)
102
- return ~(torch.isfinite(logdetx) & (logdetx > math.log(eps)))
103
-
104
- t1s = tri_winding(t1)
105
- t2s = tri_winding(t2)
106
-
107
- # Assume the triangles do not collide in the begging
108
- ret = torch.zeros(t1.shape[0], dtype=torch.bool, device=t1.device)
109
- for i in range(3):
110
- edge = torch.roll(t1s, i, dims=1)[:, :2, :]
111
- # Check if all points of triangle 2 lay on the external side of edge E.
112
- # If this is the case the triangle do not collide
113
- upd = (
114
- chk_edge(torch.cat((edge, t2s[:, 0:1]), 1))
115
- & chk_edge(torch.cat((edge, t2s[:, 1:2]), 1))
116
- & chk_edge(torch.cat((edge, t2s[:, 2:3]), 1))
117
- )
118
- # Here no collision is still True due to inversion
119
- ret = ret | upd
120
-
121
- for i in range(3):
122
- edge = torch.roll(t2s, i, dims=1)[:, :2, :]
123
-
124
- upd = (
125
- chk_edge(torch.cat((edge, t1s[:, 0:1]), 1))
126
- & chk_edge(torch.cat((edge, t1s[:, 1:2]), 1))
127
- & chk_edge(torch.cat((edge, t1s[:, 2:3]), 1))
128
- )
129
- # Here no collision is still True due to inversion
130
- ret = ret | upd
131
-
132
- return ~ret # Do the inversion
133
-
134
-
135
  ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]
136
 
137
 
 
1
  import dataclasses
2
  import importlib
 
3
  from dataclasses import dataclass
4
  from typing import Any, List, Optional, Tuple, Union
5
 
 
8
  import torch
9
  import torch.nn as nn
10
  import torch.nn.functional as F
11
+ from jaxtyping import Float, Int, Num
12
  from omegaconf import DictConfig, OmegaConf
13
  from torch import Tensor
14
 
 
76
  return F.normalize(x, dim=dim, p=2, eps=eps)
77
 
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]
80
 
81
 
sf3d/system.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
 
2
  from dataclasses import dataclass, field
3
- from typing import Any, List, Optional, Tuple
4
 
5
  import numpy as np
6
  import torch
@@ -21,15 +22,23 @@ from sf3d.models.utils import (
21
  ImageProcessor,
22
  convert_data,
23
  dilate_fill,
24
- dot,
25
  find_class,
26
  float32_to_uint8_np,
27
  normalize,
28
  scale_tensor,
29
  )
30
- from sf3d.utils import create_intrinsic_from_fov_deg, default_cond_c2w
31
 
32
- from .texture_baker import TextureBaker
 
 
 
 
 
 
 
 
 
33
 
34
 
35
  class SF3D(BaseModule):
@@ -206,6 +215,7 @@ class SF3D(BaseModule):
206
  batch["c2w_cond"] = batch["c2w_cond"].unsqueeze(1)
207
  batch["intrinsic_cond"] = batch["intrinsic_cond"].unsqueeze(1)
208
  batch["intrinsic_normed_cond"] = batch["intrinsic_normed_cond"].unsqueeze(1)
 
209
  batch_size, n_input_views = batch["rgb_cond"].shape[:2]
210
 
211
  camera_embeds: Optional[Float[Tensor, "B Nv Cc"]]
@@ -234,10 +244,54 @@ class SF3D(BaseModule):
234
 
235
  def run_image(
236
  self,
237
- image: Image,
238
  bake_resolution: int,
 
 
239
  estimate_illumination: bool = False,
240
- ) -> Tuple[trimesh.Trimesh, dict[str, Any]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  if image.mode != "RGBA":
242
  raise ValueError("Image must be in RGBA mode")
243
  img_cond = (
@@ -258,30 +312,14 @@ class SF3D(BaseModule):
258
  mask_cond,
259
  )
260
 
261
- c2w_cond = default_cond_c2w(self.cfg.default_distance).to(self.device)
262
- intrinsic, intrinsic_normed_cond = create_intrinsic_from_fov_deg(
263
- self.cfg.default_fovy_deg,
264
- self.cfg.cond_image_size,
265
- self.cfg.cond_image_size,
266
- )
267
-
268
- batch = {
269
- "rgb_cond": rgb_cond,
270
- "mask_cond": mask_cond,
271
- "c2w_cond": c2w_cond.unsqueeze(0),
272
- "intrinsic_cond": intrinsic.to(self.device).unsqueeze(0),
273
- "intrinsic_normed_cond": intrinsic_normed_cond.to(self.device).unsqueeze(0),
274
- }
275
-
276
- meshes, global_dict = self.generate_mesh(
277
- batch, bake_resolution, estimate_illumination
278
- )
279
- return meshes[0], global_dict
280
 
281
  def generate_mesh(
282
  self,
283
  batch,
284
  bake_resolution: int,
 
 
285
  estimate_illumination: bool = False,
286
  ) -> Tuple[List[trimesh.Trimesh], dict[str, Any]]:
287
  batch["rgb_cond"] = self.image_processor(
@@ -300,8 +338,11 @@ class SF3D(BaseModule):
300
  if self.global_estimator is not None and estimate_illumination:
301
  global_dict.update(self.global_estimator(non_postprocessed_codes))
302
 
 
303
  with torch.no_grad():
304
- with torch.autocast(device_type="cuda", enabled=False):
 
 
305
  meshes = self.triplane_to_meshes(scene_codes)
306
 
307
  rets = []
@@ -311,6 +352,17 @@ class SF3D(BaseModule):
311
  rets.append(trimesh.Trimesh())
312
  continue
313
 
 
 
 
 
 
 
 
 
 
 
 
314
  mesh.unwrap_uv()
315
 
316
  # Build textures
@@ -323,7 +375,6 @@ class SF3D(BaseModule):
323
  mesh.v_pos,
324
  rast,
325
  mesh.t_pos_idx,
326
- mesh.v_tex,
327
  )
328
  gb_pos = pos_bake[bake_mask]
329
 
@@ -336,7 +387,6 @@ class SF3D(BaseModule):
336
  mesh.v_nrm,
337
  rast,
338
  mesh.t_pos_idx,
339
- mesh.v_tex,
340
  )
341
  gb_nrm = F.normalize(nrm[bake_mask], dim=-1)
342
  decoded["normal"] = gb_nrm
@@ -377,29 +427,28 @@ class SF3D(BaseModule):
377
  mesh.v_tng,
378
  rast,
379
  mesh.t_pos_idx,
380
- mesh.v_tex,
381
  )
382
  gb_tng = tng[bake_mask]
383
  gb_tng = F.normalize(gb_tng, dim=-1)
384
  gb_btng = F.normalize(
385
- torch.cross(gb_tng, gb_nrm, dim=-1), dim=-1
386
  )
387
  normal = F.normalize(mat_out["normal"], dim=-1)
388
 
389
- bump = torch.cat(
390
- # Check if we have to flip some things
391
- (
392
- dot(normal, gb_tng),
393
- dot(normal, gb_btng),
394
- dot(normal, gb_nrm).clip(
395
- 0.3, 1
396
- ), # Never go below 0.3. This would indicate a flipped (or close to one) normal
397
- ),
398
- -1,
 
399
  )
400
- bump = (bump * 0.5 + 0.5).clamp(0, 1)
401
 
402
- f[bake_mask] = bump.view(-1, 3)
403
  mat_out["bump"] = f
404
  else:
405
  f[bake_mask] = v.view(-1, v.shape[-1])
@@ -410,12 +459,13 @@ class SF3D(BaseModule):
410
  return arr
411
  return (
412
  dilate_fill(
413
- arr.permute(2, 0, 1)[None, ...],
414
  bake_mask.unsqueeze(0).unsqueeze(0),
415
  iterations=bake_resolution // 150,
416
  )
417
  .squeeze(0)
418
  .permute(1, 2, 0)
 
419
  )
420
 
421
  verts_np = convert_data(mesh.v_pos)
 
1
  import os
2
+ from contextlib import nullcontext
3
  from dataclasses import dataclass, field
4
+ from typing import Any, List, Literal, Optional, Tuple, Union
5
 
6
  import numpy as np
7
  import torch
 
22
  ImageProcessor,
23
  convert_data,
24
  dilate_fill,
 
25
  find_class,
26
  float32_to_uint8_np,
27
  normalize,
28
  scale_tensor,
29
  )
30
+ from sf3d.utils import create_intrinsic_from_fov_deg, default_cond_c2w, get_device
31
 
32
+ try:
33
+ from texture_baker import TextureBaker
34
+ except ImportError:
35
+ import logging
36
+
37
+ logging.warning(
38
+ "Could not import texture_baker. Please install it via `pip install texture-baker/`"
39
+ )
40
+ # Exit early to avoid further errors
41
+ raise ImportError("texture_baker not found")
42
 
43
 
44
  class SF3D(BaseModule):
 
215
  batch["c2w_cond"] = batch["c2w_cond"].unsqueeze(1)
216
  batch["intrinsic_cond"] = batch["intrinsic_cond"].unsqueeze(1)
217
  batch["intrinsic_normed_cond"] = batch["intrinsic_normed_cond"].unsqueeze(1)
218
+
219
  batch_size, n_input_views = batch["rgb_cond"].shape[:2]
220
 
221
  camera_embeds: Optional[Float[Tensor, "B Nv Cc"]]
 
244
 
245
  def run_image(
246
  self,
247
+ image: Union[Image.Image, List[Image.Image]],
248
  bake_resolution: int,
249
+ remesh: Literal["none", "triangle", "quad"] = "none",
250
+ vertex_count: int = -1,
251
  estimate_illumination: bool = False,
252
+ ) -> Tuple[Union[trimesh.Trimesh, List[trimesh.Trimesh]], dict[str, Any]]:
253
+ if isinstance(image, list):
254
+ rgb_cond = []
255
+ mask_cond = []
256
+ for img in image:
257
+ mask, rgb = self.prepare_image(img)
258
+ mask_cond.append(mask)
259
+ rgb_cond.append(rgb)
260
+ rgb_cond = torch.stack(rgb_cond, 0)
261
+ mask_cond = torch.stack(mask_cond, 0)
262
+ batch_size = rgb_cond.shape[0]
263
+ else:
264
+ mask_cond, rgb_cond = self.prepare_image(image)
265
+ batch_size = 1
266
+
267
+ c2w_cond = default_cond_c2w(self.cfg.default_distance).to(self.device)
268
+ intrinsic, intrinsic_normed_cond = create_intrinsic_from_fov_deg(
269
+ self.cfg.default_fovy_deg,
270
+ self.cfg.cond_image_size,
271
+ self.cfg.cond_image_size,
272
+ )
273
+
274
+ batch = {
275
+ "rgb_cond": rgb_cond,
276
+ "mask_cond": mask_cond,
277
+ "c2w_cond": c2w_cond.view(1, 1, 4, 4).repeat(batch_size, 1, 1, 1),
278
+ "intrinsic_cond": intrinsic.to(self.device)
279
+ .view(1, 1, 3, 3)
280
+ .repeat(batch_size, 1, 1, 1),
281
+ "intrinsic_normed_cond": intrinsic_normed_cond.to(self.device)
282
+ .view(1, 1, 3, 3)
283
+ .repeat(batch_size, 1, 1, 1),
284
+ }
285
+
286
+ meshes, global_dict = self.generate_mesh(
287
+ batch, bake_resolution, remesh, vertex_count, estimate_illumination
288
+ )
289
+ if batch_size == 1:
290
+ return meshes[0], global_dict
291
+ else:
292
+ return meshes, global_dict
293
+
294
+ def prepare_image(self, image):
295
  if image.mode != "RGBA":
296
  raise ValueError("Image must be in RGBA mode")
297
  img_cond = (
 
312
  mask_cond,
313
  )
314
 
315
+ return mask_cond, rgb_cond
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
 
317
  def generate_mesh(
318
  self,
319
  batch,
320
  bake_resolution: int,
321
+ remesh: Literal["none", "triangle", "quad"] = "none",
322
+ vertex_count: int = -1,
323
  estimate_illumination: bool = False,
324
  ) -> Tuple[List[trimesh.Trimesh], dict[str, Any]]:
325
  batch["rgb_cond"] = self.image_processor(
 
338
  if self.global_estimator is not None and estimate_illumination:
339
  global_dict.update(self.global_estimator(non_postprocessed_codes))
340
 
341
+ device = get_device()
342
  with torch.no_grad():
343
+ with torch.autocast(
344
+ device_type=device, enabled=False
345
+ ) if "cuda" in device else nullcontext():
346
  meshes = self.triplane_to_meshes(scene_codes)
347
 
348
  rets = []
 
352
  rets.append(trimesh.Trimesh())
353
  continue
354
 
355
+ if remesh == "triangle":
356
+ mesh = mesh.triangle_remesh(triangle_vertex_count=vertex_count)
357
+ elif remesh == "quad":
358
+ mesh = mesh.quad_remesh(quad_vertex_count=vertex_count)
359
+ else:
360
+ if vertex_count > 0:
361
+ print(
362
+ "Warning: vertex_count is ignored when remesh is none"
363
+ )
364
+
365
+ print("After Remesh", mesh.v_pos.shape[0], mesh.t_pos_idx.shape[0])
366
  mesh.unwrap_uv()
367
 
368
  # Build textures
 
375
  mesh.v_pos,
376
  rast,
377
  mesh.t_pos_idx,
 
378
  )
379
  gb_pos = pos_bake[bake_mask]
380
 
 
387
  mesh.v_nrm,
388
  rast,
389
  mesh.t_pos_idx,
 
390
  )
391
  gb_nrm = F.normalize(nrm[bake_mask], dim=-1)
392
  decoded["normal"] = gb_nrm
 
427
  mesh.v_tng,
428
  rast,
429
  mesh.t_pos_idx,
 
430
  )
431
  gb_tng = tng[bake_mask]
432
  gb_tng = F.normalize(gb_tng, dim=-1)
433
  gb_btng = F.normalize(
434
+ torch.cross(gb_nrm, gb_tng, dim=-1), dim=-1
435
  )
436
  normal = F.normalize(mat_out["normal"], dim=-1)
437
 
438
+ # Create tangent space matrix and transform normal
439
+ tangent_matrix = torch.stack(
440
+ [gb_tng, gb_btng, gb_nrm], dim=-1
441
+ )
442
+ normal_tangent = torch.bmm(
443
+ tangent_matrix.transpose(1, 2), normal.unsqueeze(-1)
444
+ ).squeeze(-1)
445
+
446
+ # Convert from [-1,1] to [0,1] range for storage
447
+ normal_tangent = (normal_tangent * 0.5 + 0.5).clamp(
448
+ 0, 1
449
  )
 
450
 
451
+ f[bake_mask] = normal_tangent.view(-1, 3)
452
  mat_out["bump"] = f
453
  else:
454
  f[bake_mask] = v.view(-1, v.shape[-1])
 
459
  return arr
460
  return (
461
  dilate_fill(
462
+ arr.permute(2, 0, 1)[None, ...].contiguous(),
463
  bake_mask.unsqueeze(0).unsqueeze(0),
464
  iterations=bake_resolution // 150,
465
  )
466
  .squeeze(0)
467
  .permute(1, 2, 0)
468
+ .contiguous()
469
  )
470
 
471
  verts_np = convert_data(mesh.v_pos)
sf3d/utils.py CHANGED
@@ -1,13 +1,27 @@
1
- from typing import Any
 
2
 
3
  import numpy as np
4
  import rembg
5
  import torch
 
6
  from PIL import Image
7
 
8
  import sf3d.models.utils as sf3d_utils
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def create_intrinsic_from_fov_deg(fov_deg: float, cond_height: int, cond_width: int):
12
  intrinsic = sf3d_utils.get_intrinsic_from_fov(
13
  np.deg2rad(fov_deg),
@@ -50,42 +64,42 @@ def remove_background(
50
  return image
51
 
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def resize_foreground(
54
- image: Image,
55
  ratio: float,
 
56
  ) -> Image:
57
- image = np.array(image)
58
- assert image.shape[-1] == 4
59
- alpha = np.where(image[..., 3] > 0)
60
- y1, y2, x1, x2 = (
61
- alpha[0].min(),
62
- alpha[0].max(),
63
- alpha[1].min(),
64
- alpha[1].max(),
65
- )
66
- # crop the foreground
67
- fg = image[y1:y2, x1:x2]
68
- # pad to square
69
- size = max(fg.shape[0], fg.shape[1])
70
- ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
71
- ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
72
- new_image = np.pad(
73
- fg,
74
- ((ph0, ph1), (pw0, pw1), (0, 0)),
75
- mode="constant",
76
- constant_values=((0, 0), (0, 0), (0, 0)),
77
- )
78
 
79
- # compute padding according to the ratio
80
- new_size = int(new_image.shape[0] / ratio)
81
- # pad to size, double side
82
- ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
83
- ph1, pw1 = new_size - size - ph0, new_size - size - pw0
84
- new_image = np.pad(
85
- new_image,
86
- ((ph0, ph1), (pw0, pw1), (0, 0)),
87
- mode="constant",
88
- constant_values=((0, 0), (0, 0), (0, 0)),
89
  )
90
- new_image = Image.fromarray(new_image, mode="RGBA")
 
 
91
  return new_image
 
1
+ import os
2
+ from typing import Any, Union
3
 
4
  import numpy as np
5
  import rembg
6
  import torch
7
+ import torchvision.transforms.functional as torchvision_F
8
  from PIL import Image
9
 
10
  import sf3d.models.utils as sf3d_utils
11
 
12
 
13
+ def get_device():
14
+ if os.environ.get("SF3D_USE_CPU", "0") == "1":
15
+ return "cpu"
16
+
17
+ device = "cpu"
18
+ if torch.cuda.is_available():
19
+ device = "cuda"
20
+ elif torch.backends.mps.is_available():
21
+ device = "mps"
22
+ return device
23
+
24
+
25
  def create_intrinsic_from_fov_deg(fov_deg: float, cond_height: int, cond_width: int):
26
  intrinsic = sf3d_utils.get_intrinsic_from_fov(
27
  np.deg2rad(fov_deg),
 
64
  return image
65
 
66
 
67
+ def get_1d_bounds(arr):
68
+ nz = np.flatnonzero(arr)
69
+ return nz[0], nz[-1]
70
+
71
+
72
+ def get_bbox_from_mask(mask, thr=0.5):
73
+ masks_for_box = (mask > thr).astype(np.float32)
74
+ assert masks_for_box.sum() > 0, "Empty mask!"
75
+ x0, x1 = get_1d_bounds(masks_for_box.sum(axis=-2))
76
+ y0, y1 = get_1d_bounds(masks_for_box.sum(axis=-1))
77
+ return x0, y0, x1, y1
78
+
79
+
80
  def resize_foreground(
81
+ image: Union[Image.Image, np.ndarray],
82
  ratio: float,
83
+ out_size=None,
84
  ) -> Image:
85
+ if isinstance(image, np.ndarray):
86
+ image = Image.fromarray(image, mode="RGBA")
87
+ assert image.mode == "RGBA"
88
+ # Get bounding box
89
+ mask_np = np.array(image)[:, :, -1]
90
+ x1, y1, x2, y2 = get_bbox_from_mask(mask_np, thr=0.5)
91
+ h, w = y2 - y1, x2 - x1
92
+ yc, xc = (y1 + y2) / 2, (x1 + x2) / 2
93
+ scale = max(h, w) / ratio
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
+ new_image = torchvision_F.crop(
96
+ image,
97
+ top=int(yc - scale / 2),
98
+ left=int(xc - scale / 2),
99
+ height=int(scale),
100
+ width=int(scale),
 
 
 
 
101
  )
102
+ if out_size is not None:
103
+ new_image = new_image.resize(out_size)
104
+
105
  return new_image
texture_baker/README.md ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Texture baker
2
+
3
+ Small texture baker which rasterizes barycentric coordinates to a tensor.
4
+ It also implements an interpolation module which can be used to bake attributes to textures then.
5
+
6
+ ## Usage
7
+
8
+ The baker can quickly bake vertex attributes to the a texture atlas based on the UV coordinates.
9
+ It supports baking on the CPU and GPU.
10
+
11
+ ```python
12
+ from texture_baker import TextureBaker
13
+
14
+ mesh = ...
15
+ uv = mesh.uv # num_vertex, 2
16
+ triangle_idx = mesh.faces # num_faces, 3
17
+ vertices = mesh.vertices # num_vertex, 3
18
+
19
+ tb = TextureBaker()
20
+ # First get the barycentric coordinates
21
+ rast = tb.rasterize(
22
+ uv=uv, face_indices=triangle_idx, bake_resolution=1024
23
+ )
24
+ # Then interpolate vertex attributes
25
+ position_bake = tb.interpolate(attr=vertices, rast=rast, face_indices=triangle_idx)
26
+ ```
texture_baker/requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch
2
+ numpy
texture_baker/setup.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ import platform
4
+
5
+ import torch
6
+ from setuptools import find_packages, setup
7
+ from torch.utils.cpp_extension import (
8
+ CUDA_HOME,
9
+ BuildExtension,
10
+ CppExtension,
11
+ CUDAExtension,
12
+ )
13
+
14
+ library_name = "texture_baker"
15
+
16
+
17
+ def get_extensions():
18
+ debug_mode = os.getenv("DEBUG", "0") == "1"
19
+ use_cuda = os.getenv("USE_CUDA", "1" if torch.cuda.is_available() else "0") == "1"
20
+ use_metal = (
21
+ os.getenv("USE_METAL", "1" if torch.backends.mps.is_available() else "0") == "1"
22
+ )
23
+ use_native_arch = os.getenv("USE_NATIVE_ARCH", "1") == "1"
24
+ if debug_mode:
25
+ print("Compiling in debug mode")
26
+
27
+ use_cuda = use_cuda and CUDA_HOME is not None
28
+ extension = CUDAExtension if use_cuda else CppExtension
29
+
30
+ extra_link_args = []
31
+ extra_compile_args = {
32
+ "cxx": [
33
+ "-O3" if not debug_mode else "-O0",
34
+ "-fdiagnostics-color=always",
35
+ "-fopenmp",
36
+ ] + ["-march=native"] if use_native_arch else [],
37
+ "nvcc": [
38
+ "-O3" if not debug_mode else "-O0",
39
+ ],
40
+ }
41
+ if debug_mode:
42
+ extra_compile_args["cxx"].append("-g")
43
+ if platform.system() == "Windows":
44
+ extra_compile_args["cxx"].append("/Z7")
45
+ extra_compile_args["cxx"].append("/Od")
46
+ extra_link_args.extend(["/DEBUG"])
47
+ extra_compile_args["cxx"].append("-UNDEBUG")
48
+ extra_compile_args["nvcc"].append("-UNDEBUG")
49
+ extra_compile_args["nvcc"].append("-g")
50
+ extra_link_args.extend(["-O0", "-g"])
51
+
52
+ define_macros = []
53
+ extensions = []
54
+ libraries = []
55
+
56
+ this_dir = os.path.dirname(os.path.curdir)
57
+ sources = glob.glob(
58
+ os.path.join(this_dir, library_name, "csrc", "**", "*.cpp"), recursive=True
59
+ )
60
+
61
+ if len(sources) == 0:
62
+ print("No source files found for extension, skipping extension compilation")
63
+ return None
64
+
65
+ if use_cuda:
66
+ define_macros += [
67
+ ("THRUST_IGNORE_CUB_VERSION_CHECK", None),
68
+ ]
69
+ sources += glob.glob(
70
+ os.path.join(this_dir, library_name, "csrc", "**", "*.cu"), recursive=True
71
+ )
72
+ libraries += ["cudart", "c10_cuda"]
73
+
74
+ if use_metal:
75
+ define_macros += [
76
+ ("WITH_MPS", None),
77
+ ]
78
+ sources += glob.glob(
79
+ os.path.join(this_dir, library_name, "csrc", "**", "*.mm"), recursive=True
80
+ )
81
+ extra_compile_args.update({"cxx": ["-O3", "-arch", "arm64"]})
82
+ extra_link_args += ["-arch", "arm64"]
83
+
84
+ extensions.append(
85
+ extension(
86
+ name=f"{library_name}._C",
87
+ sources=sources,
88
+ define_macros=define_macros,
89
+ extra_compile_args=extra_compile_args,
90
+ extra_link_args=extra_link_args,
91
+ libraries=libraries
92
+ + [
93
+ "c10",
94
+ "torch",
95
+ "torch_cpu",
96
+ "torch_python",
97
+ ],
98
+ )
99
+ )
100
+
101
+ for ext in extensions:
102
+ ext.libraries = ["cudart_static" if x == "cudart" else x for x in ext.libraries]
103
+
104
+ print(extensions)
105
+
106
+ return extensions
107
+
108
+
109
+ setup(
110
+ name=library_name,
111
+ version="0.0.1",
112
+ packages=find_packages(where="."),
113
+ package_dir={"": "."},
114
+ ext_modules=get_extensions(),
115
+ install_requires=[],
116
+ package_data={
117
+ library_name: [os.path.join("csrc", "*.h"), os.path.join("csrc", "*.metal")],
118
+ },
119
+ description="Small texture baker which rasterizes barycentric coordinates to a tensor.",
120
+ long_description=open("README.md").read(),
121
+ long_description_content_type="text/markdown",
122
+ url="https://github.com/Stability-AI/texture_baker",
123
+ cmdclass={"build_ext": BuildExtension},
124
+ )
texture_baker/texture_baker/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import torch # noqa: F401
2
+
3
+ from . import _C # noqa: F401
4
+ from .baker import TextureBaker # noqa: F401
texture_baker/texture_baker/baker.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch import Tensor
4
+
5
+
6
+ class TextureBaker(nn.Module):
7
+ def __init__(self):
8
+ super().__init__()
9
+
10
+ def rasterize(
11
+ self,
12
+ uv: Tensor,
13
+ face_indices: Tensor,
14
+ bake_resolution: int,
15
+ ) -> Tensor:
16
+ """
17
+ Rasterize the UV coordinates to a barycentric coordinates
18
+ & Triangle idxs texture map
19
+
20
+ Args:
21
+ uv (Tensor, num_vertices 2, float): UV coordinates of the mesh
22
+ face_indices (Tensor, num_faces 3, int): Face indices of the mesh
23
+ bake_resolution (int): Resolution of the bake
24
+
25
+ Returns:
26
+ Tensor, bake_resolution bake_resolution 4, float: Rasterized map
27
+ """
28
+ return torch.ops.texture_baker_cpp.rasterize(
29
+ uv, face_indices.to(torch.int32), bake_resolution
30
+ )
31
+
32
+ def get_mask(self, rast: Tensor) -> Tensor:
33
+ """
34
+ Get the occupancy mask from the rasterized map
35
+
36
+ Args:
37
+ rast (Tensor, bake_resolution bake_resolution 4, float): Rasterized map
38
+
39
+ Returns:
40
+ Tensor, bake_resolution bake_resolution, bool: Mask
41
+ """
42
+ return rast[..., -1] >= 0
43
+
44
+ def interpolate(
45
+ self,
46
+ attr: Tensor,
47
+ rast: Tensor,
48
+ face_indices: Tensor,
49
+ ) -> Tensor:
50
+ """
51
+ Interpolate the attributes using the rasterized map
52
+
53
+ Args:
54
+ attr (Tensor, num_vertices 3, float): Attributes of the mesh
55
+ rast (Tensor, bake_resolution bake_resolution 4, float): Rasterized map
56
+ face_indices (Tensor, num_faces 3, int): Face indices of the mesh
57
+ uv (Tensor, num_vertices 2, float): UV coordinates of the mesh
58
+
59
+ Returns:
60
+ Tensor, bake_resolution bake_resolution 3, float: Interpolated attributes
61
+ """
62
+ return torch.ops.texture_baker_cpp.interpolate(
63
+ attr, face_indices.to(torch.int32), rast
64
+ )
65
+
66
+ def forward(
67
+ self,
68
+ attr: Tensor,
69
+ uv: Tensor,
70
+ face_indices: Tensor,
71
+ bake_resolution: int,
72
+ ) -> Tensor:
73
+ """
74
+ Bake the texture
75
+
76
+ Args:
77
+ attr (Tensor, num_vertices 3, float): Attributes of the mesh
78
+ uv (Tensor, num_vertices 2, float): UV coordinates of the mesh
79
+ face_indices (Tensor, num_faces 3, int): Face indices of the mesh
80
+ bake_resolution (int): Resolution of the bake
81
+
82
+ Returns:
83
+ Tensor, bake_resolution bake_resolution 3, float: Baked texture
84
+ """
85
+ rast = self.rasterize(uv, face_indices, bake_resolution)
86
+ return self.interpolate(attr, rast, face_indices, uv)
texture_baker/texture_baker/csrc/baker.cpp ADDED
@@ -0,0 +1,548 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/ATen.h>
2
+ #include <ATen/Context.h>
3
+ #include <chrono>
4
+ #include <cmath>
5
+ #include <omp.h>
6
+ #include <torch/extension.h>
7
+ #ifndef __ARM_ARCH_ISA_A64
8
+ #include <immintrin.h>
9
+ #endif
10
+
11
+ #include "baker.h"
12
+
13
+ // #define TIMING
14
+ #define BINS 8
15
+
16
+ namespace texture_baker_cpp {
17
+ // Calculate the centroid of a triangle
18
+ tb_float2 triangle_centroid(const tb_float2 &v0, const tb_float2 &v1,
19
+ const tb_float2 &v2) {
20
+ return {(v0.x + v1.x + v2.x) * 0.3333f, (v0.y + v1.y + v2.y) * 0.3333f};
21
+ }
22
+
23
// Evaluate candidate SAH split planes for `node` using BINS equal-width bins
// along both axes, and return the cost of the cheapest plane.
// Outputs: best_axis (0 = x, 1 = y) and best_pos (bin boundary index, 1-based).
// `centroidBounds` is the AABB of the node's triangle centroids, used to
// define the binning range. Returns FLT_MAX if no valid split exists
// (e.g. all centroids coincide on both axes).
float BVH::find_best_split_plane(const BVHNode &node, int &best_axis,
                                 int &best_pos, AABB &centroidBounds) {
  float best_cost = std::numeric_limits<float>::max();

  for (int axis = 0; axis < 2; ++axis) // We use 2 as we have only x and y
  {
    float boundsMin = centroidBounds.min[axis];
    float boundsMax = centroidBounds.max[axis];
    if (boundsMin == boundsMax) {
      // Degenerate axis: every centroid projects to the same coordinate.
      continue;
    }

    // Populate the bins
    float scale = BINS / (boundsMax - boundsMin);
    float leftCountArea[BINS - 1], rightCountArea[BINS - 1];
    int leftSum = 0, rightSum = 0;

#ifndef __ARM_ARCH_ISA_A64
#ifndef _MSC_VER
    if (__builtin_cpu_supports("sse"))
#elif (defined(_M_AMD64) || defined(_M_X64))
    // SSE supported on Windows
    if constexpr (true)
#endif
    {
      // SSE path: track per-bin min/max of all three vertices in one vector.
      // Lane layout from _mm_set_ps(a, b, c, d) is [d, c, b, a], so the
      // vertex x lands in lane 3 and y in lane 2 (see stores below).
      __m128 min4[BINS], max4[BINS];
      unsigned int count[BINS];
      for (unsigned int i = 0; i < BINS; i++)
        min4[i] = _mm_set_ps1(1e30f), max4[i] = _mm_set_ps1(-1e30f),
        count[i] = 0;
      for (int i = node.start; i < node.end; i++) {
        int tri_idx = triangle_indices[i];
        const Triangle &triangle = triangles[tri_idx];

        int binIdx = std::min(
            BINS - 1, (int)((triangle.centroid[axis] - boundsMin) * scale));
        count[binIdx]++;
        __m128 v0 = _mm_set_ps(triangle.v0.x, triangle.v0.y, 0.0f, 0.0f);
        __m128 v1 = _mm_set_ps(triangle.v1.x, triangle.v1.y, 0.0f, 0.0f);
        __m128 v2 = _mm_set_ps(triangle.v2.x, triangle.v2.y, 0.0f, 0.0f);
        min4[binIdx] = _mm_min_ps(min4[binIdx], v0);
        max4[binIdx] = _mm_max_ps(max4[binIdx], v0);
        min4[binIdx] = _mm_min_ps(min4[binIdx], v1);
        max4[binIdx] = _mm_max_ps(max4[binIdx], v1);
        min4[binIdx] = _mm_min_ps(min4[binIdx], v2);
        max4[binIdx] = _mm_max_ps(max4[binIdx], v2);
      }
      // gather data for the 7 planes between the 8 bins:
      // prefix scan from the left and suffix scan from the right.
      __m128 leftMin4 = _mm_set_ps1(1e30f), rightMin4 = leftMin4;
      __m128 leftMax4 = _mm_set_ps1(-1e30f), rightMax4 = leftMax4;
      for (int i = 0; i < BINS - 1; i++) {
        leftSum += count[i];
        rightSum += count[BINS - 1 - i];
        leftMin4 = _mm_min_ps(leftMin4, min4[i]);
        rightMin4 = _mm_min_ps(rightMin4, min4[BINS - 2 - i]);
        leftMax4 = _mm_max_ps(leftMax4, max4[i]);
        rightMax4 = _mm_max_ps(rightMax4, max4[BINS - 2 - i]);
        float le[4], re[4];
        _mm_store_ps(le, _mm_sub_ps(leftMax4, leftMin4));
        _mm_store_ps(re, _mm_sub_ps(rightMax4, rightMin4));
        // SSE order goes from back to front: le[3] = x extent, le[2] = y extent.
        leftCountArea[i] = leftSum * (le[2] * le[3]); // 2D area calculation
        rightCountArea[BINS - 2 - i] =
            rightSum * (re[2] * re[3]); // 2D area calculation
      }
    }
#else
    if constexpr (false) {
    }
#endif
    else {
      // Scalar fallback: same binning with AABB bookkeeping.
      struct Bin {
        AABB bounds;
        int triCount = 0;
      } bins[BINS];

      for (int i = node.start; i < node.end; i++) {
        int tri_idx = triangle_indices[i];
        const Triangle &triangle = triangles[tri_idx];

        int binIdx = std::min(
            BINS - 1, (int)((triangle.centroid[axis] - boundsMin) * scale));
        bins[binIdx].triCount++;
        bins[binIdx].bounds.grow(triangle.v0);
        bins[binIdx].bounds.grow(triangle.v1);
        bins[binIdx].bounds.grow(triangle.v2);
      }

      // Gather data for the planes between the bins
      AABB leftBox, rightBox;

      for (int i = 0; i < BINS - 1; i++) {
        leftSum += bins[i].triCount;
        leftBox.grow(bins[i].bounds);
        leftCountArea[i] = leftSum * leftBox.area();

        rightSum += bins[BINS - 1 - i].triCount;
        rightBox.grow(bins[BINS - 1 - i].bounds);
        rightCountArea[BINS - 2 - i] = rightSum * rightBox.area();
      }
    }

    // Calculate SAH cost for the planes: cost = leftCount*leftArea +
    // rightCount*rightArea; keep the global minimum across both axes.
    scale = (boundsMax - boundsMin) / BINS;
    for (int i = 0; i < BINS - 1; i++) {
      float planeCost = leftCountArea[i] + rightCountArea[i];
      if (planeCost < best_cost) {
        best_axis = axis;
        best_pos = i + 1;
        best_cost = planeCost;
      }
    }
  }

  return best_cost;
}
139
+
140
// Recompute `node.bbox` (vertex bounds) and `centroidBounds` (centroid
// bounds) over the node's triangle range [node.start, node.end).
void BVH::update_node_bounds(BVHNode &node, AABB &centroidBounds) {
#ifndef __ARM_ARCH_ISA_A64
#ifndef _MSC_VER
  if (__builtin_cpu_supports("sse"))
#elif (defined(_M_AMD64) || defined(_M_X64))
  // SSE supported on Windows
  if constexpr (true)
#endif
  {
    // SSE path: process two triangles per iteration, packing
    // (tri1.x, tri1.y, tri2.x, tri2.y) into one vector.
    __m128 min4 = _mm_set_ps1(1e30f), max4 = _mm_set_ps1(-1e30f);
    __m128 cmin4 = _mm_set_ps1(1e30f), cmax4 = _mm_set_ps1(-1e30f);

    for (int i = node.start; i < node.end; i += 2) {
      int tri_idx1 = triangle_indices[i];
      const Triangle &leafTri1 = triangles[tri_idx1];
      // Check if the second actually exists in the node
      __m128 v0, v1, v2, centroid;
      if (i + 1 < node.end) {
        int tri_idx2 = triangle_indices[i + 1];
        const Triangle leafTri2 = triangles[tri_idx2];

        v0 = _mm_set_ps(leafTri1.v0.x, leafTri1.v0.y, leafTri2.v0.x,
                        leafTri2.v0.y);
        v1 = _mm_set_ps(leafTri1.v1.x, leafTri1.v1.y, leafTri2.v1.x,
                        leafTri2.v1.y);
        v2 = _mm_set_ps(leafTri1.v2.x, leafTri1.v2.y, leafTri2.v2.x,
                        leafTri2.v2.y);
        centroid = _mm_set_ps(leafTri1.centroid.x, leafTri1.centroid.y,
                              leafTri2.centroid.x, leafTri2.centroid.y);
      } else {
        // Otherwise do some duplicated work (duplicate the lone triangle
        // into both lane pairs so the min/max result is unchanged).
        v0 = _mm_set_ps(leafTri1.v0.x, leafTri1.v0.y, leafTri1.v0.x,
                        leafTri1.v0.y);
        v1 = _mm_set_ps(leafTri1.v1.x, leafTri1.v1.y, leafTri1.v1.x,
                        leafTri1.v1.y);
        v2 = _mm_set_ps(leafTri1.v2.x, leafTri1.v2.y, leafTri1.v2.x,
                        leafTri1.v2.y);
        centroid = _mm_set_ps(leafTri1.centroid.x, leafTri1.centroid.y,
                              leafTri1.centroid.x, leafTri1.centroid.y);
      }

      min4 = _mm_min_ps(min4, v0);
      max4 = _mm_max_ps(max4, v0);
      min4 = _mm_min_ps(min4, v1);
      max4 = _mm_max_ps(max4, v1);
      min4 = _mm_min_ps(min4, v2);
      max4 = _mm_max_ps(max4, v2);
      cmin4 = _mm_min_ps(cmin4, centroid);
      cmax4 = _mm_max_ps(cmax4, centroid);
    }

    float min_values[4], max_values[4], cmin_values[4], cmax_values[4];
    _mm_store_ps(min_values, min4);
    _mm_store_ps(max_values, max4);
    _mm_store_ps(cmin_values, cmin4);
    _mm_store_ps(cmax_values, cmax4);

    // _mm_set_ps stores back-to-front: lanes 3/1 hold x, lanes 2/0 hold y.
    node.bbox.min.x = std::min(min_values[3], min_values[1]);
    node.bbox.min.y = std::min(min_values[2], min_values[0]);
    node.bbox.max.x = std::max(max_values[3], max_values[1]);
    node.bbox.max.y = std::max(max_values[2], max_values[0]);

    centroidBounds.min.x = std::min(cmin_values[3], cmin_values[1]);
    centroidBounds.min.y = std::min(cmin_values[2], cmin_values[0]);
    centroidBounds.max.x = std::max(cmax_values[3], cmax_values[1]);
    centroidBounds.max.y = std::max(cmax_values[2], cmax_values[0]);
  }
#else
  if constexpr (false) {
  }
#endif
  // NOTE(review): unlike find_best_split_plane, there is no `else` here, so
  // on SSE-capable hosts this scalar block runs as well and overwrites the
  // SIMD results (same values, wasted work). Presumably deliberate to keep
  // configurations where the `if` above is not emitted compiling — confirm.
  {
    node.bbox.invalidate();
    centroidBounds.invalidate();

    // Calculate the bounding box for the node
    for (int i = node.start; i < node.end; ++i) {
      int tri_idx = triangle_indices[i];
      const Triangle &tri = triangles[tri_idx];
      node.bbox.grow(tri.v0);
      node.bbox.grow(tri.v1);
      node.bbox.grow(tri.v2);
      centroidBounds.grow(tri.centroid);
    }
  }
}
226
+
227
// Build the 2D BVH over `num_indices` triangles defined by `vertices`
// (UV positions) and `indices` (per-face vertex triples). Uses breadth-first
// construction with binned-SAH splitting; leaves are created when splitting
// would not reduce cost.
void BVH::build(const tb_float2 *vertices, const tb_int3 *indices,
                const int64_t &num_indices) {
#ifdef TIMING
  auto start = std::chrono::high_resolution_clock::now();
#endif
  // Create triangles (each remembers its original face index for lookups).
  for (size_t i = 0; i < num_indices; ++i) {
    tb_int3 idx = indices[i];
    triangles.push_back(
        {vertices[idx.x], vertices[idx.y], vertices[idx.z], static_cast<int>(i),
         triangle_centroid(vertices[idx.x], vertices[idx.y], vertices[idx.z])});
  }

  // Initialize triangle_indices (the array the builder partitions in place).
  triangle_indices.resize(triangles.size());
  std::iota(triangle_indices.begin(), triangle_indices.end(), 0);

  // Build BVH nodes
  // Reserve extra capacity to fix windows specific crashes.
  // A binary tree over N leaves needs at most 2N-1 nodes, so this also
  // guarantees push_back below never reallocates (see NOTE further down).
  nodes.reserve(triangles.size() * 2 + 1);
  nodes.push_back({}); // Create the root node
  root = 0;

  // Define a struct for queue entries
  struct QueueEntry {
    int node_idx; // index into `nodes`
    int start;    // first triangle (inclusive) in triangle_indices
    int end;      // last triangle (exclusive)
  };

  // Queue for breadth-first traversal
  std::queue<QueueEntry> node_queue;
  node_queue.push({root, 0, (int)triangles.size()});

  // Process each node in the queue
  while (!node_queue.empty()) {
    QueueEntry current = node_queue.front();
    node_queue.pop();

    int node_idx = current.node_idx;
    int start = current.start;
    int end = current.end;

    BVHNode &node = nodes[node_idx];
    node.start = start;
    node.end = end;

    // Calculate the bounding box for the node
    AABB centroidBounds;
    update_node_bounds(node, centroidBounds);

    // Determine the best split using SAH
    int best_axis, best_pos;

    float splitCost =
        find_best_split_plane(node, best_axis, best_pos, centroidBounds);
    float nosplitCost = node.calculate_node_cost();

    // Stop condition: if the best cost is greater than or equal to the parent's
    // cost
    if (splitCost >= nosplitCost) {
      // Leaf node
      node.left = node.right = -1;
      continue;
    }

    float scale =
        BINS / (centroidBounds.max[best_axis] - centroidBounds.min[best_axis]);
    int i = node.start;
    int j = node.end - 1;

    // Sort the triangle_indices in the range [start, end) based on the best
    // axis (Hoare-style partition around the chosen bin boundary)
    while (i <= j) {
      // use the exact calculation we used for binning to prevent rare
      // inaccuracies
      int tri_idx = triangle_indices[i];
      tb_float2 tcentr = triangles[tri_idx].centroid;
      int binIdx = std::min(
          BINS - 1,
          (int)((tcentr[best_axis] - centroidBounds.min[best_axis]) * scale));
      if (binIdx < best_pos)
        i++;
      else
        std::swap(triangle_indices[i], triangle_indices[j--]);
    }
    int leftCount = i - node.start;
    if (leftCount == 0 || leftCount == node.num_triangles()) {
      // Degenerate split (everything fell on one side): make a leaf.
      node.left = node.right = -1;
      continue;
    }

    int mid = i;

    // Create and set left child
    node.left = nodes.size();
    nodes.push_back({});
    node_queue.push({node.left, start, mid});

    // Create and set right child
    // NOTE(review): `node` is a reference, so this assignment copies
    // nodes[node_idx] onto itself — a no-op. It cannot "re-seat" the
    // reference after push_back; the reserve() above is what actually keeps
    // the reference valid (no reallocation). Confirm before relying on it.
    node = nodes[node_idx]; // Update the node - Potentially stale reference
    node.right = nodes.size();
    nodes.push_back({});
    node_queue.push({node.right, mid, end});
  }
#ifdef TIMING
  auto end = std::chrono::high_resolution_clock::now();
  std::chrono::duration<double> elapsed = end - start;
  std::cout << "BVH build time: " << elapsed.count() << "s" << std::endl;
#endif
}
339
+
340
// Restrict `val` to the closed interval [minVal, maxVal].
float clamp(float val, float minVal, float maxVal) {
  const float lower_bounded = std::max(val, minVal);
  return std::min(lower_bounded, maxVal);
}
344
+
345
+ // Function to check if a point (xy) is inside a triangle defined by vertices
346
+ // v1, v2, v3
347
+ bool barycentric_coordinates(tb_float2 xy, tb_float2 v1, tb_float2 v2,
348
+ tb_float2 v3, float &u, float &v, float &w) {
349
+ // Vectors from v1 to v2, v3 and xy
350
+ tb_float2 v1v2 = {v2.x - v1.x, v2.y - v1.y};
351
+ tb_float2 v1v3 = {v3.x - v1.x, v3.y - v1.y};
352
+ tb_float2 xyv1 = {xy.x - v1.x, xy.y - v1.y};
353
+
354
+ // Dot products of the vectors
355
+ float d00 = v1v2.x * v1v2.x + v1v2.y * v1v2.y;
356
+ float d01 = v1v2.x * v1v3.x + v1v2.y * v1v3.y;
357
+ float d11 = v1v3.x * v1v3.x + v1v3.y * v1v3.y;
358
+ float d20 = xyv1.x * v1v2.x + xyv1.y * v1v2.y;
359
+ float d21 = xyv1.x * v1v3.x + xyv1.y * v1v3.y;
360
+
361
+ // Calculate the barycentric coordinates
362
+ float denom = d00 * d11 - d01 * d01;
363
+ v = (d11 * d20 - d01 * d21) / denom;
364
+ w = (d00 * d21 - d01 * d20) / denom;
365
+ u = 1.0f - v - w;
366
+
367
+ // Check if the point is inside the triangle
368
+ return (v >= 0.0f) && (w >= 0.0f) && (v + w <= 1.0f);
369
+ }
370
+
371
+ bool BVH::intersect(const tb_float2 &point, float &u, float &v, float &w,
372
+ int &index) const {
373
+ const int max_stack_size = 64;
374
+ int node_stack[max_stack_size];
375
+ int stack_size = 0;
376
+
377
+ node_stack[stack_size++] = root;
378
+
379
+ while (stack_size > 0) {
380
+ int node_idx = node_stack[--stack_size];
381
+ const BVHNode &node = nodes[node_idx];
382
+
383
+ if (node.is_leaf()) {
384
+ for (int i = node.start; i < node.end; ++i) {
385
+ const Triangle &tri = triangles[triangle_indices[i]];
386
+ if (barycentric_coordinates(point, tri.v0, tri.v1, tri.v2, u, v, w)) {
387
+ index = tri.index;
388
+ return true;
389
+ }
390
+ }
391
+ } else {
392
+ if (nodes[node.right].bbox.overlaps(point)) {
393
+ if (stack_size < max_stack_size) {
394
+ node_stack[stack_size++] = node.right;
395
+ } else {
396
+ // Handle stack overflow
397
+ throw std::runtime_error("Node stack overflow");
398
+ }
399
+ }
400
+ if (nodes[node.left].bbox.overlaps(point)) {
401
+ if (stack_size < max_stack_size) {
402
+ node_stack[stack_size++] = node.left;
403
+ } else {
404
+ // Handle stack overflow
405
+ throw std::runtime_error("Node stack overflow");
406
+ }
407
+ }
408
+ }
409
+ }
410
+
411
+ return false;
412
+ }
413
+
414
+ torch::Tensor rasterize_cpu(torch::Tensor uv, torch::Tensor indices,
415
+ int64_t bake_resolution) {
416
+ int width = bake_resolution;
417
+ int height = bake_resolution;
418
+ int num_pixels = width * height;
419
+ torch::Tensor rast_result = torch::empty(
420
+ {bake_resolution, bake_resolution, 4},
421
+ torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU));
422
+
423
+ float *rast_result_ptr = rast_result.contiguous().data_ptr<float>();
424
+ const tb_float2 *vertices = (tb_float2 *)uv.data_ptr<float>();
425
+ const tb_int3 *tris = (tb_int3 *)indices.data_ptr<int>();
426
+
427
+ BVH bvh;
428
+ bvh.build(vertices, tris, indices.size(0));
429
+
430
+ #ifdef TIMING
431
+ auto start = std::chrono::high_resolution_clock::now();
432
+ #endif
433
+
434
+ #pragma omp parallel for
435
+ for (int idx = 0; idx < num_pixels; ++idx) {
436
+ int x = idx / height;
437
+ int y = idx % height;
438
+ int idx_ = idx * 4; // Note: *4 because we're storing float4 per pixel
439
+
440
+ tb_float2 pixel_coord = {float(y) / height, float(x) / width};
441
+ pixel_coord.x = clamp(pixel_coord.x, 0.0f, 1.0f);
442
+ pixel_coord.y = 1.0f - clamp(pixel_coord.y, 0.0f, 1.0f);
443
+
444
+ float u, v, w;
445
+ int triangle_idx;
446
+ if (bvh.intersect(pixel_coord, u, v, w, triangle_idx)) {
447
+ rast_result_ptr[idx_ + 0] = u;
448
+ rast_result_ptr[idx_ + 1] = v;
449
+ rast_result_ptr[idx_ + 2] = w;
450
+ rast_result_ptr[idx_ + 3] = static_cast<float>(triangle_idx);
451
+ } else {
452
+ rast_result_ptr[idx_ + 0] = 0.0f;
453
+ rast_result_ptr[idx_ + 1] = 0.0f;
454
+ rast_result_ptr[idx_ + 2] = 0.0f;
455
+ rast_result_ptr[idx_ + 3] = -1.0f;
456
+ }
457
+ }
458
+
459
+ #ifdef TIMING
460
+ auto end = std::chrono::high_resolution_clock::now();
461
+ std::chrono::duration<double> elapsed = end - start;
462
+ std::cout << "Rasterization time: " << elapsed.count() << "s" << std::endl;
463
+ #endif
464
+ return rast_result;
465
+ }
466
+
467
+ torch::Tensor interpolate_cpu(torch::Tensor attr, torch::Tensor indices,
468
+ torch::Tensor rast) {
469
+ #ifdef TIMING
470
+ auto start = std::chrono::high_resolution_clock::now();
471
+ #endif
472
+ int height = rast.size(0);
473
+ int width = rast.size(1);
474
+ torch::Tensor pos_bake = torch::empty(
475
+ {height, width, 3},
476
+ torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU));
477
+
478
+ const float *attr_ptr = attr.contiguous().data_ptr<float>();
479
+ const int *indices_ptr = indices.contiguous().data_ptr<int>();
480
+ const float *rast_ptr = rast.contiguous().data_ptr<float>();
481
+ float *output_ptr = pos_bake.contiguous().data_ptr<float>();
482
+
483
+ int num_pixels = width * height;
484
+
485
+ #pragma omp parallel for
486
+ for (int idx = 0; idx < num_pixels; ++idx) {
487
+ int idx_ = idx * 4; // Index into the float4 array (4 floats per pixel)
488
+ tb_float3 barycentric = {
489
+ rast_ptr[idx_ + 0],
490
+ rast_ptr[idx_ + 1],
491
+ rast_ptr[idx_ + 2],
492
+ };
493
+ int triangle_idx = static_cast<int>(rast_ptr[idx_ + 3]);
494
+
495
+ if (triangle_idx < 0) {
496
+ output_ptr[idx * 3 + 0] = 0.0f;
497
+ output_ptr[idx * 3 + 1] = 0.0f;
498
+ output_ptr[idx * 3 + 2] = 0.0f;
499
+ continue;
500
+ }
501
+
502
+ tb_int3 triangle = {indices_ptr[3 * triangle_idx + 0],
503
+ indices_ptr[3 * triangle_idx + 1],
504
+ indices_ptr[3 * triangle_idx + 2]};
505
+ tb_float3 v1 = {attr_ptr[3 * triangle.x + 0], attr_ptr[3 * triangle.x + 1],
506
+ attr_ptr[3 * triangle.x + 2]};
507
+ tb_float3 v2 = {attr_ptr[3 * triangle.y + 0], attr_ptr[3 * triangle.y + 1],
508
+ attr_ptr[3 * triangle.y + 2]};
509
+ tb_float3 v3 = {attr_ptr[3 * triangle.z + 0], attr_ptr[3 * triangle.z + 1],
510
+ attr_ptr[3 * triangle.z + 2]};
511
+
512
+ tb_float3 interpolated;
513
+ interpolated.x =
514
+ v1.x * barycentric.x + v2.x * barycentric.y + v3.x * barycentric.z;
515
+ interpolated.y =
516
+ v1.y * barycentric.x + v2.y * barycentric.y + v3.y * barycentric.z;
517
+ interpolated.z =
518
+ v1.z * barycentric.x + v2.z * barycentric.y + v3.z * barycentric.z;
519
+
520
+ output_ptr[idx * 3 + 0] = interpolated.x;
521
+ output_ptr[idx * 3 + 1] = interpolated.y;
522
+ output_ptr[idx * 3 + 2] = interpolated.z;
523
+ }
524
+
525
+ #ifdef TIMING
526
+ auto end = std::chrono::high_resolution_clock::now();
527
+ std::chrono::duration<double> elapsed = end - start;
528
+ std::cout << "Interpolation time: " << elapsed.count() << "s" << std::endl;
529
+ #endif
530
+ return pos_bake;
531
+ }
532
+
533
// Registers _C as a Python extension module.
// The module body is intentionally empty: it exists only so the shared
// library can be imported from Python; the actual ops are exposed through
// torch.ops via the registrations below.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}

// Defines the operators (schema strings parsed by the PyTorch dispatcher).
TORCH_LIBRARY(texture_baker_cpp, m) {
  m.def("rasterize(Tensor uv, Tensor indices, int bake_resolution) -> Tensor");
  m.def("interpolate(Tensor attr, Tensor indices, Tensor rast) -> Tensor");
}

// Registers CPP implementations for the CPU dispatch key.
TORCH_LIBRARY_IMPL(texture_baker_cpp, CPU, m) {
  m.impl("rasterize", &rasterize_cpu);
  m.impl("interpolate", &interpolate_cpu);
}
547
+
548
+ } // namespace texture_baker_cpp
texture_baker/texture_baker/csrc/baker.h ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #if defined(__NVCC__) || defined(__HIPCC__) || defined(__METAL__)
4
+ #define CUDA_ENABLED
5
+ #ifndef __METAL__
6
+ #define CUDA_HOST_DEVICE __host__ __device__
7
+ #define CUDA_DEVICE __device__
8
+ #define METAL_CONSTANT_MEM
9
+ #define METAL_THREAD_MEM
10
+ #else
11
+ #define tb_float2 float2
12
+ #define CUDA_HOST_DEVICE
13
+ #define CUDA_DEVICE
14
+ #define METAL_CONSTANT_MEM constant
15
+ #define METAL_THREAD_MEM thread
16
+ #endif
17
+ #else
18
+ #define CUDA_HOST_DEVICE
19
+ #define CUDA_DEVICE
20
+ #define METAL_CONSTANT_MEM
21
+ #define METAL_THREAD_MEM
22
+ #include <cfloat>
23
+ #include <limits>
24
+ #include <vector>
25
+ #endif
26
+
27
+ namespace texture_baker_cpp {
28
+ // Structure to represent a 2D point or vector
29
#ifndef __METAL__
// Small POD vector types shared between the CPU, CUDA and Metal backends.
// Each union overlays named components (x, y, ...) with an indexable array;
// operator[] bounds-checks and throws on an out-of-range index.
// (Under Metal, tb_float2 is aliased to the builtin float2 instead.)
union alignas(8) tb_float2 {
  struct {
    float x, y;
  };

  float data[2];

  float &operator[](size_t idx) {
    if (idx > 1)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  const float &operator[](size_t idx) const {
    if (idx > 1)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  // Exact component-wise equality (no epsilon).
  bool operator==(const tb_float2 &rhs) const {
    return x == rhs.x && y == rhs.y;
  }
};

union alignas(4) tb_float3 {
  struct {
    float x, y, z;
  };

  float data[3];

  float &operator[](size_t idx) {
    if (idx > 2)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  const float &operator[](size_t idx) const {
    if (idx > 2)
      throw std::runtime_error("bad index");
    return data[idx];
  }
};

union alignas(16) tb_float4 {
  struct {
    float x, y, z, w;
  };

  float data[4];

  float &operator[](size_t idx) {
    if (idx > 3)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  const float &operator[](size_t idx) const {
    if (idx > 3)
      throw std::runtime_error("bad index");
    return data[idx];
  }
};
#endif

// Integer vertex-index triple for one face (compiled for Metal as well,
// hence the guarded operator[]).
union alignas(4) tb_int3 {
  struct {
    int x, y, z;
  };

  int data[3];
#ifndef __METAL__
  int &operator[](size_t idx) {
    if (idx > 2)
      throw std::runtime_error("bad index");
    return data[idx];
  }
#endif
};
109
+
110
+ // BVH structure to accelerate point-triangle intersection
111
+ struct alignas(16) AABB {
112
+ // Init bounding boxes with max/min
113
+ tb_float2 min = {FLT_MAX, FLT_MAX};
114
+ tb_float2 max = {FLT_MIN, FLT_MIN};
115
+
116
+ #ifndef CUDA_ENABLED
117
+ // grow the AABB to include a point
118
+ void grow(const tb_float2 &p) {
119
+ min.x = std::min(min.x, p.x);
120
+ min.y = std::min(min.y, p.y);
121
+ max.x = std::max(max.x, p.x);
122
+ max.y = std::max(max.y, p.y);
123
+ }
124
+
125
+ void grow(const AABB &b) {
126
+ if (b.min.x != FLT_MAX) {
127
+ grow(b.min);
128
+ grow(b.max);
129
+ }
130
+ }
131
+ #endif
132
+
133
+ // Check if two AABBs overlap
134
+ bool overlaps(const METAL_THREAD_MEM AABB &other) const {
135
+ return min.x <= other.max.x && max.x >= other.min.x &&
136
+ min.y <= other.max.y && max.y >= other.min.y;
137
+ }
138
+
139
+ bool overlaps(const METAL_THREAD_MEM tb_float2 &point) const {
140
+ return point.x >= min.x && point.x <= max.x && point.y >= min.y &&
141
+ point.y <= max.y;
142
+ }
143
+
144
+ #if defined(__NVCC__)
145
+ CUDA_DEVICE bool overlaps(const float2 &point) const {
146
+ return point.x >= min.x && point.x <= max.x && point.y >= min.y &&
147
+ point.y <= max.y;
148
+ }
149
+ #endif
150
+
151
+ // Initialize AABB to an invalid state
152
+ void invalidate() {
153
+ min = {FLT_MAX, FLT_MAX};
154
+ max = {FLT_MIN, FLT_MIN};
155
+ }
156
+
157
+ // Calculate the area of the AABB
158
+ float area() const {
159
+ tb_float2 extent = {max.x - min.x, max.y - min.y};
160
+ return extent.x * extent.y;
161
+ }
162
+ };
163
+
164
// One BVH node. Leaves own the triangle range [start, end) in the BVH's
// triangle_indices array; interior nodes point at two children.
struct BVHNode {
  AABB bbox;       // bounds of all triangle vertices in this node
  int start, end;  // triangle range (end exclusive)
  int left, right; // child node indices, or -1/-1 for a leaf

  int num_triangles() const { return end - start; }

  CUDA_HOST_DEVICE bool is_leaf() const { return left == -1 && right == -1; }

  // SAH cost of keeping this node unsplit: triangle count times surface area.
  float calculate_node_cost() {
    float area = bbox.area();
    return num_triangles() * area;
  }
};
178
+
179
// A UV-space triangle as stored in the BVH.
struct Triangle {
  tb_float2 v0, v1, v2; // vertex UV positions
  int index;            // original face index (reported on intersection hits)
  tb_float2 centroid;   // precomputed centroid, used for SAH binning
};
184
+
185
#ifndef __METAL__
// Host-side bounding-volume hierarchy over 2D triangles. Built once per
// mesh, then queried per texel with intersect(). Not available in the Metal
// translation unit (std::vector is host-only).
struct BVH {
  std::vector<BVHNode> nodes;        // node pool; nodes[root] is the root
  std::vector<Triangle> triangles;   // all triangles, in input order
  std::vector<int> triangle_indices; // permutation partitioned by the builder
  int root;

  // Construct the tree from vertex UVs and per-face vertex index triples.
  void build(const tb_float2 *vertices, const tb_int3 *indices,
             const int64_t &num_indices);
  // Locate the triangle containing `point`; fills barycentrics and the face
  // index on success.
  bool intersect(const tb_float2 &point, float &u, float &v, float &w,
                 int &index) const;

  // Internal build helpers.
  void update_node_bounds(BVHNode &node, AABB &centroidBounds);
  float find_best_split_plane(const BVHNode &node, int &best_axis,
                              int &best_pos, AABB &centroidBounds);
};
#endif
202
+
203
+ } // namespace texture_baker_cpp
texture_baker/texture_baker/csrc/baker_kernel.cu ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/ATen.h>
2
+ #include <ATen/Context.h>
3
+ #include <ATen/cuda/CUDAContext.h>
4
+ #include <torch/extension.h>
5
+
6
+ #include "baker.h"
7
+
8
+ // #define TIMING
9
+
10
+ #define STRINGIFY(x) #x
11
+ #define STR(x) STRINGIFY(x)
12
+ #define FILE_LINE __FILE__ ":" STR(__LINE__)
13
+ #define CUDA_CHECK_THROW(x) \
14
+ do { \
15
+ cudaError_t _result = x; \
16
+ if (_result != cudaSuccess) \
17
+ throw std::runtime_error(std::string(FILE_LINE " check failed " #x " failed: ") + cudaGetErrorString(_result)); \
18
+ } while(0)
19
+
20
+ namespace texture_baker_cpp
21
+ {
22
+
23
// Component-wise float3 addition (CUDA defines no builtin operator+ for
// its vector types).
__device__ float3 operator+(const float3 &a, const float3 &b)
{
    return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
}
27
+
28
// xy: 2D test position
// v1: vertex position 1
// v2: vertex position 2
// v3: vertex position 3
//
// Device twin of the host barycentric_coordinates() in baker.cpp: same math,
// CUDA builtin float2 for the query point.
__forceinline__ __device__ bool barycentric_coordinates(const float2 &xy, const tb_float2 &v1, const tb_float2 &v2, const tb_float2 &v3, float &u, float &v, float &w)
{
    // Return true if the point (xy) is inside the triangle defined by the vertices v1, v2, v3.
    // If the point is inside the triangle, the barycentric coordinates are stored in u, v, and w.
    float2 v1v2 = make_float2(v2.x - v1.x, v2.y - v1.y);
    float2 v1v3 = make_float2(v3.x - v1.x, v3.y - v1.y);
    float2 xyv1 = make_float2(xy.x - v1.x, xy.y - v1.y);

    // Dot products for the 2x2 barycentric solve.
    float d00 = v1v2.x * v1v2.x + v1v2.y * v1v2.y;
    float d01 = v1v2.x * v1v3.x + v1v2.y * v1v3.y;
    float d11 = v1v3.x * v1v3.x + v1v3.y * v1v3.y;
    float d20 = xyv1.x * v1v2.x + xyv1.y * v1v2.y;
    float d21 = xyv1.x * v1v3.x + xyv1.y * v1v3.y;

    // Cramer's rule; denom is zero for degenerate (zero-area) triangles.
    float denom = d00 * d11 - d01 * d01;
    v = (d11 * d20 - d01 * d21) / denom;
    w = (d00 * d21 - d01 * d20) / denom;
    u = 1.0f - v - w;

    return (v >= 0.0f) && (w >= 0.0f) && (v + w <= 1.0f);
}
54
+
55
// One thread per texel: blend the three vertex attributes of the covering
// triangle with the barycentric weights stored in `rast`; misses
// (triangle index < 0) are written as black.
__global__ void kernel_interpolate(const float3* __restrict__ attr, const int3* __restrict__ indices, const float4* __restrict__ rast, float3* __restrict__ output, int width, int height)
{
    // Interpolate the attr into output based on the rast result (barycentric coordinates, + triangle idx)
    //int idx = x * width + y;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int x = idx / width;
    int y = idx % width;

    // NOTE(review): x (the row, idx / width) is tested against width and y
    // against height; correct only because the bake is always square —
    // confirm before supporting non-square maps.
    if (x >= width || y >= height)
        return;

    float4 barycentric = rast[idx];
    int triangle_idx = int(barycentric.w);

    if (triangle_idx < 0)
    {
        // Texel not covered by any triangle.
        output[idx] = make_float3(0.0f, 0.0f, 0.0f);
        return;
    }

    float3 v1 = attr[indices[triangle_idx].x];
    float3 v2 = attr[indices[triangle_idx].y];
    float3 v3 = attr[indices[triangle_idx].z];

    // Barycentric blend of the three vertex attributes.
    output[idx] = make_float3(v1.x * barycentric.x, v1.y * barycentric.x, v1.z * barycentric.x)
                + make_float3(v2.x * barycentric.y, v2.y * barycentric.y, v2.z * barycentric.y)
                + make_float3(v3.x * barycentric.z, v3.y * barycentric.z, v3.z * barycentric.z);
}
83
+
84
// Device-side BVH traversal (mirrors BVH::intersect in baker.cpp): find the
// first triangle containing `point` via an explicit per-thread DFS stack.
// On a hit, (u, v, w) receive the barycentric coordinates and `index` the
// original face index.
__device__ bool bvh_intersect(
    const BVHNode* __restrict__ nodes,
    const Triangle* __restrict__ triangles,
    const int* __restrict__ triangle_indices,
    const int root,
    const float2 &point,
    float &u, float &v, float &w,
    int &index)
{
    constexpr int max_stack_size = 64;
    int node_stack[max_stack_size];
    int stack_size = 0;

    node_stack[stack_size++] = root;

    while (stack_size > 0)
    {
        int node_idx = node_stack[--stack_size];
        const BVHNode &node = nodes[node_idx];

        if (node.is_leaf())
        {
            // Leaf: test each triangle in the node's range.
            for (int i = node.start; i < node.end; ++i)
            {
                const Triangle &tri = triangles[triangle_indices[i]];
                if (barycentric_coordinates(point, tri.v0, tri.v1, tri.v2, u, v, w))
                {
                    index = tri.index;
                    return true;
                }
            }
        }
        else
        {
            // Push right first so the left child is visited first, matching
            // the host traversal order.
            if (nodes[node.right].bbox.overlaps(point))
            {
                if (stack_size < max_stack_size)
                {
                    node_stack[stack_size++] = node.right;
                }
                else
                {
                    // Handle stack overflow
                    // Make sure NDEBUG is not defined (see setup.py)
                    assert(0 && "Node stack overflow");
                }
            }
            if (nodes[node.left].bbox.overlaps(point))
            {
                if (stack_size < max_stack_size)
                {
                    node_stack[stack_size++] = node.left;
                }
                else
                {
                    // Handle stack overflow
                    // Make sure NDEBUG is not defined (see setup.py)
                    assert(0 && "Node stack overflow");
                }
            }
        }
    }

    return false;
}
149
+
150
// One thread per texel: map the texel to UV space, query the BVH, and store
// (u, v, w, triangle_index) — or (0, 0, 0, -1) on a miss. Device twin of the
// inner loop of rasterize_cpu.
__global__ void kernel_bake_uv(
    float2* __restrict__ uv,
    int3* __restrict__ indices,
    float4* __restrict__ output,
    const BVHNode* __restrict__ nodes,
    const Triangle* __restrict__ triangles,
    const int* __restrict__ triangle_indices,
    const int root,
    const int width,
    const int height,
    const int num_indices)
{
    //int idx = x * width + y;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int x = idx / width;
    int y = idx % width;

    // NOTE(review): column y is tested against width and row x against
    // height; correct only because the bake is always square — confirm
    // before supporting non-square maps.
    if (y >= width || x >= height)
        return;

    // We index x,y but the original coords are HW. So swap them.
    // The vertical axis is flipped so row 0 sits at v = 1.
    float2 pixel_coord = make_float2(float(y) / height, float(x) / width);
    pixel_coord.x = fminf(fmaxf(pixel_coord.x, 0.0f), 1.0f);
    pixel_coord.y = 1.0f - fminf(fmaxf(pixel_coord.y, 0.0f), 1.0f);

    float u, v, w;
    int triangle_idx;
    bool hit = bvh_intersect(nodes, triangles, triangle_indices, root, pixel_coord, u, v, w, triangle_idx);

    if (hit)
    {
        output[idx] = make_float4(u, v, w, float(triangle_idx));
        return;
    }

    output[idx] = make_float4(0.0f, 0.0f, 0.0f, -1.0f);
}
187
+
188
// CUDA implementation of `rasterize`: build a 2-D BVH over the UV triangles
// on the host, upload it, and launch one thread per output pixel that writes
// (u, v, w, triangle_index), with triangle_index = -1 meaning "no coverage".
//
// Returns a (bake_resolution, bake_resolution, 4) float32 CUDA tensor.
torch::Tensor rasterize_gpu(
    torch::Tensor uv,
    torch::Tensor indices,
    int64_t bake_resolution)
{
#ifdef TIMING
    auto start = std::chrono::high_resolution_clock::now();
#endif
    constexpr int block_size = 16 * 16;
    // FIX: round up instead of truncating so resolutions that are not a
    // multiple of the block size are still fully covered; the kernel
    // bounds-checks the excess threads.
    const int num_pixels = bake_resolution * bake_resolution;
    const int grid_size = (num_pixels + block_size - 1) / block_size;
    dim3 block_dims(block_size, 1, 1);
    dim3 grid_dims(grid_size, 1, 1);

    const int num_indices = indices.size(0);

    const int width = bake_resolution;
    const int height = bake_resolution;

    // Step 1: create an empty tensor to store the output.
    torch::Tensor rast_result = torch::empty({bake_resolution, bake_resolution, 4}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));

    // FIX: keep the contiguous tensors alive past the async kernel launch;
    // calling data_ptr() on the temporary returned by .contiguous() would
    // dangle for non-contiguous inputs.
    auto uv_contig = uv.contiguous();
    auto indices_contig = indices.contiguous();

    // The BVH is built on the CPU from host copies of the inputs.
    auto vertices_cpu = uv_contig.cpu();
    auto indices_cpu = indices_contig.cpu();

    const tb_float2 *vertices_cpu_ptr = (tb_float2 *)vertices_cpu.data_ptr<float>();
    const tb_int3 *tris_cpu_ptr = (tb_int3 *)indices_cpu.data_ptr<int>();

    BVH bvh;
    bvh.build(vertices_cpu_ptr, tris_cpu_ptr, indices.size(0));

    BVHNode *nodes_gpu = nullptr;
    Triangle *triangles_gpu = nullptr;
    int *triangle_indices_gpu = nullptr;
    const int bvh_root = bvh.root;
    cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();

    // Stream-ordered upload of the BVH. The sources are pageable host memory,
    // so cudaMemcpyAsync stages them before returning and `bvh` may safely be
    // destroyed at function exit.
    CUDA_CHECK_THROW(cudaMallocAsync(&nodes_gpu, sizeof(BVHNode) * bvh.nodes.size(), cuda_stream));
    CUDA_CHECK_THROW(cudaMallocAsync(&triangles_gpu, sizeof(Triangle) * bvh.triangles.size(), cuda_stream));
    CUDA_CHECK_THROW(cudaMallocAsync(&triangle_indices_gpu, sizeof(int) * bvh.triangle_indices.size(), cuda_stream));

    CUDA_CHECK_THROW(cudaMemcpyAsync(nodes_gpu, bvh.nodes.data(), sizeof(BVHNode) * bvh.nodes.size(), cudaMemcpyHostToDevice, cuda_stream));
    CUDA_CHECK_THROW(cudaMemcpyAsync(triangles_gpu, bvh.triangles.data(), sizeof(Triangle) * bvh.triangles.size(), cudaMemcpyHostToDevice, cuda_stream));
    CUDA_CHECK_THROW(cudaMemcpyAsync(triangle_indices_gpu, bvh.triangle_indices.data(), sizeof(int) * bvh.triangle_indices.size(), cudaMemcpyHostToDevice, cuda_stream));

    kernel_bake_uv<<<grid_dims, block_dims, 0, cuda_stream>>>(
        (float2 *)uv_contig.data_ptr<float>(),
        (int3 *)indices_contig.data_ptr<int>(),
        (float4 *)rast_result.data_ptr<float>(),
        nodes_gpu,
        triangles_gpu,
        triangle_indices_gpu,
        bvh_root,
        width,
        height,
        num_indices);

    // Frees are stream-ordered, so they only execute after the kernel.
    CUDA_CHECK_THROW(cudaFreeAsync(nodes_gpu, cuda_stream));
    CUDA_CHECK_THROW(cudaFreeAsync(triangles_gpu, cuda_stream));
    CUDA_CHECK_THROW(cudaFreeAsync(triangle_indices_gpu, cuda_stream));

#ifdef TIMING
    CUDA_CHECK_THROW(cudaStreamSynchronize(cuda_stream));
    auto end = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> elapsed = end - start;
    std::cout << "Rasterization time (CUDA): " << elapsed.count() << "s" << std::endl;
#endif
    return rast_result;
}
256
+
257
// CUDA implementation of `interpolate`: one thread per rasterized pixel,
// blending the three vertex attributes of the hit triangle with the
// barycentric weights stored in `rast` (w component = triangle index).
//
// Returns a (rast.size(0), rast.size(1), 3) float32 CUDA tensor.
torch::Tensor interpolate_gpu(
    torch::Tensor attr,
    torch::Tensor indices,
    torch::Tensor rast)
{
#ifdef TIMING
    auto start = std::chrono::high_resolution_clock::now();
#endif
    constexpr int block_size = 16 * 16;
    const int width = rast.size(0);
    const int height = rast.size(1);
    // FIX: cover width * height pixels (the original used size(0)^2, which
    // under-covers non-square rasters) and round up so totals that are not a
    // multiple of the block size still get a full launch.
    const int grid_size = (width * height + block_size - 1) / block_size;
    dim3 block_dims(block_size, 1, 1);
    dim3 grid_dims(grid_size, 1, 1);

    // Step 1: create an empty tensor to store the output.
    torch::Tensor pos_bake = torch::empty({rast.size(0), rast.size(1), 3}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));

    cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();

    // FIX: keep the contiguous tensors alive past the async launch; data_ptr()
    // on a .contiguous() temporary would dangle for non-contiguous inputs.
    auto attr_contig = attr.contiguous();
    auto indices_contig = indices.contiguous();
    auto rast_contig = rast.contiguous();

    kernel_interpolate<<<grid_dims, block_dims, 0, cuda_stream>>>(
        (float3 *)attr_contig.data_ptr<float>(),
        (int3 *)indices_contig.data_ptr<int>(),
        (float4 *)rast_contig.data_ptr<float>(),
        (float3 *)pos_bake.data_ptr<float>(),
        width,
        height);
#ifdef TIMING
    CUDA_CHECK_THROW(cudaStreamSynchronize(cuda_stream));
    auto end = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> elapsed = end - start;
    std::cout << "Interpolation time (CUDA): " << elapsed.count() << "s" << std::endl;
#endif
    return pos_bake;
}
293
+
294
// Register the CUDA backends for the ops declared by the texture_baker_cpp
// library (CPU declarations live elsewhere).
TORCH_LIBRARY_IMPL(texture_baker_cpp, CUDA, m)
{
    m.impl("interpolate", &interpolate_gpu);
    m.impl("rasterize", &rasterize_gpu);
}

} // namespace texture_baker_cpp
texture_baker/texture_baker/csrc/baker_kernel.metal ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <metal_stdlib>
2
+ using namespace metal;
3
+
4
+ // This header is inlined manually
5
+ //#include "baker.h"
6
+
7
+ // Use the texture_baker_cpp so it can use the classes from baker.h
8
+ using namespace texture_baker_cpp;
9
+
10
+ // Utility function to compute barycentric coordinates
11
// Compute the barycentric coordinates (u, v, w) of point `xy` with respect to
// the triangle (v1, v2, v3). Returns true iff the point lies inside the
// triangle (boundary included).
bool barycentric_coordinates(float2 xy, float2 v1, float2 v2, float2 v3, thread float &u, thread float &v, thread float &w) {
    const float2 e0 = v2 - v1;
    const float2 e1 = v3 - v1;
    const float2 p = xy - v1;

    const float d00 = dot(e0, e0);
    const float d01 = dot(e0, e1);
    const float d11 = dot(e1, e1);
    const float d20 = dot(p, e0);
    const float d21 = dot(p, e1);

    const float denom = d00 * d11 - d01 * d01;
    v = (d11 * d20 - d01 * d21) / denom;
    w = (d00 * d21 - d01 * d20) / denom;
    u = 1.0f - v - w;

    // u >= 0 is implied by v + w <= 1, so only three checks are needed.
    return (v >= 0.0f) && (w >= 0.0f) && (v + w <= 1.0f);
}
29
+
30
+ // Kernel function for interpolation
31
+ kernel void kernel_interpolate(constant packed_float3 *attr [[buffer(0)]],
32
+ constant packed_int3 *indices [[buffer(1)]],
33
+ constant packed_float4 *rast [[buffer(2)]],
34
+ device packed_float3 *output [[buffer(3)]],
35
+ constant int &width [[buffer(4)]],
36
+ constant int &height [[buffer(5)]],
37
+ uint3 blockIdx [[threadgroup_position_in_grid]],
38
+ uint3 threadIdx [[thread_position_in_threadgroup]],
39
+ uint3 blockDim [[threads_per_threadgroup]])
40
+ {
41
+ // Calculate global position using threadgroup and thread positions
42
+ int x = blockIdx.x * blockDim.x + threadIdx.x;
43
+ int y = blockIdx.y * blockDim.y + threadIdx.y;
44
+
45
+ if (x >= width || y >= height) return;
46
+
47
+ int idx = y * width + x;
48
+ float4 barycentric = rast[idx];
49
+ int triangle_idx = int(barycentric.w);
50
+
51
+ if (triangle_idx < 0) {
52
+ output[idx] = float3(0.0f, 0.0f, 0.0f);
53
+ return;
54
+ }
55
+
56
+ float3 v1 = attr[indices[triangle_idx].x];
57
+ float3 v2 = attr[indices[triangle_idx].y];
58
+ float3 v3 = attr[indices[triangle_idx].z];
59
+
60
+ output[idx] = v1 * barycentric.x + v2 * barycentric.y + v3 * barycentric.z;
61
+ }
62
+
63
// Point query against the BVH: iteratively walk the tree with an explicit
// stack and return the barycentric coordinates and index of the first
// triangle containing `point`. Returns false when no triangle covers it (or
// on traversal-stack overflow, since Metal has no assert).
bool bvh_intersect(
    constant BVHNode* nodes,
    constant Triangle* triangles,
    constant int* triangle_indices,
    const thread int root,
    const thread float2 &point,
    thread float &u, thread float &v, thread float &w,
    thread int &index)
{
    const int max_stack_size = 64;
    thread int node_stack[max_stack_size];
    int stack_size = 0;
    node_stack[stack_size++] = root;

    while (stack_size > 0)
    {
        const int current = node_stack[--stack_size];
        const BVHNode node = nodes[current];

        if (node.is_leaf())
        {
            // Test every triangle referenced by this leaf.
            for (int i = node.start; i < node.end; ++i)
            {
                constant Triangle &tri = triangles[triangle_indices[i]];
                if (barycentric_coordinates(point, tri.v0, tri.v1, tri.v2, u, v, w))
                {
                    index = tri.index;
                    return true;
                }
            }
            continue;
        }

        // Push the right child first so the left child is popped first.
        BVHNode child = nodes[node.right];
        if (child.bbox.overlaps(point))
        {
            if (stack_size >= max_stack_size)
            {
                // Stack overflow — bail out rather than corrupt memory.
                return false;
            }
            node_stack[stack_size++] = node.right;
        }
        child = nodes[node.left];
        if (child.bbox.overlaps(point))
        {
            if (stack_size >= max_stack_size)
            {
                return false;
            }
            node_stack[stack_size++] = node.left;
        }
    }

    return false;
}
129
+
130
+
131
+ // Kernel function for baking UV
132
+ kernel void kernel_bake_uv(constant packed_float2 *uv [[buffer(0)]],
133
+ constant packed_int3 *indices [[buffer(1)]],
134
+ device packed_float4 *output [[buffer(2)]],
135
+ constant BVHNode *nodes [[buffer(3)]],
136
+ constant Triangle *triangles [[buffer(4)]],
137
+ constant int *triangle_indices [[buffer(5)]],
138
+ constant int &root [[buffer(6)]],
139
+ constant int &width [[buffer(7)]],
140
+ constant int &height [[buffer(8)]],
141
+ constant int &num_indices [[buffer(9)]],
142
+ uint3 blockIdx [[threadgroup_position_in_grid]],
143
+ uint3 threadIdx [[thread_position_in_threadgroup]],
144
+ uint3 blockDim [[threads_per_threadgroup]])
145
+ {
146
+ // Calculate global position using threadgroup and thread positions
147
+ int x = blockIdx.x * blockDim.x + threadIdx.x;
148
+ int y = blockIdx.y * blockDim.y + threadIdx.y;
149
+
150
+
151
+ if (x >= width || y >= height) return;
152
+
153
+ int idx = x * width + y;
154
+
155
+ // Swap original coordinates
156
+ float2 pixel_coord = float2(float(y) / float(height), float(x) / float(width));
157
+ pixel_coord = clamp(pixel_coord, 0.0f, 1.0f);
158
+ pixel_coord.y = 1.0f - pixel_coord.y;
159
+
160
+ float u, v, w;
161
+ int triangle_idx;
162
+ bool hit = bvh_intersect(nodes, triangles, triangle_indices, root, pixel_coord, u, v, w, triangle_idx);
163
+
164
+ if (hit) {
165
+ output[idx] = float4(u, v, w, float(triangle_idx));
166
+ return;
167
+ }
168
+
169
+ output[idx] = float4(0.0f, 0.0f, 0.0f, -1.0f);
170
+ }
texture_baker/texture_baker/csrc/baker_kernel.mm ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+ #include <ATen/ATen.h>
3
+ #include <ATen/Context.h>
4
+ #include "baker.h"
5
+
6
+ #import <Foundation/Foundation.h>
7
+ #import <Metal/Metal.h>
8
+ #include <filesystem>
9
+
10
// Helper function to retrieve the `MTLBuffer` from a `torch::Tensor`.
// MPS tensors keep an id<MTLBuffer> as their raw storage pointer; bit_cast
// reinterprets it without triggering ObjC retain/release semantics.
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& tensor) {
    return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}
14
+
15
// Helper function to create a compute pipeline state object (PSO).
// Compiles `fullSource` as a Metal library at runtime (with the __METAL__
// macro defined so shared headers can branch on the target) and looks up
// `kernel_name` in the resulting library. Raises via TORCH_CHECK on failure.
static inline id<MTLComputePipelineState> createComputePipelineState(id<MTLDevice> device, NSString* fullSource, std::string kernel_name) {
    NSError *error = nil;

    // Load the custom kernel shader.
    MTLCompileOptions *options = [[MTLCompileOptions alloc] init];
    // Add the preprocessor macro "__METAL__"
    options.preprocessorMacros = @{@"__METAL__": @""};
    id<MTLLibrary> customKernelLibrary = [device newLibraryWithSource: fullSource options:options error:&error];
    TORCH_CHECK(customKernelLibrary, "Failed to create custom kernel library, error: ", error.localizedDescription.UTF8String);

    id<MTLFunction> customKernelFunction = [customKernelLibrary newFunctionWithName:[NSString stringWithUTF8String:kernel_name.c_str()]];
    TORCH_CHECK(customKernelFunction, "Failed to create function state object for ", kernel_name.c_str());

    id<MTLComputePipelineState> pso = [device newComputePipelineStateWithFunction:customKernelFunction error:&error];
    TORCH_CHECK(pso, error.localizedDescription.UTF8String);

    return pso;
}
34
+
35
// Locate the installed `texture_baker` Python package and return the path to
// its `csrc` directory (where the Metal shader sources live).
// Throws std::runtime_error when the package cannot be imported or located.
std::filesystem::path get_extension_path() {
    // Ensure the GIL is held before calling any Python C API function
    PyGILState_STATE gstate = PyGILState_Ensure();

    const char* module_name = "texture_baker";

    // Import the module by name. PyImport_ImportModule returns a NEW
    // reference, so it must be released on every exit path (the original
    // leaked it).
    PyObject* module = PyImport_ImportModule(module_name);
    if (!module) {
        PyGILState_Release(gstate);
        throw std::runtime_error("Could not import the module: " + std::string(module_name));
    }

    // Get the filename of the module (also returns a new reference)
    PyObject* filename_obj = PyModule_GetFilenameObject(module);
    if (filename_obj) {
        std::string path = PyUnicode_AsUTF8(filename_obj);
        Py_DECREF(filename_obj);
        Py_DECREF(module); // FIX: release the module reference
        PyGILState_Release(gstate);

        // Get the directory part of the path (removing the __init__.py)
        std::filesystem::path module_path = std::filesystem::path(path).parent_path();

        // Append the 'csrc' directory to the path
        module_path /= "csrc";

        return module_path;
    } else {
        Py_DECREF(module); // FIX: release the module reference on the error path too
        PyGILState_Release(gstate);
        throw std::runtime_error("Could not retrieve the module filename.");
    }
}
67
+
68
// Concatenate baker.h and baker_kernel.metal into one source string for
// runtime compilation (the shader cannot #include project headers itself, so
// the header is inlined ahead of the kernel source).
NSString *get_shader_sources_as_string()
{
    const std::filesystem::path csrc_path = get_extension_path();
    const std::string shader_path = (csrc_path / "baker_kernel.metal").string();
    const std::string shader_header_path = (csrc_path / "baker.h").string();
    // Load the Metal shader from the specified path
    NSError *error = nil;

    NSString* shaderHeaderSource = [
        NSString stringWithContentsOfFile:[NSString stringWithUTF8String:shader_header_path.c_str()]
        encoding:NSUTF8StringEncoding
        error:&error];
    // NOTE(review): this checks `error` rather than a nil return value, which
    // relies on the API leaving `error` untouched on success — confirm, or
    // prefer checking `shaderHeaderSource == nil`.
    if (error) {
        throw std::runtime_error("Failed to load baker.h: " + std::string(error.localizedDescription.UTF8String));
    }

    NSString* shaderSource = [
        NSString stringWithContentsOfFile:[NSString stringWithUTF8String:shader_path.c_str()]
        encoding:NSUTF8StringEncoding
        error:&error];
    if (error) {
        throw std::runtime_error("Failed to load Metal shader: " + std::string(error.localizedDescription.UTF8String));
    }

    // The header must come first so the kernel source sees its type definitions.
    NSString *fullSource = [shaderHeaderSource stringByAppendingString:shaderSource];

    return fullSource;
}
96
+
97
+ namespace texture_baker_cpp
98
+ {
99
// MPS implementation of `rasterize`: build a 2-D BVH over the UV triangles on
// the CPU, then dispatch a Metal kernel that writes (u, v, w, triangle_index)
// per output pixel, with triangle_index = -1 meaning "no coverage".
// Returns a (bake_resolution, bake_resolution, 4) float32 MPS tensor.
torch::Tensor rasterize_gpu(
    torch::Tensor uv,
    torch::Tensor indices,
    int64_t bake_resolution)
{
    TORCH_CHECK(uv.device().is_mps(), "uv must be a MPS tensor");
    TORCH_CHECK(uv.is_contiguous(), "uv must be contiguous");
    TORCH_CHECK(indices.is_contiguous(), "indices must be contiguous");

    // FIX: the uv dtype error message previously printed indices' dtype.
    TORCH_CHECK(uv.scalar_type() == torch::kFloat32, "Unsupported data type: ", uv.scalar_type());
    TORCH_CHECK(indices.scalar_type() == torch::kInt32, "Unsupported data type: ", indices.scalar_type());

    torch::Tensor rast_result = torch::empty({bake_resolution, bake_resolution, 4}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kMPS)).contiguous();

    @autoreleasepool {
        // The BVH is built on the CPU from host copies of the inputs.
        auto vertices_cpu = uv.contiguous().cpu();
        auto indices_cpu = indices.contiguous().cpu();

        const tb_float2 *vertices_cpu_ptr = (tb_float2*)vertices_cpu.contiguous().data_ptr<float>();
        const tb_int3 *tris_cpu_ptr = (tb_int3*)indices_cpu.contiguous().data_ptr<int>();

        BVH bvh;
        bvh.build(vertices_cpu_ptr, tris_cpu_ptr, indices.size(0));

        id<MTLDevice> device = MTLCreateSystemDefaultDevice();

        // Shader source is compiled at runtime (baker.h prepended to the
        // .metal source).
        NSString *fullSource = get_shader_sources_as_string();

        // Create a compute pipeline state object using the helper function
        id<MTLComputePipelineState> bake_uv_PSO = createComputePipelineState(device, fullSource, "kernel_bake_uv");

        // Get a reference to the command buffer for the MPS stream.
        id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");

        // Get a reference to the dispatch queue for the MPS stream, which encodes the synchronization with the CPU.
        dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();

        // dispatch_sync blocks until encoding is committed, so the bvh arrays
        // wrapped no-copy below remain alive long enough.
        dispatch_sync(serialQueue, ^(){
            // Start a compute pass.
            id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
            TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");

            // Get Metal buffers directly from PyTorch tensors
            auto uv_buf = getMTLBufferStorage(uv.contiguous());
            auto indices_buf = getMTLBufferStorage(indices.contiguous());
            auto rast_result_buf = getMTLBufferStorage(rast_result);

            const int width = bake_resolution;
            const int height = bake_resolution;
            const int num_indices = indices.size(0);
            const int bvh_root = bvh.root;

            // Wrap the existing CPU memory in Metal buffers with shared memory.
            // NOTE(review): newBufferWithBytesNoCopy documents a page-aligned
            // pointer requirement; std::vector storage is not guaranteed to be
            // page aligned — confirm these return non-nil on all devices.
            id<MTLBuffer> nodesBuffer = [device newBufferWithBytesNoCopy:(void*)bvh.nodes.data() length:sizeof(BVHNode) * bvh.nodes.size() options:MTLResourceStorageModeShared deallocator:nil];
            id<MTLBuffer> trianglesBuffer = [device newBufferWithBytesNoCopy:(void*)bvh.triangles.data() length:sizeof(Triangle) * bvh.triangles.size() options:MTLResourceStorageModeShared deallocator:nil];
            id<MTLBuffer> triangleIndicesBuffer = [device newBufferWithBytesNoCopy:(void*)bvh.triangle_indices.data() length:sizeof(int) * bvh.triangle_indices.size() options:MTLResourceStorageModeShared deallocator:nil];

            [computeEncoder setComputePipelineState:bake_uv_PSO];
            [computeEncoder setBuffer:uv_buf offset:uv.storage_offset() * uv.element_size() atIndex:0];
            [computeEncoder setBuffer:indices_buf offset:indices.storage_offset() * indices.element_size() atIndex:1];
            [computeEncoder setBuffer:rast_result_buf offset:rast_result.storage_offset() * rast_result.element_size() atIndex:2];
            [computeEncoder setBuffer:nodesBuffer offset:0 atIndex:3];
            [computeEncoder setBuffer:trianglesBuffer offset:0 atIndex:4];
            [computeEncoder setBuffer:triangleIndicesBuffer offset:0 atIndex:5];
            [computeEncoder setBytes:&bvh_root length:sizeof(int) atIndex:6];
            [computeEncoder setBytes:&width length:sizeof(int) atIndex:7];
            [computeEncoder setBytes:&height length:sizeof(int) atIndex:8];
            [computeEncoder setBytes:&num_indices length:sizeof(int) atIndex:9];

            // Calculate a thread group size.
            // NOTE(review): assumes bake_resolution is a multiple of 16;
            // otherwise trailing rows/columns are never dispatched.
            int block_size = 16;
            MTLSize threadgroupSize = MTLSizeMake(block_size, block_size, 1); // Fixed threadgroup size
            MTLSize numThreadgroups = MTLSizeMake(bake_resolution / block_size, bake_resolution / block_size, 1);

            // Encode the compute command.
            [computeEncoder dispatchThreadgroups:numThreadgroups threadsPerThreadgroup:threadgroupSize];
            [computeEncoder endEncoding];

            // Commit the work.
            torch::mps::commit();
        });
    }

    return rast_result;
}
185
+
186
// MPS implementation of `interpolate`: blend per-vertex attributes with the
// barycentric weights stored in `rast` (one Metal thread per pixel).
// Returns a (rast.size(0), rast.size(1), 3) float32 MPS tensor.
torch::Tensor interpolate_gpu(
    torch::Tensor attr,
    torch::Tensor indices,
    torch::Tensor rast)
{
    TORCH_CHECK(attr.is_contiguous(), "attr must be contiguous");
    TORCH_CHECK(indices.is_contiguous(), "indices must be contiguous");
    TORCH_CHECK(rast.is_contiguous(), "rast must be contiguous");

    torch::Tensor pos_bake = torch::empty({rast.size(0), rast.size(1), 3}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kMPS)).contiguous();
    // NOTE(review): csrc_path is unused below — the shader loader resolves the
    // path itself. Kept because get_extension_path() also throws if the
    // package cannot be located; confirm whether that side effect is intended.
    std::filesystem::path csrc_path = get_extension_path();

    @autoreleasepool {
        id<MTLDevice> device = MTLCreateSystemDefaultDevice();

        NSString *fullSource = get_shader_sources_as_string();
        // Create a compute pipeline state object using the helper function
        id<MTLComputePipelineState> interpolate_PSO = createComputePipelineState(device, fullSource, "kernel_interpolate");

        // Get a reference to the command buffer for the MPS stream.
        id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");

        // Get a reference to the dispatch queue for the MPS stream, which encodes the synchronization with the CPU.
        dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();

        dispatch_sync(serialQueue, ^(){
            // Start a compute pass.
            id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
            TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");

            // Get Metal buffers directly from PyTorch tensors
            auto attr_buf = getMTLBufferStorage(attr.contiguous());
            auto indices_buf = getMTLBufferStorage(indices.contiguous());
            auto rast_buf = getMTLBufferStorage(rast.contiguous());
            auto pos_bake_buf = getMTLBufferStorage(pos_bake);

            int width = rast.size(0);
            int height = rast.size(1);

            [computeEncoder setComputePipelineState:interpolate_PSO];
            [computeEncoder setBuffer:attr_buf offset:attr.storage_offset() * attr.element_size() atIndex:0];
            [computeEncoder setBuffer:indices_buf offset:indices.storage_offset() * indices.element_size() atIndex:1];
            [computeEncoder setBuffer:rast_buf offset:rast.storage_offset() * rast.element_size() atIndex:2];
            [computeEncoder setBuffer:pos_bake_buf offset:pos_bake.storage_offset() * pos_bake.element_size() atIndex:3];
            [computeEncoder setBytes:&width length:sizeof(int) atIndex:4];
            [computeEncoder setBytes:&height length:sizeof(int) atIndex:5];

            // Calculate a thread group size.
            // NOTE(review): both threadgroup counts use rast.size(0) — this
            // assumes a square raster whose side is a multiple of 16.
            int block_size = 16;
            MTLSize threadgroupSize = MTLSizeMake(block_size, block_size, 1); // Fixed threadgroup size
            MTLSize numThreadgroups = MTLSizeMake(rast.size(0) / block_size, rast.size(0) / block_size, 1);

            // Encode the compute command.
            [computeEncoder dispatchThreadgroups:numThreadgroups threadsPerThreadgroup:threadgroupSize];

            [computeEncoder endEncoding];

            // Commit the work.
            torch::mps::commit();
        });
    }

    return pos_bake;
}
252
+
253
// Register the MPS backends for the ops declared by the texture_baker_cpp
// library (CPU declarations live elsewhere).
TORCH_LIBRARY_IMPL(texture_baker_cpp, MPS, m)
{
    m.impl("interpolate", &interpolate_gpu);
    m.impl("rasterize", &rasterize_gpu);
}

} // namespace texture_baker_cpp
uv_unwrapper/README.md ADDED
File without changes
uv_unwrapper/requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch
2
+ numpy
uv_unwrapper/setup.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import glob
3
+ import os
4
+
5
+ from setuptools import find_packages, setup
6
+ from torch.utils.cpp_extension import (
7
+ BuildExtension,
8
+ CppExtension,
9
+ )
10
+
11
+ library_name = "uv_unwrapper"
12
+
13
+
14
def get_extensions():
    """Build the list of C++ extension modules for ``setup()``.

    Returns:
        A list with a single :class:`CppExtension`, or ``None`` when no C++
        sources are found (so ``setup`` skips extension compilation).
    """
    debug_mode = os.getenv("DEBUG", "0") == "1"
    if debug_mode:
        print("Compiling in debug mode")

    # MPS availability is used as a proxy for "building on macOS".
    is_mac = torch.backends.mps.is_available()
    use_native_arch = not is_mac and os.getenv("USE_NATIVE_ARCH", "1") == "1"
    extension = CppExtension

    extra_link_args = []
    # FIX: the conditional must be parenthesized. The previous
    # `[...] + ["-march=native"] if use_native_arch else []` applied the
    # ternary to the whole concatenation (ternary binds looser than +), so
    # the entire flag list silently became [] when the condition was false.
    extra_compile_args = {
        "cxx": [
            "-O3" if not debug_mode else "-O0",
            "-fdiagnostics-color=always",
            ("-Xclang " if is_mac else "") + "-fopenmp",
        ]
        + (["-march=native"] if use_native_arch else []),
    }
    if debug_mode:
        extra_compile_args["cxx"].append("-g")
        # Keep assert() active in debug builds.
        extra_compile_args["cxx"].append("-UNDEBUG")
        extra_link_args.extend(["-O0", "-g"])

    define_macros = []
    extensions = []

    this_dir = os.path.dirname(os.path.curdir)
    sources = glob.glob(
        os.path.join(this_dir, library_name, "csrc", "**", "*.cpp"), recursive=True
    )

    if len(sources) == 0:
        print("No source files found for extension, skipping extension compilation")
        return None

    extensions.append(
        extension(
            name=f"{library_name}._C",
            sources=sources,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
            extra_link_args=extra_link_args,
            # FIX: same precedence bug as above — without parentheses the base
            # library list was dropped entirely on non-mac builds.
            libraries=[
                "c10",
                "torch",
                "torch_cpu",
                "torch_python",
            ]
            + (["omp"] if is_mac else []),
        )
    )

    print(extensions)

    return extensions
+
68
+
69
# Read the long description up front with an explicit encoding so the file
# handle is closed promptly (the previous bare open() leaked it and depended
# on the platform default encoding).
with open("README.md", encoding="utf-8") as readme:
    _long_description = readme.read()

setup(
    name=library_name,
    version="0.0.1",
    packages=find_packages(),
    ext_modules=get_extensions(),
    install_requires=[],
    description="Box projection based UV unwrapper",
    long_description=_long_description,
    long_description_content_type="text/markdown",
    cmdclass={"build_ext": BuildExtension},
)
uv_unwrapper/uv_unwrapper/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import torch # noqa: F401
2
+
3
+ from . import _C # noqa: F401
4
+ from .unwrap import Unwrapper
5
+
6
+ __all__ = ["Unwrapper"]
uv_unwrapper/uv_unwrapper/csrc/bvh.cpp ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ #include "bvh.h"
4
+ #include "common.h"
5
+ #include <cstring>
6
+ #include <iostream>
7
+ #include <queue>
8
+ #include <tuple>
9
+
10
+ namespace UVUnwrapper {
11
// Build a BVH over `num_indices` 2-D triangles.
// `tri` is copied into an owned array; `actual_idx` maps each triangle back
// to its index in the caller's original ordering.
BVH::BVH(Triangle *tri, int *actual_idx, const size_t &num_indices) {
  // Copy tri to triangle
  triangle = new Triangle[num_indices];
  memcpy(triangle, tri, num_indices * sizeof(Triangle));

  // Copy actual_idx to actualIdx
  actualIdx = new int[num_indices];
  memcpy(actualIdx, actual_idx, num_indices * sizeof(int));

  triIdx = new int[num_indices];
  triCount = num_indices;

  // Worst case ~2N nodes for a binary tree over N leaves, plus slack.
  bvhNode = new BVHNode[triCount * 2 + 64];
  // Node 1 is intentionally skipped (nodesUsed starts at 2) — presumably so
  // sibling pairs stay aligned; TODO confirm against BVHNode layout.
  nodesUsed = 2;
  // NOTE(review): only the first 2*triCount nodes are zeroed, not the +64
  // slack — confirm any slack node is fully written before being read.
  memset(bvhNode, 0, triCount * 2 * sizeof(BVHNode));

  // populate triangle index array
  for (int i = 0; i < triCount; i++)
    triIdx[i] = i;

  BVHNode &root = bvhNode[0];

  root.start = 0, root.end = triCount;
  AABB centroidBounds;
  UpdateNodeBounds(0, centroidBounds);

  // subdivide recursively
  Subdivide(0, nodesUsed, centroidBounds);
}
40
+
41
+ BVH::BVH(const BVH &other)
42
+ : BVH(other.triangle, other.triIdx, other.triCount) {}
43
+
44
+ BVH::BVH(BVH &&other) noexcept // move constructor
45
+ : triIdx(std::exchange(other.triIdx, nullptr)),
46
+ actualIdx(std::exchange(other.actualIdx, nullptr)),
47
+ triangle(std::exchange(other.triangle, nullptr)),
48
+ bvhNode(std::exchange(other.bvhNode, nullptr)) {}
49
+
50
// Copy assignment via copy-and-swap: build a temporary copy and move-assign
// it into *this (gives the strong exception guarantee for free).
BVH &BVH::operator=(const BVH &other) // copy assignment
{
  return *this = BVH(other);
}
54
+
55
// Move assignment: swap all state with `other`; the previously owned
// resources are released when `other` is destroyed.
BVH &BVH::operator=(BVH &&other) noexcept // move assignment
{
  std::swap(triIdx, other.triIdx);
  std::swap(actualIdx, other.actualIdx);
  std::swap(triangle, other.triangle);
  std::swap(bvhNode, other.bvhNode);
  std::swap(triCount, other.triCount);
  std::swap(nodesUsed, other.nodesUsed);
  return *this;
}
65
+
66
+ BVH::~BVH() {
67
+ if (triIdx)
68
+ delete[] triIdx;
69
+ if (triangle)
70
+ delete[] triangle;
71
+ if (actualIdx)
72
+ delete[] actualIdx;
73
+ if (bvhNode)
74
+ delete[] bvhNode;
75
+ }
76
+
77
+ void BVH::UpdateNodeBounds(unsigned int nodeIdx, AABB &centroidBounds) {
78
+ BVHNode &node = bvhNode[nodeIdx];
79
+ #ifndef __ARM_ARCH_ISA_A64
80
+ #ifndef _MSC_VER
81
+ if (__builtin_cpu_supports("sse"))
82
+ #elif (defined(_M_AMD64) || defined(_M_X64))
83
+ // SSE supported on Windows
84
+ if constexpr (true)
85
+ #endif
86
+ {
87
+ __m128 min4 = _mm_set_ps1(FLT_MAX), max4 = _mm_set_ps1(FLT_MIN);
88
+ __m128 cmin4 = _mm_set_ps1(FLT_MAX), cmax4 = _mm_set_ps1(FLT_MIN);
89
+ for (int i = node.start; i < node.end; i += 2) {
90
+ Triangle &leafTri1 = triangle[triIdx[i]];
91
+ __m128 v0, v1, v2, centroid;
92
+ if (i + 1 < node.end) {
93
+ const Triangle leafTri2 = triangle[triIdx[i + 1]];
94
+
95
+ v0 = _mm_set_ps(leafTri1.v0.x, leafTri1.v0.y, leafTri2.v0.x,
96
+ leafTri2.v0.y);
97
+ v1 = _mm_set_ps(leafTri1.v1.x, leafTri1.v1.y, leafTri2.v1.x,
98
+ leafTri2.v1.y);
99
+ v2 = _mm_set_ps(leafTri1.v2.x, leafTri1.v2.y, leafTri2.v2.x,
100
+ leafTri2.v2.y);
101
+ centroid = _mm_set_ps(leafTri1.centroid.x, leafTri1.centroid.y,
102
+ leafTri2.centroid.x, leafTri2.centroid.y);
103
+ } else {
104
+ // Otherwise do some duplicated work
105
+ v0 = _mm_set_ps(leafTri1.v0.x, leafTri1.v0.y, leafTri1.v0.x,
106
+ leafTri1.v0.y);
107
+ v1 = _mm_set_ps(leafTri1.v1.x, leafTri1.v1.y, leafTri1.v1.x,
108
+ leafTri1.v1.y);
109
+ v2 = _mm_set_ps(leafTri1.v2.x, leafTri1.v2.y, leafTri1.v2.x,
110
+ leafTri1.v2.y);
111
+ centroid = _mm_set_ps(leafTri1.centroid.x, leafTri1.centroid.y,
112
+ leafTri1.centroid.x, leafTri1.centroid.y);
113
+ }
114
+
115
+ min4 = _mm_min_ps(min4, v0);
116
+ max4 = _mm_max_ps(max4, v0);
117
+ min4 = _mm_min_ps(min4, v1);
118
+ max4 = _mm_max_ps(max4, v1);
119
+ min4 = _mm_min_ps(min4, v2);
120
+ max4 = _mm_max_ps(max4, v2);
121
+ cmin4 = _mm_min_ps(cmin4, centroid);
122
+ cmax4 = _mm_max_ps(cmax4, centroid);
123
+ }
124
+ float min_values[4], max_values[4], cmin_values[4], cmax_values[4];
125
+ _mm_store_ps(min_values, min4);
126
+ _mm_store_ps(max_values, max4);
127
+ _mm_store_ps(cmin_values, cmin4);
128
+ _mm_store_ps(cmax_values, cmax4);
129
+
130
+ node.bbox.min.x = std::min(min_values[3], min_values[1]);
131
+ node.bbox.min.y = std::min(min_values[2], min_values[0]);
132
+ node.bbox.max.x = std::max(max_values[3], max_values[1]);
133
+ node.bbox.max.y = std::max(max_values[2], max_values[0]);
134
+
135
+ centroidBounds.min.x = std::min(cmin_values[3], cmin_values[1]);
136
+ centroidBounds.min.y = std::min(cmin_values[2], cmin_values[0]);
137
+ centroidBounds.max.x = std::max(cmax_values[3], cmax_values[1]);
138
+ centroidBounds.max.y = std::max(cmax_values[2], cmax_values[0]);
139
+ }
140
+ #else
141
+ if constexpr (false) {
142
+ }
143
+ #endif
144
+ else {
145
+ node.bbox.invalidate();
146
+ centroidBounds.invalidate();
147
+
148
+ // Calculate the bounding box for the node
149
+ for (int i = node.start; i < node.end; ++i) {
150
+ const Triangle &tri = triangle[triIdx[i]];
151
+ node.bbox.grow(tri.v0);
152
+ node.bbox.grow(tri.v1);
153
+ node.bbox.grow(tri.v2);
154
+ centroidBounds.grow(tri.centroid);
155
+ }
156
+ }
157
+ }
158
+
159
/**
 * @brief Iteratively subdivide the subtree below root_idx using binned-SAH
 * splits.
 *
 * Processes nodes breadth-first via an explicit queue (no recursion). For
 * each node, FindBestSplitPlane proposes the cheapest split; if that split
 * is not cheaper than keeping the node as a leaf, the node is finalized as
 * a leaf (left = right = -1). Otherwise triIdx[start, end) is partitioned
 * in place around the chosen bin border and two children are allocated
 * from nodePtr.
 *
 * @param root_idx index of the subtree root in bvhNode
 * @param nodePtr next free slot in bvhNode; advanced as children are created
 * @param rootCentroidBounds centroid bounds of the root's triangles
 */
void BVH::Subdivide(unsigned int root_idx, unsigned int &nodePtr,
                    AABB &rootCentroidBounds) {
  // Create a queue for the nodes to be subdivided
  std::queue<std::tuple<unsigned int, AABB>> nodeQueue;
  nodeQueue.push(std::make_tuple(root_idx, rootCentroidBounds));

  while (!nodeQueue.empty()) {
    // Get the next node to process from the queue (copies the AABB out)
    auto [node_idx, centroidBounds] = nodeQueue.front();
    nodeQueue.pop();
    BVHNode &node = bvhNode[node_idx];

    int axis, splitPos;
    float cost = FindBestSplitPlane(node, axis, splitPos, centroidBounds);

    // Splitting must beat the SAH cost of leaving this node as a leaf.
    if (cost >= node.calculate_node_cost()) {
      node.left = node.right = -1;
      continue; // Move on to the next node in the queue
    }

    // Two-pointer in-place partition of triIdx[start, end) by bin index:
    // triangles whose bin is left of splitPos go to the front.
    int i = node.start;
    int j = node.end - 1;
    // Division is safe: FindBestSplitPlane only selects an axis whose
    // centroid extent is >= 1e-8f.
    float scale = BINS / (centroidBounds.max[axis] - centroidBounds.min[axis]);
    while (i <= j) {
      int binIdx =
          std::min(BINS - 1, (int)((triangle[triIdx[i]].centroid[axis] -
                                    centroidBounds.min[axis]) *
                                   scale));
      if (binIdx < splitPos)
        i++;
      else
        std::swap(triIdx[i], triIdx[j--]);
    }

    // Degenerate partition (all triangles on one side): keep as a leaf.
    int leftCount = i - node.start;
    if (leftCount == 0 || leftCount == (int)node.num_triangles()) {
      node.left = node.right = -1;
      continue; // Move on to the next node in the queue
    }

    int mid = i;

    // Create child nodes covering [start, mid) and [mid, end)
    int leftChildIdx = nodePtr++;
    int rightChildIdx = nodePtr++;
    bvhNode[leftChildIdx].start = node.start;
    bvhNode[leftChildIdx].end = mid;
    bvhNode[rightChildIdx].start = mid;
    bvhNode[rightChildIdx].end = node.end;
    node.left = leftChildIdx;
    node.right = rightChildIdx;

    // Update the bounds for the child nodes and push them onto the queue.
    // UpdateNodeBounds overwrites centroidBounds with the child's centroid
    // bounds, which the queue entry then copies.
    UpdateNodeBounds(leftChildIdx, centroidBounds);
    nodeQueue.push(std::make_tuple(leftChildIdx, centroidBounds));

    UpdateNodeBounds(rightChildIdx, centroidBounds);
    nodeQueue.push(std::make_tuple(rightChildIdx, centroidBounds));
  }
}
221
+
222
/**
 * @brief Find the cheapest binned-SAH split plane for a node.
 *
 * Evaluates BINS - 1 candidate planes per axis (x and y only — this BVH is
 * 2D) by binning triangle vertices and sweeping partial bounds from both
 * ends. On x86 an SSE path accumulates vertex bounds with
 * _mm_min_ps/_mm_max_ps; otherwise a scalar AABB path is used.
 *
 * @param node node whose triangles are binned
 * @param best_axis out: axis of the best plane; only written when some
 * plane beats the running best (unchanged if no axis is splittable)
 * @param best_pos out: bin border index (1..BINS-1) of the best plane
 * @param centroidBounds centroid bounds used to map centroids to bins
 * @return SAH cost of the best plane, or FLT_MAX if no axis was splittable
 */
float BVH::FindBestSplitPlane(BVHNode &node, int &best_axis, int &best_pos,
                              AABB &centroidBounds) {
  float best_cost = FLT_MAX;

  for (int axis = 0; axis < 2; ++axis) // We use 2 as we have only x and y
  {
    float boundsMin = centroidBounds.min[axis];
    float boundsMax = centroidBounds.max[axis];
    // Skip axes with (near-)zero centroid extent: nothing to split, and
    // the scale below would divide by zero. Or floating point precision.
    if ((boundsMin == boundsMax) || (boundsMax - boundsMin < 1e-8f)) {
      continue;
    }

    // populate the bins
    float scale = BINS / (boundsMax - boundsMin);
    float leftCountArea[BINS - 1], rightCountArea[BINS - 1];
    int leftSum = 0, rightSum = 0;
#ifndef __ARM_ARCH_ISA_A64
#ifndef _MSC_VER
    if (__builtin_cpu_supports("sse"))
#elif (defined(_M_AMD64) || defined(_M_X64))
    // SSE supported on Windows
    if constexpr (true)
#endif
    {
      // NOTE(review): FLT_MIN is the smallest *positive* float, not the
      // most negative; this max-init is only correct if all coordinates
      // are non-negative (UV space) — confirm.
      __m128 min4[BINS], max4[BINS];
      unsigned int count[BINS];
      for (unsigned int i = 0; i < BINS; i++)
        min4[i] = _mm_set_ps1(FLT_MAX), max4[i] = _mm_set_ps1(FLT_MIN),
        count[i] = 0;
      for (int i = node.start; i < node.end; i++) {
        Triangle &tri = triangle[triIdx[i]];
        int binIdx =
            std::min(BINS - 1, (int)((tri.centroid[axis] - boundsMin) * scale));
        count[binIdx]++;

        // Lanes 3/2 hold x/y; lanes 1/0 are unused padding.
        __m128 v0 = _mm_set_ps(tri.v0.x, tri.v0.y, 0.0f, 0.0f);
        __m128 v1 = _mm_set_ps(tri.v1.x, tri.v1.y, 0.0f, 0.0f);
        __m128 v2 = _mm_set_ps(tri.v2.x, tri.v2.y, 0.0f, 0.0f);
        min4[binIdx] = _mm_min_ps(min4[binIdx], v0);
        max4[binIdx] = _mm_max_ps(max4[binIdx], v0);
        min4[binIdx] = _mm_min_ps(min4[binIdx], v1);
        max4[binIdx] = _mm_max_ps(max4[binIdx], v1);
        min4[binIdx] = _mm_min_ps(min4[binIdx], v2);
        max4[binIdx] = _mm_max_ps(max4[binIdx], v2);
      }
      // gather data for the 7 planes between the 8 bins
      __m128 leftMin4 = _mm_set_ps1(FLT_MAX), rightMin4 = leftMin4;
      __m128 leftMax4 = _mm_set_ps1(FLT_MIN), rightMax4 = leftMax4;
      for (int i = 0; i < BINS - 1; i++) {
        leftSum += count[i];
        rightSum += count[BINS - 1 - i];
        leftMin4 = _mm_min_ps(leftMin4, min4[i]);
        rightMin4 = _mm_min_ps(rightMin4, min4[BINS - 2 - i]);
        leftMax4 = _mm_max_ps(leftMax4, max4[i]);
        rightMax4 = _mm_max_ps(rightMax4, max4[BINS - 2 - i]);
        float le[4], re[4];
        _mm_store_ps(le, _mm_sub_ps(leftMax4, leftMin4));
        _mm_store_ps(re, _mm_sub_ps(rightMax4, rightMin4));
        // SSE order goes from back to front
        leftCountArea[i] = leftSum * (le[2] * le[3]); // 2D area calculation
        rightCountArea[BINS - 2 - i] =
            rightSum * (re[2] * re[3]); // 2D area calculation
      }
    }
#else
    if constexpr (false) {
    }
#endif
    else {
      // Scalar fallback: accumulate per-bin AABBs, then sweep from both
      // sides to get count*area for each candidate plane.
      struct Bin {
        AABB bounds;
        int triCount = 0;
      } bin[BINS];
      for (int i = node.start; i < node.end; i++) {
        Triangle &tri = triangle[triIdx[i]];
        int binIdx =
            std::min(BINS - 1, (int)((tri.centroid[axis] - boundsMin) * scale));
        bin[binIdx].triCount++;
        bin[binIdx].bounds.grow(tri.v0);
        bin[binIdx].bounds.grow(tri.v1);
        bin[binIdx].bounds.grow(tri.v2);
      }
      // gather data for the 7 planes between the 8 bins
      AABB leftBox, rightBox;
      for (int i = 0; i < BINS - 1; i++) {
        leftSum += bin[i].triCount;
        leftBox.grow(bin[i].bounds);
        leftCountArea[i] = leftSum * leftBox.area();
        rightSum += bin[BINS - 1 - i].triCount;
        rightBox.grow(bin[BINS - 1 - i].bounds);
        rightCountArea[BINS - 2 - i] = rightSum * rightBox.area();
      }
    }

    // calculate SAH cost for the 7 planes
    // NOTE(review): this scale value is never read again — dead store.
    scale = (boundsMax - boundsMin) / BINS;
    for (int i = 0; i < BINS - 1; i++) {
      const float planeCost = leftCountArea[i] + rightCountArea[i];
      if (planeCost < best_cost)
        best_axis = axis, best_pos = i + 1, best_cost = planeCost;
    }
  }
  return best_cost;
}
327
+
328
std::vector<int> BVH::Intersect(Triangle &tri_intersect) {
  /**
   * @brief Collect all triangles in the BVH that overlap a query triangle.
   *
   * Traverses the tree with an explicit fixed-size stack. The query
   * triangle itself (compared by vertex equality) is skipped.
   *
   * @param tri_intersect the triangle to test against the tree
   *
   * @return caller-side indices (mapped through actualIdx) of every
   * overlapping triangle; empty vector if none overlap
   *
   * @throws std::runtime_error if the traversal stack overflows
   */

  const int max_stack_size = 64;
  int node_stack[max_stack_size];
  int stack_size = 0;
  std::vector<int> intersected_triangles;

  node_stack[stack_size++] = 0; // Start with the root node (index 0)
  while (stack_size > 0) {
    int node_idx = node_stack[--stack_size];
    const BVHNode &node = bvhNode[node_idx];
    if (node.is_leaf()) {
      // Leaf: exact triangle-triangle tests against every triangle stored.
      for (int i = node.start; i < node.end; ++i) {
        const Triangle &tri = triangle[triIdx[i]];
        // Check that the triangle is not the same as the intersected triangle
        if (tri == tri_intersect)
          continue;
        if (tri_intersect.overlaps(tri)) {
          intersected_triangles.push_back(actualIdx[triIdx[i]]);
        }
      }
    } else {
      // Inner node: push children whose bbox overlaps the query.
      // Check right child first
      if (bvhNode[node.right].bbox.overlaps(tri_intersect)) {
        if (stack_size < max_stack_size) {
          node_stack[stack_size++] = node.right;
        } else {
          throw std::runtime_error("Node stack overflow");
        }
      }

      // Check left child
      if (bvhNode[node.left].bbox.overlaps(tri_intersect)) {
        if (stack_size < max_stack_size) {
          node_stack[stack_size++] = node.left;
        } else {
          throw std::runtime_error("Node stack overflow");
        }
      }
    }
  }
  return intersected_triangles; // Return all intersected triangle indices
}
379
+
380
+ } // namespace UVUnwrapper
uv_unwrapper/uv_unwrapper/csrc/bvh.h ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cfloat>
4
+ #include <cmath>
5
+ #ifndef __ARM_ARCH_ISA_A64
6
+ #include <immintrin.h>
7
+ #endif
8
+ #include <limits>
9
+ #include <vector>
10
+
11
+ #include "common.h"
12
+ #include "intersect.h"
13
+ /**
14
+ * Based on https://github.com/jbikker/bvh_article released under the unlicense.
15
+ */
16
+
17
+ // bin count for binned BVH building
18
+ #define BINS 8
19
+
20
+ namespace UVUnwrapper {
21
+ // minimalist triangle struct
22
+ struct alignas(32) Triangle {
23
+ uv_float2 v0;
24
+ uv_float2 v1;
25
+ uv_float2 v2;
26
+ uv_float2 centroid;
27
+
28
+ bool overlaps(const Triangle &other) {
29
+ // return tri_tri_overlap_test_2d(v0, v1, v2, other.v0, other.v1, other.v2);
30
+ return triangle_triangle_intersection(v0, v1, v2, other.v0, other.v1,
31
+ other.v2);
32
+ }
33
+
34
+ bool operator==(const Triangle &rhs) const {
35
+ return v0 == rhs.v0 && v1 == rhs.v1 && v2 == rhs.v2;
36
+ }
37
+ };
38
+
39
+ // minimalist AABB struct with grow functionality
40
+ struct alignas(16) AABB {
41
+ // Init bounding boxes with max/min
42
+ uv_float2 min = {FLT_MAX, FLT_MAX};
43
+ uv_float2 max = {FLT_MIN, FLT_MIN};
44
+
45
+ void grow(const uv_float2 &p) {
46
+ min.x = std::min(min.x, p.x);
47
+ min.y = std::min(min.y, p.y);
48
+ max.x = std::max(max.x, p.x);
49
+ max.y = std::max(max.y, p.y);
50
+ }
51
+
52
+ void grow(const AABB &b) {
53
+ if (b.min.x != FLT_MAX) {
54
+ grow(b.min);
55
+ grow(b.max);
56
+ }
57
+ }
58
+
59
+ bool overlaps(const Triangle &tri) {
60
+ return triangle_aabb_intersection(min, max, tri.v0, tri.v1, tri.v2);
61
+ }
62
+
63
+ float area() const {
64
+ uv_float2 extent = {max.x - min.x, max.y - min.y};
65
+ return extent.x * extent.y;
66
+ }
67
+
68
+ void invalidate() {
69
+ min = {FLT_MAX, FLT_MAX};
70
+ max = {FLT_MIN, FLT_MIN};
71
+ }
72
+ };
73
+
74
+ // 32-byte BVH node struct
75
+ struct alignas(32) BVHNode {
76
+ AABB bbox; // 16
77
+ int start = 0, end = 0; // 8
78
+ int left, right;
79
+
80
+ int num_triangles() const { return end - start; }
81
+
82
+ bool is_leaf() const { return left == -1 && right == -1; }
83
+
84
+ float calculate_node_cost() {
85
+ float area = bbox.area();
86
+ return num_triangles() * area;
87
+ }
88
+ };
89
+
90
// 2D BVH over triangles; supports querying which triangles overlap a
// given triangle (see Intersect).
class BVH {
public:
  BVH() = default;
  BVH(BVH &&other) noexcept;            // move-construct
  BVH(const BVH &other);                // copy-construct
  BVH &operator=(const BVH &other);     // copy-assign
  BVH &operator=(BVH &&other) noexcept; // move-assign
  // Build from an array of triangles; actual_idx maps internal triangle
  // order back to caller-side indices (returned by Intersect).
  BVH(Triangle *tri, int *actual_idx, const size_t &num_indices);
  ~BVH();

  // Return caller-side indices of all triangles overlapping `triangle`.
  std::vector<int> Intersect(Triangle &triangle);

private:
  // Breadth-first SAH subdivision starting at node_idx.
  void Subdivide(unsigned int node_idx, unsigned int &nodePtr,
                 AABB &centroidBounds);
  // Recompute a node's bbox and write its centroid bounds into
  // centroidBounds.
  void UpdateNodeBounds(unsigned int nodeIdx, AABB &centroidBounds);
  // Binned-SAH search; returns best cost, writes axis/splitPos.
  float FindBestSplitPlane(BVHNode &node, int &axis, int &splitPos,
                           AABB &centroidBounds);

public:
  int *triIdx = nullptr;    // triangle index permutation used by the tree
  int *actualIdx = nullptr; // maps internal indices to caller indices
  unsigned int triCount;    // presumably number of triangles — confirm
  unsigned int nodesUsed;   // presumably nodes allocated in bvhNode — confirm
  BVHNode *bvhNode = nullptr;
  Triangle *triangle = nullptr;
};
117
+
118
+ } // namespace UVUnwrapper
uv_unwrapper/uv_unwrapper/csrc/common.h ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <array>
4
+ #include <cmath>
5
+ #include <iostream>
6
+ #include <stdexcept>
7
+
8
+ const float EPSILON = 1e-7f;
9
+
10
+ // Structure to represent a 2D point or vector
11
// 2D float vector with both named (.x/.y) and indexed (operator[]) access.
// 8-byte aligned so a pair packs cleanly into SIMD registers.
union alignas(8) uv_float2 {
  struct {
    float x, y;
  };

  float data[2];

  // Bounds-checked element access; throws on an index outside [0, 1].
  float &operator[](size_t idx) {
    if (idx >= 2)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  const float &operator[](size_t idx) const {
    if (idx >= 2)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  // Exact component-wise equality.
  bool operator==(const uv_float2 &rhs) const {
    return (x == rhs.x) && (y == rhs.y);
  }
};
34
+
35
+ // Do not align as this is specifically tweaked for BVHNode
36
+ union uv_float3 {
37
+ struct {
38
+ float x, y, z;
39
+ };
40
+
41
+ float data[3];
42
+
43
+ float &operator[](size_t idx) {
44
+ if (idx > 3)
45
+ throw std::runtime_error("bad index");
46
+ return data[idx];
47
+ }
48
+
49
+ const float &operator[](size_t idx) const {
50
+ if (idx > 3)
51
+ throw std::runtime_error("bad index");
52
+ return data[idx];
53
+ }
54
+
55
+ bool operator==(const uv_float3 &rhs) const {
56
+ return x == rhs.x && y == rhs.y && z == rhs.z;
57
+ }
58
+ };
59
+
60
// 4D float vector, 16-byte aligned; named (.x/.y/.z/.w) or indexed access.
union alignas(16) uv_float4 {
  struct {
    float x, y, z, w;
  };

  float data[4];

  // Bounds-checked element access; throws outside [0, 3].
  float &operator[](size_t idx) {
    if (idx >= 4)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  const float &operator[](size_t idx) const {
    if (idx >= 4)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  // Exact component-wise equality.
  bool operator==(const uv_float4 &rhs) const {
    return (x == rhs.x) && (y == rhs.y) && (z == rhs.z) && (w == rhs.w);
  }
};

// 2D int vector, 8-byte aligned.
union alignas(8) uv_int2 {
  struct {
    int x, y;
  };

  int data[2];

  // Bounds-checked element access; throws outside [0, 1].
  int &operator[](size_t idx) {
    if (idx >= 2)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  const int &operator[](size_t idx) const {
    if (idx >= 2)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  // Component-wise equality.
  bool operator==(const uv_int2 &rhs) const { return x == rhs.x && y == rhs.y; }
};

// 3D int vector, 4-byte aligned (no padding to 16).
union alignas(4) uv_int3 {
  struct {
    int x, y, z;
  };

  int data[3];

  // Bounds-checked element access; throws outside [0, 2].
  int &operator[](size_t idx) {
    if (idx >= 3)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  const int &operator[](size_t idx) const {
    if (idx >= 3)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  // Component-wise equality.
  bool operator==(const uv_int3 &rhs) const {
    return x == rhs.x && y == rhs.y && z == rhs.z;
  }
};

// 4D int vector, 16-byte aligned.
union alignas(16) uv_int4 {
  struct {
    int x, y, z, w;
  };

  int data[4];

  // Bounds-checked element access; throws outside [0, 3].
  int &operator[](size_t idx) {
    if (idx >= 4)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  const int &operator[](size_t idx) const {
    if (idx >= 4)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  // Component-wise equality.
  bool operator==(const uv_int4 &rhs) const {
    return x == rhs.x && y == rhs.y && z == rhs.z && w == rhs.w;
  }
};
153
+
154
+ inline float calc_mean(float a, float b, float c) { return (a + b + c) / 3; }
155
+
156
+ // Create a triangle centroid
157
+ inline uv_float2 triangle_centroid(const uv_float2 &v0, const uv_float2 &v1,
158
+ const uv_float2 &v2) {
159
+ return {calc_mean(v0.x, v1.x, v2.x), calc_mean(v0.y, v1.y, v2.y)};
160
+ }
161
+
162
+ inline uv_float3 triangle_centroid(const uv_float3 &v0, const uv_float3 &v1,
163
+ const uv_float3 &v2) {
164
+ return {calc_mean(v0.x, v1.x, v2.x), calc_mean(v0.y, v1.y, v2.y),
165
+ calc_mean(v0.z, v1.z, v2.z)};
166
+ }
167
+
168
+ // Helper functions for vector math
169
+ inline uv_float2 operator-(const uv_float2 &a, const uv_float2 &b) {
170
+ return {a.x - b.x, a.y - b.y};
171
+ }
172
+
173
+ inline uv_float3 operator-(const uv_float3 &a, const uv_float3 &b) {
174
+ return {a.x - b.x, a.y - b.y, a.z - b.z};
175
+ }
176
+
177
+ inline uv_float2 operator+(const uv_float2 &a, const uv_float2 &b) {
178
+ return {a.x + b.x, a.y + b.y};
179
+ }
180
+
181
+ inline uv_float3 operator+(const uv_float3 &a, const uv_float3 &b) {
182
+ return {a.x + b.x, a.y + b.y, a.z + b.z};
183
+ }
184
+
185
+ inline uv_float2 operator*(const uv_float2 &a, float scalar) {
186
+ return {a.x * scalar, a.y * scalar};
187
+ }
188
+
189
+ inline uv_float3 operator*(const uv_float3 &a, float scalar) {
190
+ return {a.x * scalar, a.y * scalar, a.z * scalar};
191
+ }
192
+
193
+ inline float dot(const uv_float2 &a, const uv_float2 &b) {
194
+ return a.x * b.x + a.y * b.y;
195
+ }
196
+
197
+ inline float dot(const uv_float3 &a, const uv_float3 &b) {
198
+ return a.x * b.x + a.y * b.y + a.z * b.z;
199
+ }
200
+
201
+ inline float cross(const uv_float2 &a, const uv_float2 &b) {
202
+ return a.x * b.y - a.y * b.x;
203
+ }
204
+
205
+ inline uv_float3 cross(const uv_float3 &a, const uv_float3 &b) {
206
+ return {a.y * b.z - a.z * b.y, a.z * b.x - a.x * b.z, a.x * b.y - a.y * b.x};
207
+ }
208
+
209
+ inline uv_float2 abs_vec(const uv_float2 &v) {
210
+ return {std::abs(v.x), std::abs(v.y)};
211
+ }
212
+
213
+ inline uv_float2 min_vec(const uv_float2 &a, const uv_float2 &b) {
214
+ return {std::min(a.x, b.x), std::min(a.y, b.y)};
215
+ }
216
+
217
+ inline uv_float2 max_vec(const uv_float2 &a, const uv_float2 &b) {
218
+ return {std::max(a.x, b.x), std::max(a.y, b.y)};
219
+ }
220
+
221
+ inline float distance_to(const uv_float2 &a, const uv_float2 &b) {
222
+ return std::sqrt(std::pow(a.x - b.x, 2) + std::pow(a.y - b.y, 2));
223
+ }
224
+
225
+ inline float distance_to(const uv_float3 &a, const uv_float3 &b) {
226
+ return std::sqrt(std::pow(a.x - b.x, 2) + std::pow(a.y - b.y, 2) +
227
+ std::pow(a.z - b.z, 2));
228
+ }
229
+
230
+ inline uv_float2 normalize(const uv_float2 &v) {
231
+ float len = std::sqrt(v.x * v.x + v.y * v.y);
232
+ return {v.x / len, v.y / len};
233
+ }
234
+
235
+ inline uv_float3 normalize(const uv_float3 &v) {
236
+ float len = std::sqrt(v.x * v.x + v.y * v.y + v.z * v.z);
237
+ return {v.x / len, v.y / len, v.z / len};
238
+ }
239
+
240
+ inline float magnitude(const uv_float3 &v) {
241
+ return std::sqrt(v.x * v.x + v.y * v.y + v.z * v.z);
242
+ }
243
+
244
// Row-major 4x4 float matrix with basic arithmetic, determinant, power,
// and inversion (via double-precision cofactor expansion).
struct Matrix4 {
  std::array<std::array<float, 4>, 4> m; // m[row][col]

  // Default: all zeros except m[3][3] = 1 (NOT a full identity matrix).
  Matrix4() {
    for (auto &row : m) {
      row.fill(0.0f);
    }
    m[3][3] = 1.0f; // Identity matrix for 4th row and column
  }

  // Assign all 16 entries, row by row.
  void set(float m00, float m01, float m02, float m03, float m10, float m11,
           float m12, float m13, float m20, float m21, float m22, float m23,
           float m30, float m31, float m32, float m33) {
    m[0][0] = m00;
    m[0][1] = m01;
    m[0][2] = m02;
    m[0][3] = m03;
    m[1][0] = m10;
    m[1][1] = m11;
    m[1][2] = m12;
    m[1][3] = m13;
    m[2][0] = m20;
    m[2][1] = m21;
    m[2][2] = m22;
    m[2][3] = m23;
    m[3][0] = m30;
    m[3][1] = m31;
    m[3][2] = m32;
    m[3][3] = m33;
  }

  // 4x4 determinant by full cofactor expansion (24 signed products).
  float determinant() const {
    return m[0][3] * m[1][2] * m[2][1] * m[3][0] -
           m[0][2] * m[1][3] * m[2][1] * m[3][0] -
           m[0][3] * m[1][1] * m[2][2] * m[3][0] +
           m[0][1] * m[1][3] * m[2][2] * m[3][0] +
           m[0][2] * m[1][1] * m[2][3] * m[3][0] -
           m[0][1] * m[1][2] * m[2][3] * m[3][0] -
           m[0][3] * m[1][2] * m[2][0] * m[3][1] +
           m[0][2] * m[1][3] * m[2][0] * m[3][1] +
           m[0][3] * m[1][0] * m[2][2] * m[3][1] -
           m[0][0] * m[1][3] * m[2][2] * m[3][1] -
           m[0][2] * m[1][0] * m[2][3] * m[3][1] +
           m[0][0] * m[1][2] * m[2][3] * m[3][1] +
           m[0][3] * m[1][1] * m[2][0] * m[3][2] -
           m[0][1] * m[1][3] * m[2][0] * m[3][2] -
           m[0][3] * m[1][0] * m[2][1] * m[3][2] +
           m[0][0] * m[1][3] * m[2][1] * m[3][2] +
           m[0][1] * m[1][0] * m[2][3] * m[3][2] -
           m[0][0] * m[1][1] * m[2][3] * m[3][2] -
           m[0][2] * m[1][1] * m[2][0] * m[3][3] +
           m[0][1] * m[1][2] * m[2][0] * m[3][3] +
           m[0][2] * m[1][0] * m[2][1] * m[3][3] -
           m[0][0] * m[1][2] * m[2][1] * m[3][3] -
           m[0][1] * m[1][0] * m[2][2] * m[3][3] +
           m[0][0] * m[1][1] * m[2][2] * m[3][3];
  }

  // Standard matrix product: result = (*this) * other.
  Matrix4 operator*(const Matrix4 &other) const {
    Matrix4 result;
    for (int row = 0; row < 4; ++row) {
      for (int col = 0; col < 4; ++col) {
        result.m[row][col] =
            m[row][0] * other.m[0][col] + m[row][1] * other.m[1][col] +
            m[row][2] * other.m[2][col] + m[row][3] * other.m[3][col];
      }
    }
    return result;
  }

  // Element-wise scaling (all 16 entries, including m[3][3]).
  Matrix4 operator*(float scalar) const {
    Matrix4 result = *this;
    for (auto &row : result.m) {
      for (auto &element : row) {
        element *= scalar;
      }
    }
    return result;
  }

  // Element-wise sum.
  Matrix4 operator+(const Matrix4 &other) const {
    Matrix4 result;
    for (int i = 0; i < 4; ++i) {
      for (int j = 0; j < 4; ++j) {
        result.m[i][j] = m[i][j] + other.m[i][j];
      }
    }
    return result;
  }

  // Element-wise difference.
  Matrix4 operator-(const Matrix4 &other) const {
    Matrix4 result;
    for (int i = 0; i < 4; ++i) {
      for (int j = 0; j < 4; ++j) {
        result.m[i][j] = m[i][j] - other.m[i][j];
      }
    }
    return result;
  }

  // Sum of the main diagonal.
  float trace() const { return m[0][0] + m[1][1] + m[2][2] + m[3][3]; }

  // Return a fresh identity matrix (ignores this instance's contents).
  Matrix4 identity() const {
    Matrix4 identity;
    identity.set(1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1);
    return identity;
  }

  // Integer matrix power by repeated multiplication (exp >= 0).
  Matrix4 power(int exp) const {
    if (exp == 0)
      return identity();
    if (exp == 1)
      return *this;

    Matrix4 result = *this;
    for (int i = 1; i < exp; ++i) {
      result = result * (*this);
    }
    return result;
  }

  // Debug dump to stdout, one row per line.
  void print() {
    // Print all entries in 4 rows with 4 columns
    for (int i = 0; i < 4; ++i) {
      for (int j = 0; j < 4; ++j) {
        std::cout << m[i][j] << " ";
      }
      std::cout << std::endl;
    }
  }

  // Invert this matrix in place using the classic adjugate/cofactor method
  // in double precision. Returns false (leaving the matrix unmodified)
  // when |det| < 1e-6, i.e. the matrix is (near-)singular.
  bool invert() {
    double inv[16], det;
    double mArr[16];

    // Convert the matrix to a 1D array (row-major) for easier manipulation
    for (int i = 0; i < 4; ++i) {
      for (int j = 0; j < 4; ++j) {
        mArr[i * 4 + j] = static_cast<double>(m[i][j]);
      }
    }

    // Cofactors of the transposed matrix (adjugate entries).
    inv[0] = mArr[5] * mArr[10] * mArr[15] - mArr[5] * mArr[11] * mArr[14] -
             mArr[9] * mArr[6] * mArr[15] + mArr[9] * mArr[7] * mArr[14] +
             mArr[13] * mArr[6] * mArr[11] - mArr[13] * mArr[7] * mArr[10];

    inv[4] = -mArr[4] * mArr[10] * mArr[15] + mArr[4] * mArr[11] * mArr[14] +
             mArr[8] * mArr[6] * mArr[15] - mArr[8] * mArr[7] * mArr[14] -
             mArr[12] * mArr[6] * mArr[11] + mArr[12] * mArr[7] * mArr[10];

    inv[8] = mArr[4] * mArr[9] * mArr[15] - mArr[4] * mArr[11] * mArr[13] -
             mArr[8] * mArr[5] * mArr[15] + mArr[8] * mArr[7] * mArr[13] +
             mArr[12] * mArr[5] * mArr[11] - mArr[12] * mArr[7] * mArr[9];

    inv[12] = -mArr[4] * mArr[9] * mArr[14] + mArr[4] * mArr[10] * mArr[13] +
              mArr[8] * mArr[5] * mArr[14] - mArr[8] * mArr[6] * mArr[13] -
              mArr[12] * mArr[5] * mArr[10] + mArr[12] * mArr[6] * mArr[9];

    inv[1] = -mArr[1] * mArr[10] * mArr[15] + mArr[1] * mArr[11] * mArr[14] +
             mArr[9] * mArr[2] * mArr[15] - mArr[9] * mArr[3] * mArr[14] -
             mArr[13] * mArr[2] * mArr[11] + mArr[13] * mArr[3] * mArr[10];

    inv[5] = mArr[0] * mArr[10] * mArr[15] - mArr[0] * mArr[11] * mArr[14] -
             mArr[8] * mArr[2] * mArr[15] + mArr[8] * mArr[3] * mArr[14] +
             mArr[12] * mArr[2] * mArr[11] - mArr[12] * mArr[3] * mArr[10];

    inv[9] = -mArr[0] * mArr[9] * mArr[15] + mArr[0] * mArr[11] * mArr[13] +
             mArr[8] * mArr[1] * mArr[15] - mArr[8] * mArr[3] * mArr[13] -
             mArr[12] * mArr[1] * mArr[11] + mArr[12] * mArr[3] * mArr[9];

    inv[13] = mArr[0] * mArr[9] * mArr[14] - mArr[0] * mArr[10] * mArr[13] -
              mArr[8] * mArr[1] * mArr[14] + mArr[8] * mArr[2] * mArr[13] +
              mArr[12] * mArr[1] * mArr[10] - mArr[12] * mArr[2] * mArr[9];

    inv[2] = mArr[1] * mArr[6] * mArr[15] - mArr[1] * mArr[7] * mArr[14] -
             mArr[5] * mArr[2] * mArr[15] + mArr[5] * mArr[3] * mArr[14] +
             mArr[13] * mArr[2] * mArr[7] - mArr[13] * mArr[3] * mArr[6];

    inv[6] = -mArr[0] * mArr[6] * mArr[15] + mArr[0] * mArr[7] * mArr[14] +
             mArr[4] * mArr[2] * mArr[15] - mArr[4] * mArr[3] * mArr[14] -
             mArr[12] * mArr[2] * mArr[7] + mArr[12] * mArr[3] * mArr[6];

    inv[10] = mArr[0] * mArr[5] * mArr[15] - mArr[0] * mArr[7] * mArr[13] -
              mArr[4] * mArr[1] * mArr[15] + mArr[4] * mArr[3] * mArr[13] +
              mArr[12] * mArr[1] * mArr[7] - mArr[12] * mArr[3] * mArr[5];

    inv[14] = -mArr[0] * mArr[5] * mArr[14] + mArr[0] * mArr[6] * mArr[13] +
              mArr[4] * mArr[1] * mArr[14] - mArr[4] * mArr[2] * mArr[13] -
              mArr[12] * mArr[1] * mArr[6] + mArr[12] * mArr[2] * mArr[5];

    inv[3] = -mArr[1] * mArr[6] * mArr[11] + mArr[1] * mArr[7] * mArr[10] +
             mArr[5] * mArr[2] * mArr[11] - mArr[5] * mArr[3] * mArr[10] -
             mArr[9] * mArr[2] * mArr[7] + mArr[9] * mArr[3] * mArr[6];

    inv[7] = mArr[0] * mArr[6] * mArr[11] - mArr[0] * mArr[7] * mArr[10] -
             mArr[4] * mArr[2] * mArr[11] + mArr[4] * mArr[3] * mArr[10] +
             mArr[8] * mArr[2] * mArr[7] - mArr[8] * mArr[3] * mArr[6];

    inv[11] = -mArr[0] * mArr[5] * mArr[11] + mArr[0] * mArr[7] * mArr[9] +
              mArr[4] * mArr[1] * mArr[11] - mArr[4] * mArr[3] * mArr[9] -
              mArr[8] * mArr[1] * mArr[7] + mArr[8] * mArr[3] * mArr[5];

    inv[15] = mArr[0] * mArr[5] * mArr[10] - mArr[0] * mArr[6] * mArr[9] -
              mArr[4] * mArr[1] * mArr[10] + mArr[4] * mArr[2] * mArr[9] +
              mArr[8] * mArr[1] * mArr[6] - mArr[8] * mArr[2] * mArr[5];

    // Determinant via expansion along the first row.
    det = mArr[0] * inv[0] + mArr[1] * inv[4] + mArr[2] * inv[8] +
          mArr[3] * inv[12];

    if (fabs(det) < 1e-6) {
      return false;
    }

    det = 1.0 / det;

    // inverse = adjugate / det
    for (int i = 0; i < 16; i++) {
      inv[i] *= det;
    }

    // Convert the 1D array back to the 4x4 matrix
    for (int i = 0; i < 4; ++i) {
      for (int j = 0; j < 4; ++j) {
        m[i][j] = static_cast<float>(inv[i * 4 + j]);
      }
    }

    return true;
  }
};
473
+
474
+ inline void apply_matrix4(uv_float3 &v, const Matrix4 matrix) {
475
+ float newX = v.x * matrix.m[0][0] + v.y * matrix.m[0][1] +
476
+ v.z * matrix.m[0][2] + matrix.m[0][3];
477
+ float newY = v.x * matrix.m[1][0] + v.y * matrix.m[1][1] +
478
+ v.z * matrix.m[1][2] + matrix.m[1][3];
479
+ float newZ = v.x * matrix.m[2][0] + v.y * matrix.m[2][1] +
480
+ v.z * matrix.m[2][2] + matrix.m[2][3];
481
+ float w = v.x * matrix.m[3][0] + v.y * matrix.m[3][1] + v.z * matrix.m[3][2] +
482
+ matrix.m[3][3];
483
+
484
+ if (std::fabs(w) > EPSILON) {
485
+ newX /= w;
486
+ newY /= w;
487
+ newZ /= w;
488
+ }
489
+
490
+ v.x = newX;
491
+ v.y = newY;
492
+ v.z = newZ;
493
+ }
uv_unwrapper/uv_unwrapper/csrc/intersect.cpp ADDED
@@ -0,0 +1,702 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "intersect.h"
2
+ #include "bvh.h"
3
+ #include <algorithm>
4
+ #include <cmath>
5
+ #include <iostream>
6
+ #include <stdexcept>
7
+ #include <vector>
8
+
9
+ bool triangle_aabb_intersection(const uv_float2 &aabbMin,
10
+ const uv_float2 &aabbMax, const uv_float2 &v0,
11
+ const uv_float2 &v1, const uv_float2 &v2) {
12
+ // Convert the min and max aabb defintion to left, right, top, bottom
13
+ float l = aabbMin.x;
14
+ float r = aabbMax.x;
15
+ float t = aabbMin.y;
16
+ float b = aabbMax.y;
17
+
18
+ int b0 = ((v0.x > l) ? 1 : 0) | ((v0.y > t) ? 2 : 0) | ((v0.x > r) ? 4 : 0) |
19
+ ((v0.y > b) ? 8 : 0);
20
+ if (b0 == 3)
21
+ return true;
22
+
23
+ int b1 = ((v1.x > l) ? 1 : 0) | ((v1.y > t) ? 2 : 0) | ((v1.x > r) ? 4 : 0) |
24
+ ((v1.y > b) ? 8 : 0);
25
+ if (b1 == 3)
26
+ return true;
27
+
28
+ int b2 = ((v2.x > l) ? 1 : 0) | ((v2.y > t) ? 2 : 0) | ((v2.x > r) ? 4 : 0) |
29
+ ((v2.y > b) ? 8 : 0);
30
+ if (b2 == 3)
31
+ return true;
32
+
33
+ float m, c, s;
34
+
35
+ int i0 = b0 ^ b1;
36
+ if (i0 != 0) {
37
+ if (v1.x != v0.x) {
38
+ m = (v1.y - v0.y) / (v1.x - v0.x);
39
+ c = v0.y - (m * v0.x);
40
+ if (i0 & 1) {
41
+ s = m * l + c;
42
+ if (s >= t && s <= b)
43
+ return true;
44
+ }
45
+ if (i0 & 2) {
46
+ s = (t - c) / m;
47
+ if (s >= l && s <= r)
48
+ return true;
49
+ }
50
+ if (i0 & 4) {
51
+ s = m * r + c;
52
+ if (s >= t && s <= b)
53
+ return true;
54
+ }
55
+ if (i0 & 8) {
56
+ s = (b - c) / m;
57
+ if (s >= l && s <= r)
58
+ return true;
59
+ }
60
+ } else {
61
+ if (l == v0.x || r == v0.x)
62
+ return true;
63
+ if (v0.x > l && v0.x < r)
64
+ return true;
65
+ }
66
+ }
67
+
68
+ int i1 = b1 ^ b2;
69
+ if (i1 != 0) {
70
+ if (v2.x != v1.x) {
71
+ m = (v2.y - v1.y) / (v2.x - v1.x);
72
+ c = v1.y - (m * v1.x);
73
+ if (i1 & 1) {
74
+ s = m * l + c;
75
+ if (s >= t && s <= b)
76
+ return true;
77
+ }
78
+ if (i1 & 2) {
79
+ s = (t - c) / m;
80
+ if (s >= l && s <= r)
81
+ return true;
82
+ }
83
+ if (i1 & 4) {
84
+ s = m * r + c;
85
+ if (s >= t && s <= b)
86
+ return true;
87
+ }
88
+ if (i1 & 8) {
89
+ s = (b - c) / m;
90
+ if (s >= l && s <= r)
91
+ return true;
92
+ }
93
+ } else {
94
+ if (l == v1.x || r == v1.x)
95
+ return true;
96
+ if (v1.x > l && v1.x < r)
97
+ return true;
98
+ }
99
+ }
100
+
101
+ int i2 = b0 ^ b2;
102
+ if (i2 != 0) {
103
+ if (v2.x != v0.x) {
104
+ m = (v2.y - v0.y) / (v2.x - v0.x);
105
+ c = v0.y - (m * v0.x);
106
+ if (i2 & 1) {
107
+ s = m * l + c;
108
+ if (s >= t && s <= b)
109
+ return true;
110
+ }
111
+ if (i2 & 2) {
112
+ s = (t - c) / m;
113
+ if (s >= l && s <= r)
114
+ return true;
115
+ }
116
+ if (i2 & 4) {
117
+ s = m * r + c;
118
+ if (s >= t && s <= b)
119
+ return true;
120
+ }
121
+ if (i2 & 8) {
122
+ s = (b - c) / m;
123
+ if (s >= l && s <= r)
124
+ return true;
125
+ }
126
+ } else {
127
+ if (l == v0.x || r == v0.x)
128
+ return true;
129
+ if (v0.x > l && v0.x < r)
130
+ return true;
131
+ }
132
+ }
133
+
134
+ // Bounding box check
135
+ float tbb_l = std::min(v0.x, std::min(v1.x, v2.x));
136
+ float tbb_t = std::min(v0.y, std::min(v1.y, v2.y));
137
+ float tbb_r = std::max(v0.x, std::max(v1.x, v2.x));
138
+ float tbb_b = std::max(v0.y, std::max(v1.y, v2.y));
139
+
140
+ if (tbb_l <= l && tbb_r >= r && tbb_t <= t && tbb_b >= b) {
141
+ float v0x = v2.x - v0.x;
142
+ float v0y = v2.y - v0.y;
143
+ float v1x = v1.x - v0.x;
144
+ float v1y = v1.y - v0.y;
145
+ float v2x, v2y;
146
+
147
+ float dot00, dot01, dot02, dot11, dot12, invDenom, u, v;
148
+
149
+ // Top-left corner
150
+ v2x = l - v0.x;
151
+ v2y = t - v0.y;
152
+
153
+ dot00 = v0x * v0x + v0y * v0y;
154
+ dot01 = v0x * v1x + v0y * v1y;
155
+ dot02 = v0x * v2x + v0y * v2y;
156
+ dot11 = v1x * v1x + v1y * v1y;
157
+ dot12 = v1x * v2x + v1y * v2y;
158
+
159
+ invDenom = 1.0f / (dot00 * dot11 - dot01 * dot01);
160
+ u = (dot11 * dot02 - dot01 * dot12) * invDenom;
161
+ v = (dot00 * dot12 - dot01 * dot02) * invDenom;
162
+
163
+ if (u >= 0 && v >= 0 && (u + v) <= 1)
164
+ return true;
165
+
166
+ // Bottom-left corner
167
+ v2x = l - v0.x;
168
+ v2y = b - v0.y;
169
+
170
+ dot02 = v0x * v2x + v0y * v2y;
171
+ dot12 = v1x * v2x + v1y * v2y;
172
+
173
+ u = (dot11 * dot02 - dot01 * dot12) * invDenom;
174
+ v = (dot00 * dot12 - dot01 * dot02) * invDenom;
175
+
176
+ if (u >= 0 && v >= 0 && (u + v) <= 1)
177
+ return true;
178
+
179
+ // Bottom-right corner
180
+ v2x = r - v0.x;
181
+ v2y = b - v0.y;
182
+
183
+ dot02 = v0x * v2x + v0y * v2y;
184
+ dot12 = v1x * v2x + v1y * v2y;
185
+
186
+ u = (dot11 * dot02 - dot01 * dot12) * invDenom;
187
+ v = (dot00 * dot12 - dot01 * dot02) * invDenom;
188
+
189
+ if (u >= 0 && v >= 0 && (u + v) <= 1)
190
+ return true;
191
+
192
+ // Top-right corner
193
+ v2x = r - v0.x;
194
+ v2y = t - v0.y;
195
+
196
+ dot02 = v0x * v2x + v0y * v2y;
197
+ dot12 = v1x * v2x + v1y * v2y;
198
+
199
+ u = (dot11 * dot02 - dot01 * dot12) * invDenom;
200
+ v = (dot00 * dot12 - dot01 * dot02) * invDenom;
201
+
202
+ if (u >= 0 && v >= 0 && (u + v) <= 1)
203
+ return true;
204
+ }
205
+
206
+ return false;
207
+ }
208
+
209
+ void tri_winding(uv_float2 &a, uv_float2 &b, uv_float2 &c) {
210
+ float det = (a.x * (b.y - c.y) + b.x * (c.y - a.y) + c.x * (a.y - b.y));
211
+
212
+ // If the determinant is negative, the triangle is oriented clockwise
213
+ if (det < 0) {
214
+ // Swap vertices b and c to ensure counter-clockwise winding
215
+ std::swap(b, c);
216
+ }
217
+ }
218
+
219
+ struct Triangle {
220
+ uv_float3 a, b, c;
221
+
222
+ Triangle(const uv_float2 &p1, const uv_float2 &q1, const uv_float2 &r1)
223
+ : a({p1.x, p1.y, 0}), b({q1.x, q1.y, 0}), c({r1.x, r1.y, 0}) {}
224
+
225
+ Triangle(const uv_float3 &p1, const uv_float3 &q1, const uv_float3 &r1)
226
+ : a(p1), b(q1), c(r1) {}
227
+
228
+ void getNormal(uv_float3 &normal) const {
229
+ uv_float3 u = b - a;
230
+ uv_float3 v = c - a;
231
+ normal = normalize(cross(u, v));
232
+ }
233
+ };
234
+
235
+ bool isTriDegenerated(const Triangle &tri) {
236
+ uv_float3 u = tri.a - tri.b;
237
+ uv_float3 v = tri.a - tri.c;
238
+ uv_float3 cr = cross(u, v);
239
+ return fabs(cr.x) < EPSILON && fabs(cr.y) < EPSILON && fabs(cr.z) < EPSILON;
240
+ }
241
+
242
+ int orient3D(const uv_float3 &a, const uv_float3 &b, const uv_float3 &c,
243
+ const uv_float3 &d) {
244
+ Matrix4 _matrix4;
245
+ _matrix4.set(a.x, a.y, a.z, 1, b.x, b.y, b.z, 1, c.x, c.y, c.z, 1, d.x, d.y,
246
+ d.z, 1);
247
+ float det = _matrix4.determinant();
248
+
249
+ if (det < -EPSILON)
250
+ return -1;
251
+ else if (det > EPSILON)
252
+ return 1;
253
+ else
254
+ return 0;
255
+ }
256
+
257
+ int orient2D(const uv_float2 &a, const uv_float2 &b, const uv_float2 &c) {
258
+ float det = (a.x * (b.y - c.y) + b.x * (c.y - a.y) + c.x * (a.y - b.y));
259
+
260
+ if (det < -EPSILON)
261
+ return -1;
262
+ else if (det > EPSILON)
263
+ return 1;
264
+ else
265
+ return 0;
266
+ }
267
+
268
+ int orient2D(const uv_float3 &a, const uv_float3 &b, const uv_float3 &c) {
269
+ uv_float2 a_2d = {a.x, a.y};
270
+ uv_float2 b_2d = {b.x, b.y};
271
+ uv_float2 c_2d = {c.x, c.y};
272
+ return orient2D(a_2d, b_2d, c_2d);
273
+ }
274
+
275
+ void permuteTriLeft(Triangle &tri) {
276
+ uv_float3 tmp = tri.a;
277
+ tri.a = tri.b;
278
+ tri.b = tri.c;
279
+ tri.c = tmp;
280
+ }
281
+
282
+ void permuteTriRight(Triangle &tri) {
283
+ uv_float3 tmp = tri.c;
284
+ tri.c = tri.b;
285
+ tri.b = tri.a;
286
+ tri.a = tmp;
287
+ }
288
+
289
+ void makeTriCounterClockwise(Triangle &tri) {
290
+ if (orient2D(tri.a, tri.b, tri.c) < 0) {
291
+ uv_float3 tmp = tri.c;
292
+ tri.c = tri.b;
293
+ tri.b = tmp;
294
+ }
295
+ }
296
+
297
+ void intersectPlane(const uv_float3 &a, const uv_float3 &b, const uv_float3 &p,
298
+ const uv_float3 &n, uv_float3 &target) {
299
+ uv_float3 u = b - a;
300
+ uv_float3 v = a - p;
301
+ float dot1 = dot(n, u);
302
+ float dot2 = dot(n, v);
303
+ u = u * (-dot2 / dot1);
304
+ target = a + u;
305
+ }
306
+
307
+ void computeLineIntersection(const Triangle &t1, const Triangle &t2,
308
+ std::vector<uv_float3> &target) {
309
+ uv_float3 n1, n2;
310
+ t1.getNormal(n1);
311
+ t2.getNormal(n2);
312
+
313
+ int o1 = orient3D(t1.a, t1.c, t2.b, t2.a);
314
+ int o2 = orient3D(t1.a, t1.b, t2.c, t2.a);
315
+
316
+ uv_float3 i1, i2;
317
+
318
+ if (o1 > 0) {
319
+ if (o2 > 0) {
320
+ intersectPlane(t1.a, t1.c, t2.a, n2, i1);
321
+ intersectPlane(t2.a, t2.c, t1.a, n1, i2);
322
+ } else {
323
+ intersectPlane(t1.a, t1.c, t2.a, n2, i1);
324
+ intersectPlane(t1.a, t1.b, t2.a, n2, i2);
325
+ }
326
+ } else {
327
+ if (o2 > 0) {
328
+ intersectPlane(t2.a, t2.b, t1.a, n1, i1);
329
+ intersectPlane(t2.a, t2.c, t1.a, n1, i2);
330
+ } else {
331
+ intersectPlane(t2.a, t2.b, t1.a, n1, i1);
332
+ intersectPlane(t1.a, t1.b, t2.a, n2, i2);
333
+ }
334
+ }
335
+
336
+ target.push_back(i1);
337
+ if (distance_to(i1, i2) >= EPSILON) {
338
+ target.push_back(i2);
339
+ }
340
+ }
341
+
342
+ void makeTriAVertexAlone(Triangle &tri, int oa, int ob, int oc) {
343
+ // Permute a, b, c so that a is alone on its side
344
+ if (oa == ob) {
345
+ // c is alone, permute right so c becomes a
346
+ permuteTriRight(tri);
347
+ } else if (oa == oc) {
348
+ // b is alone, permute so b becomes a
349
+ permuteTriLeft(tri);
350
+ } else if (ob != oc) {
351
+ // In case a, b, c have different orientation, put a on positive side
352
+ if (ob > 0) {
353
+ permuteTriLeft(tri);
354
+ } else if (oc > 0) {
355
+ permuteTriRight(tri);
356
+ }
357
+ }
358
+ }
359
+
360
+ void makeTriAVertexPositive(Triangle &tri, const Triangle &other) {
361
+ int o = orient3D(other.a, other.b, other.c, tri.a);
362
+ if (o < 0) {
363
+ std::swap(tri.b, tri.c);
364
+ }
365
+ }
366
+
367
+ bool crossIntersect(Triangle &t1, Triangle &t2, int o1a, int o1b, int o1c,
368
+ std::vector<uv_float3> *target = nullptr) {
369
+ int o2a = orient3D(t1.a, t1.b, t1.c, t2.a);
370
+ int o2b = orient3D(t1.a, t1.b, t1.c, t2.b);
371
+ int o2c = orient3D(t1.a, t1.b, t1.c, t2.c);
372
+
373
+ if (o2a == o2b && o2a == o2c) {
374
+ return false;
375
+ }
376
+
377
+ // Make a vertex alone on its side for both triangles
378
+ makeTriAVertexAlone(t1, o1a, o1b, o1c);
379
+ makeTriAVertexAlone(t2, o2a, o2b, o2c);
380
+
381
+ // Ensure the vertex on the positive side
382
+ makeTriAVertexPositive(t2, t1);
383
+ makeTriAVertexPositive(t1, t2);
384
+
385
+ int o1 = orient3D(t1.a, t1.b, t2.a, t2.b);
386
+ int o2 = orient3D(t1.a, t1.c, t2.c, t2.a);
387
+
388
+ if (o1 <= 0 && o2 <= 0) {
389
+ if (target) {
390
+ computeLineIntersection(t1, t2, *target);
391
+ }
392
+ return true;
393
+ }
394
+
395
+ return false;
396
+ }
397
+
398
+ void linesIntersect2d(const uv_float3 &a1, const uv_float3 &b1,
399
+ const uv_float3 &a2, const uv_float3 &b2,
400
+ uv_float3 &target) {
401
+ float dx1 = a1.x - b1.x;
402
+ float dx2 = a2.x - b2.x;
403
+ float dy1 = a1.y - b1.y;
404
+ float dy2 = a2.y - b2.y;
405
+
406
+ float D = dx1 * dy2 - dx2 * dy1;
407
+
408
+ float n1 = a1.x * b1.y - a1.y * b1.x;
409
+ float n2 = a2.x * b2.y - a2.y * b2.x;
410
+
411
+ target.x = (n1 * dx2 - n2 * dx1) / D;
412
+ target.y = (n1 * dy2 - n2 * dy1) / D;
413
+ target.z = 0;
414
+ }
415
+
416
+ void clipTriangle(const Triangle &t1, const Triangle &t2,
417
+ std::vector<uv_float3> &target) {
418
+ std::vector<uv_float3> clip = {t1.a, t1.b, t1.c};
419
+ std::vector<uv_float3> output = {t2.a, t2.b, t2.c};
420
+ std::vector<int> orients(output.size() * 3, 0);
421
+ uv_float3 inter;
422
+
423
+ for (int i = 0; i < 3; ++i) {
424
+ const int i_prev = (i + 2) % 3;
425
+ std::vector<uv_float3> input;
426
+ std::copy(output.begin(), output.end(), std::back_inserter(input));
427
+ output.clear();
428
+
429
+ for (size_t j = 0; j < input.size(); ++j) {
430
+ orients[j] = orient2D(clip[i_prev], clip[i], input[j]);
431
+ }
432
+
433
+ for (size_t j = 0; j < input.size(); ++j) {
434
+ const int j_prev = (j - 1 + input.size()) % input.size();
435
+
436
+ if (orients[j] >= 0) {
437
+ if (orients[j_prev] < 0) {
438
+ linesIntersect2d(clip[i_prev], clip[i], input[j_prev], input[j],
439
+ inter);
440
+ output.push_back({inter.x, inter.y, inter.z});
441
+ }
442
+ output.push_back({input[j].x, input[j].y, input[j].z});
443
+ } else if (orients[j_prev] >= 0) {
444
+ linesIntersect2d(clip[i_prev], clip[i], input[j_prev], input[j], inter);
445
+ output.push_back({inter.x, inter.y, inter.z});
446
+ }
447
+ }
448
+ }
449
+
450
+ // Clear duplicated points
451
+ for (const auto &point : output) {
452
+ int j = 0;
453
+ bool sameFound = false;
454
+ while (!sameFound && j < target.size()) {
455
+ sameFound = distance_to(point, target[j]) <= 1e-6;
456
+ j++;
457
+ }
458
+
459
+ if (!sameFound) {
460
+ target.push_back(point);
461
+ }
462
+ }
463
+ }
464
+
465
+ bool intersectionTypeR1(const Triangle &t1, const Triangle &t2) {
466
+ const uv_float3 &p1 = t1.a;
467
+ const uv_float3 &q1 = t1.b;
468
+ const uv_float3 &r1 = t1.c;
469
+ const uv_float3 &p2 = t2.a;
470
+ const uv_float3 &r2 = t2.c;
471
+
472
+ if (orient2D(r2, p2, q1) >= 0) { // I
473
+ if (orient2D(r2, p1, q1) >= 0) { // II.a
474
+ if (orient2D(p1, p2, q1) >= 0) { // III.a
475
+ return true;
476
+ } else {
477
+ if (orient2D(p1, p2, r1) >= 0) { // IV.a
478
+ if (orient2D(q1, r1, p2) >= 0) { // V
479
+ return true;
480
+ }
481
+ }
482
+ }
483
+ }
484
+ } else {
485
+ if (orient2D(r2, p2, r1) >= 0) { // II.b
486
+ if (orient2D(q1, r1, r2) >= 0) { // III.b
487
+ if (orient2D(p1, p2, r1) >= 0) { // IV.b (diverges from paper)
488
+ return true;
489
+ }
490
+ }
491
+ }
492
+ }
493
+
494
+ return false;
495
+ }
496
+
497
+ bool intersectionTypeR2(const Triangle &t1, const Triangle &t2) {
498
+ const uv_float3 &p1 = t1.a;
499
+ const uv_float3 &q1 = t1.b;
500
+ const uv_float3 &r1 = t1.c;
501
+ const uv_float3 &p2 = t2.a;
502
+ const uv_float3 &q2 = t2.b;
503
+ const uv_float3 &r2 = t2.c;
504
+
505
+ if (orient2D(r2, p2, q1) >= 0) { // I
506
+ if (orient2D(q2, r2, q1) >= 0) { // II.a
507
+ if (orient2D(p1, p2, q1) >= 0) { // III.a
508
+ if (orient2D(p1, q2, q1) <= 0) { // IV.a
509
+ return true;
510
+ }
511
+ } else {
512
+ if (orient2D(p1, p2, r1) >= 0) { // IV.b
513
+ if (orient2D(r2, p2, r1) <= 0) { // V.a
514
+ return true;
515
+ }
516
+ }
517
+ }
518
+ } else {
519
+ if (orient2D(p1, q2, q1) <= 0) { // III.b
520
+ if (orient2D(q2, r2, r1) >= 0) { // IV.c
521
+ if (orient2D(q1, r1, q2) >= 0) { // V.b
522
+ return true;
523
+ }
524
+ }
525
+ }
526
+ }
527
+ } else {
528
+ if (orient2D(r2, p2, r1) >= 0) { // II.b
529
+ if (orient2D(q1, r1, r2) >= 0) { // III.c
530
+ if (orient2D(r1, p1, p2) >= 0) { // IV.d
531
+ return true;
532
+ }
533
+ } else {
534
+ if (orient2D(q1, r1, q2) >= 0) { // IV.e
535
+ if (orient2D(q2, r2, r1) >= 0) { // V.c
536
+ return true;
537
+ }
538
+ }
539
+ }
540
+ }
541
+ }
542
+
543
+ return false;
544
+ }
545
+
546
+ bool coplanarIntersect(Triangle &t1, Triangle &t2,
547
+ std::vector<uv_float3> *target = nullptr) {
548
+ uv_float3 normal, u, v;
549
+ t1.getNormal(normal);
550
+ normal = normalize(normal);
551
+ u = normalize(t1.a - t1.b);
552
+ v = cross(normal, u);
553
+
554
+ // Move basis to t1.a
555
+ u = u + t1.a;
556
+ v = v + t1.a;
557
+ normal = normal + t1.a;
558
+
559
+ Matrix4 _matrix;
560
+ _matrix.set(t1.a.x, u.x, v.x, normal.x, t1.a.y, u.y, v.y, normal.y, t1.a.z,
561
+ u.z, v.z, normal.z, 1, 1, 1, 1);
562
+
563
+ Matrix4 _affineMatrix;
564
+ _affineMatrix.set(0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1);
565
+
566
+ _matrix.invert(); // Invert the _matrix
567
+ _matrix = _affineMatrix * _matrix;
568
+
569
+ // Apply transformation
570
+ apply_matrix4(t1.a, _matrix);
571
+ apply_matrix4(t1.b, _matrix);
572
+ apply_matrix4(t1.c, _matrix);
573
+ apply_matrix4(t2.a, _matrix);
574
+ apply_matrix4(t2.b, _matrix);
575
+ apply_matrix4(t2.c, _matrix);
576
+
577
+ makeTriCounterClockwise(t1);
578
+ makeTriCounterClockwise(t2);
579
+
580
+ const uv_float3 &p1 = t1.a;
581
+ const uv_float3 &p2 = t2.a;
582
+ const uv_float3 &q2 = t2.b;
583
+ const uv_float3 &r2 = t2.c;
584
+
585
+ int o_p2q2 = orient2D(p2, q2, p1);
586
+ int o_q2r2 = orient2D(q2, r2, p1);
587
+ int o_r2p2 = orient2D(r2, p2, p1);
588
+
589
+ bool intersecting = false;
590
+ if (o_p2q2 >= 0) {
591
+ if (o_q2r2 >= 0) {
592
+ if (o_r2p2 >= 0) {
593
+ // + + +
594
+ intersecting = true;
595
+ } else {
596
+ // + + -
597
+ intersecting = intersectionTypeR1(t1, t2);
598
+ }
599
+ } else {
600
+ if (o_r2p2 >= 0) {
601
+ // + - +
602
+ permuteTriRight(t2);
603
+ intersecting = intersectionTypeR1(t1, t2);
604
+ } else {
605
+ // + - -
606
+ intersecting = intersectionTypeR2(t1, t2);
607
+ }
608
+ }
609
+ } else {
610
+ if (o_q2r2 >= 0) {
611
+ if (o_r2p2 >= 0) {
612
+ // - + +
613
+ permuteTriLeft(t2);
614
+ intersecting = intersectionTypeR1(t1, t2);
615
+ } else {
616
+ // - + -
617
+ permuteTriLeft(t2);
618
+ intersecting = intersectionTypeR2(t1, t2);
619
+ }
620
+ } else {
621
+ if (o_r2p2 >= 0) {
622
+ // - - +
623
+ permuteTriRight(t2);
624
+ intersecting = intersectionTypeR2(t1, t2);
625
+ } else {
626
+ // - - -
627
+ std::cerr << "Triangles should not be flat." << std::endl;
628
+ return false;
629
+ }
630
+ }
631
+ }
632
+
633
+ if (intersecting && target) {
634
+ clipTriangle(t1, t2, *target);
635
+
636
+ _matrix.invert();
637
+ // Apply the transform to each target point
638
+ for (int i = 0; i < target->size(); ++i) {
639
+ apply_matrix4(target->at(i), _matrix);
640
+ }
641
+ }
642
+
643
+ return intersecting;
644
+ }
645
+
646
+ // Helper function to calculate the area of a polygon
647
+ float polygon_area(const std::vector<uv_float3> &polygon) {
648
+ if (polygon.size() < 3)
649
+ return 0.0f; // Not a polygon
650
+
651
+ uv_float3 normal = {0.0f, 0.0f, 0.0f}; // Initialize normal vector
652
+
653
+ // Calculate the cross product of edges around the polygon
654
+ for (size_t i = 0; i < polygon.size(); ++i) {
655
+ uv_float3 p1 = polygon[i];
656
+ uv_float3 p2 = polygon[(i + 1) % polygon.size()];
657
+
658
+ normal = normal + cross(p1, p2); // Accumulate the normal vector
659
+ }
660
+
661
+ float area =
662
+ magnitude(normal) / 2.0f; // Area is half the magnitude of the normal
663
+ return area;
664
+ }
665
+
666
+ bool triangle_triangle_intersection(uv_float2 p1, uv_float2 q1, uv_float2 r1,
667
+ uv_float2 p2, uv_float2 q2, uv_float2 r2) {
668
+ Triangle t1(p1, q1, r1);
669
+ Triangle t2(p2, q2, r2);
670
+
671
+ if (isTriDegenerated(t1) || isTriDegenerated(t2)) {
672
+ // std::cerr << "Degenerated triangles provided, skipping." << std::endl;
673
+ return false;
674
+ }
675
+
676
+ int o1a = orient3D(t2.a, t2.b, t2.c, t1.a);
677
+ int o1b = orient3D(t2.a, t2.b, t2.c, t1.b);
678
+ int o1c = orient3D(t2.a, t2.b, t2.c, t1.c);
679
+
680
+ std::vector<uv_float3> intersections;
681
+ bool intersects;
682
+
683
+ if (o1a == o1b && o1a == o1c) // [[likely]]
684
+ {
685
+ intersects = o1a == 0 && coplanarIntersect(t1, t2, &intersections);
686
+ } else // [[unlikely]]
687
+ {
688
+ intersects = crossIntersect(t1, t2, o1a, o1b, o1c, &intersections);
689
+ }
690
+
691
+ if (intersects) {
692
+ float area = polygon_area(intersections);
693
+
694
+ // std::cout << "Intersection area: " << area << std::endl;
695
+ if (area < 1e-10f || std::isfinite(area) == false) {
696
+ // std::cout<<"Invalid area: " << area << std::endl;
697
+ return false; // Ignore intersection if the area is too small
698
+ }
699
+ }
700
+
701
+ return intersects;
702
+ }
uv_unwrapper/uv_unwrapper/csrc/intersect.h ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include "common.h"
4
+ #include <vector>
5
+
6
+ bool triangle_aabb_intersection(const uv_float2 &aabb_min,
7
+ const uv_float2 &aabb_max, const uv_float2 &v0,
8
+ const uv_float2 &v1, const uv_float2 &v2);
9
+ bool triangle_triangle_intersection(uv_float2 p1, uv_float2 q1, uv_float2 r1,
10
+ uv_float2 p2, uv_float2 q2, uv_float2 r2);
uv_unwrapper/uv_unwrapper/csrc/unwrapper.cpp ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "bvh.h"
2
+ #include <ATen/ATen.h>
3
+ #include <ATen/Context.h>
4
+ #include <chrono>
5
+ #include <cmath>
6
+ #include <cstring>
7
+ #include <omp.h>
8
+ #include <set>
9
+ #include <torch/extension.h>
10
+ #include <vector>
11
+
12
+ // #define TIMING
13
+
14
+ #if defined(_MSC_VER)
15
+ #include <BaseTsd.h>
16
+ typedef SSIZE_T ssize_t;
17
+ #endif
18
+
19
+ namespace UVUnwrapper {
20
+ void create_bvhs(BVH *bvhs, Triangle *triangles,
21
+ std::vector<std::set<int>> &triangle_per_face, int num_faces,
22
+ int start, int end) {
23
+ #pragma omp parallel for
24
+ for (int i = start; i < end; i++) {
25
+ int num_triangles = triangle_per_face[i].size();
26
+ Triangle *triangles_per_face = new Triangle[num_triangles];
27
+ int *indices = new int[num_triangles];
28
+ int j = 0;
29
+ for (int idx : triangle_per_face[i]) {
30
+ triangles_per_face[j] = triangles[idx];
31
+ indices[j++] = idx;
32
+ }
33
+ // Each thread writes to it's own memory space
34
+ // First check if the number of triangles is 0
35
+ if (num_triangles == 0) {
36
+ bvhs[i - start] = std::move(BVH()); // Default constructor
37
+ } else {
38
+ bvhs[i - start] = std::move(
39
+ BVH(triangles_per_face, indices,
40
+ num_triangles)); // BVH now handles memory of triangles_per_face
41
+ }
42
+ delete[] triangles_per_face;
43
+ }
44
+ }
45
+
46
+ void perform_intersection_check(BVH *bvhs, int num_bvhs, Triangle *triangles,
47
+ uv_float3 *vertex_tri_centroids,
48
+ int64_t *assign_indices_ptr,
49
+ ssize_t num_indices, int offset,
50
+ std::vector<std::set<int>> &triangle_per_face) {
51
+ std::vector<std::pair<int, int>>
52
+ unique_intersections; // Store unique intersections as pairs of triangle
53
+ // indices
54
+
55
+ // Step 1: Detect intersections in parallel
56
+ #pragma omp parallel for
57
+ for (int i = 0; i < num_indices; i++) {
58
+ if (assign_indices_ptr[i] < offset) {
59
+ continue;
60
+ }
61
+
62
+ Triangle cur_tri = triangles[i];
63
+ auto &cur_bvh = bvhs[assign_indices_ptr[i] - offset];
64
+
65
+ if (cur_bvh.bvhNode == nullptr) {
66
+ continue;
67
+ }
68
+
69
+ std::vector<int> intersections = cur_bvh.Intersect(cur_tri);
70
+
71
+ if (!intersections.empty()) {
72
+
73
+ #pragma omp critical
74
+ {
75
+ for (int intersect : intersections) {
76
+ if (i != intersect) {
77
+ // Ensure we only store unique pairs (A, B) where A < B to avoid
78
+ // duplication
79
+ if (i < intersect) {
80
+ unique_intersections.push_back(std::make_pair(i, intersect));
81
+ } else {
82
+ unique_intersections.push_back(std::make_pair(intersect, i));
83
+ }
84
+ }
85
+ }
86
+ }
87
+ }
88
+ }
89
+
90
+ // Step 2: Process unique intersections
91
+ for (int idx = 0; idx < unique_intersections.size(); idx++) {
92
+ int first = unique_intersections[idx].first;
93
+ int second = unique_intersections[idx].second;
94
+
95
+ int i_idx = assign_indices_ptr[first];
96
+
97
+ int norm_idx = i_idx % 6;
98
+ int axis = (norm_idx < 2) ? 0 : (norm_idx < 4) ? 1 : 2;
99
+ bool use_max = (i_idx % 2) == 1;
100
+
101
+ float pos_a = vertex_tri_centroids[first][axis];
102
+ float pos_b = vertex_tri_centroids[second][axis];
103
+ // Sort the intersections based on vertex_tri_centroids along the specified
104
+ // axis
105
+ if (use_max) {
106
+ if (pos_a < pos_b) {
107
+ std::swap(first, second);
108
+ }
109
+ } else {
110
+ if (pos_a > pos_b) {
111
+ std::swap(first, second);
112
+ }
113
+ }
114
+
115
+ // Update the unique intersections
116
+ unique_intersections[idx].first = first;
117
+ unique_intersections[idx].second = second;
118
+ }
119
+
120
+ // Now only get the second intersections from the pair and put them in a set
121
+ // The second intersection should always be the occluded triangle
122
+ std::set<int> second_intersections;
123
+ for (int idx = 0; idx < (int)unique_intersections.size(); idx++) {
124
+ int second = unique_intersections[idx].second;
125
+ second_intersections.insert(second);
126
+ }
127
+
128
+ for (int int_idx : second_intersections) {
129
+ // Move the second (occluded) triangle by 6
130
+ int intersect_idx = assign_indices_ptr[int_idx];
131
+ int new_index = intersect_idx + 6;
132
+ new_index = std::clamp(new_index, 0, 12);
133
+
134
+ assign_indices_ptr[int_idx] = new_index;
135
+ triangle_per_face[intersect_idx].erase(int_idx);
136
+ triangle_per_face[new_index].insert(int_idx);
137
+ }
138
+ }
139
+
140
+ torch::Tensor assign_faces_uv_to_atlas_index(torch::Tensor vertices,
141
+ torch::Tensor indices,
142
+ torch::Tensor face_uv,
143
+ torch::Tensor face_index) {
144
+ // Get the number of faces
145
+ int num_faces = indices.size(0);
146
+ torch::Tensor assign_indices =
147
+ torch::empty(
148
+ {
149
+ num_faces,
150
+ },
151
+ torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU))
152
+ .contiguous();
153
+
154
+ auto vert_accessor = vertices.accessor<float, 2>();
155
+ auto indices_accessor = indices.accessor<int64_t, 2>();
156
+ auto face_uv_accessor = face_uv.accessor<float, 2>();
157
+
158
+ const int64_t *face_index_ptr = face_index.contiguous().data_ptr<int64_t>();
159
+ int64_t *assign_indices_ptr = assign_indices.data_ptr<int64_t>();
160
+ // copy face_index to assign_indices
161
+ memcpy(assign_indices_ptr, face_index_ptr, num_faces * sizeof(int64_t));
162
+
163
+ #ifdef TIMING
164
+ auto start = std::chrono::high_resolution_clock::now();
165
+ #endif
166
+ uv_float3 *vertex_tri_centroids = new uv_float3[num_faces];
167
+ Triangle *triangles = new Triangle[num_faces];
168
+
169
+ // Use std::set to store triangles for each face
170
+ std::vector<std::set<int>> triangle_per_face;
171
+ triangle_per_face.resize(13);
172
+
173
+ #pragma omp parallel for
174
+ for (int i = 0; i < num_faces; i++) {
175
+ int face_idx = i * 3;
176
+ triangles[i].v0 = {face_uv_accessor[face_idx + 0][0],
177
+ face_uv_accessor[face_idx + 0][1]};
178
+ triangles[i].v1 = {face_uv_accessor[face_idx + 1][0],
179
+ face_uv_accessor[face_idx + 1][1]};
180
+ triangles[i].v2 = {face_uv_accessor[face_idx + 2][0],
181
+ face_uv_accessor[face_idx + 2][1]};
182
+ triangles[i].centroid =
183
+ triangle_centroid(triangles[i].v0, triangles[i].v1, triangles[i].v2);
184
+
185
+ uv_float3 v0 = {vert_accessor[indices_accessor[i][0]][0],
186
+ vert_accessor[indices_accessor[i][0]][1],
187
+ vert_accessor[indices_accessor[i][0]][2]};
188
+ uv_float3 v1 = {vert_accessor[indices_accessor[i][1]][0],
189
+ vert_accessor[indices_accessor[i][1]][1],
190
+ vert_accessor[indices_accessor[i][1]][2]};
191
+ uv_float3 v2 = {vert_accessor[indices_accessor[i][2]][0],
192
+ vert_accessor[indices_accessor[i][2]][1],
193
+ vert_accessor[indices_accessor[i][2]][2]};
194
+ vertex_tri_centroids[i] = triangle_centroid(v0, v1, v2);
195
+
196
+ // Assign the triangle to the face index
197
+ #pragma omp critical
198
+ { triangle_per_face[face_index_ptr[i]].insert(i); }
199
+ }
200
+
201
+ #ifdef TIMING
202
+ auto start_bvh = std::chrono::high_resolution_clock::now();
203
+ #endif
204
+
205
+ BVH *bvhs = new BVH[6];
206
+ create_bvhs(bvhs, triangles, triangle_per_face, num_faces, 0, 6);
207
+
208
+ #ifdef TIMING
209
+ auto end_bvh = std::chrono::high_resolution_clock::now();
210
+ std::chrono::duration<double> elapsed_seconds = end_bvh - start_bvh;
211
+ std::cout << "BVH build time: " << elapsed_seconds.count() << "s\n";
212
+
213
+ auto start_intersection_1 = std::chrono::high_resolution_clock::now();
214
+ #endif
215
+
216
+ perform_intersection_check(bvhs, 6, triangles, vertex_tri_centroids,
217
+ assign_indices_ptr, num_faces, 0,
218
+ triangle_per_face);
219
+
220
+ #ifdef TIMING
221
+ auto end_intersection_1 = std::chrono::high_resolution_clock::now();
222
+ elapsed_seconds = end_intersection_1 - start_intersection_1;
223
+ std::cout << "Intersection 1 time: " << elapsed_seconds.count() << "s\n";
224
+ #endif
225
+ // Create 6 new bvhs and delete the old ones
226
+ BVH *new_bvhs = new BVH[6];
227
+ create_bvhs(new_bvhs, triangles, triangle_per_face, num_faces, 6, 12);
228
+
229
+ #ifdef TIMING
230
+ auto end_bvh2 = std::chrono::high_resolution_clock::now();
231
+ elapsed_seconds = end_bvh2 - end_intersection_1;
232
+ std::cout << "BVH 2 build time: " << elapsed_seconds.count() << "s\n";
233
+ auto start_intersection_2 = std::chrono::high_resolution_clock::now();
234
+ #endif
235
+
236
+ perform_intersection_check(new_bvhs, 6, triangles, vertex_tri_centroids,
237
+ assign_indices_ptr, num_faces, 6,
238
+ triangle_per_face);
239
+
240
+ #ifdef TIMING
241
+ auto end_intersection_2 = std::chrono::high_resolution_clock::now();
242
+ elapsed_seconds = end_intersection_2 - start_intersection_2;
243
+ std::cout << "Intersection 2 time: " << elapsed_seconds.count() << "s\n";
244
+ elapsed_seconds = end_intersection_2 - start;
245
+ std::cout << "Total time: " << elapsed_seconds.count() << "s\n";
246
+ #endif
247
+
248
+ // Cleanup
249
+ delete[] vertex_tri_centroids;
250
+ delete[] triangles;
251
+ delete[] bvhs;
252
+ delete[] new_bvhs;
253
+
254
+ return assign_indices;
255
+ }
256
+
257
+ // Registers _C as a Python extension module.
258
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
259
+
260
+ // Defines the operators
261
+ TORCH_LIBRARY(UVUnwrapper, m) {
262
+ m.def("assign_faces_uv_to_atlas_index(Tensor vertices, Tensor indices, "
263
+ "Tensor face_uv, Tensor face_index) -> Tensor");
264
+ }
265
+
266
+ // Registers CPP implementations
267
+ TORCH_LIBRARY_IMPL(UVUnwrapper, CPU, m) {
268
+ m.impl("assign_faces_uv_to_atlas_index", &assign_faces_uv_to_atlas_index);
269
+ }
270
+
271
+ } // namespace UVUnwrapper
uv_unwrapper/uv_unwrapper/unwrap.py ADDED
@@ -0,0 +1,669 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch import Tensor
8
+
9
+
10
+ class Unwrapper(nn.Module):
11
+ def __init__(self):
12
+ super().__init__()
13
+
14
+ def _box_assign_vertex_to_cube_face(
15
+ self,
16
+ vertex_positions: Tensor,
17
+ vertex_normals: Tensor,
18
+ triangle_idxs: Tensor,
19
+ bbox: Tensor,
20
+ ) -> Tuple[Tensor, Tensor]:
21
+ """
22
+ Assigns each vertex to a cube face based on the face normal
23
+
24
+ Args:
25
+ vertex_positions (Tensor, Nv 3, float): Vertex positions
26
+ vertex_normals (Tensor, Nv 3, float): Vertex normals
27
+ triangle_idxs (Tensor, Nf 3, int): Triangle indices
28
+ bbox (Tensor, 2 3, float): Bounding box of the mesh
29
+
30
+ Returns:
31
+ Tensor, Nf 3 2, float: UV coordinates
32
+ Tensor, Nf, int: Cube face indices
33
+ """
34
+
35
+ # Test to not have a scaled model to fit the space better
36
+ # bbox_min = bbox[:1].mean(-1, keepdim=True)
37
+ # bbox_max = bbox[1:].mean(-1, keepdim=True)
38
+ # v_pos_normalized = (vertex_positions - bbox_min) / (bbox_max - bbox_min)
39
+
40
+ # Create a [0, 1] normalized vertex position
41
+ v_pos_normalized = (vertex_positions - bbox[:1]) / (bbox[1:] - bbox[:1])
42
+ # And to [-1, 1]
43
+ v_pos_normalized = 2.0 * v_pos_normalized - 1.0
44
+
45
+ # Get all vertex positions for each triangle
46
+ # Now how do we define to which face the triangle belongs? Mean face pos? Max vertex pos?
47
+ v0 = v_pos_normalized[triangle_idxs[:, 0]]
48
+ v1 = v_pos_normalized[triangle_idxs[:, 1]]
49
+ v2 = v_pos_normalized[triangle_idxs[:, 2]]
50
+ tri_stack = torch.stack([v0, v1, v2], dim=1)
51
+
52
+ vn0 = vertex_normals[triangle_idxs[:, 0]]
53
+ vn1 = vertex_normals[triangle_idxs[:, 1]]
54
+ vn2 = vertex_normals[triangle_idxs[:, 2]]
55
+ tri_stack_nrm = torch.stack([vn0, vn1, vn2], dim=1)
56
+
57
+ # Just average the normals per face
58
+ face_normal = F.normalize(torch.sum(tri_stack_nrm, 1), eps=1e-6, dim=-1)
59
+
60
+ # Now decide based on the face normal in which box map we project
61
+ # abs_x, abs_y, abs_z = tri_stack_nrm.abs().unbind(-1)
62
+ abs_x, abs_y, abs_z = tri_stack.abs().unbind(-1)
63
+
64
+ axis = torch.tensor(
65
+ [
66
+ [1, 0, 0], # 0
67
+ [-1, 0, 0], # 1
68
+ [0, 1, 0], # 2
69
+ [0, -1, 0], # 3
70
+ [0, 0, 1], # 4
71
+ [0, 0, -1], # 5
72
+ ],
73
+ device=face_normal.device,
74
+ dtype=face_normal.dtype,
75
+ )
76
+ face_normal_axis = (face_normal[:, None] * axis[None]).sum(-1)
77
+ index = face_normal_axis.argmax(-1)
78
+
79
+ max_axis, uc, vc = (
80
+ torch.ones_like(abs_x),
81
+ torch.zeros_like(tri_stack[..., :1]),
82
+ torch.zeros_like(tri_stack[..., :1]),
83
+ )
84
+ mask_pos_x = index == 0
85
+ max_axis[mask_pos_x] = abs_x[mask_pos_x]
86
+ uc[mask_pos_x] = tri_stack[mask_pos_x][..., 1:2]
87
+ vc[mask_pos_x] = -tri_stack[mask_pos_x][..., -1:]
88
+
89
+ mask_neg_x = index == 1
90
+ max_axis[mask_neg_x] = abs_x[mask_neg_x]
91
+ uc[mask_neg_x] = tri_stack[mask_neg_x][..., 1:2]
92
+ vc[mask_neg_x] = -tri_stack[mask_neg_x][..., -1:]
93
+
94
+ mask_pos_y = index == 2
95
+ max_axis[mask_pos_y] = abs_y[mask_pos_y]
96
+ uc[mask_pos_y] = tri_stack[mask_pos_y][..., 0:1]
97
+ vc[mask_pos_y] = -tri_stack[mask_pos_y][..., -1:]
98
+
99
+ mask_neg_y = index == 3
100
+ max_axis[mask_neg_y] = abs_y[mask_neg_y]
101
+ uc[mask_neg_y] = tri_stack[mask_neg_y][..., 0:1]
102
+ vc[mask_neg_y] = -tri_stack[mask_neg_y][..., -1:]
103
+
104
+ mask_pos_z = index == 4
105
+ max_axis[mask_pos_z] = abs_z[mask_pos_z]
106
+ uc[mask_pos_z] = tri_stack[mask_pos_z][..., 0:1]
107
+ vc[mask_pos_z] = tri_stack[mask_pos_z][..., 1:2]
108
+
109
+ mask_neg_z = index == 5
110
+ max_axis[mask_neg_z] = abs_z[mask_neg_z]
111
+ uc[mask_neg_z] = tri_stack[mask_neg_z][..., 0:1]
112
+ vc[mask_neg_z] = -tri_stack[mask_neg_z][..., 1:2]
113
+
114
+ # UC from [-1, 1] to [0, 1]
115
+ max_dim_div = max_axis.max(dim=0, keepdim=True).values
116
+ uc = ((uc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
117
+ vc = ((vc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
118
+
119
+ uv = torch.stack([uc, vc], dim=-1)
120
+
121
+ return uv, index
122
+
123
def _assign_faces_uv_to_atlas_index(
    self,
    vertex_positions: Tensor,
    triangle_idxs: Tensor,
    face_uv: Tensor,
    face_index: Tensor,
) -> Tensor:  # noqa: F821
    """
    Assign each face's UVs to a final atlas chart index.

    Delegates to the custom ``UVUnwrapper`` op, which only runs on the CPU.

    Args:
        vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions
        triangle_idxs (Integer[Tensor, "Nf 3"]): Triangle indices
        face_uv (Float[Tensor, "Nf 3 2"]): Face UV coordinates
        face_index (Integer[Tensor, "Nf"]): Cube-face index per face

    Returns:
        Integer[Tensor, "Nf"]: Atlas index per face
    """
    # The custom op is CPU-only: move all inputs over, then bring the
    # result back to whatever device the mesh lives on.
    target_device = vertex_positions.device
    atlas_index = torch.ops.UVUnwrapper.assign_faces_uv_to_atlas_index(
        vertex_positions.cpu(),
        triangle_idxs.cpu(),
        face_uv.view(-1, 2).cpu(),  # flatten to one UV row per corner
        face_index.cpu(),
    )
    return atlas_index.to(target_device)
148
+
149
+ def _find_slice_offset_and_scale(
150
+ self, index: Tensor
151
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: # noqa: F821
152
+ """
153
+ Find the slice offset and scale
154
+
155
+ Args:
156
+ index (Integer[Tensor, "Nf"]): Atlas index
157
+
158
+ Returns:
159
+ Float[Tensor, "Nf"]: Offset x
160
+ Float[Tensor, "Nf"]: Offset y
161
+ Float[Tensor, "Nf"]: Division x
162
+ Float[Tensor, "Nf"]: Division y
163
+ """
164
+
165
+ # 6 due to the 6 cube faces
166
+ off = 1 / 3
167
+ dupl_off = 1 / 6
168
+
169
+ # Here, we need to decide how to pack the textures in the case of overlap
170
+ def x_offset_calc(x, i):
171
+ offset_calc = i // 6
172
+ # Initial coordinates - just 3x2 grid
173
+ if offset_calc == 0:
174
+ return off * x
175
+ else:
176
+ # Smaller 3x2 grid plus eventual shift to right for
177
+ # second overlap
178
+ return dupl_off * x + min(offset_calc - 1, 1) * 0.5
179
+
180
+ def y_offset_calc(x, i):
181
+ offset_calc = i // 6
182
+ # Initial coordinates - just a 3x2 grid
183
+ if offset_calc == 0:
184
+ return off * x
185
+ else:
186
+ # Smaller coordinates in the lowest row
187
+ return dupl_off * x + off * 2
188
+
189
+ offset_x = torch.zeros_like(index, dtype=torch.float32)
190
+ offset_y = torch.zeros_like(index, dtype=torch.float32)
191
+ offset_x_vals = [0, 1, 2, 0, 1, 2]
192
+ offset_y_vals = [0, 0, 0, 1, 1, 1]
193
+ for i in range(index.max().item() + 1):
194
+ mask = index == i
195
+ if not mask.any():
196
+ continue
197
+ offset_x[mask] = x_offset_calc(offset_x_vals[i % 6], i)
198
+ offset_y[mask] = y_offset_calc(offset_y_vals[i % 6], i)
199
+
200
+ div_x = torch.full_like(index, 6 // 2, dtype=torch.float32)
201
+ # All overlap elements are saved in half scale
202
+ div_x[index >= 6] = 6
203
+ div_y = div_x.clone() # Same for y
204
+ # Except for the random overlaps
205
+ div_x[index >= 12] = 2
206
+ # But the random overlaps are saved in a large block in the lower thirds
207
+ div_y[index >= 12] = 3
208
+
209
+ return offset_x, offset_y, div_x, div_y
210
+
211
+ def _calculate_tangents(
212
+ self,
213
+ vertex_positions: Tensor,
214
+ vertex_normals: Tensor,
215
+ triangle_idxs: Tensor,
216
+ face_uv: Tensor,
217
+ ) -> Tensor:
218
+ """
219
+ Calculate the tangents for each triangle
220
+
221
+ Args:
222
+ vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions
223
+ vertex_normals (Float[Tensor, "Nv 3"]): Vertex normals
224
+ triangle_idxs (Integer[Tensor, "Nf 3"]): Triangle indices
225
+ face_uv (Float[Tensor, "Nf 3 2"]): Face UV coordinates
226
+
227
+ Returns:
228
+ Float[Tensor, "Nf 3 4"]: Tangents
229
+ """
230
+ vn_idx = [None] * 3
231
+ pos = [None] * 3
232
+ tex = face_uv.unbind(1)
233
+ for i in range(0, 3):
234
+ pos[i] = vertex_positions[triangle_idxs[:, i]]
235
+ # t_nrm_idx is always the same as t_pos_idx
236
+ vn_idx[i] = triangle_idxs[:, i]
237
+
238
+ if(torch.backends.mps.is_available()):
239
+ tangents = torch.zeros_like(vertex_normals).contiguous()
240
+ tansum = torch.zeros_like(vertex_normals).contiguous()
241
+ else:
242
+ tangents = torch.zeros_like(vertex_normals)
243
+ tansum = torch.zeros_like(vertex_normals)
244
+
245
+ # Compute tangent space for each triangle
246
+ duv1 = tex[1] - tex[0]
247
+ duv2 = tex[2] - tex[0]
248
+ dpos1 = pos[1] - pos[0]
249
+ dpos2 = pos[2] - pos[0]
250
+
251
+ tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
252
+
253
+ denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]
254
+
255
+ # Avoid division by zero for degenerated texture coordinates
256
+ denom_safe = denom.clip(1e-6)
257
+ tang = tng_nom / denom_safe
258
+
259
+ # Update all 3 vertices
260
+ for i in range(0, 3):
261
+ idx = vn_idx[i][:, None].repeat(1, 3)
262
+ tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang
263
+ tansum.scatter_add_(
264
+ 0, idx, torch.ones_like(tang)
265
+ ) # tansum[n_i] = tansum[n_i] + 1
266
+ # Also normalize it. Here we do not normalize the individual triangles first so larger area
267
+ # triangles influence the tangent space more
268
+ tangents = tangents / tansum
269
+
270
+ # Normalize and make sure tangent is perpendicular to normal
271
+ tangents = F.normalize(tangents, dim=1)
272
+ tangents = F.normalize(
273
+ tangents
274
+ - (tangents * vertex_normals).sum(-1, keepdim=True) * vertex_normals
275
+ )
276
+
277
+ return tangents
278
+
279
def _rotate_uv_slices_consistent_space(
    self,
    vertex_positions: Tensor,
    vertex_normals: Tensor,
    triangle_idxs: Tensor,
    uv: Tensor,
    index: Tensor,
) -> Tensor:
    """
    Rotate the UV slices so they are in a consistent space.

    For each of the 6 cube-face slices, computes one global 2D rotation that
    aligns the slice's mean UV tangent with an "expected" tangent field, then
    applies it and renormalizes the slice's UVs back into [0, 1].

    NOTE(review): `uv` is modified in place via the masked assignments below
    and is also returned.

    Args:
        vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions
        vertex_normals (Float[Tensor, "Nv 3"]): Vertex normals
        triangle_idxs (Integer[Tensor, "Nf 3"]): Triangle indices
        uv (Float[Tensor, "Nf 3 2"]): UV coordinates
        index (Integer[Tensor, "Nf"]): Atlas index

    Returns:
        Float[Tensor, "Nf 3 2"]: Rotated UV coordinates
    """

    # Per-vertex tangents induced by the current UV parameterization
    tangents = self._calculate_tangents(
        vertex_positions, vertex_normals, triangle_idxs, uv
    )
    # Reference direction field: positions rotated 90 degrees around z,
    # i.e. (-y, x, 0); defines the "expected" tangent orientation
    pos_stack = torch.stack(
        [
            -vertex_positions[..., 1],
            vertex_positions[..., 0],
            torch.zeros_like(vertex_positions[..., 0]),
        ],
        dim=-1,
    )
    # Project the reference field into each vertex's tangent plane via a
    # double cross product with the vertex normal
    expected_tangents = F.normalize(
        torch.linalg.cross(
            vertex_normals,
            torch.linalg.cross(pos_stack, vertex_normals, dim=-1),
            dim=-1,
        ),
        -1,
    )

    # Gather both tangent fields per triangle corner: (Nf, 3, 3)
    actual_tangents = tangents[triangle_idxs]
    expected_tangents = expected_tangents[triangle_idxs]

    def rotation_matrix_2d(theta):
        # Standard CCW 2D rotation matrix for the scalar angle theta
        c, s = torch.cos(theta), torch.sin(theta)
        return torch.tensor([[c, -s], [s, c]])

    # Now find one rotation per cube-face slice
    index_mod = index % 6  # Values >= 6 shouldn't happen here. Just for safety
    for i in range(6):
        mask = index_mod == i
        if not mask.any():
            continue

        # Mean tangent direction over all faces/corners in this slice
        actual_mean_tangent = actual_tangents[mask].mean(dim=(0, 1))
        expected_mean_tangent = expected_tangents[mask].mean(dim=(0, 1))

        # Signed angle between the mean actual and expected tangents
        # (atan2 of 2D cross over dot gives the full-range angle)
        dot_product = torch.dot(actual_mean_tangent, expected_mean_tangent)
        cross_product = (
            actual_mean_tangent[0] * expected_mean_tangent[1]
            - actual_mean_tangent[1] * expected_mean_tangent[0]
        )
        angle = torch.atan2(cross_product, dot_product)

        rot_matrix = rotation_matrix_2d(angle).to(mask.device)
        # Center the uv coordinate to be in the range of -1 to 1 and 0 centered
        uv_cur = uv[mask] * 2 - 1  # Center it first
        # Rotate it
        uv[mask] = torch.einsum("ij,nfj->nfi", rot_matrix, uv_cur)

        # Rescale uv[mask] to be within the 0-1 range
        # NOTE(review): divides by (max - min) of the whole slice; a fully
        # degenerate slice (all UVs equal) would divide by zero — presumably
        # never happens for real meshes; verify upstream.
        uv[mask] = (uv[mask] - uv[mask].min()) / (uv[mask].max() - uv[mask].min())

    return uv
355
+
356
+ def _handle_slice_uvs(
357
+ self,
358
+ uv: Tensor,
359
+ index: Tensor, # noqa: F821
360
+ island_padding: float,
361
+ max_index: int = 6 * 2,
362
+ ) -> Tensor: # noqa: F821
363
+ """
364
+ Handle the slice UVs
365
+
366
+ Args:
367
+ uv (Float[Tensor, "Nf 3 2"]): UV coordinates
368
+ index (Integer[Tensor, "Nf"]): Atlas index
369
+ island_padding (float): Island padding
370
+ max_index (int): Maximum index
371
+
372
+ Returns:
373
+ Float[Tensor, "Nf 3 2"]: Updated UV coordinates
374
+
375
+ """
376
+ uc, vc = uv.unbind(-1)
377
+
378
+ # Get the second slice (The first overlap)
379
+ index_filter = [index == i for i in range(6, max_index)]
380
+
381
+ # Normalize them to always fully fill the atlas patch
382
+ for i, fi in enumerate(index_filter):
383
+ if fi.sum() > 0:
384
+ # Scale the slice but only up to a factor of 2
385
+ # This keeps the texture resolution with the first slice in line (Half space in UV)
386
+ uc[fi] = (uc[fi] - uc[fi].min()) / (uc[fi].max() - uc[fi].min()).clip(
387
+ 0.5
388
+ )
389
+ vc[fi] = (vc[fi] - vc[fi].min()) / (vc[fi].max() - vc[fi].min()).clip(
390
+ 0.5
391
+ )
392
+
393
+ uc_padded = (uc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
394
+ vc_padded = (vc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
395
+
396
+ return torch.stack([uc_padded, vc_padded], dim=-1)
397
+
398
def _handle_remaining_uvs(
    self,
    uv: Tensor,
    index: Tensor,  # noqa: F821
    island_padding: float,
) -> Tensor:
    """
    Handle the remaining UVs (the ones that are not slices).

    All faces with atlas index >= 12 are packed individually as small
    squares into a rectangle covering half the atlas width and a third of
    its height (the lower-right region).

    Args:
        uv (Float[Tensor, "Nf 3 2"]): UV coordinates
        index (Integer[Tensor, "Nf"]): Atlas index
        island_padding (float): Island padding

    Returns:
        Float[Tensor, "Nf 3 2"]: Updated UV coordinates
    """
    uc, vc = uv.unbind(-1)
    # Get all remaining elements (beyond the two slice layers 0-11)
    remaining_filter = index >= 6 * 2
    squares_left = remaining_filter.sum()

    if squares_left == 0:
        return uv

    uc = uc[remaining_filter]
    vc = vc[remaining_filter]

    # The remaining triangles are distributed in a rectangle that takes
    # 0.5 of the entire uv space in width and 1/3 in height,
    # so its area fraction is 0.5 * (1/3) = 1/6
    ratio = 0.5 * (1 / 3)

    # Grid dimensions: width/height chosen so ~squares_left cells fit the
    # rectangle's aspect ratio
    mult = math.sqrt(squares_left / ratio)
    num_square_width = int(math.ceil(0.5 * mult))
    num_square_height = int(math.ceil(squares_left / num_square_width))

    # Size of one grid cell in the (local) rectangle's coordinates
    width = 1 / num_square_width
    height = 1 / num_square_height

    # The idea is again to keep the texture resolution consistent with the
    # first slice. This only occupies half the region in the texture chart
    # but the scaling on the squares assumes full coverage, hence the cap.
    clip_val = min(width, height) * 1.5
    # Now normalize the UVs, taking the maximum scaling into account
    uc = (uc - uc.min(dim=1, keepdim=True).values) / (
        uc.amax(dim=1, keepdim=True) - uc.amin(dim=1, keepdim=True)
    ).clip(clip_val)
    vc = (vc - vc.min(dim=1, keepdim=True).values) / (
        vc.amax(dim=1, keepdim=True) - vc.amin(dim=1, keepdim=True)
    ).clip(clip_val)
    # Add a small padding (scaled by grid density so the absolute gap is
    # comparable to the slice padding)
    uc = (
        uc * (1 - island_padding * num_square_width * 0.5)
        + island_padding * num_square_width * 0.25
    ).clip(0, 1)
    vc = (
        vc * (1 - island_padding * num_square_height * 0.5)
        + island_padding * num_square_height * 0.25
    ).clip(0, 1)

    # Shrink each triangle into its grid cell
    uc = uc * width
    vc = vc * height

    # And calculate offsets for each element
    idx = torch.arange(uc.shape[0], device=uc.device, dtype=torch.int32)
    x_idx = idx % num_square_width
    y_idx = idx // num_square_width
    # And move each triangle to its own spot
    uc = uc + x_idx[:, None] * width
    vc = vc + y_idx[:, None] * height

    # Final global padding pass over the packed rectangle
    uc = (uc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
    vc = (vc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)

    # Write the packed coordinates back for the remaining faces only
    uv[remaining_filter] = torch.stack([uc, vc], dim=-1)

    return uv
476
+
477
def _distribute_individual_uvs_in_atlas(
    self,
    face_uv: Tensor,
    assigned_faces: Tensor,
    offset_x: Tensor,
    offset_y: Tensor,
    div_x: Tensor,
    div_y: Tensor,
    island_padding: float,
) -> Tensor:
    """
    Distribute the individual UVs in the atlas.

    Normalizes the overlap slices and leftover triangles, then scales each
    face into its atlas cell and shifts it to its assigned position.

    Args:
        face_uv (Float[Tensor, "Nf 3 2"]): Face UV coordinates
        assigned_faces (Integer[Tensor, "Nf"]): Assigned atlas index per face
        offset_x (Float[Tensor, "Nf"]): Offset x
        offset_y (Float[Tensor, "Nf"]): Offset y
        div_x (Float[Tensor, "Nf"]): Division x
        div_y (Float[Tensor, "Nf"]): Division y
        island_padding (float): Island padding

    Returns:
        Float[Tensor, "Nf*3 2"]: Flattened per-corner UV coordinates
    """
    # First normalize the overlap slices ...
    placed_uv = self._handle_slice_uvs(face_uv, assigned_faces, island_padding)
    # ... then pack whatever overlap elements remain
    placed_uv = self._handle_remaining_uvs(
        placed_uv, assigned_faces, island_padding
    )

    # Scale every face into its atlas cell and translate it into place
    u_coords, v_coords = placed_uv.unbind(-1)
    u_coords = u_coords / div_x[:, None] + offset_x[:, None]
    v_coords = v_coords / div_y[:, None] + offset_y[:, None]

    # Flatten to one UV row per triangle corner
    return torch.stack([u_coords, v_coords], dim=-1).view(-1, 2)
516
+
517
+ def _get_unique_face_uv(
518
+ self,
519
+ uv: Tensor,
520
+ ) -> Tuple[Tensor, Tensor]:
521
+ """
522
+ Get the unique face UV
523
+
524
+ Args:
525
+ uv (Float[Tensor, "Nf 3 2"]): UV coordinates
526
+
527
+ Returns:
528
+ Float[Tensor, "Utex 3"]: Unique UV coordinates
529
+ Integer[Tensor, "Nf"]: Vertex index
530
+ """
531
+ unique_uv, unique_idx = torch.unique(uv, return_inverse=True, dim=0)
532
+ # And add the face to uv index mapping
533
+ vtex_idx = unique_idx.view(-1, 3)
534
+
535
+ return unique_uv, vtex_idx
536
+
537
def _align_mesh_with_main_axis(
    self, vertex_positions: Tensor, vertex_normals: Tensor
) -> Tuple[Tensor, Tensor]:
    """
    Align the mesh with the main axis.

    Finds the two principal axes of the vertex cloud via PCA, derives the
    third by cross product, assigns each to the closest canonical axis, and
    rotates positions and normals into that frame.

    Args:
        vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions
        vertex_normals (Float[Tensor, "Nv 3"]): Vertex normals

    Returns:
        Float[Tensor, "Nv 3"]: Rotated vertex positions
        Float[Tensor, "Nv 3"]: Rotated vertex normals
    """

    # Use pca to find the 2 main axes (third is derived by cross product).
    # pca_lowrank is randomized, so set the seed to make it repeatable.
    torch.manual_seed(0)
    _, _, v = torch.pca_lowrank(vertex_positions, q=2)
    main_axis, seconday_axis = v[:, 0], v[:, 1]

    main_axis = F.normalize(main_axis, eps=1e-6, dim=-1)  # 3,
    # Orthogonalize the second axis against the first (Gram-Schmidt)
    seconday_axis = F.normalize(
        seconday_axis
        - (seconday_axis * main_axis).sum(-1, keepdim=True) * main_axis,
        eps=1e-6,
        dim=-1,
    )  # 3,
    # Create perpendicular third axis
    third_axis = F.normalize(
        torch.cross(main_axis, seconday_axis, dim=-1), dim=-1, eps=1e-6
    )  # 3,

    # Check to which canonical axis each principal axis aligns best
    main_axis_max_idx = main_axis.abs().argmax().item()
    seconday_axis_max_idx = seconday_axis.abs().argmax().item()
    third_axis_max_idx = third_axis.abs().argmax().item()

    # Now sort the axes based on the argmax so they align with the canonical
    # axes. If two axes have the same argmax, reassign one of them; the loop
    # runs at most twice before raising (cur_index guard below).
    all_possible_axis = {0, 1, 2}
    cur_index = 1
    while (
        len(set([main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx]))
        != 3
    ):
        # Find the canonical axis nothing currently claims
        missing_axis = all_possible_axis - set(
            [main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx]
        )
        missing_axis = missing_axis.pop()
        # Just assign it to the third axis first, as it had the smallest
        # contribution to the overall shape; then the second axis
        if cur_index == 1:
            third_axis_max_idx = missing_axis
        elif cur_index == 2:
            seconday_axis_max_idx = missing_axis
        else:
            raise ValueError("Could not find 3 unique axis")
        cur_index += 1

    # Sanity check: all three canonical slots must now be distinct
    if len({main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx}) != 3:
        raise ValueError("Could not find 3 unique axis")

    axes = [None] * 3
    axes[main_axis_max_idx] = main_axis
    axes[seconday_axis_max_idx] = seconday_axis
    axes[third_axis_max_idx] = third_axis
    # Create the rotation matrix from the individual axes (rows = new basis)
    rot_mat = torch.stack(axes, dim=1).T

    # Now rotate the vertex positions and vertex normals so the mesh aligns
    # with the canonical axes
    vertex_positions = torch.einsum("ij,nj->ni", rot_mat, vertex_positions)
    vertex_normals = torch.einsum("ij,nj->ni", rot_mat, vertex_normals)

    return vertex_positions, vertex_normals
614
+
615
def forward(
    self,
    vertex_positions: Tensor,
    vertex_normals: Tensor,
    triangle_idxs: Tensor,
    island_padding: float,
) -> Tuple[Tensor, Tensor]:
    """
    Unwrap the mesh into a UV atlas.

    Pipeline: align the mesh with its principal axes, project each face onto
    one of the 6 cube faces, rotate each cube-face slice into a consistent
    orientation, resolve overlaps into atlas chart indices, lay the charts
    out in the atlas, and finally deduplicate the resulting UVs.

    Args:
        vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions
        vertex_normals (Float[Tensor, "Nv 3"]): Vertex normals
        triangle_idxs (Integer[Tensor, "Nf 3"]): Triangle indices
        island_padding (float): Island padding

    Returns:
        Float[Tensor, "Utex 3"]: Unique UV coordinates
        Integer[Tensor, "Nf"]: Vertex index
    """
    # Rotate the mesh so its principal axes line up with x/y/z
    vertex_positions, vertex_normals = self._align_mesh_with_main_axis(
        vertex_positions, vertex_normals
    )
    # Axis-aligned bounding box: row 0 = min corner, row 1 = max corner
    bbox = torch.stack(
        [vertex_positions.min(dim=0).values, vertex_positions.max(dim=0).values],
        dim=0,
    )  # 2, 3

    # Project every face onto one of the 6 cube faces
    face_uv, face_index = self._box_assign_vertex_to_cube_face(
        vertex_positions, vertex_normals, triangle_idxs, bbox
    )

    # Rotate each cube-face slice into a consistent orientation
    face_uv = self._rotate_uv_slices_consistent_space(
        vertex_positions, vertex_normals, triangle_idxs, face_uv, face_index
    )

    # Resolve overlaps: each face gets a final atlas chart index
    assigned_atlas_index = self._assign_faces_uv_to_atlas_index(
        vertex_positions, triangle_idxs, face_uv, face_index
    )

    # Per-chart placement (offset + scale) within the atlas
    offset_x, offset_y, div_x, div_y = self._find_slice_offset_and_scale(
        assigned_atlas_index
    )

    # Lay all faces out in the atlas
    placed_uv = self._distribute_individual_uvs_in_atlas(
        face_uv,
        assigned_atlas_index,
        offset_x,
        offset_y,
        div_x,
        div_y,
        island_padding,
    )

    # Deduplicate the UVs and return the face -> UV mapping
    return self._get_unique_face_uv(placed_uv)