Spaces:
Running
on
L4
Running
on
L4
Update to latest inference code
Browse filesCo-authored-by: Mark Boss <[email protected]>
- .gitattributes +1 -0
- .gitignore +167 -0
- .pre-commit-config.yaml +24 -0
- README.md +2 -4
- __init__.py +201 -0
- demo_files/scatterplot.jpg +0 -0
- demo_files/workflows/sf3d_example.json +254 -0
- app.py → gradio_app.py +87 -71
- requirements.txt +12 -4
- ruff.toml +3 -0
- run.py +141 -0
- sf3d/models/image_estimator/clip_based_estimator.py +1 -1
- sf3d/models/mesh.py +119 -2
- sf3d/models/network.py +21 -3
- sf3d/models/utils.py +1 -57
- sf3d/system.py +93 -43
- sf3d/utils.py +48 -34
- texture_baker/README.md +26 -0
- texture_baker/requirements.txt +2 -0
- texture_baker/setup.py +124 -0
- texture_baker/texture_baker/__init__.py +4 -0
- texture_baker/texture_baker/baker.py +86 -0
- texture_baker/texture_baker/csrc/baker.cpp +548 -0
- texture_baker/texture_baker/csrc/baker.h +203 -0
- texture_baker/texture_baker/csrc/baker_kernel.cu +301 -0
- texture_baker/texture_baker/csrc/baker_kernel.metal +170 -0
- texture_baker/texture_baker/csrc/baker_kernel.mm +260 -0
- uv_unwrapper/README.md +0 -0
- uv_unwrapper/requirements.txt +2 -0
- uv_unwrapper/setup.py +79 -0
- uv_unwrapper/uv_unwrapper/__init__.py +6 -0
- uv_unwrapper/uv_unwrapper/csrc/bvh.cpp +380 -0
- uv_unwrapper/uv_unwrapper/csrc/bvh.h +118 -0
- uv_unwrapper/uv_unwrapper/csrc/common.h +493 -0
- uv_unwrapper/uv_unwrapper/csrc/intersect.cpp +702 -0
- uv_unwrapper/uv_unwrapper/csrc/intersect.h +10 -0
- uv_unwrapper/uv_unwrapper/csrc/unwrapper.cpp +271 -0
- uv_unwrapper/uv_unwrapper/unwrap.py +669 -0
.gitattributes
CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
*.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
*.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
*.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
*.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
*.whl filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
110 |
+
.pdm.toml
|
111 |
+
.pdm-python
|
112 |
+
.pdm-build/
|
113 |
+
|
114 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
115 |
+
__pypackages__/
|
116 |
+
|
117 |
+
# Celery stuff
|
118 |
+
celerybeat-schedule
|
119 |
+
celerybeat.pid
|
120 |
+
|
121 |
+
# SageMath parsed files
|
122 |
+
*.sage.py
|
123 |
+
|
124 |
+
# Environments
|
125 |
+
.env
|
126 |
+
.venv*/
|
127 |
+
env/
|
128 |
+
venv*/
|
129 |
+
ENV/
|
130 |
+
env.bak/
|
131 |
+
|
132 |
+
# Spyder project settings
|
133 |
+
.spyderproject
|
134 |
+
.spyproject
|
135 |
+
|
136 |
+
# Rope project settings
|
137 |
+
.ropeproject
|
138 |
+
|
139 |
+
# mkdocs documentation
|
140 |
+
/site
|
141 |
+
|
142 |
+
# mypy
|
143 |
+
.mypy_cache/
|
144 |
+
.dmypy.json
|
145 |
+
dmypy.json
|
146 |
+
|
147 |
+
# Pyre type checker
|
148 |
+
.pyre/
|
149 |
+
|
150 |
+
# pytype static type analyzer
|
151 |
+
.pytype/
|
152 |
+
|
153 |
+
# Cython debug symbols
|
154 |
+
cython_debug/
|
155 |
+
|
156 |
+
# PyCharm
|
157 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
158 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
159 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
160 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
161 |
+
#.idea/
|
162 |
+
.vs/
|
163 |
+
.idea/
|
164 |
+
.vscode/
|
165 |
+
|
166 |
+
stabilityai/
|
167 |
+
output/
|
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
default_language_version:
|
2 |
+
python: python3
|
3 |
+
|
4 |
+
repos:
|
5 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
6 |
+
rev: v4.4.0
|
7 |
+
hooks:
|
8 |
+
- id: trailing-whitespace
|
9 |
+
- id: check-ast
|
10 |
+
- id: check-merge-conflict
|
11 |
+
- id: check-yaml
|
12 |
+
- id: end-of-file-fixer
|
13 |
+
- id: trailing-whitespace
|
14 |
+
args: [--markdown-linebreak-ext=md]
|
15 |
+
|
16 |
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
17 |
+
# Ruff version.
|
18 |
+
rev: v0.3.5
|
19 |
+
hooks:
|
20 |
+
# Run the linter.
|
21 |
+
- id: ruff
|
22 |
+
args: [ --fix ]
|
23 |
+
# Run the formatter.
|
24 |
+
- id: ruff-format
|
README.md
CHANGED
@@ -4,9 +4,9 @@ emoji: 🎮
|
|
4 |
colorFrom: purple
|
5 |
colorTo: indigo
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.
|
8 |
python_version: 3.10.13
|
9 |
-
app_file:
|
10 |
pinned: false
|
11 |
models:
|
12 |
- stabilityai/stable-fast-3d
|
@@ -14,5 +14,3 @@ license: other
|
|
14 |
license_name: stabilityai-ai-community
|
15 |
license_link: LICENSE.md
|
16 |
---
|
17 |
-
|
18 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
4 |
colorFrom: purple
|
5 |
colorTo: indigo
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.41.0
|
8 |
python_version: 3.10.13
|
9 |
+
app_file: gradio_app.py
|
10 |
pinned: false
|
11 |
models:
|
12 |
- stabilityai/stable-fast-3d
|
|
|
14 |
license_name: stabilityai-ai-community
|
15 |
license_link: LICENSE.md
|
16 |
---
|
|
|
|
__init__.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
from contextlib import nullcontext
|
6 |
+
|
7 |
+
import comfy.model_management
|
8 |
+
import folder_paths
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import trimesh
|
12 |
+
from PIL import Image
|
13 |
+
from trimesh.exchange import gltf
|
14 |
+
|
15 |
+
sys.path.append(os.path.dirname(__file__))
|
16 |
+
from sf3d.system import SF3D
|
17 |
+
from sf3d.utils import resize_foreground
|
18 |
+
|
19 |
+
SF3D_CATEGORY = "StableFast3D"
|
20 |
+
SF3D_MODEL_NAME = "stabilityai/stable-fast-3d"
|
21 |
+
|
22 |
+
|
23 |
+
class StableFast3DLoader:
|
24 |
+
CATEGORY = SF3D_CATEGORY
|
25 |
+
FUNCTION = "load"
|
26 |
+
RETURN_NAMES = ("sf3d_model",)
|
27 |
+
RETURN_TYPES = ("SF3D_MODEL",)
|
28 |
+
|
29 |
+
@classmethod
|
30 |
+
def INPUT_TYPES(cls):
|
31 |
+
return {"required": {}}
|
32 |
+
|
33 |
+
def load(self):
|
34 |
+
device = comfy.model_management.get_torch_device()
|
35 |
+
model = SF3D.from_pretrained(
|
36 |
+
SF3D_MODEL_NAME,
|
37 |
+
config_name="config.yaml",
|
38 |
+
weight_name="model.safetensors",
|
39 |
+
)
|
40 |
+
model.to(device)
|
41 |
+
model.eval()
|
42 |
+
|
43 |
+
return (model,)
|
44 |
+
|
45 |
+
|
46 |
+
class StableFast3DPreview:
|
47 |
+
CATEGORY = SF3D_CATEGORY
|
48 |
+
FUNCTION = "preview"
|
49 |
+
OUTPUT_NODE = True
|
50 |
+
RETURN_TYPES = ()
|
51 |
+
|
52 |
+
@classmethod
|
53 |
+
def INPUT_TYPES(s):
|
54 |
+
return {"required": {"mesh": ("MESH",)}}
|
55 |
+
|
56 |
+
def preview(self, mesh):
|
57 |
+
glbs = []
|
58 |
+
for m in mesh:
|
59 |
+
scene = trimesh.Scene(m)
|
60 |
+
glb_data = gltf.export_glb(scene, include_normals=True)
|
61 |
+
glb_base64 = base64.b64encode(glb_data).decode("utf-8")
|
62 |
+
glbs.append(glb_base64)
|
63 |
+
return {"ui": {"glbs": glbs}}
|
64 |
+
|
65 |
+
|
66 |
+
class StableFast3DSampler:
|
67 |
+
CATEGORY = SF3D_CATEGORY
|
68 |
+
FUNCTION = "predict"
|
69 |
+
RETURN_NAMES = ("mesh",)
|
70 |
+
RETURN_TYPES = ("MESH",)
|
71 |
+
|
72 |
+
@classmethod
|
73 |
+
def INPUT_TYPES(s):
|
74 |
+
return {
|
75 |
+
"required": {
|
76 |
+
"model": ("SF3D_MODEL",),
|
77 |
+
"image": ("IMAGE",),
|
78 |
+
"foreground_ratio": (
|
79 |
+
"FLOAT",
|
80 |
+
{"default": 0.85, "min": 0.0, "max": 1.0, "step": 0.01},
|
81 |
+
),
|
82 |
+
"texture_resolution": (
|
83 |
+
"INT",
|
84 |
+
{"default": 1024, "min": 512, "max": 2048, "step": 256},
|
85 |
+
),
|
86 |
+
},
|
87 |
+
"optional": {
|
88 |
+
"mask": ("MASK",),
|
89 |
+
"remesh": (["none", "triangle", "quad"],),
|
90 |
+
"vertex_count": (
|
91 |
+
"INT",
|
92 |
+
{"default": -1, "min": -1, "max": 20000, "step": 1},
|
93 |
+
),
|
94 |
+
},
|
95 |
+
}
|
96 |
+
|
97 |
+
def predict(
|
98 |
+
s,
|
99 |
+
model,
|
100 |
+
image,
|
101 |
+
mask,
|
102 |
+
foreground_ratio,
|
103 |
+
texture_resolution,
|
104 |
+
remesh="none",
|
105 |
+
vertex_count=-1,
|
106 |
+
):
|
107 |
+
if image.shape[0] != 1:
|
108 |
+
raise ValueError("Only one image can be processed at a time")
|
109 |
+
|
110 |
+
pil_image = Image.fromarray(
|
111 |
+
torch.clamp(torch.round(255.0 * image[0]), 0, 255)
|
112 |
+
.type(torch.uint8)
|
113 |
+
.cpu()
|
114 |
+
.numpy()
|
115 |
+
)
|
116 |
+
|
117 |
+
if mask is not None:
|
118 |
+
print("Using Mask")
|
119 |
+
mask_np = np.clip(255.0 * mask[0].detach().cpu().numpy(), 0, 255).astype(
|
120 |
+
np.uint8
|
121 |
+
)
|
122 |
+
mask_pil = Image.fromarray(mask_np, mode="L")
|
123 |
+
pil_image.putalpha(mask_pil)
|
124 |
+
else:
|
125 |
+
if image.shape[3] != 4:
|
126 |
+
print("No mask or alpha channel detected, Converting to RGBA")
|
127 |
+
pil_image = pil_image.convert("RGBA")
|
128 |
+
|
129 |
+
pil_image = resize_foreground(pil_image, foreground_ratio)
|
130 |
+
print(remesh)
|
131 |
+
with torch.no_grad():
|
132 |
+
with torch.autocast(
|
133 |
+
device_type="cuda", dtype=torch.bfloat16
|
134 |
+
) if "cuda" in comfy.model_management.get_torch_device().type else nullcontext():
|
135 |
+
mesh, glob_dict = model.run_image(
|
136 |
+
pil_image,
|
137 |
+
bake_resolution=texture_resolution,
|
138 |
+
remesh=remesh,
|
139 |
+
vertex_count=vertex_count,
|
140 |
+
)
|
141 |
+
|
142 |
+
if mesh.vertices.shape[0] == 0:
|
143 |
+
raise ValueError("No subject detected in the image")
|
144 |
+
|
145 |
+
return ([mesh],)
|
146 |
+
|
147 |
+
|
148 |
+
class StableFast3DSave:
|
149 |
+
CATEGORY = SF3D_CATEGORY
|
150 |
+
FUNCTION = "save"
|
151 |
+
OUTPUT_NODE = True
|
152 |
+
RETURN_TYPES = ()
|
153 |
+
|
154 |
+
@classmethod
|
155 |
+
def INPUT_TYPES(s):
|
156 |
+
return {
|
157 |
+
"required": {
|
158 |
+
"mesh": ("MESH",),
|
159 |
+
"filename_prefix": ("STRING", {"default": "SF3D"}),
|
160 |
+
}
|
161 |
+
}
|
162 |
+
|
163 |
+
def __init__(self):
|
164 |
+
self.type = "output"
|
165 |
+
|
166 |
+
def save(self, mesh, filename_prefix):
|
167 |
+
output_dir = folder_paths.get_output_directory()
|
168 |
+
glbs = []
|
169 |
+
for idx, m in enumerate(mesh):
|
170 |
+
scene = trimesh.Scene(m)
|
171 |
+
glb_data = gltf.export_glb(scene, include_normals=True)
|
172 |
+
logging.info(f"Generated GLB model with {len(glb_data)} bytes")
|
173 |
+
|
174 |
+
full_output_folder, filename, counter, subfolder, filename_prefix = (
|
175 |
+
folder_paths.get_save_image_path(filename_prefix, output_dir)
|
176 |
+
)
|
177 |
+
filename = filename.replace("%batch_num%", str(idx))
|
178 |
+
out_path = os.path.join(full_output_folder, f"{filename}_{counter:05}_.glb")
|
179 |
+
with open(out_path, "wb") as f:
|
180 |
+
f.write(glb_data)
|
181 |
+
glbs.append(base64.b64encode(glb_data).decode("utf-8"))
|
182 |
+
return {"ui": {"glbs": glbs}}
|
183 |
+
|
184 |
+
|
185 |
+
NODE_DISPLAY_NAME_MAPPINGS = {
|
186 |
+
"StableFast3DLoader": "Stable Fast 3D Loader",
|
187 |
+
"StableFast3DPreview": "Stable Fast 3D Preview",
|
188 |
+
"StableFast3DSampler": "Stable Fast 3D Sampler",
|
189 |
+
"StableFast3DSave": "Stable Fast 3D Save",
|
190 |
+
}
|
191 |
+
|
192 |
+
NODE_CLASS_MAPPINGS = {
|
193 |
+
"StableFast3DLoader": StableFast3DLoader,
|
194 |
+
"StableFast3DPreview": StableFast3DPreview,
|
195 |
+
"StableFast3DSampler": StableFast3DSampler,
|
196 |
+
"StableFast3DSave": StableFast3DSave,
|
197 |
+
}
|
198 |
+
|
199 |
+
WEB_DIRECTORY = "./comfyui"
|
200 |
+
|
201 |
+
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]
|
demo_files/scatterplot.jpg
CHANGED
![]() |
![]() |
demo_files/workflows/sf3d_example.json
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"last_node_id": 10,
|
3 |
+
"last_link_id": 12,
|
4 |
+
"nodes": [
|
5 |
+
{
|
6 |
+
"id": 8,
|
7 |
+
"type": "StableFast3DSampler",
|
8 |
+
"pos": [
|
9 |
+
756.9950672198843,
|
10 |
+
9.735666739723854
|
11 |
+
],
|
12 |
+
"size": {
|
13 |
+
"0": 315,
|
14 |
+
"1": 166
|
15 |
+
},
|
16 |
+
"flags": {},
|
17 |
+
"order": 3,
|
18 |
+
"mode": 0,
|
19 |
+
"inputs": [
|
20 |
+
{
|
21 |
+
"name": "model",
|
22 |
+
"type": "SF3D_MODEL",
|
23 |
+
"link": 8
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"name": "image",
|
27 |
+
"type": "IMAGE",
|
28 |
+
"link": 10,
|
29 |
+
"slot_index": 1
|
30 |
+
},
|
31 |
+
{
|
32 |
+
"name": "mask",
|
33 |
+
"type": "MASK",
|
34 |
+
"link": 11
|
35 |
+
},
|
36 |
+
{
|
37 |
+
"name": "remesh",
|
38 |
+
"type": "none",
|
39 |
+
"link": null,
|
40 |
+
"slot_index": 3
|
41 |
+
}
|
42 |
+
],
|
43 |
+
"outputs": [
|
44 |
+
{
|
45 |
+
"name": "mesh",
|
46 |
+
"type": "MESH",
|
47 |
+
"links": [
|
48 |
+
9
|
49 |
+
],
|
50 |
+
"shape": 3,
|
51 |
+
"slot_index": 0
|
52 |
+
}
|
53 |
+
],
|
54 |
+
"properties": {
|
55 |
+
"Node name for S&R": "StableFast3DSampler"
|
56 |
+
},
|
57 |
+
"widgets_values": [
|
58 |
+
0.85,
|
59 |
+
1024,
|
60 |
+
"triangle"
|
61 |
+
]
|
62 |
+
},
|
63 |
+
{
|
64 |
+
"id": 9,
|
65 |
+
"type": "StableFast3DSave",
|
66 |
+
"pos": [
|
67 |
+
1116,
|
68 |
+
8
|
69 |
+
],
|
70 |
+
"size": [
|
71 |
+
600,
|
72 |
+
512
|
73 |
+
],
|
74 |
+
"flags": {},
|
75 |
+
"order": 4,
|
76 |
+
"mode": 0,
|
77 |
+
"inputs": [
|
78 |
+
{
|
79 |
+
"name": "mesh",
|
80 |
+
"type": "MESH",
|
81 |
+
"link": 9
|
82 |
+
}
|
83 |
+
],
|
84 |
+
"properties": {
|
85 |
+
"Node name for S&R": "StableFast3DSave"
|
86 |
+
},
|
87 |
+
"widgets_values": [
|
88 |
+
"SF3D",
|
89 |
+
null
|
90 |
+
]
|
91 |
+
},
|
92 |
+
{
|
93 |
+
"id": 6,
|
94 |
+
"type": "InvertMask",
|
95 |
+
"pos": [
|
96 |
+
485,
|
97 |
+
132
|
98 |
+
],
|
99 |
+
"size": {
|
100 |
+
"0": 210,
|
101 |
+
"1": 26
|
102 |
+
},
|
103 |
+
"flags": {},
|
104 |
+
"order": 2,
|
105 |
+
"mode": 0,
|
106 |
+
"inputs": [
|
107 |
+
{
|
108 |
+
"name": "mask",
|
109 |
+
"type": "MASK",
|
110 |
+
"link": 6
|
111 |
+
}
|
112 |
+
],
|
113 |
+
"outputs": [
|
114 |
+
{
|
115 |
+
"name": "MASK",
|
116 |
+
"type": "MASK",
|
117 |
+
"links": [
|
118 |
+
11
|
119 |
+
],
|
120 |
+
"shape": 3,
|
121 |
+
"slot_index": 0
|
122 |
+
}
|
123 |
+
],
|
124 |
+
"properties": {
|
125 |
+
"Node name for S&R": "InvertMask"
|
126 |
+
}
|
127 |
+
},
|
128 |
+
{
|
129 |
+
"id": 1,
|
130 |
+
"type": "LoadImage",
|
131 |
+
"pos": [
|
132 |
+
105,
|
133 |
+
26
|
134 |
+
],
|
135 |
+
"size": {
|
136 |
+
"0": 315,
|
137 |
+
"1": 314
|
138 |
+
},
|
139 |
+
"flags": {},
|
140 |
+
"order": 0,
|
141 |
+
"mode": 0,
|
142 |
+
"outputs": [
|
143 |
+
{
|
144 |
+
"name": "IMAGE",
|
145 |
+
"type": "IMAGE",
|
146 |
+
"links": [
|
147 |
+
10
|
148 |
+
],
|
149 |
+
"shape": 3,
|
150 |
+
"slot_index": 0
|
151 |
+
},
|
152 |
+
{
|
153 |
+
"name": "MASK",
|
154 |
+
"type": "MASK",
|
155 |
+
"links": [
|
156 |
+
6
|
157 |
+
],
|
158 |
+
"shape": 3,
|
159 |
+
"slot_index": 1
|
160 |
+
}
|
161 |
+
],
|
162 |
+
"properties": {
|
163 |
+
"Node name for S&R": "LoadImage"
|
164 |
+
},
|
165 |
+
"widgets_values": [
|
166 |
+
"axe (1).png",
|
167 |
+
"image"
|
168 |
+
]
|
169 |
+
},
|
170 |
+
{
|
171 |
+
"id": 7,
|
172 |
+
"type": "StableFast3DLoader",
|
173 |
+
"pos": [
|
174 |
+
478,
|
175 |
+
-27
|
176 |
+
],
|
177 |
+
"size": {
|
178 |
+
"0": 210,
|
179 |
+
"1": 26
|
180 |
+
},
|
181 |
+
"flags": {},
|
182 |
+
"order": 1,
|
183 |
+
"mode": 0,
|
184 |
+
"outputs": [
|
185 |
+
{
|
186 |
+
"name": "sf3d_model",
|
187 |
+
"type": "SF3D_MODEL",
|
188 |
+
"links": [
|
189 |
+
8
|
190 |
+
],
|
191 |
+
"shape": 3,
|
192 |
+
"slot_index": 0
|
193 |
+
}
|
194 |
+
],
|
195 |
+
"properties": {
|
196 |
+
"Node name for S&R": "StableFast3DLoader"
|
197 |
+
}
|
198 |
+
}
|
199 |
+
],
|
200 |
+
"links": [
|
201 |
+
[
|
202 |
+
6,
|
203 |
+
1,
|
204 |
+
1,
|
205 |
+
6,
|
206 |
+
0,
|
207 |
+
"MASK"
|
208 |
+
],
|
209 |
+
[
|
210 |
+
8,
|
211 |
+
7,
|
212 |
+
0,
|
213 |
+
8,
|
214 |
+
0,
|
215 |
+
"SF3D_MODEL"
|
216 |
+
],
|
217 |
+
[
|
218 |
+
9,
|
219 |
+
8,
|
220 |
+
0,
|
221 |
+
9,
|
222 |
+
0,
|
223 |
+
"MESH"
|
224 |
+
],
|
225 |
+
[
|
226 |
+
10,
|
227 |
+
1,
|
228 |
+
0,
|
229 |
+
8,
|
230 |
+
1,
|
231 |
+
"IMAGE"
|
232 |
+
],
|
233 |
+
[
|
234 |
+
11,
|
235 |
+
6,
|
236 |
+
0,
|
237 |
+
8,
|
238 |
+
2,
|
239 |
+
"MASK"
|
240 |
+
]
|
241 |
+
],
|
242 |
+
"groups": [],
|
243 |
+
"config": {},
|
244 |
+
"extra": {
|
245 |
+
"ds": {
|
246 |
+
"scale": 0.6209213230591552,
|
247 |
+
"offset": [
|
248 |
+
80.89139921077967,
|
249 |
+
610.3296066172098
|
250 |
+
]
|
251 |
+
}
|
252 |
+
},
|
253 |
+
"version": 0.4
|
254 |
+
}
|
app.py → gradio_app.py
RENAMED
@@ -1,6 +1,7 @@
|
|
1 |
import os
|
2 |
import tempfile
|
3 |
import time
|
|
|
4 |
from functools import lru_cache
|
5 |
from typing import Any
|
6 |
|
@@ -11,9 +12,13 @@ import torch
|
|
11 |
from gradio_litmodel3d import LitModel3D
|
12 |
from PIL import Image
|
13 |
|
|
|
|
|
14 |
import sf3d.utils as sf3d_utils
|
15 |
from sf3d.system import SF3D
|
16 |
|
|
|
|
|
17 |
rembg_session = rembg.new_session()
|
18 |
|
19 |
COND_WIDTH = 512
|
@@ -28,32 +33,48 @@ intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(
|
|
28 |
COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
|
29 |
)
|
30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
model = SF3D.from_pretrained(
|
33 |
"stabilityai/stable-fast-3d",
|
34 |
config_name="config.yaml",
|
35 |
weight_name="model.safetensors",
|
36 |
)
|
37 |
-
model.eval()
|
|
|
38 |
|
39 |
example_files = [
|
40 |
os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples")
|
41 |
]
|
42 |
|
43 |
|
44 |
-
def run_model(input_image):
|
45 |
start = time.time()
|
46 |
with torch.no_grad():
|
47 |
-
with torch.autocast(
|
|
|
|
|
48 |
model_batch = create_batch(input_image)
|
49 |
-
model_batch = {k: v.
|
50 |
-
trimesh_mesh, _glob_dict = model.generate_mesh(
|
|
|
|
|
51 |
trimesh_mesh = trimesh_mesh[0]
|
52 |
|
53 |
# Create new tmp file
|
54 |
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".glb")
|
55 |
|
56 |
trimesh_mesh.export(tmp_file.name, file_type="glb", include_normals=True)
|
|
|
57 |
|
58 |
print("Generation took:", time.time() - start, "s")
|
59 |
|
@@ -104,61 +125,6 @@ def remove_background(input_image: Image) -> Image:
|
|
104 |
return rembg.remove(input_image, session=rembg_session)
|
105 |
|
106 |
|
107 |
-
def resize_foreground(
|
108 |
-
image: Image,
|
109 |
-
ratio: float,
|
110 |
-
) -> Image:
|
111 |
-
image = np.array(image)
|
112 |
-
assert image.shape[-1] == 4
|
113 |
-
alpha = np.where(image[..., 3] > 0)
|
114 |
-
y1, y2, x1, x2 = (
|
115 |
-
alpha[0].min(),
|
116 |
-
alpha[0].max(),
|
117 |
-
alpha[1].min(),
|
118 |
-
alpha[1].max(),
|
119 |
-
)
|
120 |
-
# crop the foreground
|
121 |
-
fg = image[y1:y2, x1:x2]
|
122 |
-
# pad to square
|
123 |
-
size = max(fg.shape[0], fg.shape[1])
|
124 |
-
ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
|
125 |
-
ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
|
126 |
-
new_image = np.pad(
|
127 |
-
fg,
|
128 |
-
((ph0, ph1), (pw0, pw1), (0, 0)),
|
129 |
-
mode="constant",
|
130 |
-
constant_values=((0, 0), (0, 0), (0, 0)),
|
131 |
-
)
|
132 |
-
|
133 |
-
# compute padding according to the ratio
|
134 |
-
new_size = int(new_image.shape[0] / ratio)
|
135 |
-
# pad to size, double side
|
136 |
-
ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
|
137 |
-
ph1, pw1 = new_size - size - ph0, new_size - size - pw0
|
138 |
-
new_image = np.pad(
|
139 |
-
new_image,
|
140 |
-
((ph0, ph1), (pw0, pw1), (0, 0)),
|
141 |
-
mode="constant",
|
142 |
-
constant_values=((0, 0), (0, 0), (0, 0)),
|
143 |
-
)
|
144 |
-
new_image = Image.fromarray(new_image, mode="RGBA").resize(
|
145 |
-
(COND_WIDTH, COND_HEIGHT)
|
146 |
-
)
|
147 |
-
return new_image
|
148 |
-
|
149 |
-
|
150 |
-
def square_crop(input_image: Image) -> Image:
|
151 |
-
# Perform a center square crop
|
152 |
-
min_size = min(input_image.size)
|
153 |
-
left = (input_image.size[0] - min_size) // 2
|
154 |
-
top = (input_image.size[1] - min_size) // 2
|
155 |
-
right = (input_image.size[0] + min_size) // 2
|
156 |
-
bottom = (input_image.size[1] + min_size) // 2
|
157 |
-
return input_image.crop((left, top, right, bottom)).resize(
|
158 |
-
(COND_WIDTH, COND_HEIGHT)
|
159 |
-
)
|
160 |
-
|
161 |
-
|
162 |
def show_mask_img(input_image: Image) -> Image:
|
163 |
img_numpy = np.array(input_image)
|
164 |
alpha = img_numpy[:, :, 3] / 255.0
|
@@ -167,9 +133,27 @@ def show_mask_img(input_image: Image) -> Image:
|
|
167 |
return Image.fromarray(new_img.astype(np.uint8), mode="RGB")
|
168 |
|
169 |
|
170 |
-
def run_button(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
if run_btn == "Run":
|
172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
|
174 |
return (
|
175 |
gr.update(),
|
@@ -182,12 +166,13 @@ def run_button(run_btn, input_image, background_state, foreground_ratio):
|
|
182 |
elif run_btn == "Remove Background":
|
183 |
rem_removed = remove_background(input_image)
|
184 |
|
185 |
-
|
186 |
-
|
|
|
187 |
|
188 |
return (
|
189 |
gr.update(value="Run", visible=True),
|
190 |
-
|
191 |
fr_res,
|
192 |
gr.update(value=show_mask_img(fr_res), visible=True),
|
193 |
gr.update(value=None, visible=False),
|
@@ -210,11 +195,12 @@ def requires_bg_remove(image, fr):
|
|
210 |
|
211 |
if min_alpha == 0:
|
212 |
print("Already has alpha")
|
213 |
-
|
214 |
-
|
|
|
215 |
return (
|
216 |
gr.update(value="Run", visible=True),
|
217 |
-
|
218 |
fr_res,
|
219 |
gr.update(value=show_mask_img(fr_res), visible=True),
|
220 |
gr.update(visible=False),
|
@@ -231,7 +217,9 @@ def requires_bg_remove(image, fr):
|
|
231 |
|
232 |
|
233 |
def update_foreground_ratio(img_proc, fr):
|
234 |
-
foreground_res = resize_foreground(
|
|
|
|
|
235 |
return (
|
236 |
foreground_res,
|
237 |
gr.update(value=show_mask_img(foreground_res)),
|
@@ -250,7 +238,8 @@ with gr.Blocks() as demo:
|
|
250 |
**Tips**
|
251 |
1. If the image already has an alpha channel, you can skip the background removal step.
|
252 |
2. You can adjust the foreground ratio to control the size of the foreground object. This can influence the shape
|
253 |
-
3. You can
|
|
|
254 |
""")
|
255 |
with gr.Row(variant="panel"):
|
256 |
with gr.Column():
|
@@ -280,6 +269,30 @@ with gr.Blocks() as demo:
|
|
280 |
outputs=[background_remove_state, preview_removal],
|
281 |
)
|
282 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
283 |
run_btn = gr.Button("Run", variant="primary", visible=False)
|
284 |
|
285 |
with gr.Column():
|
@@ -341,6 +354,9 @@ with gr.Blocks() as demo:
|
|
341 |
input_img,
|
342 |
background_remove_state,
|
343 |
foreground_ratio,
|
|
|
|
|
|
|
344 |
],
|
345 |
outputs=[
|
346 |
run_btn,
|
@@ -352,4 +368,4 @@ with gr.Blocks() as demo:
|
|
352 |
],
|
353 |
)
|
354 |
|
355 |
-
demo.launch()
|
|
|
1 |
import os
|
2 |
import tempfile
|
3 |
import time
|
4 |
+
from contextlib import nullcontext
|
5 |
from functools import lru_cache
|
6 |
from typing import Any
|
7 |
|
|
|
12 |
from gradio_litmodel3d import LitModel3D
|
13 |
from PIL import Image
|
14 |
|
15 |
+
os.system("USE_CUDA=1 pip install -vv --no-build-isolation ./texture_baker ./uv_unwrapper")
|
16 |
+
|
17 |
import sf3d.utils as sf3d_utils
|
18 |
from sf3d.system import SF3D
|
19 |
|
20 |
+
os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.environ.get("TMPDIR", "/tmp"), "gradio")
|
21 |
+
|
22 |
rembg_session = rembg.new_session()
|
23 |
|
24 |
COND_WIDTH = 512
|
|
|
33 |
COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
|
34 |
)
|
35 |
|
36 |
+
generated_files = []
|
37 |
+
|
38 |
+
# Delete previous gradio temp dir folder
|
39 |
+
if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
|
40 |
+
print(f"Deleting {os.environ['GRADIO_TEMP_DIR']}")
|
41 |
+
import shutil
|
42 |
+
|
43 |
+
shutil.rmtree(os.environ["GRADIO_TEMP_DIR"])
|
44 |
+
|
45 |
+
device = sf3d_utils.get_device()
|
46 |
|
47 |
model = SF3D.from_pretrained(
|
48 |
"stabilityai/stable-fast-3d",
|
49 |
config_name="config.yaml",
|
50 |
weight_name="model.safetensors",
|
51 |
)
|
52 |
+
model.eval()
|
53 |
+
model = model.to(device)
|
54 |
|
55 |
example_files = [
|
56 |
os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples")
|
57 |
]
|
58 |
|
59 |
|
60 |
+
def run_model(input_image, remesh_option, vertex_count, texture_size):
|
61 |
start = time.time()
|
62 |
with torch.no_grad():
|
63 |
+
with torch.autocast(
|
64 |
+
device_type=device, dtype=torch.bfloat16
|
65 |
+
) if "cuda" in device else nullcontext():
|
66 |
model_batch = create_batch(input_image)
|
67 |
+
model_batch = {k: v.to(device) for k, v in model_batch.items()}
|
68 |
+
trimesh_mesh, _glob_dict = model.generate_mesh(
|
69 |
+
model_batch, texture_size, remesh_option, vertex_count
|
70 |
+
)
|
71 |
trimesh_mesh = trimesh_mesh[0]
|
72 |
|
73 |
# Create new tmp file
|
74 |
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".glb")
|
75 |
|
76 |
trimesh_mesh.export(tmp_file.name, file_type="glb", include_normals=True)
|
77 |
+
generated_files.append(tmp_file.name)
|
78 |
|
79 |
print("Generation took:", time.time() - start, "s")
|
80 |
|
|
|
125 |
return rembg.remove(input_image, session=rembg_session)
|
126 |
|
127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
def show_mask_img(input_image: Image) -> Image:
|
129 |
img_numpy = np.array(input_image)
|
130 |
alpha = img_numpy[:, :, 3] / 255.0
|
|
|
133 |
return Image.fromarray(new_img.astype(np.uint8), mode="RGB")
|
134 |
|
135 |
|
136 |
+
def run_button(
|
137 |
+
run_btn,
|
138 |
+
input_image,
|
139 |
+
background_state,
|
140 |
+
foreground_ratio,
|
141 |
+
remesh_option,
|
142 |
+
vertex_count,
|
143 |
+
texture_size,
|
144 |
+
):
|
145 |
if run_btn == "Run":
|
146 |
+
if torch.cuda.is_available():
|
147 |
+
torch.cuda.reset_peak_memory_stats()
|
148 |
+
glb_file: str = run_model(
|
149 |
+
background_state, remesh_option.lower(), vertex_count, texture_size
|
150 |
+
)
|
151 |
+
if torch.cuda.is_available():
|
152 |
+
print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
|
153 |
+
elif torch.backends.mps.is_available():
|
154 |
+
print(
|
155 |
+
"Peak Memory:", torch.mps.driver_allocated_memory() / 1024 / 1024, "MB"
|
156 |
+
)
|
157 |
|
158 |
return (
|
159 |
gr.update(),
|
|
|
166 |
elif run_btn == "Remove Background":
|
167 |
rem_removed = remove_background(input_image)
|
168 |
|
169 |
+
fr_res = sf3d_utils.resize_foreground(
|
170 |
+
rem_removed, foreground_ratio, out_size=(COND_WIDTH, COND_HEIGHT)
|
171 |
+
)
|
172 |
|
173 |
return (
|
174 |
gr.update(value="Run", visible=True),
|
175 |
+
rem_removed,
|
176 |
fr_res,
|
177 |
gr.update(value=show_mask_img(fr_res), visible=True),
|
178 |
gr.update(value=None, visible=False),
|
|
|
195 |
|
196 |
if min_alpha == 0:
|
197 |
print("Already has alpha")
|
198 |
+
fr_res = sf3d_utils.resize_foreground(
|
199 |
+
image, foreground_ratio, out_size=(COND_WIDTH, COND_HEIGHT)
|
200 |
+
)
|
201 |
return (
|
202 |
gr.update(value="Run", visible=True),
|
203 |
+
image,
|
204 |
fr_res,
|
205 |
gr.update(value=show_mask_img(fr_res), visible=True),
|
206 |
gr.update(visible=False),
|
|
|
217 |
|
218 |
|
219 |
def update_foreground_ratio(img_proc, fr):
|
220 |
+
foreground_res = sf3d_utils.resize_foreground(
|
221 |
+
img_proc, fr, out_size=(COND_WIDTH, COND_HEIGHT)
|
222 |
+
)
|
223 |
return (
|
224 |
foreground_res,
|
225 |
gr.update(value=show_mask_img(foreground_res)),
|
|
|
238 |
**Tips**
|
239 |
1. If the image already has an alpha channel, you can skip the background removal step.
|
240 |
2. You can adjust the foreground ratio to control the size of the foreground object. This can influence the shape
|
241 |
+
3. You can select the remeshing option to control the mesh topology. This can introduce artifacts in the mesh on thin surfaces and should be turned off in such cases.
|
242 |
+
4. You can upload your own HDR environment map to light the 3D model.
|
243 |
""")
|
244 |
with gr.Row(variant="panel"):
|
245 |
with gr.Column():
|
|
|
269 |
outputs=[background_remove_state, preview_removal],
|
270 |
)
|
271 |
|
272 |
+
remesh_option = gr.Radio(
|
273 |
+
choices=["None", "Triangle", "Quad"],
|
274 |
+
label="Remeshing",
|
275 |
+
value="None",
|
276 |
+
visible=True,
|
277 |
+
)
|
278 |
+
|
279 |
+
vertex_count_slider = gr.Slider(
|
280 |
+
label="Target Vertex Count",
|
281 |
+
minimum=-1,
|
282 |
+
maximum=20000,
|
283 |
+
value=-1,
|
284 |
+
visible=True,
|
285 |
+
)
|
286 |
+
|
287 |
+
texture_size = gr.Slider(
|
288 |
+
label="Texture Size",
|
289 |
+
minimum=512,
|
290 |
+
maximum=2048,
|
291 |
+
value=1024,
|
292 |
+
step=256,
|
293 |
+
visible=True,
|
294 |
+
)
|
295 |
+
|
296 |
run_btn = gr.Button("Run", variant="primary", visible=False)
|
297 |
|
298 |
with gr.Column():
|
|
|
354 |
input_img,
|
355 |
background_remove_state,
|
356 |
foreground_ratio,
|
357 |
+
remesh_option,
|
358 |
+
vertex_count_slider,
|
359 |
+
texture_size,
|
360 |
],
|
361 |
outputs=[
|
362 |
run_btn,
|
|
|
368 |
],
|
369 |
)
|
370 |
|
371 |
+
demo.queue().launch(share=False)
|
requirements.txt
CHANGED
@@ -1,13 +1,21 @@
|
|
1 |
-
|
2 |
-
|
|
|
|
|
3 |
einops==0.7.0
|
4 |
jaxtyping==0.2.31
|
5 |
omegaconf==2.3.0
|
6 |
transformers==4.42.3
|
7 |
-
slangtorch==1.2.2
|
8 |
open_clip_torch==2.24.0
|
9 |
trimesh==4.4.1
|
10 |
numpy==1.26.4
|
11 |
huggingface-hub==0.23.4
|
12 |
-
rembg[gpu]==2.0.57
|
|
|
|
|
|
|
|
|
13 |
gradio-litmodel3d==0.0.1
|
|
|
|
|
|
|
|
1 |
+
wheel
|
2 |
+
setuptools==69.5.1
|
3 |
+
torch==2.5.1
|
4 |
+
torchvision==0.20.1
|
5 |
einops==0.7.0
|
6 |
jaxtyping==0.2.31
|
7 |
omegaconf==2.3.0
|
8 |
transformers==4.42.3
|
|
|
9 |
open_clip_torch==2.24.0
|
10 |
trimesh==4.4.1
|
11 |
numpy==1.26.4
|
12 |
huggingface-hub==0.23.4
|
13 |
+
rembg[gpu]==2.0.57; sys_platform != 'darwin'
|
14 |
+
rembg==2.0.57; sys_platform == 'darwin'
|
15 |
+
pynanoinstantmeshes==0.0.3
|
16 |
+
gpytoolbox==0.2.0
|
17 |
+
gradio==4.41.0
|
18 |
gradio-litmodel3d==0.0.1
|
19 |
+
# (HF hack) These are installed at runtime in gradio_app.py
|
20 |
+
# ./texture_baker/
|
21 |
+
# ./uv_unwrapper/
|
ruff.toml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[lint]
|
2 |
+
ignore = ["F722"]
|
3 |
+
extend-select = ["I"]
|
run.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from contextlib import nullcontext
|
4 |
+
|
5 |
+
import rembg
|
6 |
+
import torch
|
7 |
+
from PIL import Image
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from sf3d.system import SF3D
|
11 |
+
from sf3d.utils import get_device, remove_background, resize_foreground
|
12 |
+
|
13 |
+
if __name__ == "__main__":
|
14 |
+
parser = argparse.ArgumentParser()
|
15 |
+
parser.add_argument(
|
16 |
+
"image", type=str, nargs="+", help="Path to input image(s) or folder."
|
17 |
+
)
|
18 |
+
parser.add_argument(
|
19 |
+
"--device",
|
20 |
+
default=get_device(),
|
21 |
+
type=str,
|
22 |
+
help=f"Device to use. If no CUDA/MPS-compatible device is found, the baking will fail. Default: '{get_device()}'",
|
23 |
+
)
|
24 |
+
parser.add_argument(
|
25 |
+
"--pretrained-model",
|
26 |
+
default="stabilityai/stable-fast-3d",
|
27 |
+
type=str,
|
28 |
+
help="Path to the pretrained model. Could be either a huggingface model id is or a local path. Default: 'stabilityai/stable-fast-3d'",
|
29 |
+
)
|
30 |
+
parser.add_argument(
|
31 |
+
"--foreground-ratio",
|
32 |
+
default=0.85,
|
33 |
+
type=float,
|
34 |
+
help="Ratio of the foreground size to the image size. Only used when --no-remove-bg is not specified. Default: 0.85",
|
35 |
+
)
|
36 |
+
parser.add_argument(
|
37 |
+
"--output-dir",
|
38 |
+
default="output/",
|
39 |
+
type=str,
|
40 |
+
help="Output directory to save the results. Default: 'output/'",
|
41 |
+
)
|
42 |
+
parser.add_argument(
|
43 |
+
"--texture-resolution",
|
44 |
+
default=1024,
|
45 |
+
type=int,
|
46 |
+
help="Texture atlas resolution. Default: 1024",
|
47 |
+
)
|
48 |
+
parser.add_argument(
|
49 |
+
"--remesh_option",
|
50 |
+
choices=["none", "triangle", "quad"],
|
51 |
+
default="none",
|
52 |
+
help="Remeshing option",
|
53 |
+
)
|
54 |
+
parser.add_argument(
|
55 |
+
"--target_vertex_count",
|
56 |
+
type=int,
|
57 |
+
help="Target vertex count. -1 does not perform a reduction.",
|
58 |
+
default=-1,
|
59 |
+
)
|
60 |
+
parser.add_argument(
|
61 |
+
"--batch_size", default=1, type=int, help="Batch size for inference"
|
62 |
+
)
|
63 |
+
args = parser.parse_args()
|
64 |
+
|
65 |
+
# Ensure args.device contains cuda
|
66 |
+
devices = ["cuda", "mps", "cpu"]
|
67 |
+
if not any(args.device in device for device in devices):
|
68 |
+
raise ValueError("Invalid device. Use cuda, mps or cpu")
|
69 |
+
|
70 |
+
output_dir = args.output_dir
|
71 |
+
os.makedirs(output_dir, exist_ok=True)
|
72 |
+
|
73 |
+
device = args.device
|
74 |
+
if not (torch.cuda.is_available() or torch.backends.mps.is_available()):
|
75 |
+
device = "cpu"
|
76 |
+
|
77 |
+
print("Device used: ", device)
|
78 |
+
|
79 |
+
model = SF3D.from_pretrained(
|
80 |
+
args.pretrained_model,
|
81 |
+
config_name="config.yaml",
|
82 |
+
weight_name="model.safetensors",
|
83 |
+
)
|
84 |
+
model.to(device)
|
85 |
+
model.eval()
|
86 |
+
|
87 |
+
rembg_session = rembg.new_session()
|
88 |
+
images = []
|
89 |
+
idx = 0
|
90 |
+
for image_path in args.image:
|
91 |
+
|
92 |
+
def handle_image(image_path, idx):
|
93 |
+
image = remove_background(
|
94 |
+
Image.open(image_path).convert("RGBA"), rembg_session
|
95 |
+
)
|
96 |
+
image = resize_foreground(image, args.foreground_ratio)
|
97 |
+
os.makedirs(os.path.join(output_dir, str(idx)), exist_ok=True)
|
98 |
+
image.save(os.path.join(output_dir, str(idx), "input.png"))
|
99 |
+
images.append(image)
|
100 |
+
|
101 |
+
if os.path.isdir(image_path):
|
102 |
+
image_paths = [
|
103 |
+
os.path.join(image_path, f)
|
104 |
+
for f in os.listdir(image_path)
|
105 |
+
if f.endswith((".png", ".jpg", ".jpeg"))
|
106 |
+
]
|
107 |
+
for image_path in image_paths:
|
108 |
+
handle_image(image_path, idx)
|
109 |
+
idx += 1
|
110 |
+
else:
|
111 |
+
handle_image(image_path, idx)
|
112 |
+
idx += 1
|
113 |
+
|
114 |
+
for i in tqdm(range(0, len(images), args.batch_size)):
|
115 |
+
image = images[i : i + args.batch_size]
|
116 |
+
if torch.cuda.is_available():
|
117 |
+
torch.cuda.reset_peak_memory_stats()
|
118 |
+
with torch.no_grad():
|
119 |
+
with torch.autocast(
|
120 |
+
device_type=device, dtype=torch.bfloat16
|
121 |
+
) if "cuda" in device else nullcontext():
|
122 |
+
mesh, glob_dict = model.run_image(
|
123 |
+
image,
|
124 |
+
bake_resolution=args.texture_resolution,
|
125 |
+
remesh=args.remesh_option,
|
126 |
+
vertex_count=args.target_vertex_count,
|
127 |
+
)
|
128 |
+
if torch.cuda.is_available():
|
129 |
+
print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
|
130 |
+
elif torch.backends.mps.is_available():
|
131 |
+
print(
|
132 |
+
"Peak Memory:", torch.mps.driver_allocated_memory() / 1024 / 1024, "MB"
|
133 |
+
)
|
134 |
+
|
135 |
+
if len(image) == 1:
|
136 |
+
out_mesh_path = os.path.join(output_dir, str(i), "mesh.glb")
|
137 |
+
mesh.export(out_mesh_path, include_normals=True)
|
138 |
+
else:
|
139 |
+
for j in range(len(mesh)):
|
140 |
+
out_mesh_path = os.path.join(output_dir, str(i + j), "mesh.glb")
|
141 |
+
mesh[j].export(out_mesh_path, include_normals=True)
|
sf3d/models/image_estimator/clip_based_estimator.py
CHANGED
@@ -95,7 +95,7 @@ class ClipBasedHeadEstimator(BaseModule):
|
|
95 |
# Run the model
|
96 |
# Resize cond_image to 224
|
97 |
cond_image = nn.functional.interpolate(
|
98 |
-
cond_image.flatten(0, 1).permute(0, 3, 1, 2),
|
99 |
size=(224, 224),
|
100 |
mode="bilinear",
|
101 |
align_corners=False,
|
|
|
95 |
# Run the model
|
96 |
# Resize cond_image to 224
|
97 |
cond_image = nn.functional.interpolate(
|
98 |
+
cond_image.flatten(0, 1).permute(0, 3, 1, 2).contiguous(),
|
99 |
size=(224, 224),
|
100 |
mode="bilinear",
|
101 |
align_corners=False,
|
sf3d/models/mesh.py
CHANGED
@@ -1,15 +1,30 @@
|
|
1 |
from __future__ import annotations
|
2 |
|
|
|
3 |
from typing import Any, Dict, Optional
|
4 |
|
|
|
|
|
|
|
5 |
import torch
|
6 |
import torch.nn.functional as F
|
|
|
7 |
from jaxtyping import Float, Integer
|
8 |
from torch import Tensor
|
9 |
|
10 |
-
from sf3d.box_uv_unwrap import box_projection_uv_unwrap
|
11 |
from sf3d.models.utils import dot
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
class Mesh:
|
15 |
def __init__(
|
@@ -25,6 +40,8 @@ class Mesh:
|
|
25 |
for k, v in kwargs.items():
|
26 |
self.add_extra(k, v)
|
27 |
|
|
|
|
|
28 |
def add_extra(self, k, v) -> None:
|
29 |
self.extras[k] = v
|
30 |
|
@@ -131,12 +148,112 @@ class Mesh:
|
|
131 |
|
132 |
return tangents
|
133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
@torch.no_grad()
|
135 |
def unwrap_uv(
|
136 |
self,
|
137 |
island_padding: float = 0.02,
|
138 |
) -> Mesh:
|
139 |
-
uv, indices =
|
140 |
self.v_pos, self.v_nrm, self.t_pos_idx, island_padding
|
141 |
)
|
142 |
|
|
|
1 |
from __future__ import annotations
|
2 |
|
3 |
+
import math
|
4 |
from typing import Any, Dict, Optional
|
5 |
|
6 |
+
import gpytoolbox
|
7 |
+
import numpy as np
|
8 |
+
import pynanoinstantmeshes
|
9 |
import torch
|
10 |
import torch.nn.functional as F
|
11 |
+
import trimesh
|
12 |
from jaxtyping import Float, Integer
|
13 |
from torch import Tensor
|
14 |
|
|
|
15 |
from sf3d.models.utils import dot
|
16 |
|
17 |
+
try:
|
18 |
+
from uv_unwrapper import Unwrapper
|
19 |
+
except ImportError:
|
20 |
+
import logging
|
21 |
+
|
22 |
+
logging.warning(
|
23 |
+
"Could not import uv_unwrapper. Please install it via `pip install uv_unwrapper/`"
|
24 |
+
)
|
25 |
+
# Exit early to avoid further errors
|
26 |
+
raise ImportError("uv_unwrapper not found")
|
27 |
+
|
28 |
|
29 |
class Mesh:
|
30 |
def __init__(
|
|
|
40 |
for k, v in kwargs.items():
|
41 |
self.add_extra(k, v)
|
42 |
|
43 |
+
self.unwrapper = Unwrapper()
|
44 |
+
|
45 |
def add_extra(self, k, v) -> None:
|
46 |
self.extras[k] = v
|
47 |
|
|
|
148 |
|
149 |
return tangents
|
150 |
|
151 |
+
def quad_remesh(
|
152 |
+
self,
|
153 |
+
quad_vertex_count: int = -1,
|
154 |
+
quad_rosy: int = 4,
|
155 |
+
quad_crease_angle: float = -1.0,
|
156 |
+
quad_smooth_iter: int = 2,
|
157 |
+
quad_align_to_boundaries: bool = False,
|
158 |
+
) -> Mesh:
|
159 |
+
if quad_vertex_count < 0:
|
160 |
+
quad_vertex_count = self.v_pos.shape[0]
|
161 |
+
v_pos = self.v_pos.detach().cpu().numpy().astype(np.float32)
|
162 |
+
t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.uint32)
|
163 |
+
|
164 |
+
new_vert, new_faces = pynanoinstantmeshes.remesh(
|
165 |
+
v_pos,
|
166 |
+
t_pos_idx,
|
167 |
+
quad_vertex_count // 4,
|
168 |
+
rosy=quad_rosy,
|
169 |
+
posy=4,
|
170 |
+
creaseAngle=quad_crease_angle,
|
171 |
+
align_to_boundaries=quad_align_to_boundaries,
|
172 |
+
smooth_iter=quad_smooth_iter,
|
173 |
+
deterministic=False,
|
174 |
+
)
|
175 |
+
|
176 |
+
# Briefly load in trimesh
|
177 |
+
mesh = trimesh.Trimesh(vertices=new_vert, faces=new_faces.astype(np.int32))
|
178 |
+
|
179 |
+
v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos).contiguous()
|
180 |
+
t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx).contiguous()
|
181 |
+
|
182 |
+
# Create new mesh
|
183 |
+
return Mesh(v_pos, t_pos_idx)
|
184 |
+
|
185 |
+
def triangle_remesh(
|
186 |
+
self,
|
187 |
+
triangle_average_edge_length_multiplier: Optional[float] = None,
|
188 |
+
triangle_remesh_steps: int = 10,
|
189 |
+
triangle_vertex_count=-1,
|
190 |
+
):
|
191 |
+
if triangle_vertex_count > 0:
|
192 |
+
reduction = triangle_vertex_count / self.v_pos.shape[0]
|
193 |
+
print("Triangle reduction:", reduction)
|
194 |
+
v_pos = self.v_pos.detach().cpu().numpy().astype(np.float32)
|
195 |
+
t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.int32)
|
196 |
+
if reduction > 1.0:
|
197 |
+
subdivide_iters = int(math.ceil(math.log(reduction) / math.log(2)))
|
198 |
+
print("Subdivide iters:", subdivide_iters)
|
199 |
+
v_pos, t_pos_idx = gpytoolbox.subdivide(
|
200 |
+
v_pos,
|
201 |
+
t_pos_idx,
|
202 |
+
iters=subdivide_iters,
|
203 |
+
)
|
204 |
+
reduction = triangle_vertex_count / v_pos.shape[0]
|
205 |
+
|
206 |
+
# Simplify
|
207 |
+
points_out, faces_out, _, _ = gpytoolbox.decimate(
|
208 |
+
v_pos,
|
209 |
+
t_pos_idx,
|
210 |
+
face_ratio=reduction,
|
211 |
+
)
|
212 |
+
|
213 |
+
# Convert back to torch
|
214 |
+
self.v_pos = torch.from_numpy(points_out).to(self.v_pos)
|
215 |
+
self.t_pos_idx = torch.from_numpy(faces_out).to(self.t_pos_idx)
|
216 |
+
self._edges = None
|
217 |
+
triangle_average_edge_length_multiplier = None
|
218 |
+
|
219 |
+
edges = self.edges
|
220 |
+
if triangle_average_edge_length_multiplier is None:
|
221 |
+
h = None
|
222 |
+
else:
|
223 |
+
h = float(
|
224 |
+
torch.linalg.norm(
|
225 |
+
self.v_pos[edges[:, 0]] - self.v_pos[edges[:, 1]], dim=1
|
226 |
+
)
|
227 |
+
.mean()
|
228 |
+
.item()
|
229 |
+
* triangle_average_edge_length_multiplier
|
230 |
+
)
|
231 |
+
|
232 |
+
# Convert to numpy
|
233 |
+
v_pos = self.v_pos.detach().cpu().numpy().astype(np.float64)
|
234 |
+
t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.int32)
|
235 |
+
|
236 |
+
# Remesh
|
237 |
+
v_remesh, f_remesh = gpytoolbox.remesh_botsch(
|
238 |
+
v_pos,
|
239 |
+
t_pos_idx,
|
240 |
+
triangle_remesh_steps,
|
241 |
+
h,
|
242 |
+
)
|
243 |
+
|
244 |
+
# Convert back to torch
|
245 |
+
v_pos = torch.from_numpy(v_remesh).to(self.v_pos).contiguous()
|
246 |
+
t_pos_idx = torch.from_numpy(f_remesh).to(self.t_pos_idx).contiguous()
|
247 |
+
|
248 |
+
# Create new mesh
|
249 |
+
return Mesh(v_pos, t_pos_idx)
|
250 |
+
|
251 |
@torch.no_grad()
|
252 |
def unwrap_uv(
|
253 |
self,
|
254 |
island_padding: float = 0.02,
|
255 |
) -> Mesh:
|
256 |
+
uv, indices = self.unwrapper(
|
257 |
self.v_pos, self.v_nrm, self.t_pos_idx, island_padding
|
258 |
)
|
259 |
|
sf3d/models/network.py
CHANGED
@@ -7,10 +7,23 @@ import torch.nn.functional as F
|
|
7 |
from einops import rearrange
|
8 |
from jaxtyping import Float
|
9 |
from torch import Tensor
|
|
|
10 |
from torch.autograd import Function
|
11 |
-
from torch.cuda.amp import custom_bwd, custom_fwd
|
12 |
|
13 |
from sf3d.models.utils import BaseModule, normalize
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
|
16 |
class PixelShuffleUpsampleNetwork(BaseModule):
|
@@ -65,13 +78,18 @@ class _TruncExp(Function): # pylint: disable=abstract-method
|
|
65 |
# Implementation from torch-ngp:
|
66 |
# https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
|
67 |
@staticmethod
|
68 |
-
@
|
|
|
|
|
|
|
|
|
|
|
69 |
def forward(ctx, x): # pylint: disable=arguments-differ
|
70 |
ctx.save_for_backward(x)
|
71 |
return torch.exp(x)
|
72 |
|
73 |
@staticmethod
|
74 |
-
@custom_bwd
|
75 |
def backward(ctx, g): # pylint: disable=arguments-differ
|
76 |
x = ctx.saved_tensors[0]
|
77 |
return g * torch.exp(torch.clamp(x, max=15))
|
|
|
7 |
from einops import rearrange
|
8 |
from jaxtyping import Float
|
9 |
from torch import Tensor
|
10 |
+
from torch.amp import custom_bwd, custom_fwd
|
11 |
from torch.autograd import Function
|
|
|
12 |
|
13 |
from sf3d.models.utils import BaseModule, normalize
|
14 |
+
from sf3d.utils import get_device
|
15 |
+
|
16 |
+
|
17 |
+
def conditional_decorator(decorator_with_args, condition, *args, **kwargs):
|
18 |
+
def wrapper(fn):
|
19 |
+
if condition:
|
20 |
+
if len(kwargs) == 0:
|
21 |
+
return decorator_with_args
|
22 |
+
return decorator_with_args(*args, **kwargs)(fn)
|
23 |
+
else:
|
24 |
+
return fn
|
25 |
+
|
26 |
+
return wrapper
|
27 |
|
28 |
|
29 |
class PixelShuffleUpsampleNetwork(BaseModule):
|
|
|
78 |
# Implementation from torch-ngp:
|
79 |
# https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
|
80 |
@staticmethod
|
81 |
+
@conditional_decorator(
|
82 |
+
custom_fwd,
|
83 |
+
"cuda" in get_device(),
|
84 |
+
cast_inputs=torch.float32,
|
85 |
+
device_type="cuda",
|
86 |
+
)
|
87 |
def forward(ctx, x): # pylint: disable=arguments-differ
|
88 |
ctx.save_for_backward(x)
|
89 |
return torch.exp(x)
|
90 |
|
91 |
@staticmethod
|
92 |
+
@conditional_decorator(custom_bwd, "cuda" in get_device())
|
93 |
def backward(ctx, g): # pylint: disable=arguments-differ
|
94 |
x = ctx.saved_tensors[0]
|
95 |
return g * torch.exp(torch.clamp(x, max=15))
|
sf3d/models/utils.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
import dataclasses
|
2 |
import importlib
|
3 |
-
import math
|
4 |
from dataclasses import dataclass
|
5 |
from typing import Any, List, Optional, Tuple, Union
|
6 |
|
@@ -9,7 +8,7 @@ import PIL
|
|
9 |
import torch
|
10 |
import torch.nn as nn
|
11 |
import torch.nn.functional as F
|
12 |
-
from jaxtyping import
|
13 |
from omegaconf import DictConfig, OmegaConf
|
14 |
from torch import Tensor
|
15 |
|
@@ -77,61 +76,6 @@ def normalize(x, dim=-1, eps=None):
|
|
77 |
return F.normalize(x, dim=dim, p=2, eps=eps)
|
78 |
|
79 |
|
80 |
-
def tri_winding(tri: Float[Tensor, "*B 3 2"]) -> Float[Tensor, "*B 3 3"]:
|
81 |
-
# One pad for determinant
|
82 |
-
tri_sq = F.pad(tri, (0, 1), "constant", 1.0)
|
83 |
-
det_tri = torch.det(tri_sq)
|
84 |
-
tri_rev = torch.cat(
|
85 |
-
(tri_sq[..., 0:1, :], tri_sq[..., 2:3, :], tri_sq[..., 1:2, :]), -2
|
86 |
-
)
|
87 |
-
tri_sq[det_tri < 0] = tri_rev[det_tri < 0]
|
88 |
-
return tri_sq
|
89 |
-
|
90 |
-
|
91 |
-
def triangle_intersection_2d(
|
92 |
-
t1: Float[Tensor, "*B 3 2"],
|
93 |
-
t2: Float[Tensor, "*B 3 2"],
|
94 |
-
eps=1e-12,
|
95 |
-
) -> Float[Tensor, "*B"]: # noqa: F821
|
96 |
-
"""Returns True if triangles collide, False otherwise"""
|
97 |
-
|
98 |
-
def chk_edge(x: Float[Tensor, "*B 3 3"]) -> Bool[Tensor, "*B"]: # noqa: F821
|
99 |
-
logdetx = torch.logdet(x.double())
|
100 |
-
if eps is None:
|
101 |
-
return ~torch.isfinite(logdetx)
|
102 |
-
return ~(torch.isfinite(logdetx) & (logdetx > math.log(eps)))
|
103 |
-
|
104 |
-
t1s = tri_winding(t1)
|
105 |
-
t2s = tri_winding(t2)
|
106 |
-
|
107 |
-
# Assume the triangles do not collide in the begging
|
108 |
-
ret = torch.zeros(t1.shape[0], dtype=torch.bool, device=t1.device)
|
109 |
-
for i in range(3):
|
110 |
-
edge = torch.roll(t1s, i, dims=1)[:, :2, :]
|
111 |
-
# Check if all points of triangle 2 lay on the external side of edge E.
|
112 |
-
# If this is the case the triangle do not collide
|
113 |
-
upd = (
|
114 |
-
chk_edge(torch.cat((edge, t2s[:, 0:1]), 1))
|
115 |
-
& chk_edge(torch.cat((edge, t2s[:, 1:2]), 1))
|
116 |
-
& chk_edge(torch.cat((edge, t2s[:, 2:3]), 1))
|
117 |
-
)
|
118 |
-
# Here no collision is still True due to inversion
|
119 |
-
ret = ret | upd
|
120 |
-
|
121 |
-
for i in range(3):
|
122 |
-
edge = torch.roll(t2s, i, dims=1)[:, :2, :]
|
123 |
-
|
124 |
-
upd = (
|
125 |
-
chk_edge(torch.cat((edge, t1s[:, 0:1]), 1))
|
126 |
-
& chk_edge(torch.cat((edge, t1s[:, 1:2]), 1))
|
127 |
-
& chk_edge(torch.cat((edge, t1s[:, 2:3]), 1))
|
128 |
-
)
|
129 |
-
# Here no collision is still True due to inversion
|
130 |
-
ret = ret | upd
|
131 |
-
|
132 |
-
return ~ret # Do the inversion
|
133 |
-
|
134 |
-
|
135 |
ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]
|
136 |
|
137 |
|
|
|
1 |
import dataclasses
|
2 |
import importlib
|
|
|
3 |
from dataclasses import dataclass
|
4 |
from typing import Any, List, Optional, Tuple, Union
|
5 |
|
|
|
8 |
import torch
|
9 |
import torch.nn as nn
|
10 |
import torch.nn.functional as F
|
11 |
+
from jaxtyping import Float, Int, Num
|
12 |
from omegaconf import DictConfig, OmegaConf
|
13 |
from torch import Tensor
|
14 |
|
|
|
76 |
return F.normalize(x, dim=dim, p=2, eps=eps)
|
77 |
|
78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]
|
80 |
|
81 |
|
sf3d/system.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import os
|
|
|
2 |
from dataclasses import dataclass, field
|
3 |
-
from typing import Any, List, Optional, Tuple
|
4 |
|
5 |
import numpy as np
|
6 |
import torch
|
@@ -21,15 +22,23 @@ from sf3d.models.utils import (
|
|
21 |
ImageProcessor,
|
22 |
convert_data,
|
23 |
dilate_fill,
|
24 |
-
dot,
|
25 |
find_class,
|
26 |
float32_to_uint8_np,
|
27 |
normalize,
|
28 |
scale_tensor,
|
29 |
)
|
30 |
-
from sf3d.utils import create_intrinsic_from_fov_deg, default_cond_c2w
|
31 |
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
|
35 |
class SF3D(BaseModule):
|
@@ -206,6 +215,7 @@ class SF3D(BaseModule):
|
|
206 |
batch["c2w_cond"] = batch["c2w_cond"].unsqueeze(1)
|
207 |
batch["intrinsic_cond"] = batch["intrinsic_cond"].unsqueeze(1)
|
208 |
batch["intrinsic_normed_cond"] = batch["intrinsic_normed_cond"].unsqueeze(1)
|
|
|
209 |
batch_size, n_input_views = batch["rgb_cond"].shape[:2]
|
210 |
|
211 |
camera_embeds: Optional[Float[Tensor, "B Nv Cc"]]
|
@@ -234,10 +244,54 @@ class SF3D(BaseModule):
|
|
234 |
|
235 |
def run_image(
|
236 |
self,
|
237 |
-
image: Image,
|
238 |
bake_resolution: int,
|
|
|
|
|
239 |
estimate_illumination: bool = False,
|
240 |
-
) -> Tuple[trimesh.Trimesh, dict[str, Any]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
if image.mode != "RGBA":
|
242 |
raise ValueError("Image must be in RGBA mode")
|
243 |
img_cond = (
|
@@ -258,30 +312,14 @@ class SF3D(BaseModule):
|
|
258 |
mask_cond,
|
259 |
)
|
260 |
|
261 |
-
|
262 |
-
intrinsic, intrinsic_normed_cond = create_intrinsic_from_fov_deg(
|
263 |
-
self.cfg.default_fovy_deg,
|
264 |
-
self.cfg.cond_image_size,
|
265 |
-
self.cfg.cond_image_size,
|
266 |
-
)
|
267 |
-
|
268 |
-
batch = {
|
269 |
-
"rgb_cond": rgb_cond,
|
270 |
-
"mask_cond": mask_cond,
|
271 |
-
"c2w_cond": c2w_cond.unsqueeze(0),
|
272 |
-
"intrinsic_cond": intrinsic.to(self.device).unsqueeze(0),
|
273 |
-
"intrinsic_normed_cond": intrinsic_normed_cond.to(self.device).unsqueeze(0),
|
274 |
-
}
|
275 |
-
|
276 |
-
meshes, global_dict = self.generate_mesh(
|
277 |
-
batch, bake_resolution, estimate_illumination
|
278 |
-
)
|
279 |
-
return meshes[0], global_dict
|
280 |
|
281 |
def generate_mesh(
|
282 |
self,
|
283 |
batch,
|
284 |
bake_resolution: int,
|
|
|
|
|
285 |
estimate_illumination: bool = False,
|
286 |
) -> Tuple[List[trimesh.Trimesh], dict[str, Any]]:
|
287 |
batch["rgb_cond"] = self.image_processor(
|
@@ -300,8 +338,11 @@ class SF3D(BaseModule):
|
|
300 |
if self.global_estimator is not None and estimate_illumination:
|
301 |
global_dict.update(self.global_estimator(non_postprocessed_codes))
|
302 |
|
|
|
303 |
with torch.no_grad():
|
304 |
-
with torch.autocast(
|
|
|
|
|
305 |
meshes = self.triplane_to_meshes(scene_codes)
|
306 |
|
307 |
rets = []
|
@@ -311,6 +352,17 @@ class SF3D(BaseModule):
|
|
311 |
rets.append(trimesh.Trimesh())
|
312 |
continue
|
313 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
314 |
mesh.unwrap_uv()
|
315 |
|
316 |
# Build textures
|
@@ -323,7 +375,6 @@ class SF3D(BaseModule):
|
|
323 |
mesh.v_pos,
|
324 |
rast,
|
325 |
mesh.t_pos_idx,
|
326 |
-
mesh.v_tex,
|
327 |
)
|
328 |
gb_pos = pos_bake[bake_mask]
|
329 |
|
@@ -336,7 +387,6 @@ class SF3D(BaseModule):
|
|
336 |
mesh.v_nrm,
|
337 |
rast,
|
338 |
mesh.t_pos_idx,
|
339 |
-
mesh.v_tex,
|
340 |
)
|
341 |
gb_nrm = F.normalize(nrm[bake_mask], dim=-1)
|
342 |
decoded["normal"] = gb_nrm
|
@@ -377,29 +427,28 @@ class SF3D(BaseModule):
|
|
377 |
mesh.v_tng,
|
378 |
rast,
|
379 |
mesh.t_pos_idx,
|
380 |
-
mesh.v_tex,
|
381 |
)
|
382 |
gb_tng = tng[bake_mask]
|
383 |
gb_tng = F.normalize(gb_tng, dim=-1)
|
384 |
gb_btng = F.normalize(
|
385 |
-
torch.cross(
|
386 |
)
|
387 |
normal = F.normalize(mat_out["normal"], dim=-1)
|
388 |
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
|
|
399 |
)
|
400 |
-
bump = (bump * 0.5 + 0.5).clamp(0, 1)
|
401 |
|
402 |
-
f[bake_mask] =
|
403 |
mat_out["bump"] = f
|
404 |
else:
|
405 |
f[bake_mask] = v.view(-1, v.shape[-1])
|
@@ -410,12 +459,13 @@ class SF3D(BaseModule):
|
|
410 |
return arr
|
411 |
return (
|
412 |
dilate_fill(
|
413 |
-
arr.permute(2, 0, 1)[None, ...],
|
414 |
bake_mask.unsqueeze(0).unsqueeze(0),
|
415 |
iterations=bake_resolution // 150,
|
416 |
)
|
417 |
.squeeze(0)
|
418 |
.permute(1, 2, 0)
|
|
|
419 |
)
|
420 |
|
421 |
verts_np = convert_data(mesh.v_pos)
|
|
|
1 |
import os
|
2 |
+
from contextlib import nullcontext
|
3 |
from dataclasses import dataclass, field
|
4 |
+
from typing import Any, List, Literal, Optional, Tuple, Union
|
5 |
|
6 |
import numpy as np
|
7 |
import torch
|
|
|
22 |
ImageProcessor,
|
23 |
convert_data,
|
24 |
dilate_fill,
|
|
|
25 |
find_class,
|
26 |
float32_to_uint8_np,
|
27 |
normalize,
|
28 |
scale_tensor,
|
29 |
)
|
30 |
+
from sf3d.utils import create_intrinsic_from_fov_deg, default_cond_c2w, get_device
|
31 |
|
32 |
+
try:
|
33 |
+
from texture_baker import TextureBaker
|
34 |
+
except ImportError:
|
35 |
+
import logging
|
36 |
+
|
37 |
+
logging.warning(
|
38 |
+
"Could not import texture_baker. Please install it via `pip install texture-baker/`"
|
39 |
+
)
|
40 |
+
# Exit early to avoid further errors
|
41 |
+
raise ImportError("texture_baker not found")
|
42 |
|
43 |
|
44 |
class SF3D(BaseModule):
|
|
|
215 |
batch["c2w_cond"] = batch["c2w_cond"].unsqueeze(1)
|
216 |
batch["intrinsic_cond"] = batch["intrinsic_cond"].unsqueeze(1)
|
217 |
batch["intrinsic_normed_cond"] = batch["intrinsic_normed_cond"].unsqueeze(1)
|
218 |
+
|
219 |
batch_size, n_input_views = batch["rgb_cond"].shape[:2]
|
220 |
|
221 |
camera_embeds: Optional[Float[Tensor, "B Nv Cc"]]
|
|
|
244 |
|
245 |
def run_image(
|
246 |
self,
|
247 |
+
image: Union[Image.Image, List[Image.Image]],
|
248 |
bake_resolution: int,
|
249 |
+
remesh: Literal["none", "triangle", "quad"] = "none",
|
250 |
+
vertex_count: int = -1,
|
251 |
estimate_illumination: bool = False,
|
252 |
+
) -> Tuple[Union[trimesh.Trimesh, List[trimesh.Trimesh]], dict[str, Any]]:
|
253 |
+
if isinstance(image, list):
|
254 |
+
rgb_cond = []
|
255 |
+
mask_cond = []
|
256 |
+
for img in image:
|
257 |
+
mask, rgb = self.prepare_image(img)
|
258 |
+
mask_cond.append(mask)
|
259 |
+
rgb_cond.append(rgb)
|
260 |
+
rgb_cond = torch.stack(rgb_cond, 0)
|
261 |
+
mask_cond = torch.stack(mask_cond, 0)
|
262 |
+
batch_size = rgb_cond.shape[0]
|
263 |
+
else:
|
264 |
+
mask_cond, rgb_cond = self.prepare_image(image)
|
265 |
+
batch_size = 1
|
266 |
+
|
267 |
+
c2w_cond = default_cond_c2w(self.cfg.default_distance).to(self.device)
|
268 |
+
intrinsic, intrinsic_normed_cond = create_intrinsic_from_fov_deg(
|
269 |
+
self.cfg.default_fovy_deg,
|
270 |
+
self.cfg.cond_image_size,
|
271 |
+
self.cfg.cond_image_size,
|
272 |
+
)
|
273 |
+
|
274 |
+
batch = {
|
275 |
+
"rgb_cond": rgb_cond,
|
276 |
+
"mask_cond": mask_cond,
|
277 |
+
"c2w_cond": c2w_cond.view(1, 1, 4, 4).repeat(batch_size, 1, 1, 1),
|
278 |
+
"intrinsic_cond": intrinsic.to(self.device)
|
279 |
+
.view(1, 1, 3, 3)
|
280 |
+
.repeat(batch_size, 1, 1, 1),
|
281 |
+
"intrinsic_normed_cond": intrinsic_normed_cond.to(self.device)
|
282 |
+
.view(1, 1, 3, 3)
|
283 |
+
.repeat(batch_size, 1, 1, 1),
|
284 |
+
}
|
285 |
+
|
286 |
+
meshes, global_dict = self.generate_mesh(
|
287 |
+
batch, bake_resolution, remesh, vertex_count, estimate_illumination
|
288 |
+
)
|
289 |
+
if batch_size == 1:
|
290 |
+
return meshes[0], global_dict
|
291 |
+
else:
|
292 |
+
return meshes, global_dict
|
293 |
+
|
294 |
+
def prepare_image(self, image):
|
295 |
if image.mode != "RGBA":
|
296 |
raise ValueError("Image must be in RGBA mode")
|
297 |
img_cond = (
|
|
|
312 |
mask_cond,
|
313 |
)
|
314 |
|
315 |
+
return mask_cond, rgb_cond
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
316 |
|
317 |
def generate_mesh(
|
318 |
self,
|
319 |
batch,
|
320 |
bake_resolution: int,
|
321 |
+
remesh: Literal["none", "triangle", "quad"] = "none",
|
322 |
+
vertex_count: int = -1,
|
323 |
estimate_illumination: bool = False,
|
324 |
) -> Tuple[List[trimesh.Trimesh], dict[str, Any]]:
|
325 |
batch["rgb_cond"] = self.image_processor(
|
|
|
338 |
if self.global_estimator is not None and estimate_illumination:
|
339 |
global_dict.update(self.global_estimator(non_postprocessed_codes))
|
340 |
|
341 |
+
device = get_device()
|
342 |
with torch.no_grad():
|
343 |
+
with torch.autocast(
|
344 |
+
device_type=device, enabled=False
|
345 |
+
) if "cuda" in device else nullcontext():
|
346 |
meshes = self.triplane_to_meshes(scene_codes)
|
347 |
|
348 |
rets = []
|
|
|
352 |
rets.append(trimesh.Trimesh())
|
353 |
continue
|
354 |
|
355 |
+
if remesh == "triangle":
|
356 |
+
mesh = mesh.triangle_remesh(triangle_vertex_count=vertex_count)
|
357 |
+
elif remesh == "quad":
|
358 |
+
mesh = mesh.quad_remesh(quad_vertex_count=vertex_count)
|
359 |
+
else:
|
360 |
+
if vertex_count > 0:
|
361 |
+
print(
|
362 |
+
"Warning: vertex_count is ignored when remesh is none"
|
363 |
+
)
|
364 |
+
|
365 |
+
print("After Remesh", mesh.v_pos.shape[0], mesh.t_pos_idx.shape[0])
|
366 |
mesh.unwrap_uv()
|
367 |
|
368 |
# Build textures
|
|
|
375 |
mesh.v_pos,
|
376 |
rast,
|
377 |
mesh.t_pos_idx,
|
|
|
378 |
)
|
379 |
gb_pos = pos_bake[bake_mask]
|
380 |
|
|
|
387 |
mesh.v_nrm,
|
388 |
rast,
|
389 |
mesh.t_pos_idx,
|
|
|
390 |
)
|
391 |
gb_nrm = F.normalize(nrm[bake_mask], dim=-1)
|
392 |
decoded["normal"] = gb_nrm
|
|
|
427 |
mesh.v_tng,
|
428 |
rast,
|
429 |
mesh.t_pos_idx,
|
|
|
430 |
)
|
431 |
gb_tng = tng[bake_mask]
|
432 |
gb_tng = F.normalize(gb_tng, dim=-1)
|
433 |
gb_btng = F.normalize(
|
434 |
+
torch.cross(gb_nrm, gb_tng, dim=-1), dim=-1
|
435 |
)
|
436 |
normal = F.normalize(mat_out["normal"], dim=-1)
|
437 |
|
438 |
+
# Create tangent space matrix and transform normal
|
439 |
+
tangent_matrix = torch.stack(
|
440 |
+
[gb_tng, gb_btng, gb_nrm], dim=-1
|
441 |
+
)
|
442 |
+
normal_tangent = torch.bmm(
|
443 |
+
tangent_matrix.transpose(1, 2), normal.unsqueeze(-1)
|
444 |
+
).squeeze(-1)
|
445 |
+
|
446 |
+
# Convert from [-1,1] to [0,1] range for storage
|
447 |
+
normal_tangent = (normal_tangent * 0.5 + 0.5).clamp(
|
448 |
+
0, 1
|
449 |
)
|
|
|
450 |
|
451 |
+
f[bake_mask] = normal_tangent.view(-1, 3)
|
452 |
mat_out["bump"] = f
|
453 |
else:
|
454 |
f[bake_mask] = v.view(-1, v.shape[-1])
|
|
|
459 |
return arr
|
460 |
return (
|
461 |
dilate_fill(
|
462 |
+
arr.permute(2, 0, 1)[None, ...].contiguous(),
|
463 |
bake_mask.unsqueeze(0).unsqueeze(0),
|
464 |
iterations=bake_resolution // 150,
|
465 |
)
|
466 |
.squeeze(0)
|
467 |
.permute(1, 2, 0)
|
468 |
+
.contiguous()
|
469 |
)
|
470 |
|
471 |
verts_np = convert_data(mesh.v_pos)
|
sf3d/utils.py
CHANGED
@@ -1,13 +1,27 @@
|
|
1 |
-
|
|
|
2 |
|
3 |
import numpy as np
|
4 |
import rembg
|
5 |
import torch
|
|
|
6 |
from PIL import Image
|
7 |
|
8 |
import sf3d.models.utils as sf3d_utils
|
9 |
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
def create_intrinsic_from_fov_deg(fov_deg: float, cond_height: int, cond_width: int):
|
12 |
intrinsic = sf3d_utils.get_intrinsic_from_fov(
|
13 |
np.deg2rad(fov_deg),
|
@@ -50,42 +64,42 @@ def remove_background(
|
|
50 |
return image
|
51 |
|
52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
def resize_foreground(
|
54 |
-
image: Image,
|
55 |
ratio: float,
|
|
|
56 |
) -> Image:
|
57 |
-
image
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
)
|
66 |
-
# crop the foreground
|
67 |
-
fg = image[y1:y2, x1:x2]
|
68 |
-
# pad to square
|
69 |
-
size = max(fg.shape[0], fg.shape[1])
|
70 |
-
ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
|
71 |
-
ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
|
72 |
-
new_image = np.pad(
|
73 |
-
fg,
|
74 |
-
((ph0, ph1), (pw0, pw1), (0, 0)),
|
75 |
-
mode="constant",
|
76 |
-
constant_values=((0, 0), (0, 0), (0, 0)),
|
77 |
-
)
|
78 |
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
new_image,
|
86 |
-
((ph0, ph1), (pw0, pw1), (0, 0)),
|
87 |
-
mode="constant",
|
88 |
-
constant_values=((0, 0), (0, 0), (0, 0)),
|
89 |
)
|
90 |
-
|
|
|
|
|
91 |
return new_image
|
|
|
1 |
+
import os
|
2 |
+
from typing import Any, Union
|
3 |
|
4 |
import numpy as np
|
5 |
import rembg
|
6 |
import torch
|
7 |
+
import torchvision.transforms.functional as torchvision_F
|
8 |
from PIL import Image
|
9 |
|
10 |
import sf3d.models.utils as sf3d_utils
|
11 |
|
12 |
|
13 |
+
def get_device():
|
14 |
+
if os.environ.get("SF3D_USE_CPU", "0") == "1":
|
15 |
+
return "cpu"
|
16 |
+
|
17 |
+
device = "cpu"
|
18 |
+
if torch.cuda.is_available():
|
19 |
+
device = "cuda"
|
20 |
+
elif torch.backends.mps.is_available():
|
21 |
+
device = "mps"
|
22 |
+
return device
|
23 |
+
|
24 |
+
|
25 |
def create_intrinsic_from_fov_deg(fov_deg: float, cond_height: int, cond_width: int):
|
26 |
intrinsic = sf3d_utils.get_intrinsic_from_fov(
|
27 |
np.deg2rad(fov_deg),
|
|
|
64 |
return image
|
65 |
|
66 |
|
67 |
+
def get_1d_bounds(arr):
|
68 |
+
nz = np.flatnonzero(arr)
|
69 |
+
return nz[0], nz[-1]
|
70 |
+
|
71 |
+
|
72 |
+
def get_bbox_from_mask(mask, thr=0.5):
|
73 |
+
masks_for_box = (mask > thr).astype(np.float32)
|
74 |
+
assert masks_for_box.sum() > 0, "Empty mask!"
|
75 |
+
x0, x1 = get_1d_bounds(masks_for_box.sum(axis=-2))
|
76 |
+
y0, y1 = get_1d_bounds(masks_for_box.sum(axis=-1))
|
77 |
+
return x0, y0, x1, y1
|
78 |
+
|
79 |
+
|
80 |
def resize_foreground(
|
81 |
+
image: Union[Image.Image, np.ndarray],
|
82 |
ratio: float,
|
83 |
+
out_size=None,
|
84 |
) -> Image:
|
85 |
+
if isinstance(image, np.ndarray):
|
86 |
+
image = Image.fromarray(image, mode="RGBA")
|
87 |
+
assert image.mode == "RGBA"
|
88 |
+
# Get bounding box
|
89 |
+
mask_np = np.array(image)[:, :, -1]
|
90 |
+
x1, y1, x2, y2 = get_bbox_from_mask(mask_np, thr=0.5)
|
91 |
+
h, w = y2 - y1, x2 - x1
|
92 |
+
yc, xc = (y1 + y2) / 2, (x1 + x2) / 2
|
93 |
+
scale = max(h, w) / ratio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
+
new_image = torchvision_F.crop(
|
96 |
+
image,
|
97 |
+
top=int(yc - scale / 2),
|
98 |
+
left=int(xc - scale / 2),
|
99 |
+
height=int(scale),
|
100 |
+
width=int(scale),
|
|
|
|
|
|
|
|
|
101 |
)
|
102 |
+
if out_size is not None:
|
103 |
+
new_image = new_image.resize(out_size)
|
104 |
+
|
105 |
return new_image
|
texture_baker/README.md
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Texture baker
|
2 |
+
|
3 |
+
Small texture baker which rasterizes barycentric coordinates to a tensor.
|
4 |
+
It also implements an interpolation module which can be used to bake attributes to textures then.
|
5 |
+
|
6 |
+
## Usage
|
7 |
+
|
8 |
+
The baker can quickly bake vertex attributes to the a texture atlas based on the UV coordinates.
|
9 |
+
It supports baking on the CPU and GPU.
|
10 |
+
|
11 |
+
```python
|
12 |
+
from texture_baker import TextureBaker
|
13 |
+
|
14 |
+
mesh = ...
|
15 |
+
uv = mesh.uv # num_vertex, 2
|
16 |
+
triangle_idx = mesh.faces # num_faces, 3
|
17 |
+
vertices = mesh.vertices # num_vertex, 3
|
18 |
+
|
19 |
+
tb = TextureBaker()
|
20 |
+
# First get the barycentric coordinates
|
21 |
+
rast = tb.rasterize(
|
22 |
+
uv=uv, face_indices=triangle_idx, bake_resolution=1024
|
23 |
+
)
|
24 |
+
# Then interpolate vertex attributes
|
25 |
+
position_bake = tb.interpolate(attr=vertices, rast=rast, face_indices=triangle_idx)
|
26 |
+
```
|
texture_baker/requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
numpy
|
texture_baker/setup.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
import platform
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from setuptools import find_packages, setup
|
7 |
+
from torch.utils.cpp_extension import (
|
8 |
+
CUDA_HOME,
|
9 |
+
BuildExtension,
|
10 |
+
CppExtension,
|
11 |
+
CUDAExtension,
|
12 |
+
)
|
13 |
+
|
14 |
+
library_name = "texture_baker"
|
15 |
+
|
16 |
+
|
17 |
+
def get_extensions():
|
18 |
+
debug_mode = os.getenv("DEBUG", "0") == "1"
|
19 |
+
use_cuda = os.getenv("USE_CUDA", "1" if torch.cuda.is_available() else "0") == "1"
|
20 |
+
use_metal = (
|
21 |
+
os.getenv("USE_METAL", "1" if torch.backends.mps.is_available() else "0") == "1"
|
22 |
+
)
|
23 |
+
use_native_arch = os.getenv("USE_NATIVE_ARCH", "1") == "1"
|
24 |
+
if debug_mode:
|
25 |
+
print("Compiling in debug mode")
|
26 |
+
|
27 |
+
use_cuda = use_cuda and CUDA_HOME is not None
|
28 |
+
extension = CUDAExtension if use_cuda else CppExtension
|
29 |
+
|
30 |
+
extra_link_args = []
|
31 |
+
extra_compile_args = {
|
32 |
+
"cxx": [
|
33 |
+
"-O3" if not debug_mode else "-O0",
|
34 |
+
"-fdiagnostics-color=always",
|
35 |
+
"-fopenmp",
|
36 |
+
] + ["-march=native"] if use_native_arch else [],
|
37 |
+
"nvcc": [
|
38 |
+
"-O3" if not debug_mode else "-O0",
|
39 |
+
],
|
40 |
+
}
|
41 |
+
if debug_mode:
|
42 |
+
extra_compile_args["cxx"].append("-g")
|
43 |
+
if platform.system() == "Windows":
|
44 |
+
extra_compile_args["cxx"].append("/Z7")
|
45 |
+
extra_compile_args["cxx"].append("/Od")
|
46 |
+
extra_link_args.extend(["/DEBUG"])
|
47 |
+
extra_compile_args["cxx"].append("-UNDEBUG")
|
48 |
+
extra_compile_args["nvcc"].append("-UNDEBUG")
|
49 |
+
extra_compile_args["nvcc"].append("-g")
|
50 |
+
extra_link_args.extend(["-O0", "-g"])
|
51 |
+
|
52 |
+
define_macros = []
|
53 |
+
extensions = []
|
54 |
+
libraries = []
|
55 |
+
|
56 |
+
this_dir = os.path.dirname(os.path.curdir)
|
57 |
+
sources = glob.glob(
|
58 |
+
os.path.join(this_dir, library_name, "csrc", "**", "*.cpp"), recursive=True
|
59 |
+
)
|
60 |
+
|
61 |
+
if len(sources) == 0:
|
62 |
+
print("No source files found for extension, skipping extension compilation")
|
63 |
+
return None
|
64 |
+
|
65 |
+
if use_cuda:
|
66 |
+
define_macros += [
|
67 |
+
("THRUST_IGNORE_CUB_VERSION_CHECK", None),
|
68 |
+
]
|
69 |
+
sources += glob.glob(
|
70 |
+
os.path.join(this_dir, library_name, "csrc", "**", "*.cu"), recursive=True
|
71 |
+
)
|
72 |
+
libraries += ["cudart", "c10_cuda"]
|
73 |
+
|
74 |
+
if use_metal:
|
75 |
+
define_macros += [
|
76 |
+
("WITH_MPS", None),
|
77 |
+
]
|
78 |
+
sources += glob.glob(
|
79 |
+
os.path.join(this_dir, library_name, "csrc", "**", "*.mm"), recursive=True
|
80 |
+
)
|
81 |
+
extra_compile_args.update({"cxx": ["-O3", "-arch", "arm64"]})
|
82 |
+
extra_link_args += ["-arch", "arm64"]
|
83 |
+
|
84 |
+
extensions.append(
|
85 |
+
extension(
|
86 |
+
name=f"{library_name}._C",
|
87 |
+
sources=sources,
|
88 |
+
define_macros=define_macros,
|
89 |
+
extra_compile_args=extra_compile_args,
|
90 |
+
extra_link_args=extra_link_args,
|
91 |
+
libraries=libraries
|
92 |
+
+ [
|
93 |
+
"c10",
|
94 |
+
"torch",
|
95 |
+
"torch_cpu",
|
96 |
+
"torch_python",
|
97 |
+
],
|
98 |
+
)
|
99 |
+
)
|
100 |
+
|
101 |
+
for ext in extensions:
|
102 |
+
ext.libraries = ["cudart_static" if x == "cudart" else x for x in ext.libraries]
|
103 |
+
|
104 |
+
print(extensions)
|
105 |
+
|
106 |
+
return extensions
|
107 |
+
|
108 |
+
|
109 |
+
setup(
|
110 |
+
name=library_name,
|
111 |
+
version="0.0.1",
|
112 |
+
packages=find_packages(where="."),
|
113 |
+
package_dir={"": "."},
|
114 |
+
ext_modules=get_extensions(),
|
115 |
+
install_requires=[],
|
116 |
+
package_data={
|
117 |
+
library_name: [os.path.join("csrc", "*.h"), os.path.join("csrc", "*.metal")],
|
118 |
+
},
|
119 |
+
description="Small texture baker which rasterizes barycentric coordinates to a tensor.",
|
120 |
+
long_description=open("README.md").read(),
|
121 |
+
long_description_content_type="text/markdown",
|
122 |
+
url="https://github.com/Stability-AI/texture_baker",
|
123 |
+
cmdclass={"build_ext": BuildExtension},
|
124 |
+
)
|
texture_baker/texture_baker/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch # noqa: F401
|
2 |
+
|
3 |
+
from . import _C # noqa: F401
|
4 |
+
from .baker import TextureBaker # noqa: F401
|
texture_baker/texture_baker/baker.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch import Tensor
|
4 |
+
|
5 |
+
|
6 |
+
class TextureBaker(nn.Module):
|
7 |
+
def __init__(self):
|
8 |
+
super().__init__()
|
9 |
+
|
10 |
+
def rasterize(
|
11 |
+
self,
|
12 |
+
uv: Tensor,
|
13 |
+
face_indices: Tensor,
|
14 |
+
bake_resolution: int,
|
15 |
+
) -> Tensor:
|
16 |
+
"""
|
17 |
+
Rasterize the UV coordinates to a barycentric coordinates
|
18 |
+
& Triangle idxs texture map
|
19 |
+
|
20 |
+
Args:
|
21 |
+
uv (Tensor, num_vertices 2, float): UV coordinates of the mesh
|
22 |
+
face_indices (Tensor, num_faces 3, int): Face indices of the mesh
|
23 |
+
bake_resolution (int): Resolution of the bake
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
Tensor, bake_resolution bake_resolution 4, float: Rasterized map
|
27 |
+
"""
|
28 |
+
return torch.ops.texture_baker_cpp.rasterize(
|
29 |
+
uv, face_indices.to(torch.int32), bake_resolution
|
30 |
+
)
|
31 |
+
|
32 |
+
def get_mask(self, rast: Tensor) -> Tensor:
|
33 |
+
"""
|
34 |
+
Get the occupancy mask from the rasterized map
|
35 |
+
|
36 |
+
Args:
|
37 |
+
rast (Tensor, bake_resolution bake_resolution 4, float): Rasterized map
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
Tensor, bake_resolution bake_resolution, bool: Mask
|
41 |
+
"""
|
42 |
+
return rast[..., -1] >= 0
|
43 |
+
|
44 |
+
def interpolate(
|
45 |
+
self,
|
46 |
+
attr: Tensor,
|
47 |
+
rast: Tensor,
|
48 |
+
face_indices: Tensor,
|
49 |
+
) -> Tensor:
|
50 |
+
"""
|
51 |
+
Interpolate the attributes using the rasterized map
|
52 |
+
|
53 |
+
Args:
|
54 |
+
attr (Tensor, num_vertices 3, float): Attributes of the mesh
|
55 |
+
rast (Tensor, bake_resolution bake_resolution 4, float): Rasterized map
|
56 |
+
face_indices (Tensor, num_faces 3, int): Face indices of the mesh
|
57 |
+
uv (Tensor, num_vertices 2, float): UV coordinates of the mesh
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
Tensor, bake_resolution bake_resolution 3, float: Interpolated attributes
|
61 |
+
"""
|
62 |
+
return torch.ops.texture_baker_cpp.interpolate(
|
63 |
+
attr, face_indices.to(torch.int32), rast
|
64 |
+
)
|
65 |
+
|
66 |
+
def forward(
|
67 |
+
self,
|
68 |
+
attr: Tensor,
|
69 |
+
uv: Tensor,
|
70 |
+
face_indices: Tensor,
|
71 |
+
bake_resolution: int,
|
72 |
+
) -> Tensor:
|
73 |
+
"""
|
74 |
+
Bake the texture
|
75 |
+
|
76 |
+
Args:
|
77 |
+
attr (Tensor, num_vertices 3, float): Attributes of the mesh
|
78 |
+
uv (Tensor, num_vertices 2, float): UV coordinates of the mesh
|
79 |
+
face_indices (Tensor, num_faces 3, int): Face indices of the mesh
|
80 |
+
bake_resolution (int): Resolution of the bake
|
81 |
+
|
82 |
+
Returns:
|
83 |
+
Tensor, bake_resolution bake_resolution 3, float: Baked texture
|
84 |
+
"""
|
85 |
+
rast = self.rasterize(uv, face_indices, bake_resolution)
|
86 |
+
return self.interpolate(attr, rast, face_indices, uv)
|
texture_baker/texture_baker/csrc/baker.cpp
ADDED
@@ -0,0 +1,548 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <ATen/ATen.h>
|
2 |
+
#include <ATen/Context.h>
|
3 |
+
#include <chrono>
|
4 |
+
#include <cmath>
|
5 |
+
#include <omp.h>
|
6 |
+
#include <torch/extension.h>
|
7 |
+
#ifndef __ARM_ARCH_ISA_A64
|
8 |
+
#include <immintrin.h>
|
9 |
+
#endif
|
10 |
+
|
11 |
+
#include "baker.h"
|
12 |
+
|
13 |
+
// #define TIMING
|
14 |
+
#define BINS 8
|
15 |
+
|
16 |
+
namespace texture_baker_cpp {
|
17 |
+
// Calculate the centroid of a triangle
|
18 |
+
tb_float2 triangle_centroid(const tb_float2 &v0, const tb_float2 &v1,
|
19 |
+
const tb_float2 &v2) {
|
20 |
+
return {(v0.x + v1.x + v2.x) * 0.3333f, (v0.y + v1.y + v2.y) * 0.3333f};
|
21 |
+
}
|
22 |
+
|
23 |
+
float BVH::find_best_split_plane(const BVHNode &node, int &best_axis,
|
24 |
+
int &best_pos, AABB ¢roidBounds) {
|
25 |
+
float best_cost = std::numeric_limits<float>::max();
|
26 |
+
|
27 |
+
for (int axis = 0; axis < 2; ++axis) // We use 2 as we have only x and y
|
28 |
+
{
|
29 |
+
float boundsMin = centroidBounds.min[axis];
|
30 |
+
float boundsMax = centroidBounds.max[axis];
|
31 |
+
if (boundsMin == boundsMax) {
|
32 |
+
continue;
|
33 |
+
}
|
34 |
+
|
35 |
+
// Populate the bins
|
36 |
+
float scale = BINS / (boundsMax - boundsMin);
|
37 |
+
float leftCountArea[BINS - 1], rightCountArea[BINS - 1];
|
38 |
+
int leftSum = 0, rightSum = 0;
|
39 |
+
|
40 |
+
#ifndef __ARM_ARCH_ISA_A64
|
41 |
+
#ifndef _MSC_VER
|
42 |
+
if (__builtin_cpu_supports("sse"))
|
43 |
+
#elif (defined(_M_AMD64) || defined(_M_X64))
|
44 |
+
// SSE supported on Windows
|
45 |
+
if constexpr (true)
|
46 |
+
#endif
|
47 |
+
{
|
48 |
+
__m128 min4[BINS], max4[BINS];
|
49 |
+
unsigned int count[BINS];
|
50 |
+
for (unsigned int i = 0; i < BINS; i++)
|
51 |
+
min4[i] = _mm_set_ps1(1e30f), max4[i] = _mm_set_ps1(-1e30f),
|
52 |
+
count[i] = 0;
|
53 |
+
for (int i = node.start; i < node.end; i++) {
|
54 |
+
int tri_idx = triangle_indices[i];
|
55 |
+
const Triangle &triangle = triangles[tri_idx];
|
56 |
+
|
57 |
+
int binIdx = std::min(
|
58 |
+
BINS - 1, (int)((triangle.centroid[axis] - boundsMin) * scale));
|
59 |
+
count[binIdx]++;
|
60 |
+
__m128 v0 = _mm_set_ps(triangle.v0.x, triangle.v0.y, 0.0f, 0.0f);
|
61 |
+
__m128 v1 = _mm_set_ps(triangle.v1.x, triangle.v1.y, 0.0f, 0.0f);
|
62 |
+
__m128 v2 = _mm_set_ps(triangle.v2.x, triangle.v2.y, 0.0f, 0.0f);
|
63 |
+
min4[binIdx] = _mm_min_ps(min4[binIdx], v0);
|
64 |
+
max4[binIdx] = _mm_max_ps(max4[binIdx], v0);
|
65 |
+
min4[binIdx] = _mm_min_ps(min4[binIdx], v1);
|
66 |
+
max4[binIdx] = _mm_max_ps(max4[binIdx], v1);
|
67 |
+
min4[binIdx] = _mm_min_ps(min4[binIdx], v2);
|
68 |
+
max4[binIdx] = _mm_max_ps(max4[binIdx], v2);
|
69 |
+
}
|
70 |
+
// gather data for the 7 planes between the 8 bins
|
71 |
+
__m128 leftMin4 = _mm_set_ps1(1e30f), rightMin4 = leftMin4;
|
72 |
+
__m128 leftMax4 = _mm_set_ps1(-1e30f), rightMax4 = leftMax4;
|
73 |
+
for (int i = 0; i < BINS - 1; i++) {
|
74 |
+
leftSum += count[i];
|
75 |
+
rightSum += count[BINS - 1 - i];
|
76 |
+
leftMin4 = _mm_min_ps(leftMin4, min4[i]);
|
77 |
+
rightMin4 = _mm_min_ps(rightMin4, min4[BINS - 2 - i]);
|
78 |
+
leftMax4 = _mm_max_ps(leftMax4, max4[i]);
|
79 |
+
rightMax4 = _mm_max_ps(rightMax4, max4[BINS - 2 - i]);
|
80 |
+
float le[4], re[4];
|
81 |
+
_mm_store_ps(le, _mm_sub_ps(leftMax4, leftMin4));
|
82 |
+
_mm_store_ps(re, _mm_sub_ps(rightMax4, rightMin4));
|
83 |
+
// SSE order goes from back to front
|
84 |
+
leftCountArea[i] = leftSum * (le[2] * le[3]); // 2D area calculation
|
85 |
+
rightCountArea[BINS - 2 - i] =
|
86 |
+
rightSum * (re[2] * re[3]); // 2D area calculation
|
87 |
+
}
|
88 |
+
}
|
89 |
+
#else
|
90 |
+
if constexpr (false) {
|
91 |
+
}
|
92 |
+
#endif
|
93 |
+
else {
|
94 |
+
struct Bin {
|
95 |
+
AABB bounds;
|
96 |
+
int triCount = 0;
|
97 |
+
} bins[BINS];
|
98 |
+
|
99 |
+
for (int i = node.start; i < node.end; i++) {
|
100 |
+
int tri_idx = triangle_indices[i];
|
101 |
+
const Triangle &triangle = triangles[tri_idx];
|
102 |
+
|
103 |
+
int binIdx = std::min(
|
104 |
+
BINS - 1, (int)((triangle.centroid[axis] - boundsMin) * scale));
|
105 |
+
bins[binIdx].triCount++;
|
106 |
+
bins[binIdx].bounds.grow(triangle.v0);
|
107 |
+
bins[binIdx].bounds.grow(triangle.v1);
|
108 |
+
bins[binIdx].bounds.grow(triangle.v2);
|
109 |
+
}
|
110 |
+
|
111 |
+
// Gather data for the planes between the bins
|
112 |
+
AABB leftBox, rightBox;
|
113 |
+
|
114 |
+
for (int i = 0; i < BINS - 1; i++) {
|
115 |
+
leftSum += bins[i].triCount;
|
116 |
+
leftBox.grow(bins[i].bounds);
|
117 |
+
leftCountArea[i] = leftSum * leftBox.area();
|
118 |
+
|
119 |
+
rightSum += bins[BINS - 1 - i].triCount;
|
120 |
+
rightBox.grow(bins[BINS - 1 - i].bounds);
|
121 |
+
rightCountArea[BINS - 2 - i] = rightSum * rightBox.area();
|
122 |
+
}
|
123 |
+
}
|
124 |
+
|
125 |
+
// Calculate SAH cost for the planes
|
126 |
+
scale = (boundsMax - boundsMin) / BINS;
|
127 |
+
for (int i = 0; i < BINS - 1; i++) {
|
128 |
+
float planeCost = leftCountArea[i] + rightCountArea[i];
|
129 |
+
if (planeCost < best_cost) {
|
130 |
+
best_axis = axis;
|
131 |
+
best_pos = i + 1;
|
132 |
+
best_cost = planeCost;
|
133 |
+
}
|
134 |
+
}
|
135 |
+
}
|
136 |
+
|
137 |
+
return best_cost;
|
138 |
+
}
|
139 |
+
|
140 |
+
void BVH::update_node_bounds(BVHNode &node, AABB ¢roidBounds) {
|
141 |
+
#ifndef __ARM_ARCH_ISA_A64
|
142 |
+
#ifndef _MSC_VER
|
143 |
+
if (__builtin_cpu_supports("sse"))
|
144 |
+
#elif (defined(_M_AMD64) || defined(_M_X64))
|
145 |
+
// SSE supported on Windows
|
146 |
+
if constexpr (true)
|
147 |
+
#endif
|
148 |
+
{
|
149 |
+
__m128 min4 = _mm_set_ps1(1e30f), max4 = _mm_set_ps1(-1e30f);
|
150 |
+
__m128 cmin4 = _mm_set_ps1(1e30f), cmax4 = _mm_set_ps1(-1e30f);
|
151 |
+
|
152 |
+
for (int i = node.start; i < node.end; i += 2) {
|
153 |
+
int tri_idx1 = triangle_indices[i];
|
154 |
+
const Triangle &leafTri1 = triangles[tri_idx1];
|
155 |
+
// Check if the second actually exists in the node
|
156 |
+
__m128 v0, v1, v2, centroid;
|
157 |
+
if (i + 1 < node.end) {
|
158 |
+
int tri_idx2 = triangle_indices[i + 1];
|
159 |
+
const Triangle leafTri2 = triangles[tri_idx2];
|
160 |
+
|
161 |
+
v0 = _mm_set_ps(leafTri1.v0.x, leafTri1.v0.y, leafTri2.v0.x,
|
162 |
+
leafTri2.v0.y);
|
163 |
+
v1 = _mm_set_ps(leafTri1.v1.x, leafTri1.v1.y, leafTri2.v1.x,
|
164 |
+
leafTri2.v1.y);
|
165 |
+
v2 = _mm_set_ps(leafTri1.v2.x, leafTri1.v2.y, leafTri2.v2.x,
|
166 |
+
leafTri2.v2.y);
|
167 |
+
centroid = _mm_set_ps(leafTri1.centroid.x, leafTri1.centroid.y,
|
168 |
+
leafTri2.centroid.x, leafTri2.centroid.y);
|
169 |
+
} else {
|
170 |
+
// Otherwise do some duplicated work
|
171 |
+
v0 = _mm_set_ps(leafTri1.v0.x, leafTri1.v0.y, leafTri1.v0.x,
|
172 |
+
leafTri1.v0.y);
|
173 |
+
v1 = _mm_set_ps(leafTri1.v1.x, leafTri1.v1.y, leafTri1.v1.x,
|
174 |
+
leafTri1.v1.y);
|
175 |
+
v2 = _mm_set_ps(leafTri1.v2.x, leafTri1.v2.y, leafTri1.v2.x,
|
176 |
+
leafTri1.v2.y);
|
177 |
+
centroid = _mm_set_ps(leafTri1.centroid.x, leafTri1.centroid.y,
|
178 |
+
leafTri1.centroid.x, leafTri1.centroid.y);
|
179 |
+
}
|
180 |
+
|
181 |
+
min4 = _mm_min_ps(min4, v0);
|
182 |
+
max4 = _mm_max_ps(max4, v0);
|
183 |
+
min4 = _mm_min_ps(min4, v1);
|
184 |
+
max4 = _mm_max_ps(max4, v1);
|
185 |
+
min4 = _mm_min_ps(min4, v2);
|
186 |
+
max4 = _mm_max_ps(max4, v2);
|
187 |
+
cmin4 = _mm_min_ps(cmin4, centroid);
|
188 |
+
cmax4 = _mm_max_ps(cmax4, centroid);
|
189 |
+
}
|
190 |
+
|
191 |
+
float min_values[4], max_values[4], cmin_values[4], cmax_values[4];
|
192 |
+
_mm_store_ps(min_values, min4);
|
193 |
+
_mm_store_ps(max_values, max4);
|
194 |
+
_mm_store_ps(cmin_values, cmin4);
|
195 |
+
_mm_store_ps(cmax_values, cmax4);
|
196 |
+
|
197 |
+
node.bbox.min.x = std::min(min_values[3], min_values[1]);
|
198 |
+
node.bbox.min.y = std::min(min_values[2], min_values[0]);
|
199 |
+
node.bbox.max.x = std::max(max_values[3], max_values[1]);
|
200 |
+
node.bbox.max.y = std::max(max_values[2], max_values[0]);
|
201 |
+
|
202 |
+
centroidBounds.min.x = std::min(cmin_values[3], cmin_values[1]);
|
203 |
+
centroidBounds.min.y = std::min(cmin_values[2], cmin_values[0]);
|
204 |
+
centroidBounds.max.x = std::max(cmax_values[3], cmax_values[1]);
|
205 |
+
centroidBounds.max.y = std::max(cmax_values[2], cmax_values[0]);
|
206 |
+
}
|
207 |
+
#else
|
208 |
+
if constexpr (false) {
|
209 |
+
}
|
210 |
+
#endif
|
211 |
+
{
|
212 |
+
node.bbox.invalidate();
|
213 |
+
centroidBounds.invalidate();
|
214 |
+
|
215 |
+
// Calculate the bounding box for the node
|
216 |
+
for (int i = node.start; i < node.end; ++i) {
|
217 |
+
int tri_idx = triangle_indices[i];
|
218 |
+
const Triangle &tri = triangles[tri_idx];
|
219 |
+
node.bbox.grow(tri.v0);
|
220 |
+
node.bbox.grow(tri.v1);
|
221 |
+
node.bbox.grow(tri.v2);
|
222 |
+
centroidBounds.grow(tri.centroid);
|
223 |
+
}
|
224 |
+
}
|
225 |
+
}
|
226 |
+
|
227 |
+
void BVH::build(const tb_float2 *vertices, const tb_int3 *indices,
|
228 |
+
const int64_t &num_indices) {
|
229 |
+
#ifdef TIMING
|
230 |
+
auto start = std::chrono::high_resolution_clock::now();
|
231 |
+
#endif
|
232 |
+
// Create triangles
|
233 |
+
for (size_t i = 0; i < num_indices; ++i) {
|
234 |
+
tb_int3 idx = indices[i];
|
235 |
+
triangles.push_back(
|
236 |
+
{vertices[idx.x], vertices[idx.y], vertices[idx.z], static_cast<int>(i),
|
237 |
+
triangle_centroid(vertices[idx.x], vertices[idx.y], vertices[idx.z])});
|
238 |
+
}
|
239 |
+
|
240 |
+
// Initialize triangle_indices
|
241 |
+
triangle_indices.resize(triangles.size());
|
242 |
+
std::iota(triangle_indices.begin(), triangle_indices.end(), 0);
|
243 |
+
|
244 |
+
// Build BVH nodes
|
245 |
+
// Reserve extra capacity to fix windows specific crashes
|
246 |
+
nodes.reserve(triangles.size() * 2 + 1);
|
247 |
+
nodes.push_back({}); // Create the root node
|
248 |
+
root = 0;
|
249 |
+
|
250 |
+
// Define a struct for queue entries
|
251 |
+
struct QueueEntry {
|
252 |
+
int node_idx;
|
253 |
+
int start;
|
254 |
+
int end;
|
255 |
+
};
|
256 |
+
|
257 |
+
// Queue for breadth-first traversal
|
258 |
+
std::queue<QueueEntry> node_queue;
|
259 |
+
node_queue.push({root, 0, (int)triangles.size()});
|
260 |
+
|
261 |
+
// Process each node in the queue
|
262 |
+
while (!node_queue.empty()) {
|
263 |
+
QueueEntry current = node_queue.front();
|
264 |
+
node_queue.pop();
|
265 |
+
|
266 |
+
int node_idx = current.node_idx;
|
267 |
+
int start = current.start;
|
268 |
+
int end = current.end;
|
269 |
+
|
270 |
+
BVHNode &node = nodes[node_idx];
|
271 |
+
node.start = start;
|
272 |
+
node.end = end;
|
273 |
+
|
274 |
+
// Calculate the bounding box for the node
|
275 |
+
AABB centroidBounds;
|
276 |
+
update_node_bounds(node, centroidBounds);
|
277 |
+
|
278 |
+
// Determine the best split using SAH
|
279 |
+
int best_axis, best_pos;
|
280 |
+
|
281 |
+
float splitCost =
|
282 |
+
find_best_split_plane(node, best_axis, best_pos, centroidBounds);
|
283 |
+
float nosplitCost = node.calculate_node_cost();
|
284 |
+
|
285 |
+
// Stop condition: if the best cost is greater than or equal to the parent's
|
286 |
+
// cost
|
287 |
+
if (splitCost >= nosplitCost) {
|
288 |
+
// Leaf node
|
289 |
+
node.left = node.right = -1;
|
290 |
+
continue;
|
291 |
+
}
|
292 |
+
|
293 |
+
float scale =
|
294 |
+
BINS / (centroidBounds.max[best_axis] - centroidBounds.min[best_axis]);
|
295 |
+
int i = node.start;
|
296 |
+
int j = node.end - 1;
|
297 |
+
|
298 |
+
// Sort the triangle_indices in the range [start, end) based on the best
|
299 |
+
// axis
|
300 |
+
while (i <= j) {
|
301 |
+
// use the exact calculation we used for binning to prevent rare
|
302 |
+
// inaccuracies
|
303 |
+
int tri_idx = triangle_indices[i];
|
304 |
+
tb_float2 tcentr = triangles[tri_idx].centroid;
|
305 |
+
int binIdx = std::min(
|
306 |
+
BINS - 1,
|
307 |
+
(int)((tcentr[best_axis] - centroidBounds.min[best_axis]) * scale));
|
308 |
+
if (binIdx < best_pos)
|
309 |
+
i++;
|
310 |
+
else
|
311 |
+
std::swap(triangle_indices[i], triangle_indices[j--]);
|
312 |
+
}
|
313 |
+
int leftCount = i - node.start;
|
314 |
+
if (leftCount == 0 || leftCount == node.num_triangles()) {
|
315 |
+
// Leaf node
|
316 |
+
node.left = node.right = -1;
|
317 |
+
continue;
|
318 |
+
}
|
319 |
+
|
320 |
+
int mid = i;
|
321 |
+
|
322 |
+
// Create and set left child
|
323 |
+
node.left = nodes.size();
|
324 |
+
nodes.push_back({});
|
325 |
+
node_queue.push({node.left, start, mid});
|
326 |
+
|
327 |
+
// Create and set right child
|
328 |
+
node = nodes[node_idx]; // Update the node - Potentially stale reference
|
329 |
+
node.right = nodes.size();
|
330 |
+
nodes.push_back({});
|
331 |
+
node_queue.push({node.right, mid, end});
|
332 |
+
}
|
333 |
+
#ifdef TIMING
|
334 |
+
auto end = std::chrono::high_resolution_clock::now();
|
335 |
+
std::chrono::duration<double> elapsed = end - start;
|
336 |
+
std::cout << "BVH build time: " << elapsed.count() << "s" << std::endl;
|
337 |
+
#endif
|
338 |
+
}
|
339 |
+
|
340 |
+
// Utility function to clamp a value between a minimum and a maximum
|
341 |
+
float clamp(float val, float minVal, float maxVal) {
|
342 |
+
return std::min(std::max(val, minVal), maxVal);
|
343 |
+
}
|
344 |
+
|
345 |
+
// Function to check if a point (xy) is inside a triangle defined by vertices
|
346 |
+
// v1, v2, v3
|
347 |
+
bool barycentric_coordinates(tb_float2 xy, tb_float2 v1, tb_float2 v2,
|
348 |
+
tb_float2 v3, float &u, float &v, float &w) {
|
349 |
+
// Vectors from v1 to v2, v3 and xy
|
350 |
+
tb_float2 v1v2 = {v2.x - v1.x, v2.y - v1.y};
|
351 |
+
tb_float2 v1v3 = {v3.x - v1.x, v3.y - v1.y};
|
352 |
+
tb_float2 xyv1 = {xy.x - v1.x, xy.y - v1.y};
|
353 |
+
|
354 |
+
// Dot products of the vectors
|
355 |
+
float d00 = v1v2.x * v1v2.x + v1v2.y * v1v2.y;
|
356 |
+
float d01 = v1v2.x * v1v3.x + v1v2.y * v1v3.y;
|
357 |
+
float d11 = v1v3.x * v1v3.x + v1v3.y * v1v3.y;
|
358 |
+
float d20 = xyv1.x * v1v2.x + xyv1.y * v1v2.y;
|
359 |
+
float d21 = xyv1.x * v1v3.x + xyv1.y * v1v3.y;
|
360 |
+
|
361 |
+
// Calculate the barycentric coordinates
|
362 |
+
float denom = d00 * d11 - d01 * d01;
|
363 |
+
v = (d11 * d20 - d01 * d21) / denom;
|
364 |
+
w = (d00 * d21 - d01 * d20) / denom;
|
365 |
+
u = 1.0f - v - w;
|
366 |
+
|
367 |
+
// Check if the point is inside the triangle
|
368 |
+
return (v >= 0.0f) && (w >= 0.0f) && (v + w <= 1.0f);
|
369 |
+
}
|
370 |
+
|
371 |
+
bool BVH::intersect(const tb_float2 &point, float &u, float &v, float &w,
|
372 |
+
int &index) const {
|
373 |
+
const int max_stack_size = 64;
|
374 |
+
int node_stack[max_stack_size];
|
375 |
+
int stack_size = 0;
|
376 |
+
|
377 |
+
node_stack[stack_size++] = root;
|
378 |
+
|
379 |
+
while (stack_size > 0) {
|
380 |
+
int node_idx = node_stack[--stack_size];
|
381 |
+
const BVHNode &node = nodes[node_idx];
|
382 |
+
|
383 |
+
if (node.is_leaf()) {
|
384 |
+
for (int i = node.start; i < node.end; ++i) {
|
385 |
+
const Triangle &tri = triangles[triangle_indices[i]];
|
386 |
+
if (barycentric_coordinates(point, tri.v0, tri.v1, tri.v2, u, v, w)) {
|
387 |
+
index = tri.index;
|
388 |
+
return true;
|
389 |
+
}
|
390 |
+
}
|
391 |
+
} else {
|
392 |
+
if (nodes[node.right].bbox.overlaps(point)) {
|
393 |
+
if (stack_size < max_stack_size) {
|
394 |
+
node_stack[stack_size++] = node.right;
|
395 |
+
} else {
|
396 |
+
// Handle stack overflow
|
397 |
+
throw std::runtime_error("Node stack overflow");
|
398 |
+
}
|
399 |
+
}
|
400 |
+
if (nodes[node.left].bbox.overlaps(point)) {
|
401 |
+
if (stack_size < max_stack_size) {
|
402 |
+
node_stack[stack_size++] = node.left;
|
403 |
+
} else {
|
404 |
+
// Handle stack overflow
|
405 |
+
throw std::runtime_error("Node stack overflow");
|
406 |
+
}
|
407 |
+
}
|
408 |
+
}
|
409 |
+
}
|
410 |
+
|
411 |
+
return false;
|
412 |
+
}
|
413 |
+
|
414 |
+
torch::Tensor rasterize_cpu(torch::Tensor uv, torch::Tensor indices,
|
415 |
+
int64_t bake_resolution) {
|
416 |
+
int width = bake_resolution;
|
417 |
+
int height = bake_resolution;
|
418 |
+
int num_pixels = width * height;
|
419 |
+
torch::Tensor rast_result = torch::empty(
|
420 |
+
{bake_resolution, bake_resolution, 4},
|
421 |
+
torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU));
|
422 |
+
|
423 |
+
float *rast_result_ptr = rast_result.contiguous().data_ptr<float>();
|
424 |
+
const tb_float2 *vertices = (tb_float2 *)uv.data_ptr<float>();
|
425 |
+
const tb_int3 *tris = (tb_int3 *)indices.data_ptr<int>();
|
426 |
+
|
427 |
+
BVH bvh;
|
428 |
+
bvh.build(vertices, tris, indices.size(0));
|
429 |
+
|
430 |
+
#ifdef TIMING
|
431 |
+
auto start = std::chrono::high_resolution_clock::now();
|
432 |
+
#endif
|
433 |
+
|
434 |
+
#pragma omp parallel for
|
435 |
+
for (int idx = 0; idx < num_pixels; ++idx) {
|
436 |
+
int x = idx / height;
|
437 |
+
int y = idx % height;
|
438 |
+
int idx_ = idx * 4; // Note: *4 because we're storing float4 per pixel
|
439 |
+
|
440 |
+
tb_float2 pixel_coord = {float(y) / height, float(x) / width};
|
441 |
+
pixel_coord.x = clamp(pixel_coord.x, 0.0f, 1.0f);
|
442 |
+
pixel_coord.y = 1.0f - clamp(pixel_coord.y, 0.0f, 1.0f);
|
443 |
+
|
444 |
+
float u, v, w;
|
445 |
+
int triangle_idx;
|
446 |
+
if (bvh.intersect(pixel_coord, u, v, w, triangle_idx)) {
|
447 |
+
rast_result_ptr[idx_ + 0] = u;
|
448 |
+
rast_result_ptr[idx_ + 1] = v;
|
449 |
+
rast_result_ptr[idx_ + 2] = w;
|
450 |
+
rast_result_ptr[idx_ + 3] = static_cast<float>(triangle_idx);
|
451 |
+
} else {
|
452 |
+
rast_result_ptr[idx_ + 0] = 0.0f;
|
453 |
+
rast_result_ptr[idx_ + 1] = 0.0f;
|
454 |
+
rast_result_ptr[idx_ + 2] = 0.0f;
|
455 |
+
rast_result_ptr[idx_ + 3] = -1.0f;
|
456 |
+
}
|
457 |
+
}
|
458 |
+
|
459 |
+
#ifdef TIMING
|
460 |
+
auto end = std::chrono::high_resolution_clock::now();
|
461 |
+
std::chrono::duration<double> elapsed = end - start;
|
462 |
+
std::cout << "Rasterization time: " << elapsed.count() << "s" << std::endl;
|
463 |
+
#endif
|
464 |
+
return rast_result;
|
465 |
+
}
|
466 |
+
|
467 |
+
torch::Tensor interpolate_cpu(torch::Tensor attr, torch::Tensor indices,
|
468 |
+
torch::Tensor rast) {
|
469 |
+
#ifdef TIMING
|
470 |
+
auto start = std::chrono::high_resolution_clock::now();
|
471 |
+
#endif
|
472 |
+
int height = rast.size(0);
|
473 |
+
int width = rast.size(1);
|
474 |
+
torch::Tensor pos_bake = torch::empty(
|
475 |
+
{height, width, 3},
|
476 |
+
torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU));
|
477 |
+
|
478 |
+
const float *attr_ptr = attr.contiguous().data_ptr<float>();
|
479 |
+
const int *indices_ptr = indices.contiguous().data_ptr<int>();
|
480 |
+
const float *rast_ptr = rast.contiguous().data_ptr<float>();
|
481 |
+
float *output_ptr = pos_bake.contiguous().data_ptr<float>();
|
482 |
+
|
483 |
+
int num_pixels = width * height;
|
484 |
+
|
485 |
+
#pragma omp parallel for
|
486 |
+
for (int idx = 0; idx < num_pixels; ++idx) {
|
487 |
+
int idx_ = idx * 4; // Index into the float4 array (4 floats per pixel)
|
488 |
+
tb_float3 barycentric = {
|
489 |
+
rast_ptr[idx_ + 0],
|
490 |
+
rast_ptr[idx_ + 1],
|
491 |
+
rast_ptr[idx_ + 2],
|
492 |
+
};
|
493 |
+
int triangle_idx = static_cast<int>(rast_ptr[idx_ + 3]);
|
494 |
+
|
495 |
+
if (triangle_idx < 0) {
|
496 |
+
output_ptr[idx * 3 + 0] = 0.0f;
|
497 |
+
output_ptr[idx * 3 + 1] = 0.0f;
|
498 |
+
output_ptr[idx * 3 + 2] = 0.0f;
|
499 |
+
continue;
|
500 |
+
}
|
501 |
+
|
502 |
+
tb_int3 triangle = {indices_ptr[3 * triangle_idx + 0],
|
503 |
+
indices_ptr[3 * triangle_idx + 1],
|
504 |
+
indices_ptr[3 * triangle_idx + 2]};
|
505 |
+
tb_float3 v1 = {attr_ptr[3 * triangle.x + 0], attr_ptr[3 * triangle.x + 1],
|
506 |
+
attr_ptr[3 * triangle.x + 2]};
|
507 |
+
tb_float3 v2 = {attr_ptr[3 * triangle.y + 0], attr_ptr[3 * triangle.y + 1],
|
508 |
+
attr_ptr[3 * triangle.y + 2]};
|
509 |
+
tb_float3 v3 = {attr_ptr[3 * triangle.z + 0], attr_ptr[3 * triangle.z + 1],
|
510 |
+
attr_ptr[3 * triangle.z + 2]};
|
511 |
+
|
512 |
+
tb_float3 interpolated;
|
513 |
+
interpolated.x =
|
514 |
+
v1.x * barycentric.x + v2.x * barycentric.y + v3.x * barycentric.z;
|
515 |
+
interpolated.y =
|
516 |
+
v1.y * barycentric.x + v2.y * barycentric.y + v3.y * barycentric.z;
|
517 |
+
interpolated.z =
|
518 |
+
v1.z * barycentric.x + v2.z * barycentric.y + v3.z * barycentric.z;
|
519 |
+
|
520 |
+
output_ptr[idx * 3 + 0] = interpolated.x;
|
521 |
+
output_ptr[idx * 3 + 1] = interpolated.y;
|
522 |
+
output_ptr[idx * 3 + 2] = interpolated.z;
|
523 |
+
}
|
524 |
+
|
525 |
+
#ifdef TIMING
|
526 |
+
auto end = std::chrono::high_resolution_clock::now();
|
527 |
+
std::chrono::duration<double> elapsed = end - start;
|
528 |
+
std::cout << "Interpolation time: " << elapsed.count() << "s" << std::endl;
|
529 |
+
#endif
|
530 |
+
return pos_bake;
|
531 |
+
}
|
532 |
+
|
533 |
+
// Registers _C as a Python extension module.
|
534 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
|
535 |
+
|
536 |
+
// Defines the operators
|
537 |
+
TORCH_LIBRARY(texture_baker_cpp, m) {
|
538 |
+
m.def("rasterize(Tensor uv, Tensor indices, int bake_resolution) -> Tensor");
|
539 |
+
m.def("interpolate(Tensor attr, Tensor indices, Tensor rast) -> Tensor");
|
540 |
+
}
|
541 |
+
|
542 |
+
// Registers CPP implementations
|
543 |
+
TORCH_LIBRARY_IMPL(texture_baker_cpp, CPU, m) {
|
544 |
+
m.impl("rasterize", &rasterize_cpu);
|
545 |
+
m.impl("interpolate", &interpolate_cpu);
|
546 |
+
}
|
547 |
+
|
548 |
+
} // namespace texture_baker_cpp
|
texture_baker/texture_baker/csrc/baker.h
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#if defined(__NVCC__) || defined(__HIPCC__) || defined(__METAL__)
|
4 |
+
#define CUDA_ENABLED
|
5 |
+
#ifndef __METAL__
|
6 |
+
#define CUDA_HOST_DEVICE __host__ __device__
|
7 |
+
#define CUDA_DEVICE __device__
|
8 |
+
#define METAL_CONSTANT_MEM
|
9 |
+
#define METAL_THREAD_MEM
|
10 |
+
#else
|
11 |
+
#define tb_float2 float2
|
12 |
+
#define CUDA_HOST_DEVICE
|
13 |
+
#define CUDA_DEVICE
|
14 |
+
#define METAL_CONSTANT_MEM constant
|
15 |
+
#define METAL_THREAD_MEM thread
|
16 |
+
#endif
|
17 |
+
#else
|
18 |
+
#define CUDA_HOST_DEVICE
|
19 |
+
#define CUDA_DEVICE
|
20 |
+
#define METAL_CONSTANT_MEM
|
21 |
+
#define METAL_THREAD_MEM
|
22 |
+
#include <cfloat>
|
23 |
+
#include <limits>
|
24 |
+
#include <vector>
|
25 |
+
#endif
|
26 |
+
|
27 |
+
namespace texture_baker_cpp {
|
28 |
+
// Structure to represent a 2D point or vector
|
29 |
+
#ifndef __METAL__
|
30 |
+
union alignas(8) tb_float2 {
|
31 |
+
struct {
|
32 |
+
float x, y;
|
33 |
+
};
|
34 |
+
|
35 |
+
float data[2];
|
36 |
+
|
37 |
+
float &operator[](size_t idx) {
|
38 |
+
if (idx > 1)
|
39 |
+
throw std::runtime_error("bad index");
|
40 |
+
return data[idx];
|
41 |
+
}
|
42 |
+
|
43 |
+
const float &operator[](size_t idx) const {
|
44 |
+
if (idx > 1)
|
45 |
+
throw std::runtime_error("bad index");
|
46 |
+
return data[idx];
|
47 |
+
}
|
48 |
+
|
49 |
+
bool operator==(const tb_float2 &rhs) const {
|
50 |
+
return x == rhs.x && y == rhs.y;
|
51 |
+
}
|
52 |
+
};
|
53 |
+
|
54 |
+
union alignas(4) tb_float3 {
|
55 |
+
struct {
|
56 |
+
float x, y, z;
|
57 |
+
};
|
58 |
+
|
59 |
+
float data[3];
|
60 |
+
|
61 |
+
float &operator[](size_t idx) {
|
62 |
+
if (idx > 2)
|
63 |
+
throw std::runtime_error("bad index");
|
64 |
+
return data[idx];
|
65 |
+
}
|
66 |
+
|
67 |
+
const float &operator[](size_t idx) const {
|
68 |
+
if (idx > 2)
|
69 |
+
throw std::runtime_error("bad index");
|
70 |
+
return data[idx];
|
71 |
+
}
|
72 |
+
};
|
73 |
+
|
74 |
+
union alignas(16) tb_float4 {
|
75 |
+
struct {
|
76 |
+
float x, y, z, w;
|
77 |
+
};
|
78 |
+
|
79 |
+
float data[4];
|
80 |
+
|
81 |
+
float &operator[](size_t idx) {
|
82 |
+
if (idx > 3)
|
83 |
+
throw std::runtime_error("bad index");
|
84 |
+
return data[idx];
|
85 |
+
}
|
86 |
+
|
87 |
+
const float &operator[](size_t idx) const {
|
88 |
+
if (idx > 3)
|
89 |
+
throw std::runtime_error("bad index");
|
90 |
+
return data[idx];
|
91 |
+
}
|
92 |
+
};
|
93 |
+
#endif
|
94 |
+
|
95 |
+
union alignas(4) tb_int3 {
|
96 |
+
struct {
|
97 |
+
int x, y, z;
|
98 |
+
};
|
99 |
+
|
100 |
+
int data[3];
|
101 |
+
#ifndef __METAL__
|
102 |
+
int &operator[](size_t idx) {
|
103 |
+
if (idx > 2)
|
104 |
+
throw std::runtime_error("bad index");
|
105 |
+
return data[idx];
|
106 |
+
}
|
107 |
+
#endif
|
108 |
+
};
|
109 |
+
|
110 |
+
// BVH structure to accelerate point-triangle intersection
|
111 |
+
struct alignas(16) AABB {
|
112 |
+
// Init bounding boxes with max/min
|
113 |
+
tb_float2 min = {FLT_MAX, FLT_MAX};
|
114 |
+
tb_float2 max = {FLT_MIN, FLT_MIN};
|
115 |
+
|
116 |
+
#ifndef CUDA_ENABLED
|
117 |
+
// grow the AABB to include a point
|
118 |
+
void grow(const tb_float2 &p) {
|
119 |
+
min.x = std::min(min.x, p.x);
|
120 |
+
min.y = std::min(min.y, p.y);
|
121 |
+
max.x = std::max(max.x, p.x);
|
122 |
+
max.y = std::max(max.y, p.y);
|
123 |
+
}
|
124 |
+
|
125 |
+
void grow(const AABB &b) {
|
126 |
+
if (b.min.x != FLT_MAX) {
|
127 |
+
grow(b.min);
|
128 |
+
grow(b.max);
|
129 |
+
}
|
130 |
+
}
|
131 |
+
#endif
|
132 |
+
|
133 |
+
// Check if two AABBs overlap
|
134 |
+
bool overlaps(const METAL_THREAD_MEM AABB &other) const {
|
135 |
+
return min.x <= other.max.x && max.x >= other.min.x &&
|
136 |
+
min.y <= other.max.y && max.y >= other.min.y;
|
137 |
+
}
|
138 |
+
|
139 |
+
bool overlaps(const METAL_THREAD_MEM tb_float2 &point) const {
|
140 |
+
return point.x >= min.x && point.x <= max.x && point.y >= min.y &&
|
141 |
+
point.y <= max.y;
|
142 |
+
}
|
143 |
+
|
144 |
+
#if defined(__NVCC__)
|
145 |
+
CUDA_DEVICE bool overlaps(const float2 &point) const {
|
146 |
+
return point.x >= min.x && point.x <= max.x && point.y >= min.y &&
|
147 |
+
point.y <= max.y;
|
148 |
+
}
|
149 |
+
#endif
|
150 |
+
|
151 |
+
// Initialize AABB to an invalid state
|
152 |
+
void invalidate() {
|
153 |
+
min = {FLT_MAX, FLT_MAX};
|
154 |
+
max = {FLT_MIN, FLT_MIN};
|
155 |
+
}
|
156 |
+
|
157 |
+
// Calculate the area of the AABB
|
158 |
+
float area() const {
|
159 |
+
tb_float2 extent = {max.x - min.x, max.y - min.y};
|
160 |
+
return extent.x * extent.y;
|
161 |
+
}
|
162 |
+
};
|
163 |
+
|
164 |
+
struct BVHNode {
|
165 |
+
AABB bbox;
|
166 |
+
int start, end;
|
167 |
+
int left, right;
|
168 |
+
|
169 |
+
int num_triangles() const { return end - start; }
|
170 |
+
|
171 |
+
CUDA_HOST_DEVICE bool is_leaf() const { return left == -1 && right == -1; }
|
172 |
+
|
173 |
+
float calculate_node_cost() {
|
174 |
+
float area = bbox.area();
|
175 |
+
return num_triangles() * area;
|
176 |
+
}
|
177 |
+
};
|
178 |
+
|
179 |
+
struct Triangle {
|
180 |
+
tb_float2 v0, v1, v2;
|
181 |
+
int index;
|
182 |
+
tb_float2 centroid;
|
183 |
+
};
|
184 |
+
|
185 |
+
#ifndef __METAL__
|
186 |
+
struct BVH {
|
187 |
+
std::vector<BVHNode> nodes;
|
188 |
+
std::vector<Triangle> triangles;
|
189 |
+
std::vector<int> triangle_indices;
|
190 |
+
int root;
|
191 |
+
|
192 |
+
void build(const tb_float2 *vertices, const tb_int3 *indices,
|
193 |
+
const int64_t &num_indices);
|
194 |
+
bool intersect(const tb_float2 &point, float &u, float &v, float &w,
|
195 |
+
int &index) const;
|
196 |
+
|
197 |
+
void update_node_bounds(BVHNode &node, AABB ¢roidBounds);
|
198 |
+
float find_best_split_plane(const BVHNode &node, int &best_axis,
|
199 |
+
int &best_pos, AABB ¢roidBounds);
|
200 |
+
};
|
201 |
+
#endif
|
202 |
+
|
203 |
+
} // namespace texture_baker_cpp
|
texture_baker/texture_baker/csrc/baker_kernel.cu
ADDED
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <ATen/ATen.h>
|
2 |
+
#include <ATen/Context.h>
|
3 |
+
#include <ATen/cuda/CUDAContext.h>
|
4 |
+
#include <torch/extension.h>
|
5 |
+
|
6 |
+
#include "baker.h"
|
7 |
+
|
8 |
+
// #define TIMING
|
9 |
+
|
10 |
+
#define STRINGIFY(x) #x
|
11 |
+
#define STR(x) STRINGIFY(x)
|
12 |
+
#define FILE_LINE __FILE__ ":" STR(__LINE__)
|
13 |
+
#define CUDA_CHECK_THROW(x) \
|
14 |
+
do { \
|
15 |
+
cudaError_t _result = x; \
|
16 |
+
if (_result != cudaSuccess) \
|
17 |
+
throw std::runtime_error(std::string(FILE_LINE " check failed " #x " failed: ") + cudaGetErrorString(_result)); \
|
18 |
+
} while(0)
|
19 |
+
|
20 |
+
namespace texture_baker_cpp
|
21 |
+
{
|
22 |
+
|
23 |
+
__device__ float3 operator+(const float3 &a, const float3 &b)
|
24 |
+
{
|
25 |
+
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
|
26 |
+
}
|
27 |
+
|
28 |
+
// xy: 2D test position
|
29 |
+
// v1: vertex position 1
|
30 |
+
// v2: vertex position 2
|
31 |
+
// v3: vertex position 3
|
32 |
+
//
|
33 |
+
__forceinline__ __device__ bool barycentric_coordinates(const float2 &xy, const tb_float2 &v1, const tb_float2 &v2, const tb_float2 &v3, float &u, float &v, float &w)
|
34 |
+
{
|
35 |
+
// Return true if the point (xy) is inside the triangle defined by the vertices v1, v2, v3.
|
36 |
+
// If the point is inside the triangle, the barycentric coordinates are stored in u, v, and w.
|
37 |
+
float2 v1v2 = make_float2(v2.x - v1.x, v2.y - v1.y);
|
38 |
+
float2 v1v3 = make_float2(v3.x - v1.x, v3.y - v1.y);
|
39 |
+
float2 xyv1 = make_float2(xy.x - v1.x, xy.y - v1.y);
|
40 |
+
|
41 |
+
float d00 = v1v2.x * v1v2.x + v1v2.y * v1v2.y;
|
42 |
+
float d01 = v1v2.x * v1v3.x + v1v2.y * v1v3.y;
|
43 |
+
float d11 = v1v3.x * v1v3.x + v1v3.y * v1v3.y;
|
44 |
+
float d20 = xyv1.x * v1v2.x + xyv1.y * v1v2.y;
|
45 |
+
float d21 = xyv1.x * v1v3.x + xyv1.y * v1v3.y;
|
46 |
+
|
47 |
+
float denom = d00 * d11 - d01 * d01;
|
48 |
+
v = (d11 * d20 - d01 * d21) / denom;
|
49 |
+
w = (d00 * d21 - d01 * d20) / denom;
|
50 |
+
u = 1.0f - v - w;
|
51 |
+
|
52 |
+
return (v >= 0.0f) && (w >= 0.0f) && (v + w <= 1.0f);
|
53 |
+
}
|
54 |
+
|
55 |
+
__global__ void kernel_interpolate(const float3* __restrict__ attr, const int3* __restrict__ indices, const float4* __restrict__ rast, float3* __restrict__ output, int width, int height)
|
56 |
+
{
|
57 |
+
// Interpolate the attr into output based on the rast result (barycentric coordinates, + triangle idx)
|
58 |
+
//int idx = x * width + y;
|
59 |
+
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
60 |
+
int x = idx / width;
|
61 |
+
int y = idx % width;
|
62 |
+
|
63 |
+
if (x >= width || y >= height)
|
64 |
+
return;
|
65 |
+
|
66 |
+
float4 barycentric = rast[idx];
|
67 |
+
int triangle_idx = int(barycentric.w);
|
68 |
+
|
69 |
+
if (triangle_idx < 0)
|
70 |
+
{
|
71 |
+
output[idx] = make_float3(0.0f, 0.0f, 0.0f);
|
72 |
+
return;
|
73 |
+
}
|
74 |
+
|
75 |
+
float3 v1 = attr[indices[triangle_idx].x];
|
76 |
+
float3 v2 = attr[indices[triangle_idx].y];
|
77 |
+
float3 v3 = attr[indices[triangle_idx].z];
|
78 |
+
|
79 |
+
output[idx] = make_float3(v1.x * barycentric.x, v1.y * barycentric.x, v1.z * barycentric.x)
|
80 |
+
+ make_float3(v2.x * barycentric.y, v2.y * barycentric.y, v2.z * barycentric.y)
|
81 |
+
+ make_float3(v3.x * barycentric.z, v3.y * barycentric.z, v3.z * barycentric.z);
|
82 |
+
}
|
83 |
+
|
84 |
+
__device__ bool bvh_intersect(
|
85 |
+
const BVHNode* __restrict__ nodes,
|
86 |
+
const Triangle* __restrict__ triangles,
|
87 |
+
const int* __restrict__ triangle_indices,
|
88 |
+
const int root,
|
89 |
+
const float2 &point,
|
90 |
+
float &u, float &v, float &w,
|
91 |
+
int &index)
|
92 |
+
{
|
93 |
+
constexpr int max_stack_size = 64;
|
94 |
+
int node_stack[max_stack_size];
|
95 |
+
int stack_size = 0;
|
96 |
+
|
97 |
+
node_stack[stack_size++] = root;
|
98 |
+
|
99 |
+
while (stack_size > 0)
|
100 |
+
{
|
101 |
+
int node_idx = node_stack[--stack_size];
|
102 |
+
const BVHNode &node = nodes[node_idx];
|
103 |
+
|
104 |
+
if (node.is_leaf())
|
105 |
+
{
|
106 |
+
for (int i = node.start; i < node.end; ++i)
|
107 |
+
{
|
108 |
+
const Triangle &tri = triangles[triangle_indices[i]];
|
109 |
+
if (barycentric_coordinates(point, tri.v0, tri.v1, tri.v2, u, v, w))
|
110 |
+
{
|
111 |
+
index = tri.index;
|
112 |
+
return true;
|
113 |
+
}
|
114 |
+
}
|
115 |
+
}
|
116 |
+
else
|
117 |
+
{
|
118 |
+
if (nodes[node.right].bbox.overlaps(point))
|
119 |
+
{
|
120 |
+
if (stack_size < max_stack_size)
|
121 |
+
{
|
122 |
+
node_stack[stack_size++] = node.right;
|
123 |
+
}
|
124 |
+
else
|
125 |
+
{
|
126 |
+
// Handle stack overflow
|
127 |
+
// Make sure NDEBUG is not defined (see setup.py)
|
128 |
+
assert(0 && "Node stack overflow");
|
129 |
+
}
|
130 |
+
}
|
131 |
+
if (nodes[node.left].bbox.overlaps(point))
|
132 |
+
{
|
133 |
+
if (stack_size < max_stack_size)
|
134 |
+
{
|
135 |
+
node_stack[stack_size++] = node.left;
|
136 |
+
}
|
137 |
+
else
|
138 |
+
{
|
139 |
+
// Handle stack overflow
|
140 |
+
// Make sure NDEBUG is not defined (see setup.py)
|
141 |
+
assert(0 && "Node stack overflow");
|
142 |
+
}
|
143 |
+
}
|
144 |
+
}
|
145 |
+
}
|
146 |
+
|
147 |
+
return false;
|
148 |
+
}
|
149 |
+
|
150 |
+
__global__ void kernel_bake_uv(
|
151 |
+
float2* __restrict__ uv,
|
152 |
+
int3* __restrict__ indices,
|
153 |
+
float4* __restrict__ output,
|
154 |
+
const BVHNode* __restrict__ nodes,
|
155 |
+
const Triangle* __restrict__ triangles,
|
156 |
+
const int* __restrict__ triangle_indices,
|
157 |
+
const int root,
|
158 |
+
const int width,
|
159 |
+
const int height,
|
160 |
+
const int num_indices)
|
161 |
+
{
|
162 |
+
//int idx = x * width + y;
|
163 |
+
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
164 |
+
int x = idx / width;
|
165 |
+
int y = idx % width;
|
166 |
+
|
167 |
+
if (y >= width || x >= height)
|
168 |
+
return;
|
169 |
+
|
170 |
+
// We index x,y but the original coords are HW. So swap them
|
171 |
+
float2 pixel_coord = make_float2(float(y) / height, float(x) / width);
|
172 |
+
pixel_coord.x = fminf(fmaxf(pixel_coord.x, 0.0f), 1.0f);
|
173 |
+
pixel_coord.y = 1.0f - fminf(fmaxf(pixel_coord.y, 0.0f), 1.0f);
|
174 |
+
|
175 |
+
float u, v, w;
|
176 |
+
int triangle_idx;
|
177 |
+
bool hit = bvh_intersect(nodes, triangles, triangle_indices, root, pixel_coord, u, v, w, triangle_idx);
|
178 |
+
|
179 |
+
if (hit)
|
180 |
+
{
|
181 |
+
output[idx] = make_float4(u, v, w, float(triangle_idx));
|
182 |
+
return;
|
183 |
+
}
|
184 |
+
|
185 |
+
output[idx] = make_float4(0.0f, 0.0f, 0.0f, -1.0f);
|
186 |
+
}
|
187 |
+
|
188 |
+
torch::Tensor rasterize_gpu(
|
189 |
+
torch::Tensor uv,
|
190 |
+
torch::Tensor indices,
|
191 |
+
int64_t bake_resolution)
|
192 |
+
{
|
193 |
+
#ifdef TIMING
|
194 |
+
auto start = std::chrono::high_resolution_clock::now();
|
195 |
+
#endif
|
196 |
+
constexpr int block_size = 16 * 16;
|
197 |
+
int grid_size = bake_resolution * bake_resolution / block_size;
|
198 |
+
dim3 block_dims(block_size, 1, 1);
|
199 |
+
dim3 grid_dims(grid_size, 1, 1);
|
200 |
+
|
201 |
+
int num_indices = indices.size(0);
|
202 |
+
|
203 |
+
int width = bake_resolution;
|
204 |
+
int height = bake_resolution;
|
205 |
+
|
206 |
+
// Step 1: create an empty tensor to store the output.
|
207 |
+
torch::Tensor rast_result = torch::empty({bake_resolution, bake_resolution, 4}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));
|
208 |
+
|
209 |
+
auto vertices_cpu = uv.contiguous().cpu();
|
210 |
+
auto indices_cpu = indices.contiguous().cpu();
|
211 |
+
|
212 |
+
const tb_float2 *vertices_cpu_ptr = (tb_float2*)vertices_cpu.contiguous().data_ptr<float>();
|
213 |
+
const tb_int3 *tris_cpu_ptr = (tb_int3*)indices_cpu.contiguous().data_ptr<int>();
|
214 |
+
|
215 |
+
BVH bvh;
|
216 |
+
bvh.build(vertices_cpu_ptr, tris_cpu_ptr, indices.size(0));
|
217 |
+
|
218 |
+
BVHNode *nodes_gpu = nullptr;
|
219 |
+
Triangle *triangles_gpu = nullptr;
|
220 |
+
int *triangle_indices_gpu = nullptr;
|
221 |
+
const int bvh_root = bvh.root;
|
222 |
+
cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
|
223 |
+
|
224 |
+
CUDA_CHECK_THROW(cudaMallocAsync(&nodes_gpu, sizeof(BVHNode) * bvh.nodes.size(), cuda_stream));
|
225 |
+
CUDA_CHECK_THROW(cudaMallocAsync(&triangles_gpu, sizeof(Triangle) * bvh.triangles.size(), cuda_stream));
|
226 |
+
CUDA_CHECK_THROW(cudaMallocAsync(&triangle_indices_gpu, sizeof(int) * bvh.triangle_indices.size(), cuda_stream));
|
227 |
+
|
228 |
+
CUDA_CHECK_THROW(cudaMemcpyAsync(nodes_gpu, bvh.nodes.data(), sizeof(BVHNode) * bvh.nodes.size(), cudaMemcpyHostToDevice, cuda_stream));
|
229 |
+
CUDA_CHECK_THROW(cudaMemcpyAsync(triangles_gpu, bvh.triangles.data(), sizeof(Triangle) * bvh.triangles.size(), cudaMemcpyHostToDevice, cuda_stream));
|
230 |
+
CUDA_CHECK_THROW(cudaMemcpyAsync(triangle_indices_gpu, bvh.triangle_indices.data(), sizeof(int) * bvh.triangle_indices.size(), cudaMemcpyHostToDevice, cuda_stream));
|
231 |
+
|
232 |
+
kernel_bake_uv<<<grid_dims, block_dims, 0, cuda_stream>>>(
|
233 |
+
(float2 *)uv.contiguous().data_ptr<float>(),
|
234 |
+
(int3 *)indices.contiguous().data_ptr<int>(),
|
235 |
+
(float4 *)rast_result.contiguous().data_ptr<float>(),
|
236 |
+
nodes_gpu,
|
237 |
+
triangles_gpu,
|
238 |
+
triangle_indices_gpu,
|
239 |
+
bvh_root,
|
240 |
+
width,
|
241 |
+
height,
|
242 |
+
num_indices);
|
243 |
+
|
244 |
+
CUDA_CHECK_THROW(cudaFreeAsync(nodes_gpu, cuda_stream));
|
245 |
+
CUDA_CHECK_THROW(cudaFreeAsync(triangles_gpu, cuda_stream));
|
246 |
+
CUDA_CHECK_THROW(cudaFreeAsync(triangle_indices_gpu, cuda_stream));
|
247 |
+
|
248 |
+
#ifdef TIMING
|
249 |
+
CUDA_CHECK_THROW(cudaStreamSynchronize(cuda_stream));
|
250 |
+
auto end = std::chrono::high_resolution_clock::now();
|
251 |
+
std::chrono::duration<double> elapsed = end - start;
|
252 |
+
std::cout << "Rasterization time (CUDA): " << elapsed.count() << "s" << std::endl;
|
253 |
+
#endif
|
254 |
+
return rast_result;
|
255 |
+
}
|
256 |
+
|
257 |
+
torch::Tensor interpolate_gpu(
|
258 |
+
torch::Tensor attr,
|
259 |
+
torch::Tensor indices,
|
260 |
+
torch::Tensor rast)
|
261 |
+
{
|
262 |
+
#ifdef TIMING
|
263 |
+
auto start = std::chrono::high_resolution_clock::now();
|
264 |
+
#endif
|
265 |
+
constexpr int block_size = 16 * 16;
|
266 |
+
int grid_size = rast.size(0) * rast.size(0) / block_size;
|
267 |
+
dim3 block_dims(block_size, 1, 1);
|
268 |
+
dim3 grid_dims(grid_size, 1, 1);
|
269 |
+
|
270 |
+
// Step 1: create an empty tensor to store the output.
|
271 |
+
torch::Tensor pos_bake = torch::empty({rast.size(0), rast.size(1), 3}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));
|
272 |
+
|
273 |
+
int width = rast.size(0);
|
274 |
+
int height = rast.size(1);
|
275 |
+
|
276 |
+
cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
|
277 |
+
|
278 |
+
kernel_interpolate<<<grid_dims, block_dims, 0, cuda_stream>>>(
|
279 |
+
(float3 *)attr.contiguous().data_ptr<float>(),
|
280 |
+
(int3 *)indices.contiguous().data_ptr<int>(),
|
281 |
+
(float4 *)rast.contiguous().data_ptr<float>(),
|
282 |
+
(float3 *)pos_bake.contiguous().data_ptr<float>(),
|
283 |
+
width,
|
284 |
+
height);
|
285 |
+
#ifdef TIMING
|
286 |
+
CUDA_CHECK_THROW(cudaStreamSynchronize(cuda_stream));
|
287 |
+
auto end = std::chrono::high_resolution_clock::now();
|
288 |
+
std::chrono::duration<double> elapsed = end - start;
|
289 |
+
std::cout << "Interpolation time (CUDA): " << elapsed.count() << "s" << std::endl;
|
290 |
+
#endif
|
291 |
+
return pos_bake;
|
292 |
+
}
|
293 |
+
|
294 |
+
// Registers CUDA implementations
|
295 |
+
TORCH_LIBRARY_IMPL(texture_baker_cpp, CUDA, m)
|
296 |
+
{
|
297 |
+
m.impl("rasterize", &rasterize_gpu);
|
298 |
+
m.impl("interpolate", &interpolate_gpu);
|
299 |
+
}
|
300 |
+
|
301 |
+
}
|
texture_baker/texture_baker/csrc/baker_kernel.metal
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <metal_stdlib>
|
2 |
+
using namespace metal;
|
3 |
+
|
4 |
+
// This header is inlined manually
|
5 |
+
//#include "baker.h"
|
6 |
+
|
7 |
+
// Use the texture_baker_cpp so it can use the classes from baker.h
|
8 |
+
using namespace texture_baker_cpp;
|
9 |
+
|
10 |
+
// Utility function to compute barycentric coordinates
|
11 |
+
bool barycentric_coordinates(float2 xy, float2 v1, float2 v2, float2 v3, thread float &u, thread float &v, thread float &w) {
|
12 |
+
float2 v1v2 = v2 - v1;
|
13 |
+
float2 v1v3 = v3 - v1;
|
14 |
+
float2 xyv1 = xy - v1;
|
15 |
+
|
16 |
+
float d00 = dot(v1v2, v1v2);
|
17 |
+
float d01 = dot(v1v2, v1v3);
|
18 |
+
float d11 = dot(v1v3, v1v3);
|
19 |
+
float d20 = dot(xyv1, v1v2);
|
20 |
+
float d21 = dot(xyv1, v1v3);
|
21 |
+
|
22 |
+
float denom = d00 * d11 - d01 * d01;
|
23 |
+
v = (d11 * d20 - d01 * d21) / denom;
|
24 |
+
w = (d00 * d21 - d01 * d20) / denom;
|
25 |
+
u = 1.0f - v - w;
|
26 |
+
|
27 |
+
return (v >= 0.0f) && (w >= 0.0f) && (v + w <= 1.0f);
|
28 |
+
}
|
29 |
+
|
30 |
+
// Kernel function for interpolation
|
31 |
+
kernel void kernel_interpolate(constant packed_float3 *attr [[buffer(0)]],
|
32 |
+
constant packed_int3 *indices [[buffer(1)]],
|
33 |
+
constant packed_float4 *rast [[buffer(2)]],
|
34 |
+
device packed_float3 *output [[buffer(3)]],
|
35 |
+
constant int &width [[buffer(4)]],
|
36 |
+
constant int &height [[buffer(5)]],
|
37 |
+
uint3 blockIdx [[threadgroup_position_in_grid]],
|
38 |
+
uint3 threadIdx [[thread_position_in_threadgroup]],
|
39 |
+
uint3 blockDim [[threads_per_threadgroup]])
|
40 |
+
{
|
41 |
+
// Calculate global position using threadgroup and thread positions
|
42 |
+
int x = blockIdx.x * blockDim.x + threadIdx.x;
|
43 |
+
int y = blockIdx.y * blockDim.y + threadIdx.y;
|
44 |
+
|
45 |
+
if (x >= width || y >= height) return;
|
46 |
+
|
47 |
+
int idx = y * width + x;
|
48 |
+
float4 barycentric = rast[idx];
|
49 |
+
int triangle_idx = int(barycentric.w);
|
50 |
+
|
51 |
+
if (triangle_idx < 0) {
|
52 |
+
output[idx] = float3(0.0f, 0.0f, 0.0f);
|
53 |
+
return;
|
54 |
+
}
|
55 |
+
|
56 |
+
float3 v1 = attr[indices[triangle_idx].x];
|
57 |
+
float3 v2 = attr[indices[triangle_idx].y];
|
58 |
+
float3 v3 = attr[indices[triangle_idx].z];
|
59 |
+
|
60 |
+
output[idx] = v1 * barycentric.x + v2 * barycentric.y + v3 * barycentric.z;
|
61 |
+
}
|
62 |
+
|
63 |
+
bool bvh_intersect(
|
64 |
+
constant BVHNode* nodes,
|
65 |
+
constant Triangle* triangles,
|
66 |
+
constant int* triangle_indices,
|
67 |
+
const thread int root,
|
68 |
+
const thread float2 &point,
|
69 |
+
thread float &u, thread float &v, thread float &w,
|
70 |
+
thread int &index)
|
71 |
+
{
|
72 |
+
const int max_stack_size = 64;
|
73 |
+
thread int node_stack[max_stack_size];
|
74 |
+
int stack_size = 0;
|
75 |
+
|
76 |
+
node_stack[stack_size++] = root;
|
77 |
+
|
78 |
+
while (stack_size > 0)
|
79 |
+
{
|
80 |
+
int node_idx = node_stack[--stack_size];
|
81 |
+
BVHNode node = nodes[node_idx];
|
82 |
+
|
83 |
+
if (node.is_leaf())
|
84 |
+
{
|
85 |
+
for (int i = node.start; i < node.end; ++i)
|
86 |
+
{
|
87 |
+
constant Triangle &tri = triangles[triangle_indices[i]];
|
88 |
+
if (barycentric_coordinates(point, tri.v0, tri.v1, tri.v2, u, v, w))
|
89 |
+
{
|
90 |
+
index = tri.index;
|
91 |
+
return true;
|
92 |
+
}
|
93 |
+
}
|
94 |
+
}
|
95 |
+
else
|
96 |
+
{
|
97 |
+
BVHNode test_node = nodes[node.right];
|
98 |
+
if (test_node.bbox.overlaps(point))
|
99 |
+
{
|
100 |
+
if (stack_size < max_stack_size)
|
101 |
+
{
|
102 |
+
node_stack[stack_size++] = node.right;
|
103 |
+
}
|
104 |
+
else
|
105 |
+
{
|
106 |
+
// Handle stack overflow
|
107 |
+
// Sadly, metal doesn't support asserts (but you could try enabling metal validation layers)
|
108 |
+
return false;
|
109 |
+
}
|
110 |
+
}
|
111 |
+
test_node = nodes[node.left];
|
112 |
+
if (test_node.bbox.overlaps(point))
|
113 |
+
{
|
114 |
+
if (stack_size < max_stack_size)
|
115 |
+
{
|
116 |
+
node_stack[stack_size++] = node.left;
|
117 |
+
}
|
118 |
+
else
|
119 |
+
{
|
120 |
+
// Handle stack overflow
|
121 |
+
return false;
|
122 |
+
}
|
123 |
+
}
|
124 |
+
}
|
125 |
+
}
|
126 |
+
|
127 |
+
return false;
|
128 |
+
}
|
129 |
+
|
130 |
+
|
131 |
+
// Kernel function for baking UV
|
132 |
+
kernel void kernel_bake_uv(constant packed_float2 *uv [[buffer(0)]],
|
133 |
+
constant packed_int3 *indices [[buffer(1)]],
|
134 |
+
device packed_float4 *output [[buffer(2)]],
|
135 |
+
constant BVHNode *nodes [[buffer(3)]],
|
136 |
+
constant Triangle *triangles [[buffer(4)]],
|
137 |
+
constant int *triangle_indices [[buffer(5)]],
|
138 |
+
constant int &root [[buffer(6)]],
|
139 |
+
constant int &width [[buffer(7)]],
|
140 |
+
constant int &height [[buffer(8)]],
|
141 |
+
constant int &num_indices [[buffer(9)]],
|
142 |
+
uint3 blockIdx [[threadgroup_position_in_grid]],
|
143 |
+
uint3 threadIdx [[thread_position_in_threadgroup]],
|
144 |
+
uint3 blockDim [[threads_per_threadgroup]])
|
145 |
+
{
|
146 |
+
// Calculate global position using threadgroup and thread positions
|
147 |
+
int x = blockIdx.x * blockDim.x + threadIdx.x;
|
148 |
+
int y = blockIdx.y * blockDim.y + threadIdx.y;
|
149 |
+
|
150 |
+
|
151 |
+
if (x >= width || y >= height) return;
|
152 |
+
|
153 |
+
int idx = x * width + y;
|
154 |
+
|
155 |
+
// Swap original coordinates
|
156 |
+
float2 pixel_coord = float2(float(y) / float(height), float(x) / float(width));
|
157 |
+
pixel_coord = clamp(pixel_coord, 0.0f, 1.0f);
|
158 |
+
pixel_coord.y = 1.0f - pixel_coord.y;
|
159 |
+
|
160 |
+
float u, v, w;
|
161 |
+
int triangle_idx;
|
162 |
+
bool hit = bvh_intersect(nodes, triangles, triangle_indices, root, pixel_coord, u, v, w, triangle_idx);
|
163 |
+
|
164 |
+
if (hit) {
|
165 |
+
output[idx] = float4(u, v, w, float(triangle_idx));
|
166 |
+
return;
|
167 |
+
}
|
168 |
+
|
169 |
+
output[idx] = float4(0.0f, 0.0f, 0.0f, -1.0f);
|
170 |
+
}
|
texture_baker/texture_baker/csrc/baker_kernel.mm
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <torch/extension.h>
|
2 |
+
#include <ATen/ATen.h>
|
3 |
+
#include <ATen/Context.h>
|
4 |
+
#include "baker.h"
|
5 |
+
|
6 |
+
#import <Foundation/Foundation.h>
|
7 |
+
#import <Metal/Metal.h>
|
8 |
+
#include <filesystem>
|
9 |
+
|
10 |
+
// Helper function to retrieve the `MTLBuffer` from a `torch::Tensor`.
|
11 |
+
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& tensor) {
|
12 |
+
return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
|
13 |
+
}
|
14 |
+
|
15 |
+
// Helper function to create a compute pipeline state object (PSO).
|
16 |
+
static inline id<MTLComputePipelineState> createComputePipelineState(id<MTLDevice> device, NSString* fullSource, std::string kernel_name) {
|
17 |
+
NSError *error = nil;
|
18 |
+
|
19 |
+
// Load the custom kernel shader.
|
20 |
+
MTLCompileOptions *options = [[MTLCompileOptions alloc] init];
|
21 |
+
// Add the preprocessor macro "__METAL__"
|
22 |
+
options.preprocessorMacros = @{@"__METAL__": @""};
|
23 |
+
id<MTLLibrary> customKernelLibrary = [device newLibraryWithSource: fullSource options:options error:&error];
|
24 |
+
TORCH_CHECK(customKernelLibrary, "Failed to create custom kernel library, error: ", error.localizedDescription.UTF8String);
|
25 |
+
|
26 |
+
id<MTLFunction> customKernelFunction = [customKernelLibrary newFunctionWithName:[NSString stringWithUTF8String:kernel_name.c_str()]];
|
27 |
+
TORCH_CHECK(customKernelFunction, "Failed to create function state object for ", kernel_name.c_str());
|
28 |
+
|
29 |
+
id<MTLComputePipelineState> pso = [device newComputePipelineStateWithFunction:customKernelFunction error:&error];
|
30 |
+
TORCH_CHECK(pso, error.localizedDescription.UTF8String);
|
31 |
+
|
32 |
+
return pso;
|
33 |
+
}
|
34 |
+
|
35 |
+
std::filesystem::path get_extension_path() {
|
36 |
+
// Ensure the GIL is held before calling any Python C API function
|
37 |
+
PyGILState_STATE gstate = PyGILState_Ensure();
|
38 |
+
|
39 |
+
const char* module_name = "texture_baker";
|
40 |
+
|
41 |
+
// Import the module by name
|
42 |
+
PyObject* module = PyImport_ImportModule(module_name);
|
43 |
+
if (!module) {
|
44 |
+
PyGILState_Release(gstate);
|
45 |
+
throw std::runtime_error("Could not import the module: " + std::string(module_name));
|
46 |
+
}
|
47 |
+
|
48 |
+
// Get the filename of the module
|
49 |
+
PyObject* filename_obj = PyModule_GetFilenameObject(module);
|
50 |
+
if (filename_obj) {
|
51 |
+
std::string path = PyUnicode_AsUTF8(filename_obj);
|
52 |
+
Py_DECREF(filename_obj);
|
53 |
+
PyGILState_Release(gstate);
|
54 |
+
|
55 |
+
// Get the directory part of the path (removing the __init__.py)
|
56 |
+
std::filesystem::path module_path = std::filesystem::path(path).parent_path();
|
57 |
+
|
58 |
+
// Append the 'csrc' directory to the path
|
59 |
+
module_path /= "csrc";
|
60 |
+
|
61 |
+
return module_path;
|
62 |
+
} else {
|
63 |
+
PyGILState_Release(gstate);
|
64 |
+
throw std::runtime_error("Could not retrieve the module filename.");
|
65 |
+
}
|
66 |
+
}
|
67 |
+
|
68 |
+
NSString *get_shader_sources_as_string()
|
69 |
+
{
|
70 |
+
const std::filesystem::path csrc_path = get_extension_path();
|
71 |
+
const std::string shader_path = (csrc_path / "baker_kernel.metal").string();
|
72 |
+
const std::string shader_header_path = (csrc_path / "baker.h").string();
|
73 |
+
// Load the Metal shader from the specified path
|
74 |
+
NSError *error = nil;
|
75 |
+
|
76 |
+
NSString* shaderHeaderSource = [
|
77 |
+
NSString stringWithContentsOfFile:[NSString stringWithUTF8String:shader_header_path.c_str()]
|
78 |
+
encoding:NSUTF8StringEncoding
|
79 |
+
error:&error];
|
80 |
+
if (error) {
|
81 |
+
throw std::runtime_error("Failed to load baker.h: " + std::string(error.localizedDescription.UTF8String));
|
82 |
+
}
|
83 |
+
|
84 |
+
NSString* shaderSource = [
|
85 |
+
NSString stringWithContentsOfFile:[NSString stringWithUTF8String:shader_path.c_str()]
|
86 |
+
encoding:NSUTF8StringEncoding
|
87 |
+
error:&error];
|
88 |
+
if (error) {
|
89 |
+
throw std::runtime_error("Failed to load Metal shader: " + std::string(error.localizedDescription.UTF8String));
|
90 |
+
}
|
91 |
+
|
92 |
+
NSString *fullSource = [shaderHeaderSource stringByAppendingString:shaderSource];
|
93 |
+
|
94 |
+
return fullSource;
|
95 |
+
}
|
96 |
+
|
97 |
+
namespace texture_baker_cpp
|
98 |
+
{
|
99 |
+
torch::Tensor rasterize_gpu(
|
100 |
+
torch::Tensor uv,
|
101 |
+
torch::Tensor indices,
|
102 |
+
int64_t bake_resolution)
|
103 |
+
{
|
104 |
+
TORCH_CHECK(uv.device().is_mps(), "uv must be a MPS tensor");
|
105 |
+
TORCH_CHECK(uv.is_contiguous(), "uv must be contiguous");
|
106 |
+
TORCH_CHECK(indices.is_contiguous(), "indices must be contiguous");
|
107 |
+
|
108 |
+
TORCH_CHECK(uv.scalar_type() == torch::kFloat32, "Unsupported data type: ", indices.scalar_type());
|
109 |
+
TORCH_CHECK(indices.scalar_type() == torch::kInt32, "Unsupported data type: ", indices.scalar_type());
|
110 |
+
|
111 |
+
torch::Tensor rast_result = torch::empty({bake_resolution, bake_resolution, 4}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kMPS)).contiguous();
|
112 |
+
|
113 |
+
@autoreleasepool {
|
114 |
+
auto vertices_cpu = uv.contiguous().cpu();
|
115 |
+
auto indices_cpu = indices.contiguous().cpu();
|
116 |
+
|
117 |
+
const tb_float2 *vertices_cpu_ptr = (tb_float2*)vertices_cpu.contiguous().data_ptr<float>();
|
118 |
+
const tb_int3 *tris_cpu_ptr = (tb_int3*)indices_cpu.contiguous().data_ptr<int>();
|
119 |
+
|
120 |
+
BVH bvh;
|
121 |
+
bvh.build(vertices_cpu_ptr, tris_cpu_ptr, indices.size(0));
|
122 |
+
|
123 |
+
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
|
124 |
+
|
125 |
+
NSString *fullSource = get_shader_sources_as_string();
|
126 |
+
|
127 |
+
// Create a compute pipeline state object using the helper function
|
128 |
+
id<MTLComputePipelineState> bake_uv_PSO = createComputePipelineState(device, fullSource, "kernel_bake_uv");
|
129 |
+
|
130 |
+
// Get a reference to the command buffer for the MPS stream.
|
131 |
+
id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
|
132 |
+
TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");
|
133 |
+
|
134 |
+
// Get a reference to the dispatch queue for the MPS stream, which encodes the synchronization with the CPU.
|
135 |
+
dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();
|
136 |
+
|
137 |
+
dispatch_sync(serialQueue, ^(){
|
138 |
+
// Start a compute pass.
|
139 |
+
id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
|
140 |
+
TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");
|
141 |
+
|
142 |
+
// Get Metal buffers directly from PyTorch tensors
|
143 |
+
auto uv_buf = getMTLBufferStorage(uv.contiguous());
|
144 |
+
auto indices_buf = getMTLBufferStorage(indices.contiguous());
|
145 |
+
auto rast_result_buf = getMTLBufferStorage(rast_result);
|
146 |
+
|
147 |
+
const int width = bake_resolution;
|
148 |
+
const int height = bake_resolution;
|
149 |
+
const int num_indices = indices.size(0);
|
150 |
+
const int bvh_root = bvh.root;
|
151 |
+
|
152 |
+
// Wrap the existing CPU memory in Metal buffers with shared memory
|
153 |
+
id<MTLBuffer> nodesBuffer = [device newBufferWithBytesNoCopy:(void*)bvh.nodes.data() length:sizeof(BVHNode) * bvh.nodes.size() options:MTLResourceStorageModeShared deallocator:nil];
|
154 |
+
id<MTLBuffer> trianglesBuffer = [device newBufferWithBytesNoCopy:(void*)bvh.triangles.data() length:sizeof(Triangle) * bvh.triangles.size() options:MTLResourceStorageModeShared deallocator:nil];
|
155 |
+
id<MTLBuffer> triangleIndicesBuffer = [device newBufferWithBytesNoCopy:(void*)bvh.triangle_indices.data() length:sizeof(int) * bvh.triangle_indices.size() options:MTLResourceStorageModeShared deallocator:nil];
|
156 |
+
|
157 |
+
[computeEncoder setComputePipelineState:bake_uv_PSO];
|
158 |
+
[computeEncoder setBuffer:uv_buf offset:uv.storage_offset() * uv.element_size() atIndex:0];
|
159 |
+
[computeEncoder setBuffer:indices_buf offset:indices.storage_offset() * indices.element_size() atIndex:1];
|
160 |
+
[computeEncoder setBuffer:rast_result_buf offset:rast_result.storage_offset() * rast_result.element_size() atIndex:2];
|
161 |
+
[computeEncoder setBuffer:nodesBuffer offset:0 atIndex:3];
|
162 |
+
[computeEncoder setBuffer:trianglesBuffer offset:0 atIndex:4];
|
163 |
+
[computeEncoder setBuffer:triangleIndicesBuffer offset:0 atIndex:5];
|
164 |
+
[computeEncoder setBytes:&bvh_root length:sizeof(int) atIndex:6];
|
165 |
+
[computeEncoder setBytes:&width length:sizeof(int) atIndex:7];
|
166 |
+
[computeEncoder setBytes:&height length:sizeof(int) atIndex:8];
|
167 |
+
[computeEncoder setBytes:&num_indices length:sizeof(int) atIndex:9];
|
168 |
+
|
169 |
+
// Calculate a thread group size.
|
170 |
+
int block_size = 16;
|
171 |
+
MTLSize threadgroupSize = MTLSizeMake(block_size, block_size, 1); // Fixed threadgroup size
|
172 |
+
MTLSize numThreadgroups = MTLSizeMake(bake_resolution / block_size, bake_resolution / block_size, 1);
|
173 |
+
|
174 |
+
// Encode the compute command.
|
175 |
+
[computeEncoder dispatchThreadgroups:numThreadgroups threadsPerThreadgroup:threadgroupSize];
|
176 |
+
[computeEncoder endEncoding];
|
177 |
+
|
178 |
+
// Commit the work.
|
179 |
+
torch::mps::commit();
|
180 |
+
});
|
181 |
+
}
|
182 |
+
|
183 |
+
return rast_result;
|
184 |
+
}
|
185 |
+
|
186 |
+
torch::Tensor interpolate_gpu(
|
187 |
+
torch::Tensor attr,
|
188 |
+
torch::Tensor indices,
|
189 |
+
torch::Tensor rast)
|
190 |
+
{
|
191 |
+
TORCH_CHECK(attr.is_contiguous(), "attr must be contiguous");
|
192 |
+
TORCH_CHECK(indices.is_contiguous(), "indices must be contiguous");
|
193 |
+
TORCH_CHECK(rast.is_contiguous(), "rast must be contiguous");
|
194 |
+
|
195 |
+
torch::Tensor pos_bake = torch::empty({rast.size(0), rast.size(1), 3}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kMPS)).contiguous();
|
196 |
+
std::filesystem::path csrc_path = get_extension_path();
|
197 |
+
|
198 |
+
@autoreleasepool {
|
199 |
+
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
|
200 |
+
|
201 |
+
NSString *fullSource = get_shader_sources_as_string();
|
202 |
+
// Create a compute pipeline state object using the helper function
|
203 |
+
id<MTLComputePipelineState> interpolate_PSO = createComputePipelineState(device, fullSource, "kernel_interpolate");
|
204 |
+
|
205 |
+
// Get a reference to the command buffer for the MPS stream.
|
206 |
+
id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
|
207 |
+
TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");
|
208 |
+
|
209 |
+
// Get a reference to the dispatch queue for the MPS stream, which encodes the synchronization with the CPU.
|
210 |
+
dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();
|
211 |
+
|
212 |
+
dispatch_sync(serialQueue, ^(){
|
213 |
+
// Start a compute pass.
|
214 |
+
id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
|
215 |
+
TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");
|
216 |
+
|
217 |
+
// Get Metal buffers directly from PyTorch tensors
|
218 |
+
auto attr_buf = getMTLBufferStorage(attr.contiguous());
|
219 |
+
auto indices_buf = getMTLBufferStorage(indices.contiguous());
|
220 |
+
auto rast_buf = getMTLBufferStorage(rast.contiguous());
|
221 |
+
auto pos_bake_buf = getMTLBufferStorage(pos_bake);
|
222 |
+
|
223 |
+
int width = rast.size(0);
|
224 |
+
int height = rast.size(1);
|
225 |
+
|
226 |
+
[computeEncoder setComputePipelineState:interpolate_PSO];
|
227 |
+
[computeEncoder setBuffer:attr_buf offset:attr.storage_offset() * attr.element_size() atIndex:0];
|
228 |
+
[computeEncoder setBuffer:indices_buf offset:indices.storage_offset() * indices.element_size() atIndex:1];
|
229 |
+
[computeEncoder setBuffer:rast_buf offset:rast.storage_offset() * rast.element_size() atIndex:2];
|
230 |
+
[computeEncoder setBuffer:pos_bake_buf offset:pos_bake.storage_offset() * pos_bake.element_size() atIndex:3];
|
231 |
+
[computeEncoder setBytes:&width length:sizeof(int) atIndex:4];
|
232 |
+
[computeEncoder setBytes:&height length:sizeof(int) atIndex:5];
|
233 |
+
|
234 |
+
// Calculate a thread group size.
|
235 |
+
|
236 |
+
int block_size = 16;
|
237 |
+
MTLSize threadgroupSize = MTLSizeMake(block_size, block_size, 1); // Fixed threadgroup size
|
238 |
+
MTLSize numThreadgroups = MTLSizeMake(rast.size(0) / block_size, rast.size(0) / block_size, 1);
|
239 |
+
|
240 |
+
// Encode the compute command.
|
241 |
+
[computeEncoder dispatchThreadgroups:numThreadgroups threadsPerThreadgroup:threadgroupSize];
|
242 |
+
|
243 |
+
[computeEncoder endEncoding];
|
244 |
+
|
245 |
+
// Commit the work.
|
246 |
+
torch::mps::commit();
|
247 |
+
});
|
248 |
+
}
|
249 |
+
|
250 |
+
return pos_bake;
|
251 |
+
}
|
252 |
+
|
253 |
+
// Registers MPS implementations
|
254 |
+
TORCH_LIBRARY_IMPL(texture_baker_cpp, MPS, m)
|
255 |
+
{
|
256 |
+
m.impl("rasterize", &rasterize_gpu);
|
257 |
+
m.impl("interpolate", &interpolate_gpu);
|
258 |
+
}
|
259 |
+
|
260 |
+
}
|
uv_unwrapper/README.md
ADDED
File without changes
|
uv_unwrapper/requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
numpy
|
uv_unwrapper/setup.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import glob
|
3 |
+
import os
|
4 |
+
|
5 |
+
from setuptools import find_packages, setup
|
6 |
+
from torch.utils.cpp_extension import (
|
7 |
+
BuildExtension,
|
8 |
+
CppExtension,
|
9 |
+
)
|
10 |
+
|
11 |
+
library_name = "uv_unwrapper"
|
12 |
+
|
13 |
+
|
14 |
+
def get_extensions():
|
15 |
+
debug_mode = os.getenv("DEBUG", "0") == "1"
|
16 |
+
if debug_mode:
|
17 |
+
print("Compiling in debug mode")
|
18 |
+
|
19 |
+
is_mac = True if torch.backends.mps.is_available() else False
|
20 |
+
use_native_arch = not is_mac and os.getenv("USE_NATIVE_ARCH", "1") == "1"
|
21 |
+
extension = CppExtension
|
22 |
+
|
23 |
+
extra_link_args = []
|
24 |
+
extra_compile_args = {
|
25 |
+
"cxx": [
|
26 |
+
"-O3" if not debug_mode else "-O0",
|
27 |
+
"-fdiagnostics-color=always",
|
28 |
+
("-Xclang " if is_mac else "") + "-fopenmp",
|
29 |
+
] + ["-march=native"] if use_native_arch else [],
|
30 |
+
}
|
31 |
+
if debug_mode:
|
32 |
+
extra_compile_args["cxx"].append("-g")
|
33 |
+
extra_compile_args["cxx"].append("-UNDEBUG")
|
34 |
+
extra_link_args.extend(["-O0", "-g"])
|
35 |
+
|
36 |
+
define_macros = []
|
37 |
+
extensions = []
|
38 |
+
|
39 |
+
this_dir = os.path.dirname(os.path.curdir)
|
40 |
+
sources = glob.glob(
|
41 |
+
os.path.join(this_dir, library_name, "csrc", "**", "*.cpp"), recursive=True
|
42 |
+
)
|
43 |
+
|
44 |
+
if len(sources) == 0:
|
45 |
+
print("No source files found for extension, skipping extension compilation")
|
46 |
+
return None
|
47 |
+
|
48 |
+
extensions.append(
|
49 |
+
extension(
|
50 |
+
name=f"{library_name}._C",
|
51 |
+
sources=sources,
|
52 |
+
define_macros=define_macros,
|
53 |
+
extra_compile_args=extra_compile_args,
|
54 |
+
extra_link_args=extra_link_args,
|
55 |
+
libraries=[
|
56 |
+
"c10",
|
57 |
+
"torch",
|
58 |
+
"torch_cpu",
|
59 |
+
"torch_python"
|
60 |
+
] + ["omp"] if is_mac else [],
|
61 |
+
)
|
62 |
+
)
|
63 |
+
|
64 |
+
print(extensions)
|
65 |
+
|
66 |
+
return extensions
|
67 |
+
|
68 |
+
|
69 |
+
setup(
|
70 |
+
name=library_name,
|
71 |
+
version="0.0.1",
|
72 |
+
packages=find_packages(),
|
73 |
+
ext_modules=get_extensions(),
|
74 |
+
install_requires=[],
|
75 |
+
description="Box projection based UV unwrapper",
|
76 |
+
long_description=open("README.md").read(),
|
77 |
+
long_description_content_type="text/markdown",
|
78 |
+
cmdclass={"build_ext": BuildExtension},
|
79 |
+
)
|
uv_unwrapper/uv_unwrapper/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch # noqa: F401
|
2 |
+
|
3 |
+
from . import _C # noqa: F401
|
4 |
+
from .unwrap import Unwrapper
|
5 |
+
|
6 |
+
__all__ = ["Unwrapper"]
|
uv_unwrapper/uv_unwrapper/csrc/bvh.cpp
ADDED
@@ -0,0 +1,380 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
#include "bvh.h"
|
4 |
+
#include "common.h"
|
5 |
+
#include <cstring>
|
6 |
+
#include <iostream>
|
7 |
+
#include <queue>
|
8 |
+
#include <tuple>
|
9 |
+
|
10 |
+
namespace UVUnwrapper {
|
11 |
+
BVH::BVH(Triangle *tri, int *actual_idx, const size_t &num_indices) {
|
12 |
+
// Copty tri to triangle
|
13 |
+
triangle = new Triangle[num_indices];
|
14 |
+
memcpy(triangle, tri, num_indices * sizeof(Triangle));
|
15 |
+
|
16 |
+
// Copy actual_idx to actualIdx
|
17 |
+
actualIdx = new int[num_indices];
|
18 |
+
memcpy(actualIdx, actual_idx, num_indices * sizeof(int));
|
19 |
+
|
20 |
+
triIdx = new int[num_indices];
|
21 |
+
triCount = num_indices;
|
22 |
+
|
23 |
+
bvhNode = new BVHNode[triCount * 2 + 64];
|
24 |
+
nodesUsed = 2;
|
25 |
+
memset(bvhNode, 0, triCount * 2 * sizeof(BVHNode));
|
26 |
+
|
27 |
+
// populate triangle index array
|
28 |
+
for (int i = 0; i < triCount; i++)
|
29 |
+
triIdx[i] = i;
|
30 |
+
|
31 |
+
BVHNode &root = bvhNode[0];
|
32 |
+
|
33 |
+
root.start = 0, root.end = triCount;
|
34 |
+
AABB centroidBounds;
|
35 |
+
UpdateNodeBounds(0, centroidBounds);
|
36 |
+
|
37 |
+
// subdivide recursively
|
38 |
+
Subdivide(0, nodesUsed, centroidBounds);
|
39 |
+
}
|
40 |
+
|
41 |
+
BVH::BVH(const BVH &other)
|
42 |
+
: BVH(other.triangle, other.triIdx, other.triCount) {}
|
43 |
+
|
44 |
+
BVH::BVH(BVH &&other) noexcept // move constructor
|
45 |
+
: triIdx(std::exchange(other.triIdx, nullptr)),
|
46 |
+
actualIdx(std::exchange(other.actualIdx, nullptr)),
|
47 |
+
triangle(std::exchange(other.triangle, nullptr)),
|
48 |
+
bvhNode(std::exchange(other.bvhNode, nullptr)) {}
|
49 |
+
|
50 |
+
BVH &BVH::operator=(const BVH &other) // copy assignment
|
51 |
+
{
|
52 |
+
return *this = BVH(other);
|
53 |
+
}
|
54 |
+
|
55 |
+
BVH &BVH::operator=(BVH &&other) noexcept // move assignment
|
56 |
+
{
|
57 |
+
std::swap(triIdx, other.triIdx);
|
58 |
+
std::swap(actualIdx, other.actualIdx);
|
59 |
+
std::swap(triangle, other.triangle);
|
60 |
+
std::swap(bvhNode, other.bvhNode);
|
61 |
+
std::swap(triCount, other.triCount);
|
62 |
+
std::swap(nodesUsed, other.nodesUsed);
|
63 |
+
return *this;
|
64 |
+
}
|
65 |
+
|
66 |
+
BVH::~BVH() {
|
67 |
+
if (triIdx)
|
68 |
+
delete[] triIdx;
|
69 |
+
if (triangle)
|
70 |
+
delete[] triangle;
|
71 |
+
if (actualIdx)
|
72 |
+
delete[] actualIdx;
|
73 |
+
if (bvhNode)
|
74 |
+
delete[] bvhNode;
|
75 |
+
}
|
76 |
+
|
77 |
+
void BVH::UpdateNodeBounds(unsigned int nodeIdx, AABB ¢roidBounds) {
|
78 |
+
BVHNode &node = bvhNode[nodeIdx];
|
79 |
+
#ifndef __ARM_ARCH_ISA_A64
|
80 |
+
#ifndef _MSC_VER
|
81 |
+
if (__builtin_cpu_supports("sse"))
|
82 |
+
#elif (defined(_M_AMD64) || defined(_M_X64))
|
83 |
+
// SSE supported on Windows
|
84 |
+
if constexpr (true)
|
85 |
+
#endif
|
86 |
+
{
|
87 |
+
__m128 min4 = _mm_set_ps1(FLT_MAX), max4 = _mm_set_ps1(FLT_MIN);
|
88 |
+
__m128 cmin4 = _mm_set_ps1(FLT_MAX), cmax4 = _mm_set_ps1(FLT_MIN);
|
89 |
+
for (int i = node.start; i < node.end; i += 2) {
|
90 |
+
Triangle &leafTri1 = triangle[triIdx[i]];
|
91 |
+
__m128 v0, v1, v2, centroid;
|
92 |
+
if (i + 1 < node.end) {
|
93 |
+
const Triangle leafTri2 = triangle[triIdx[i + 1]];
|
94 |
+
|
95 |
+
v0 = _mm_set_ps(leafTri1.v0.x, leafTri1.v0.y, leafTri2.v0.x,
|
96 |
+
leafTri2.v0.y);
|
97 |
+
v1 = _mm_set_ps(leafTri1.v1.x, leafTri1.v1.y, leafTri2.v1.x,
|
98 |
+
leafTri2.v1.y);
|
99 |
+
v2 = _mm_set_ps(leafTri1.v2.x, leafTri1.v2.y, leafTri2.v2.x,
|
100 |
+
leafTri2.v2.y);
|
101 |
+
centroid = _mm_set_ps(leafTri1.centroid.x, leafTri1.centroid.y,
|
102 |
+
leafTri2.centroid.x, leafTri2.centroid.y);
|
103 |
+
} else {
|
104 |
+
// Otherwise do some duplicated work
|
105 |
+
v0 = _mm_set_ps(leafTri1.v0.x, leafTri1.v0.y, leafTri1.v0.x,
|
106 |
+
leafTri1.v0.y);
|
107 |
+
v1 = _mm_set_ps(leafTri1.v1.x, leafTri1.v1.y, leafTri1.v1.x,
|
108 |
+
leafTri1.v1.y);
|
109 |
+
v2 = _mm_set_ps(leafTri1.v2.x, leafTri1.v2.y, leafTri1.v2.x,
|
110 |
+
leafTri1.v2.y);
|
111 |
+
centroid = _mm_set_ps(leafTri1.centroid.x, leafTri1.centroid.y,
|
112 |
+
leafTri1.centroid.x, leafTri1.centroid.y);
|
113 |
+
}
|
114 |
+
|
115 |
+
min4 = _mm_min_ps(min4, v0);
|
116 |
+
max4 = _mm_max_ps(max4, v0);
|
117 |
+
min4 = _mm_min_ps(min4, v1);
|
118 |
+
max4 = _mm_max_ps(max4, v1);
|
119 |
+
min4 = _mm_min_ps(min4, v2);
|
120 |
+
max4 = _mm_max_ps(max4, v2);
|
121 |
+
cmin4 = _mm_min_ps(cmin4, centroid);
|
122 |
+
cmax4 = _mm_max_ps(cmax4, centroid);
|
123 |
+
}
|
124 |
+
float min_values[4], max_values[4], cmin_values[4], cmax_values[4];
|
125 |
+
_mm_store_ps(min_values, min4);
|
126 |
+
_mm_store_ps(max_values, max4);
|
127 |
+
_mm_store_ps(cmin_values, cmin4);
|
128 |
+
_mm_store_ps(cmax_values, cmax4);
|
129 |
+
|
130 |
+
node.bbox.min.x = std::min(min_values[3], min_values[1]);
|
131 |
+
node.bbox.min.y = std::min(min_values[2], min_values[0]);
|
132 |
+
node.bbox.max.x = std::max(max_values[3], max_values[1]);
|
133 |
+
node.bbox.max.y = std::max(max_values[2], max_values[0]);
|
134 |
+
|
135 |
+
centroidBounds.min.x = std::min(cmin_values[3], cmin_values[1]);
|
136 |
+
centroidBounds.min.y = std::min(cmin_values[2], cmin_values[0]);
|
137 |
+
centroidBounds.max.x = std::max(cmax_values[3], cmax_values[1]);
|
138 |
+
centroidBounds.max.y = std::max(cmax_values[2], cmax_values[0]);
|
139 |
+
}
|
140 |
+
#else
|
141 |
+
if constexpr (false) {
|
142 |
+
}
|
143 |
+
#endif
|
144 |
+
else {
|
145 |
+
node.bbox.invalidate();
|
146 |
+
centroidBounds.invalidate();
|
147 |
+
|
148 |
+
// Calculate the bounding box for the node
|
149 |
+
for (int i = node.start; i < node.end; ++i) {
|
150 |
+
const Triangle &tri = triangle[triIdx[i]];
|
151 |
+
node.bbox.grow(tri.v0);
|
152 |
+
node.bbox.grow(tri.v1);
|
153 |
+
node.bbox.grow(tri.v2);
|
154 |
+
centroidBounds.grow(tri.centroid);
|
155 |
+
}
|
156 |
+
}
|
157 |
+
}
|
158 |
+
|
159 |
+
void BVH::Subdivide(unsigned int root_idx, unsigned int &nodePtr,
|
160 |
+
AABB &rootCentroidBounds) {
|
161 |
+
// Create a queue for the nodes to be subdivided
|
162 |
+
std::queue<std::tuple<unsigned int, AABB>> nodeQueue;
|
163 |
+
nodeQueue.push(std::make_tuple(root_idx, rootCentroidBounds));
|
164 |
+
|
165 |
+
while (!nodeQueue.empty()) {
|
166 |
+
// Get the next node to process from the queue
|
167 |
+
auto [node_idx, centroidBounds] = nodeQueue.front();
|
168 |
+
nodeQueue.pop();
|
169 |
+
BVHNode &node = bvhNode[node_idx];
|
170 |
+
|
171 |
+
// Check if left is -1 and right not or vice versa
|
172 |
+
|
173 |
+
int axis, splitPos;
|
174 |
+
float cost = FindBestSplitPlane(node, axis, splitPos, centroidBounds);
|
175 |
+
|
176 |
+
if (cost >= node.calculate_node_cost()) {
|
177 |
+
node.left = node.right = -1;
|
178 |
+
continue; // Move on to the next node in the queue
|
179 |
+
}
|
180 |
+
|
181 |
+
int i = node.start;
|
182 |
+
int j = node.end - 1;
|
183 |
+
float scale = BINS / (centroidBounds.max[axis] - centroidBounds.min[axis]);
|
184 |
+
while (i <= j) {
|
185 |
+
int binIdx =
|
186 |
+
std::min(BINS - 1, (int)((triangle[triIdx[i]].centroid[axis] -
|
187 |
+
centroidBounds.min[axis]) *
|
188 |
+
scale));
|
189 |
+
if (binIdx < splitPos)
|
190 |
+
i++;
|
191 |
+
else
|
192 |
+
std::swap(triIdx[i], triIdx[j--]);
|
193 |
+
}
|
194 |
+
|
195 |
+
int leftCount = i - node.start;
|
196 |
+
if (leftCount == 0 || leftCount == (int)node.num_triangles()) {
|
197 |
+
node.left = node.right = -1;
|
198 |
+
continue; // Move on to the next node in the queue
|
199 |
+
}
|
200 |
+
|
201 |
+
int mid = i;
|
202 |
+
|
203 |
+
// Create child nodes
|
204 |
+
int leftChildIdx = nodePtr++;
|
205 |
+
int rightChildIdx = nodePtr++;
|
206 |
+
bvhNode[leftChildIdx].start = node.start;
|
207 |
+
bvhNode[leftChildIdx].end = mid;
|
208 |
+
bvhNode[rightChildIdx].start = mid;
|
209 |
+
bvhNode[rightChildIdx].end = node.end;
|
210 |
+
node.left = leftChildIdx;
|
211 |
+
node.right = rightChildIdx;
|
212 |
+
|
213 |
+
// Update the bounds for the child nodes and push them onto the queue
|
214 |
+
UpdateNodeBounds(leftChildIdx, centroidBounds);
|
215 |
+
nodeQueue.push(std::make_tuple(leftChildIdx, centroidBounds));
|
216 |
+
|
217 |
+
UpdateNodeBounds(rightChildIdx, centroidBounds);
|
218 |
+
nodeQueue.push(std::make_tuple(rightChildIdx, centroidBounds));
|
219 |
+
}
|
220 |
+
}
|
221 |
+
|
222 |
+
float BVH::FindBestSplitPlane(BVHNode &node, int &best_axis, int &best_pos,
|
223 |
+
AABB ¢roidBounds) {
|
224 |
+
float best_cost = FLT_MAX;
|
225 |
+
|
226 |
+
for (int axis = 0; axis < 2; ++axis) // We use 2 as we have only x and y
|
227 |
+
{
|
228 |
+
float boundsMin = centroidBounds.min[axis];
|
229 |
+
float boundsMax = centroidBounds.max[axis];
|
230 |
+
// Or floating point precision
|
231 |
+
if ((boundsMin == boundsMax) || (boundsMax - boundsMin < 1e-8f)) {
|
232 |
+
continue;
|
233 |
+
}
|
234 |
+
|
235 |
+
// populate the bins
|
236 |
+
float scale = BINS / (boundsMax - boundsMin);
|
237 |
+
float leftCountArea[BINS - 1], rightCountArea[BINS - 1];
|
238 |
+
int leftSum = 0, rightSum = 0;
|
239 |
+
#ifndef __ARM_ARCH_ISA_A64
|
240 |
+
#ifndef _MSC_VER
|
241 |
+
if (__builtin_cpu_supports("sse"))
|
242 |
+
#elif (defined(_M_AMD64) || defined(_M_X64))
|
243 |
+
// SSE supported on Windows
|
244 |
+
if constexpr (true)
|
245 |
+
#endif
|
246 |
+
{
|
247 |
+
__m128 min4[BINS], max4[BINS];
|
248 |
+
unsigned int count[BINS];
|
249 |
+
for (unsigned int i = 0; i < BINS; i++)
|
250 |
+
min4[i] = _mm_set_ps1(FLT_MAX), max4[i] = _mm_set_ps1(FLT_MIN),
|
251 |
+
count[i] = 0;
|
252 |
+
for (int i = node.start; i < node.end; i++) {
|
253 |
+
Triangle &tri = triangle[triIdx[i]];
|
254 |
+
int binIdx =
|
255 |
+
std::min(BINS - 1, (int)((tri.centroid[axis] - boundsMin) * scale));
|
256 |
+
count[binIdx]++;
|
257 |
+
|
258 |
+
__m128 v0 = _mm_set_ps(tri.v0.x, tri.v0.y, 0.0f, 0.0f);
|
259 |
+
__m128 v1 = _mm_set_ps(tri.v1.x, tri.v1.y, 0.0f, 0.0f);
|
260 |
+
__m128 v2 = _mm_set_ps(tri.v2.x, tri.v2.y, 0.0f, 0.0f);
|
261 |
+
min4[binIdx] = _mm_min_ps(min4[binIdx], v0);
|
262 |
+
max4[binIdx] = _mm_max_ps(max4[binIdx], v0);
|
263 |
+
min4[binIdx] = _mm_min_ps(min4[binIdx], v1);
|
264 |
+
max4[binIdx] = _mm_max_ps(max4[binIdx], v1);
|
265 |
+
min4[binIdx] = _mm_min_ps(min4[binIdx], v2);
|
266 |
+
max4[binIdx] = _mm_max_ps(max4[binIdx], v2);
|
267 |
+
}
|
268 |
+
// gather data for the 7 planes between the 8 bins
|
269 |
+
__m128 leftMin4 = _mm_set_ps1(FLT_MAX), rightMin4 = leftMin4;
|
270 |
+
__m128 leftMax4 = _mm_set_ps1(FLT_MIN), rightMax4 = leftMax4;
|
271 |
+
for (int i = 0; i < BINS - 1; i++) {
|
272 |
+
leftSum += count[i];
|
273 |
+
rightSum += count[BINS - 1 - i];
|
274 |
+
leftMin4 = _mm_min_ps(leftMin4, min4[i]);
|
275 |
+
rightMin4 = _mm_min_ps(rightMin4, min4[BINS - 2 - i]);
|
276 |
+
leftMax4 = _mm_max_ps(leftMax4, max4[i]);
|
277 |
+
rightMax4 = _mm_max_ps(rightMax4, max4[BINS - 2 - i]);
|
278 |
+
float le[4], re[4];
|
279 |
+
_mm_store_ps(le, _mm_sub_ps(leftMax4, leftMin4));
|
280 |
+
_mm_store_ps(re, _mm_sub_ps(rightMax4, rightMin4));
|
281 |
+
// SSE order goes from back to front
|
282 |
+
leftCountArea[i] = leftSum * (le[2] * le[3]); // 2D area calculation
|
283 |
+
rightCountArea[BINS - 2 - i] =
|
284 |
+
rightSum * (re[2] * re[3]); // 2D area calculation
|
285 |
+
}
|
286 |
+
}
|
287 |
+
#else
|
288 |
+
if constexpr (false) {
|
289 |
+
}
|
290 |
+
#endif
|
291 |
+
else {
|
292 |
+
struct Bin {
|
293 |
+
AABB bounds;
|
294 |
+
int triCount = 0;
|
295 |
+
} bin[BINS];
|
296 |
+
for (int i = node.start; i < node.end; i++) {
|
297 |
+
Triangle &tri = triangle[triIdx[i]];
|
298 |
+
int binIdx =
|
299 |
+
std::min(BINS - 1, (int)((tri.centroid[axis] - boundsMin) * scale));
|
300 |
+
bin[binIdx].triCount++;
|
301 |
+
bin[binIdx].bounds.grow(tri.v0);
|
302 |
+
bin[binIdx].bounds.grow(tri.v1);
|
303 |
+
bin[binIdx].bounds.grow(tri.v2);
|
304 |
+
}
|
305 |
+
// gather data for the 7 planes between the 8 bins
|
306 |
+
AABB leftBox, rightBox;
|
307 |
+
for (int i = 0; i < BINS - 1; i++) {
|
308 |
+
leftSum += bin[i].triCount;
|
309 |
+
leftBox.grow(bin[i].bounds);
|
310 |
+
leftCountArea[i] = leftSum * leftBox.area();
|
311 |
+
rightSum += bin[BINS - 1 - i].triCount;
|
312 |
+
rightBox.grow(bin[BINS - 1 - i].bounds);
|
313 |
+
rightCountArea[BINS - 2 - i] = rightSum * rightBox.area();
|
314 |
+
}
|
315 |
+
}
|
316 |
+
|
317 |
+
// calculate SAH cost for the 7 planes
|
318 |
+
scale = (boundsMax - boundsMin) / BINS;
|
319 |
+
for (int i = 0; i < BINS - 1; i++) {
|
320 |
+
const float planeCost = leftCountArea[i] + rightCountArea[i];
|
321 |
+
if (planeCost < best_cost)
|
322 |
+
best_axis = axis, best_pos = i + 1, best_cost = planeCost;
|
323 |
+
}
|
324 |
+
}
|
325 |
+
return best_cost;
|
326 |
+
}
|
327 |
+
|
328 |
+
std::vector<int> BVH::Intersect(Triangle &tri_intersect) {
|
329 |
+
/**
|
330 |
+
* @brief Intersect a triangle with the BVH
|
331 |
+
*
|
332 |
+
* @param triangle the triangle to intersect
|
333 |
+
*
|
334 |
+
* @return -1 for no intersection, the index of the intersected triangle
|
335 |
+
* otherwise
|
336 |
+
*/
|
337 |
+
|
338 |
+
const int max_stack_size = 64;
|
339 |
+
int node_stack[max_stack_size];
|
340 |
+
int stack_size = 0;
|
341 |
+
std::vector<int> intersected_triangles;
|
342 |
+
|
343 |
+
node_stack[stack_size++] = 0; // Start with the root node (index 0)
|
344 |
+
while (stack_size > 0) {
|
345 |
+
int node_idx = node_stack[--stack_size];
|
346 |
+
const BVHNode &node = bvhNode[node_idx];
|
347 |
+
if (node.is_leaf()) {
|
348 |
+
for (int i = node.start; i < node.end; ++i) {
|
349 |
+
const Triangle &tri = triangle[triIdx[i]];
|
350 |
+
// Check that the triangle is not the same as the intersected triangle
|
351 |
+
if (tri == tri_intersect)
|
352 |
+
continue;
|
353 |
+
if (tri_intersect.overlaps(tri)) {
|
354 |
+
intersected_triangles.push_back(actualIdx[triIdx[i]]);
|
355 |
+
}
|
356 |
+
}
|
357 |
+
} else {
|
358 |
+
// Check right child first
|
359 |
+
if (bvhNode[node.right].bbox.overlaps(tri_intersect)) {
|
360 |
+
if (stack_size < max_stack_size) {
|
361 |
+
node_stack[stack_size++] = node.right;
|
362 |
+
} else {
|
363 |
+
throw std::runtime_error("Node stack overflow");
|
364 |
+
}
|
365 |
+
}
|
366 |
+
|
367 |
+
// Check left child
|
368 |
+
if (bvhNode[node.left].bbox.overlaps(tri_intersect)) {
|
369 |
+
if (stack_size < max_stack_size) {
|
370 |
+
node_stack[stack_size++] = node.left;
|
371 |
+
} else {
|
372 |
+
throw std::runtime_error("Node stack overflow");
|
373 |
+
}
|
374 |
+
}
|
375 |
+
}
|
376 |
+
}
|
377 |
+
return intersected_triangles; // Return all intersected triangle indices
|
378 |
+
}
|
379 |
+
|
380 |
+
} // namespace UVUnwrapper
|
uv_unwrapper/uv_unwrapper/csrc/bvh.h
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include <cfloat>
|
4 |
+
#include <cmath>
|
5 |
+
#ifndef __ARM_ARCH_ISA_A64
|
6 |
+
#include <immintrin.h>
|
7 |
+
#endif
|
8 |
+
#include <limits>
|
9 |
+
#include <vector>
|
10 |
+
|
11 |
+
#include "common.h"
|
12 |
+
#include "intersect.h"
|
13 |
+
/**
|
14 |
+
* Based on https://github.com/jbikker/bvh_article released under the unlicense.
|
15 |
+
*/
|
16 |
+
|
17 |
+
// bin count for binned BVH building
|
18 |
+
#define BINS 8
|
19 |
+
|
20 |
+
namespace UVUnwrapper {
|
21 |
+
// minimalist triangle struct
|
22 |
+
struct alignas(32) Triangle {
|
23 |
+
uv_float2 v0;
|
24 |
+
uv_float2 v1;
|
25 |
+
uv_float2 v2;
|
26 |
+
uv_float2 centroid;
|
27 |
+
|
28 |
+
bool overlaps(const Triangle &other) {
|
29 |
+
// return tri_tri_overlap_test_2d(v0, v1, v2, other.v0, other.v1, other.v2);
|
30 |
+
return triangle_triangle_intersection(v0, v1, v2, other.v0, other.v1,
|
31 |
+
other.v2);
|
32 |
+
}
|
33 |
+
|
34 |
+
bool operator==(const Triangle &rhs) const {
|
35 |
+
return v0 == rhs.v0 && v1 == rhs.v1 && v2 == rhs.v2;
|
36 |
+
}
|
37 |
+
};
|
38 |
+
|
39 |
+
// minimalist AABB struct with grow functionality
|
40 |
+
struct alignas(16) AABB {
|
41 |
+
// Init bounding boxes with max/min
|
42 |
+
uv_float2 min = {FLT_MAX, FLT_MAX};
|
43 |
+
uv_float2 max = {FLT_MIN, FLT_MIN};
|
44 |
+
|
45 |
+
void grow(const uv_float2 &p) {
|
46 |
+
min.x = std::min(min.x, p.x);
|
47 |
+
min.y = std::min(min.y, p.y);
|
48 |
+
max.x = std::max(max.x, p.x);
|
49 |
+
max.y = std::max(max.y, p.y);
|
50 |
+
}
|
51 |
+
|
52 |
+
void grow(const AABB &b) {
|
53 |
+
if (b.min.x != FLT_MAX) {
|
54 |
+
grow(b.min);
|
55 |
+
grow(b.max);
|
56 |
+
}
|
57 |
+
}
|
58 |
+
|
59 |
+
bool overlaps(const Triangle &tri) {
|
60 |
+
return triangle_aabb_intersection(min, max, tri.v0, tri.v1, tri.v2);
|
61 |
+
}
|
62 |
+
|
63 |
+
float area() const {
|
64 |
+
uv_float2 extent = {max.x - min.x, max.y - min.y};
|
65 |
+
return extent.x * extent.y;
|
66 |
+
}
|
67 |
+
|
68 |
+
void invalidate() {
|
69 |
+
min = {FLT_MAX, FLT_MAX};
|
70 |
+
max = {FLT_MIN, FLT_MIN};
|
71 |
+
}
|
72 |
+
};
|
73 |
+
|
74 |
+
// 32-byte BVH node struct
|
75 |
+
struct alignas(32) BVHNode {
|
76 |
+
AABB bbox; // 16
|
77 |
+
int start = 0, end = 0; // 8
|
78 |
+
int left, right;
|
79 |
+
|
80 |
+
int num_triangles() const { return end - start; }
|
81 |
+
|
82 |
+
bool is_leaf() const { return left == -1 && right == -1; }
|
83 |
+
|
84 |
+
float calculate_node_cost() {
|
85 |
+
float area = bbox.area();
|
86 |
+
return num_triangles() * area;
|
87 |
+
}
|
88 |
+
};
|
89 |
+
|
90 |
+
class BVH {
|
91 |
+
public:
|
92 |
+
BVH() = default;
|
93 |
+
BVH(BVH &&other) noexcept;
|
94 |
+
BVH(const BVH &other);
|
95 |
+
BVH &operator=(const BVH &other);
|
96 |
+
BVH &operator=(BVH &&other) noexcept;
|
97 |
+
BVH(Triangle *tri, int *actual_idx, const size_t &num_indices);
|
98 |
+
~BVH();
|
99 |
+
|
100 |
+
std::vector<int> Intersect(Triangle &triangle);
|
101 |
+
|
102 |
+
private:
|
103 |
+
void Subdivide(unsigned int node_idx, unsigned int &nodePtr,
|
104 |
+
AABB ¢roidBounds);
|
105 |
+
void UpdateNodeBounds(unsigned int nodeIdx, AABB ¢roidBounds);
|
106 |
+
float FindBestSplitPlane(BVHNode &node, int &axis, int &splitPos,
|
107 |
+
AABB ¢roidBounds);
|
108 |
+
|
109 |
+
public:
|
110 |
+
int *triIdx = nullptr;
|
111 |
+
int *actualIdx = nullptr;
|
112 |
+
unsigned int triCount;
|
113 |
+
unsigned int nodesUsed;
|
114 |
+
BVHNode *bvhNode = nullptr;
|
115 |
+
Triangle *triangle = nullptr;
|
116 |
+
};
|
117 |
+
|
118 |
+
} // namespace UVUnwrapper
|
uv_unwrapper/uv_unwrapper/csrc/common.h
ADDED
@@ -0,0 +1,493 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include <array>
|
4 |
+
#include <cmath>
|
5 |
+
#include <iostream>
|
6 |
+
#include <stdexcept>
|
7 |
+
|
8 |
+
const float EPSILON = 1e-7f;
|
9 |
+
|
10 |
+
// Structure to represent a 2D point or vector
|
11 |
+
union alignas(8) uv_float2 {
|
12 |
+
struct {
|
13 |
+
float x, y;
|
14 |
+
};
|
15 |
+
|
16 |
+
float data[2];
|
17 |
+
|
18 |
+
float &operator[](size_t idx) {
|
19 |
+
if (idx > 1)
|
20 |
+
throw std::runtime_error("bad index");
|
21 |
+
return data[idx];
|
22 |
+
}
|
23 |
+
|
24 |
+
const float &operator[](size_t idx) const {
|
25 |
+
if (idx > 1)
|
26 |
+
throw std::runtime_error("bad index");
|
27 |
+
return data[idx];
|
28 |
+
}
|
29 |
+
|
30 |
+
bool operator==(const uv_float2 &rhs) const {
|
31 |
+
return x == rhs.x && y == rhs.y;
|
32 |
+
}
|
33 |
+
};
|
34 |
+
|
35 |
+
// Do not align as this is specifically tweaked for BVHNode
|
36 |
+
union uv_float3 {
|
37 |
+
struct {
|
38 |
+
float x, y, z;
|
39 |
+
};
|
40 |
+
|
41 |
+
float data[3];
|
42 |
+
|
43 |
+
float &operator[](size_t idx) {
|
44 |
+
if (idx > 3)
|
45 |
+
throw std::runtime_error("bad index");
|
46 |
+
return data[idx];
|
47 |
+
}
|
48 |
+
|
49 |
+
const float &operator[](size_t idx) const {
|
50 |
+
if (idx > 3)
|
51 |
+
throw std::runtime_error("bad index");
|
52 |
+
return data[idx];
|
53 |
+
}
|
54 |
+
|
55 |
+
bool operator==(const uv_float3 &rhs) const {
|
56 |
+
return x == rhs.x && y == rhs.y && z == rhs.z;
|
57 |
+
}
|
58 |
+
};
|
59 |
+
|
60 |
+
union alignas(16) uv_float4 {
|
61 |
+
struct {
|
62 |
+
float x, y, z, w;
|
63 |
+
};
|
64 |
+
|
65 |
+
float data[4];
|
66 |
+
|
67 |
+
float &operator[](size_t idx) {
|
68 |
+
if (idx > 3)
|
69 |
+
throw std::runtime_error("bad index");
|
70 |
+
return data[idx];
|
71 |
+
}
|
72 |
+
|
73 |
+
const float &operator[](size_t idx) const {
|
74 |
+
if (idx > 3)
|
75 |
+
throw std::runtime_error("bad index");
|
76 |
+
return data[idx];
|
77 |
+
}
|
78 |
+
|
79 |
+
bool operator==(const uv_float4 &rhs) const {
|
80 |
+
return x == rhs.x && y == rhs.y && z == rhs.z && w == rhs.w;
|
81 |
+
}
|
82 |
+
};
|
83 |
+
|
84 |
+
union alignas(8) uv_int2 {
|
85 |
+
struct {
|
86 |
+
int x, y;
|
87 |
+
};
|
88 |
+
|
89 |
+
int data[2];
|
90 |
+
|
91 |
+
int &operator[](size_t idx) {
|
92 |
+
if (idx > 1)
|
93 |
+
throw std::runtime_error("bad index");
|
94 |
+
return data[idx];
|
95 |
+
}
|
96 |
+
|
97 |
+
const int &operator[](size_t idx) const {
|
98 |
+
if (idx > 1)
|
99 |
+
throw std::runtime_error("bad index");
|
100 |
+
return data[idx];
|
101 |
+
}
|
102 |
+
|
103 |
+
bool operator==(const uv_int2 &rhs) const { return x == rhs.x && y == rhs.y; }
|
104 |
+
};
|
105 |
+
|
106 |
+
union alignas(4) uv_int3 {
|
107 |
+
struct {
|
108 |
+
int x, y, z;
|
109 |
+
};
|
110 |
+
|
111 |
+
int data[3];
|
112 |
+
|
113 |
+
int &operator[](size_t idx) {
|
114 |
+
if (idx > 2)
|
115 |
+
throw std::runtime_error("bad index");
|
116 |
+
return data[idx];
|
117 |
+
}
|
118 |
+
|
119 |
+
const int &operator[](size_t idx) const {
|
120 |
+
if (idx > 2)
|
121 |
+
throw std::runtime_error("bad index");
|
122 |
+
return data[idx];
|
123 |
+
}
|
124 |
+
|
125 |
+
bool operator==(const uv_int3 &rhs) const {
|
126 |
+
return x == rhs.x && y == rhs.y && z == rhs.z;
|
127 |
+
}
|
128 |
+
};
|
129 |
+
|
130 |
+
union alignas(16) uv_int4 {
|
131 |
+
struct {
|
132 |
+
int x, y, z, w;
|
133 |
+
};
|
134 |
+
|
135 |
+
int data[4];
|
136 |
+
|
137 |
+
int &operator[](size_t idx) {
|
138 |
+
if (idx > 3)
|
139 |
+
throw std::runtime_error("bad index");
|
140 |
+
return data[idx];
|
141 |
+
}
|
142 |
+
|
143 |
+
const int &operator[](size_t idx) const {
|
144 |
+
if (idx > 3)
|
145 |
+
throw std::runtime_error("bad index");
|
146 |
+
return data[idx];
|
147 |
+
}
|
148 |
+
|
149 |
+
bool operator==(const uv_int4 &rhs) const {
|
150 |
+
return x == rhs.x && y == rhs.y && z == rhs.z && w == rhs.w;
|
151 |
+
}
|
152 |
+
};
|
153 |
+
|
154 |
+
inline float calc_mean(float a, float b, float c) { return (a + b + c) / 3; }
|
155 |
+
|
156 |
+
// Create a triangle centroid
|
157 |
+
inline uv_float2 triangle_centroid(const uv_float2 &v0, const uv_float2 &v1,
|
158 |
+
const uv_float2 &v2) {
|
159 |
+
return {calc_mean(v0.x, v1.x, v2.x), calc_mean(v0.y, v1.y, v2.y)};
|
160 |
+
}
|
161 |
+
|
162 |
+
inline uv_float3 triangle_centroid(const uv_float3 &v0, const uv_float3 &v1,
|
163 |
+
const uv_float3 &v2) {
|
164 |
+
return {calc_mean(v0.x, v1.x, v2.x), calc_mean(v0.y, v1.y, v2.y),
|
165 |
+
calc_mean(v0.z, v1.z, v2.z)};
|
166 |
+
}
|
167 |
+
|
168 |
+
// Helper functions for vector math
|
169 |
+
inline uv_float2 operator-(const uv_float2 &a, const uv_float2 &b) {
|
170 |
+
return {a.x - b.x, a.y - b.y};
|
171 |
+
}
|
172 |
+
|
173 |
+
inline uv_float3 operator-(const uv_float3 &a, const uv_float3 &b) {
|
174 |
+
return {a.x - b.x, a.y - b.y, a.z - b.z};
|
175 |
+
}
|
176 |
+
|
177 |
+
inline uv_float2 operator+(const uv_float2 &a, const uv_float2 &b) {
|
178 |
+
return {a.x + b.x, a.y + b.y};
|
179 |
+
}
|
180 |
+
|
181 |
+
inline uv_float3 operator+(const uv_float3 &a, const uv_float3 &b) {
|
182 |
+
return {a.x + b.x, a.y + b.y, a.z + b.z};
|
183 |
+
}
|
184 |
+
|
185 |
+
inline uv_float2 operator*(const uv_float2 &a, float scalar) {
|
186 |
+
return {a.x * scalar, a.y * scalar};
|
187 |
+
}
|
188 |
+
|
189 |
+
inline uv_float3 operator*(const uv_float3 &a, float scalar) {
|
190 |
+
return {a.x * scalar, a.y * scalar, a.z * scalar};
|
191 |
+
}
|
192 |
+
|
193 |
+
inline float dot(const uv_float2 &a, const uv_float2 &b) {
|
194 |
+
return a.x * b.x + a.y * b.y;
|
195 |
+
}
|
196 |
+
|
197 |
+
inline float dot(const uv_float3 &a, const uv_float3 &b) {
|
198 |
+
return a.x * b.x + a.y * b.y + a.z * b.z;
|
199 |
+
}
|
200 |
+
|
201 |
+
inline float cross(const uv_float2 &a, const uv_float2 &b) {
|
202 |
+
return a.x * b.y - a.y * b.x;
|
203 |
+
}
|
204 |
+
|
205 |
+
inline uv_float3 cross(const uv_float3 &a, const uv_float3 &b) {
|
206 |
+
return {a.y * b.z - a.z * b.y, a.z * b.x - a.x * b.z, a.x * b.y - a.y * b.x};
|
207 |
+
}
|
208 |
+
|
209 |
+
inline uv_float2 abs_vec(const uv_float2 &v) {
|
210 |
+
return {std::abs(v.x), std::abs(v.y)};
|
211 |
+
}
|
212 |
+
|
213 |
+
inline uv_float2 min_vec(const uv_float2 &a, const uv_float2 &b) {
|
214 |
+
return {std::min(a.x, b.x), std::min(a.y, b.y)};
|
215 |
+
}
|
216 |
+
|
217 |
+
inline uv_float2 max_vec(const uv_float2 &a, const uv_float2 &b) {
|
218 |
+
return {std::max(a.x, b.x), std::max(a.y, b.y)};
|
219 |
+
}
|
220 |
+
|
221 |
+
inline float distance_to(const uv_float2 &a, const uv_float2 &b) {
|
222 |
+
return std::sqrt(std::pow(a.x - b.x, 2) + std::pow(a.y - b.y, 2));
|
223 |
+
}
|
224 |
+
|
225 |
+
inline float distance_to(const uv_float3 &a, const uv_float3 &b) {
|
226 |
+
return std::sqrt(std::pow(a.x - b.x, 2) + std::pow(a.y - b.y, 2) +
|
227 |
+
std::pow(a.z - b.z, 2));
|
228 |
+
}
|
229 |
+
|
230 |
+
inline uv_float2 normalize(const uv_float2 &v) {
|
231 |
+
float len = std::sqrt(v.x * v.x + v.y * v.y);
|
232 |
+
return {v.x / len, v.y / len};
|
233 |
+
}
|
234 |
+
|
235 |
+
inline uv_float3 normalize(const uv_float3 &v) {
|
236 |
+
float len = std::sqrt(v.x * v.x + v.y * v.y + v.z * v.z);
|
237 |
+
return {v.x / len, v.y / len, v.z / len};
|
238 |
+
}
|
239 |
+
|
240 |
+
inline float magnitude(const uv_float3 &v) {
|
241 |
+
return std::sqrt(v.x * v.x + v.y * v.y + v.z * v.z);
|
242 |
+
}
|
243 |
+
|
244 |
+
struct Matrix4 {
|
245 |
+
std::array<std::array<float, 4>, 4> m;
|
246 |
+
|
247 |
+
Matrix4() {
|
248 |
+
for (auto &row : m) {
|
249 |
+
row.fill(0.0f);
|
250 |
+
}
|
251 |
+
m[3][3] = 1.0f; // Identity matrix for 4th row and column
|
252 |
+
}
|
253 |
+
|
254 |
+
void set(float m00, float m01, float m02, float m03, float m10, float m11,
|
255 |
+
float m12, float m13, float m20, float m21, float m22, float m23,
|
256 |
+
float m30, float m31, float m32, float m33) {
|
257 |
+
m[0][0] = m00;
|
258 |
+
m[0][1] = m01;
|
259 |
+
m[0][2] = m02;
|
260 |
+
m[0][3] = m03;
|
261 |
+
m[1][0] = m10;
|
262 |
+
m[1][1] = m11;
|
263 |
+
m[1][2] = m12;
|
264 |
+
m[1][3] = m13;
|
265 |
+
m[2][0] = m20;
|
266 |
+
m[2][1] = m21;
|
267 |
+
m[2][2] = m22;
|
268 |
+
m[2][3] = m23;
|
269 |
+
m[3][0] = m30;
|
270 |
+
m[3][1] = m31;
|
271 |
+
m[3][2] = m32;
|
272 |
+
m[3][3] = m33;
|
273 |
+
}
|
274 |
+
|
275 |
+
float determinant() const {
|
276 |
+
return m[0][3] * m[1][2] * m[2][1] * m[3][0] -
|
277 |
+
m[0][2] * m[1][3] * m[2][1] * m[3][0] -
|
278 |
+
m[0][3] * m[1][1] * m[2][2] * m[3][0] +
|
279 |
+
m[0][1] * m[1][3] * m[2][2] * m[3][0] +
|
280 |
+
m[0][2] * m[1][1] * m[2][3] * m[3][0] -
|
281 |
+
m[0][1] * m[1][2] * m[2][3] * m[3][0] -
|
282 |
+
m[0][3] * m[1][2] * m[2][0] * m[3][1] +
|
283 |
+
m[0][2] * m[1][3] * m[2][0] * m[3][1] +
|
284 |
+
m[0][3] * m[1][0] * m[2][2] * m[3][1] -
|
285 |
+
m[0][0] * m[1][3] * m[2][2] * m[3][1] -
|
286 |
+
m[0][2] * m[1][0] * m[2][3] * m[3][1] +
|
287 |
+
m[0][0] * m[1][2] * m[2][3] * m[3][1] +
|
288 |
+
m[0][3] * m[1][1] * m[2][0] * m[3][2] -
|
289 |
+
m[0][1] * m[1][3] * m[2][0] * m[3][2] -
|
290 |
+
m[0][3] * m[1][0] * m[2][1] * m[3][2] +
|
291 |
+
m[0][0] * m[1][3] * m[2][1] * m[3][2] +
|
292 |
+
m[0][1] * m[1][0] * m[2][3] * m[3][2] -
|
293 |
+
m[0][0] * m[1][1] * m[2][3] * m[3][2] -
|
294 |
+
m[0][2] * m[1][1] * m[2][0] * m[3][3] +
|
295 |
+
m[0][1] * m[1][2] * m[2][0] * m[3][3] +
|
296 |
+
m[0][2] * m[1][0] * m[2][1] * m[3][3] -
|
297 |
+
m[0][0] * m[1][2] * m[2][1] * m[3][3] -
|
298 |
+
m[0][1] * m[1][0] * m[2][2] * m[3][3] +
|
299 |
+
m[0][0] * m[1][1] * m[2][2] * m[3][3];
|
300 |
+
}
|
301 |
+
|
302 |
+
Matrix4 operator*(const Matrix4 &other) const {
|
303 |
+
Matrix4 result;
|
304 |
+
for (int row = 0; row < 4; ++row) {
|
305 |
+
for (int col = 0; col < 4; ++col) {
|
306 |
+
result.m[row][col] =
|
307 |
+
m[row][0] * other.m[0][col] + m[row][1] * other.m[1][col] +
|
308 |
+
m[row][2] * other.m[2][col] + m[row][3] * other.m[3][col];
|
309 |
+
}
|
310 |
+
}
|
311 |
+
return result;
|
312 |
+
}
|
313 |
+
|
314 |
+
Matrix4 operator*(float scalar) const {
|
315 |
+
Matrix4 result = *this;
|
316 |
+
for (auto &row : result.m) {
|
317 |
+
for (auto &element : row) {
|
318 |
+
element *= scalar;
|
319 |
+
}
|
320 |
+
}
|
321 |
+
return result;
|
322 |
+
}
|
323 |
+
|
324 |
+
Matrix4 operator+(const Matrix4 &other) const {
|
325 |
+
Matrix4 result;
|
326 |
+
for (int i = 0; i < 4; ++i) {
|
327 |
+
for (int j = 0; j < 4; ++j) {
|
328 |
+
result.m[i][j] = m[i][j] + other.m[i][j];
|
329 |
+
}
|
330 |
+
}
|
331 |
+
return result;
|
332 |
+
}
|
333 |
+
|
334 |
+
Matrix4 operator-(const Matrix4 &other) const {
|
335 |
+
Matrix4 result;
|
336 |
+
for (int i = 0; i < 4; ++i) {
|
337 |
+
for (int j = 0; j < 4; ++j) {
|
338 |
+
result.m[i][j] = m[i][j] - other.m[i][j];
|
339 |
+
}
|
340 |
+
}
|
341 |
+
return result;
|
342 |
+
}
|
343 |
+
|
344 |
+
float trace() const { return m[0][0] + m[1][1] + m[2][2] + m[3][3]; }
|
345 |
+
|
346 |
+
Matrix4 identity() const {
|
347 |
+
Matrix4 identity;
|
348 |
+
identity.set(1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1);
|
349 |
+
return identity;
|
350 |
+
}
|
351 |
+
|
352 |
+
Matrix4 power(int exp) const {
|
353 |
+
if (exp == 0)
|
354 |
+
return identity();
|
355 |
+
if (exp == 1)
|
356 |
+
return *this;
|
357 |
+
|
358 |
+
Matrix4 result = *this;
|
359 |
+
for (int i = 1; i < exp; ++i) {
|
360 |
+
result = result * (*this);
|
361 |
+
}
|
362 |
+
return result;
|
363 |
+
}
|
364 |
+
|
365 |
+
void print() {
|
366 |
+
// Print all entries in 4 rows with 4 columns
|
367 |
+
for (int i = 0; i < 4; ++i) {
|
368 |
+
for (int j = 0; j < 4; ++j) {
|
369 |
+
std::cout << m[i][j] << " ";
|
370 |
+
}
|
371 |
+
std::cout << std::endl;
|
372 |
+
}
|
373 |
+
}
|
374 |
+
|
375 |
+
bool invert() {
|
376 |
+
double inv[16], det;
|
377 |
+
double mArr[16];
|
378 |
+
|
379 |
+
// Convert the matrix to a 1D array for easier manipulation
|
380 |
+
for (int i = 0; i < 4; ++i) {
|
381 |
+
for (int j = 0; j < 4; ++j) {
|
382 |
+
mArr[i * 4 + j] = static_cast<double>(m[i][j]);
|
383 |
+
}
|
384 |
+
}
|
385 |
+
|
386 |
+
inv[0] = mArr[5] * mArr[10] * mArr[15] - mArr[5] * mArr[11] * mArr[14] -
|
387 |
+
mArr[9] * mArr[6] * mArr[15] + mArr[9] * mArr[7] * mArr[14] +
|
388 |
+
mArr[13] * mArr[6] * mArr[11] - mArr[13] * mArr[7] * mArr[10];
|
389 |
+
|
390 |
+
inv[4] = -mArr[4] * mArr[10] * mArr[15] + mArr[4] * mArr[11] * mArr[14] +
|
391 |
+
mArr[8] * mArr[6] * mArr[15] - mArr[8] * mArr[7] * mArr[14] -
|
392 |
+
mArr[12] * mArr[6] * mArr[11] + mArr[12] * mArr[7] * mArr[10];
|
393 |
+
|
394 |
+
inv[8] = mArr[4] * mArr[9] * mArr[15] - mArr[4] * mArr[11] * mArr[13] -
|
395 |
+
mArr[8] * mArr[5] * mArr[15] + mArr[8] * mArr[7] * mArr[13] +
|
396 |
+
mArr[12] * mArr[5] * mArr[11] - mArr[12] * mArr[7] * mArr[9];
|
397 |
+
|
398 |
+
inv[12] = -mArr[4] * mArr[9] * mArr[14] + mArr[4] * mArr[10] * mArr[13] +
|
399 |
+
mArr[8] * mArr[5] * mArr[14] - mArr[8] * mArr[6] * mArr[13] -
|
400 |
+
mArr[12] * mArr[5] * mArr[10] + mArr[12] * mArr[6] * mArr[9];
|
401 |
+
|
402 |
+
inv[1] = -mArr[1] * mArr[10] * mArr[15] + mArr[1] * mArr[11] * mArr[14] +
|
403 |
+
mArr[9] * mArr[2] * mArr[15] - mArr[9] * mArr[3] * mArr[14] -
|
404 |
+
mArr[13] * mArr[2] * mArr[11] + mArr[13] * mArr[3] * mArr[10];
|
405 |
+
|
406 |
+
inv[5] = mArr[0] * mArr[10] * mArr[15] - mArr[0] * mArr[11] * mArr[14] -
|
407 |
+
mArr[8] * mArr[2] * mArr[15] + mArr[8] * mArr[3] * mArr[14] +
|
408 |
+
mArr[12] * mArr[2] * mArr[11] - mArr[12] * mArr[3] * mArr[10];
|
409 |
+
|
410 |
+
inv[9] = -mArr[0] * mArr[9] * mArr[15] + mArr[0] * mArr[11] * mArr[13] +
|
411 |
+
mArr[8] * mArr[1] * mArr[15] - mArr[8] * mArr[3] * mArr[13] -
|
412 |
+
mArr[12] * mArr[1] * mArr[11] + mArr[12] * mArr[3] * mArr[9];
|
413 |
+
|
414 |
+
inv[13] = mArr[0] * mArr[9] * mArr[14] - mArr[0] * mArr[10] * mArr[13] -
|
415 |
+
mArr[8] * mArr[1] * mArr[14] + mArr[8] * mArr[2] * mArr[13] +
|
416 |
+
mArr[12] * mArr[1] * mArr[10] - mArr[12] * mArr[2] * mArr[9];
|
417 |
+
|
418 |
+
inv[2] = mArr[1] * mArr[6] * mArr[15] - mArr[1] * mArr[7] * mArr[14] -
|
419 |
+
mArr[5] * mArr[2] * mArr[15] + mArr[5] * mArr[3] * mArr[14] +
|
420 |
+
mArr[13] * mArr[2] * mArr[7] - mArr[13] * mArr[3] * mArr[6];
|
421 |
+
|
422 |
+
inv[6] = -mArr[0] * mArr[6] * mArr[15] + mArr[0] * mArr[7] * mArr[14] +
|
423 |
+
mArr[4] * mArr[2] * mArr[15] - mArr[4] * mArr[3] * mArr[14] -
|
424 |
+
mArr[12] * mArr[2] * mArr[7] + mArr[12] * mArr[3] * mArr[6];
|
425 |
+
|
426 |
+
inv[10] = mArr[0] * mArr[5] * mArr[15] - mArr[0] * mArr[7] * mArr[13] -
|
427 |
+
mArr[4] * mArr[1] * mArr[15] + mArr[4] * mArr[3] * mArr[13] +
|
428 |
+
mArr[12] * mArr[1] * mArr[7] - mArr[12] * mArr[3] * mArr[5];
|
429 |
+
|
430 |
+
inv[14] = -mArr[0] * mArr[5] * mArr[14] + mArr[0] * mArr[6] * mArr[13] +
|
431 |
+
mArr[4] * mArr[1] * mArr[14] - mArr[4] * mArr[2] * mArr[13] -
|
432 |
+
mArr[12] * mArr[1] * mArr[6] + mArr[12] * mArr[2] * mArr[5];
|
433 |
+
|
434 |
+
inv[3] = -mArr[1] * mArr[6] * mArr[11] + mArr[1] * mArr[7] * mArr[10] +
|
435 |
+
mArr[5] * mArr[2] * mArr[11] - mArr[5] * mArr[3] * mArr[10] -
|
436 |
+
mArr[9] * mArr[2] * mArr[7] + mArr[9] * mArr[3] * mArr[6];
|
437 |
+
|
438 |
+
inv[7] = mArr[0] * mArr[6] * mArr[11] - mArr[0] * mArr[7] * mArr[10] -
|
439 |
+
mArr[4] * mArr[2] * mArr[11] + mArr[4] * mArr[3] * mArr[10] +
|
440 |
+
mArr[8] * mArr[2] * mArr[7] - mArr[8] * mArr[3] * mArr[6];
|
441 |
+
|
442 |
+
inv[11] = -mArr[0] * mArr[5] * mArr[11] + mArr[0] * mArr[7] * mArr[9] +
|
443 |
+
mArr[4] * mArr[1] * mArr[11] - mArr[4] * mArr[3] * mArr[9] -
|
444 |
+
mArr[8] * mArr[1] * mArr[7] + mArr[8] * mArr[3] * mArr[5];
|
445 |
+
|
446 |
+
inv[15] = mArr[0] * mArr[5] * mArr[10] - mArr[0] * mArr[6] * mArr[9] -
|
447 |
+
mArr[4] * mArr[1] * mArr[10] + mArr[4] * mArr[2] * mArr[9] +
|
448 |
+
mArr[8] * mArr[1] * mArr[6] - mArr[8] * mArr[2] * mArr[5];
|
449 |
+
|
450 |
+
det = mArr[0] * inv[0] + mArr[1] * inv[4] + mArr[2] * inv[8] +
|
451 |
+
mArr[3] * inv[12];
|
452 |
+
|
453 |
+
if (fabs(det) < 1e-6) {
|
454 |
+
return false;
|
455 |
+
}
|
456 |
+
|
457 |
+
det = 1.0 / det;
|
458 |
+
|
459 |
+
for (int i = 0; i < 16; i++) {
|
460 |
+
inv[i] *= det;
|
461 |
+
}
|
462 |
+
|
463 |
+
// Convert the 1D array back to the 4x4 matrix
|
464 |
+
for (int i = 0; i < 4; ++i) {
|
465 |
+
for (int j = 0; j < 4; ++j) {
|
466 |
+
m[i][j] = static_cast<float>(inv[i * 4 + j]);
|
467 |
+
}
|
468 |
+
}
|
469 |
+
|
470 |
+
return true;
|
471 |
+
}
|
472 |
+
};
|
473 |
+
|
474 |
+
inline void apply_matrix4(uv_float3 &v, const Matrix4 matrix) {
|
475 |
+
float newX = v.x * matrix.m[0][0] + v.y * matrix.m[0][1] +
|
476 |
+
v.z * matrix.m[0][2] + matrix.m[0][3];
|
477 |
+
float newY = v.x * matrix.m[1][0] + v.y * matrix.m[1][1] +
|
478 |
+
v.z * matrix.m[1][2] + matrix.m[1][3];
|
479 |
+
float newZ = v.x * matrix.m[2][0] + v.y * matrix.m[2][1] +
|
480 |
+
v.z * matrix.m[2][2] + matrix.m[2][3];
|
481 |
+
float w = v.x * matrix.m[3][0] + v.y * matrix.m[3][1] + v.z * matrix.m[3][2] +
|
482 |
+
matrix.m[3][3];
|
483 |
+
|
484 |
+
if (std::fabs(w) > EPSILON) {
|
485 |
+
newX /= w;
|
486 |
+
newY /= w;
|
487 |
+
newZ /= w;
|
488 |
+
}
|
489 |
+
|
490 |
+
v.x = newX;
|
491 |
+
v.y = newY;
|
492 |
+
v.z = newZ;
|
493 |
+
}
|
uv_unwrapper/uv_unwrapper/csrc/intersect.cpp
ADDED
@@ -0,0 +1,702 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "intersect.h"
|
2 |
+
#include "bvh.h"
|
3 |
+
#include <algorithm>
|
4 |
+
#include <cmath>
|
5 |
+
#include <iostream>
|
6 |
+
#include <stdexcept>
|
7 |
+
#include <vector>
|
8 |
+
|
9 |
+
bool triangle_aabb_intersection(const uv_float2 &aabbMin,
|
10 |
+
const uv_float2 &aabbMax, const uv_float2 &v0,
|
11 |
+
const uv_float2 &v1, const uv_float2 &v2) {
|
12 |
+
// Convert the min and max aabb defintion to left, right, top, bottom
|
13 |
+
float l = aabbMin.x;
|
14 |
+
float r = aabbMax.x;
|
15 |
+
float t = aabbMin.y;
|
16 |
+
float b = aabbMax.y;
|
17 |
+
|
18 |
+
int b0 = ((v0.x > l) ? 1 : 0) | ((v0.y > t) ? 2 : 0) | ((v0.x > r) ? 4 : 0) |
|
19 |
+
((v0.y > b) ? 8 : 0);
|
20 |
+
if (b0 == 3)
|
21 |
+
return true;
|
22 |
+
|
23 |
+
int b1 = ((v1.x > l) ? 1 : 0) | ((v1.y > t) ? 2 : 0) | ((v1.x > r) ? 4 : 0) |
|
24 |
+
((v1.y > b) ? 8 : 0);
|
25 |
+
if (b1 == 3)
|
26 |
+
return true;
|
27 |
+
|
28 |
+
int b2 = ((v2.x > l) ? 1 : 0) | ((v2.y > t) ? 2 : 0) | ((v2.x > r) ? 4 : 0) |
|
29 |
+
((v2.y > b) ? 8 : 0);
|
30 |
+
if (b2 == 3)
|
31 |
+
return true;
|
32 |
+
|
33 |
+
float m, c, s;
|
34 |
+
|
35 |
+
int i0 = b0 ^ b1;
|
36 |
+
if (i0 != 0) {
|
37 |
+
if (v1.x != v0.x) {
|
38 |
+
m = (v1.y - v0.y) / (v1.x - v0.x);
|
39 |
+
c = v0.y - (m * v0.x);
|
40 |
+
if (i0 & 1) {
|
41 |
+
s = m * l + c;
|
42 |
+
if (s >= t && s <= b)
|
43 |
+
return true;
|
44 |
+
}
|
45 |
+
if (i0 & 2) {
|
46 |
+
s = (t - c) / m;
|
47 |
+
if (s >= l && s <= r)
|
48 |
+
return true;
|
49 |
+
}
|
50 |
+
if (i0 & 4) {
|
51 |
+
s = m * r + c;
|
52 |
+
if (s >= t && s <= b)
|
53 |
+
return true;
|
54 |
+
}
|
55 |
+
if (i0 & 8) {
|
56 |
+
s = (b - c) / m;
|
57 |
+
if (s >= l && s <= r)
|
58 |
+
return true;
|
59 |
+
}
|
60 |
+
} else {
|
61 |
+
if (l == v0.x || r == v0.x)
|
62 |
+
return true;
|
63 |
+
if (v0.x > l && v0.x < r)
|
64 |
+
return true;
|
65 |
+
}
|
66 |
+
}
|
67 |
+
|
68 |
+
int i1 = b1 ^ b2;
|
69 |
+
if (i1 != 0) {
|
70 |
+
if (v2.x != v1.x) {
|
71 |
+
m = (v2.y - v1.y) / (v2.x - v1.x);
|
72 |
+
c = v1.y - (m * v1.x);
|
73 |
+
if (i1 & 1) {
|
74 |
+
s = m * l + c;
|
75 |
+
if (s >= t && s <= b)
|
76 |
+
return true;
|
77 |
+
}
|
78 |
+
if (i1 & 2) {
|
79 |
+
s = (t - c) / m;
|
80 |
+
if (s >= l && s <= r)
|
81 |
+
return true;
|
82 |
+
}
|
83 |
+
if (i1 & 4) {
|
84 |
+
s = m * r + c;
|
85 |
+
if (s >= t && s <= b)
|
86 |
+
return true;
|
87 |
+
}
|
88 |
+
if (i1 & 8) {
|
89 |
+
s = (b - c) / m;
|
90 |
+
if (s >= l && s <= r)
|
91 |
+
return true;
|
92 |
+
}
|
93 |
+
} else {
|
94 |
+
if (l == v1.x || r == v1.x)
|
95 |
+
return true;
|
96 |
+
if (v1.x > l && v1.x < r)
|
97 |
+
return true;
|
98 |
+
}
|
99 |
+
}
|
100 |
+
|
101 |
+
int i2 = b0 ^ b2;
|
102 |
+
if (i2 != 0) {
|
103 |
+
if (v2.x != v0.x) {
|
104 |
+
m = (v2.y - v0.y) / (v2.x - v0.x);
|
105 |
+
c = v0.y - (m * v0.x);
|
106 |
+
if (i2 & 1) {
|
107 |
+
s = m * l + c;
|
108 |
+
if (s >= t && s <= b)
|
109 |
+
return true;
|
110 |
+
}
|
111 |
+
if (i2 & 2) {
|
112 |
+
s = (t - c) / m;
|
113 |
+
if (s >= l && s <= r)
|
114 |
+
return true;
|
115 |
+
}
|
116 |
+
if (i2 & 4) {
|
117 |
+
s = m * r + c;
|
118 |
+
if (s >= t && s <= b)
|
119 |
+
return true;
|
120 |
+
}
|
121 |
+
if (i2 & 8) {
|
122 |
+
s = (b - c) / m;
|
123 |
+
if (s >= l && s <= r)
|
124 |
+
return true;
|
125 |
+
}
|
126 |
+
} else {
|
127 |
+
if (l == v0.x || r == v0.x)
|
128 |
+
return true;
|
129 |
+
if (v0.x > l && v0.x < r)
|
130 |
+
return true;
|
131 |
+
}
|
132 |
+
}
|
133 |
+
|
134 |
+
// Bounding box check
|
135 |
+
float tbb_l = std::min(v0.x, std::min(v1.x, v2.x));
|
136 |
+
float tbb_t = std::min(v0.y, std::min(v1.y, v2.y));
|
137 |
+
float tbb_r = std::max(v0.x, std::max(v1.x, v2.x));
|
138 |
+
float tbb_b = std::max(v0.y, std::max(v1.y, v2.y));
|
139 |
+
|
140 |
+
if (tbb_l <= l && tbb_r >= r && tbb_t <= t && tbb_b >= b) {
|
141 |
+
float v0x = v2.x - v0.x;
|
142 |
+
float v0y = v2.y - v0.y;
|
143 |
+
float v1x = v1.x - v0.x;
|
144 |
+
float v1y = v1.y - v0.y;
|
145 |
+
float v2x, v2y;
|
146 |
+
|
147 |
+
float dot00, dot01, dot02, dot11, dot12, invDenom, u, v;
|
148 |
+
|
149 |
+
// Top-left corner
|
150 |
+
v2x = l - v0.x;
|
151 |
+
v2y = t - v0.y;
|
152 |
+
|
153 |
+
dot00 = v0x * v0x + v0y * v0y;
|
154 |
+
dot01 = v0x * v1x + v0y * v1y;
|
155 |
+
dot02 = v0x * v2x + v0y * v2y;
|
156 |
+
dot11 = v1x * v1x + v1y * v1y;
|
157 |
+
dot12 = v1x * v2x + v1y * v2y;
|
158 |
+
|
159 |
+
invDenom = 1.0f / (dot00 * dot11 - dot01 * dot01);
|
160 |
+
u = (dot11 * dot02 - dot01 * dot12) * invDenom;
|
161 |
+
v = (dot00 * dot12 - dot01 * dot02) * invDenom;
|
162 |
+
|
163 |
+
if (u >= 0 && v >= 0 && (u + v) <= 1)
|
164 |
+
return true;
|
165 |
+
|
166 |
+
// Bottom-left corner
|
167 |
+
v2x = l - v0.x;
|
168 |
+
v2y = b - v0.y;
|
169 |
+
|
170 |
+
dot02 = v0x * v2x + v0y * v2y;
|
171 |
+
dot12 = v1x * v2x + v1y * v2y;
|
172 |
+
|
173 |
+
u = (dot11 * dot02 - dot01 * dot12) * invDenom;
|
174 |
+
v = (dot00 * dot12 - dot01 * dot02) * invDenom;
|
175 |
+
|
176 |
+
if (u >= 0 && v >= 0 && (u + v) <= 1)
|
177 |
+
return true;
|
178 |
+
|
179 |
+
// Bottom-right corner
|
180 |
+
v2x = r - v0.x;
|
181 |
+
v2y = b - v0.y;
|
182 |
+
|
183 |
+
dot02 = v0x * v2x + v0y * v2y;
|
184 |
+
dot12 = v1x * v2x + v1y * v2y;
|
185 |
+
|
186 |
+
u = (dot11 * dot02 - dot01 * dot12) * invDenom;
|
187 |
+
v = (dot00 * dot12 - dot01 * dot02) * invDenom;
|
188 |
+
|
189 |
+
if (u >= 0 && v >= 0 && (u + v) <= 1)
|
190 |
+
return true;
|
191 |
+
|
192 |
+
// Top-right corner
|
193 |
+
v2x = r - v0.x;
|
194 |
+
v2y = t - v0.y;
|
195 |
+
|
196 |
+
dot02 = v0x * v2x + v0y * v2y;
|
197 |
+
dot12 = v1x * v2x + v1y * v2y;
|
198 |
+
|
199 |
+
u = (dot11 * dot02 - dot01 * dot12) * invDenom;
|
200 |
+
v = (dot00 * dot12 - dot01 * dot02) * invDenom;
|
201 |
+
|
202 |
+
if (u >= 0 && v >= 0 && (u + v) <= 1)
|
203 |
+
return true;
|
204 |
+
}
|
205 |
+
|
206 |
+
return false;
|
207 |
+
}
|
208 |
+
|
209 |
+
void tri_winding(uv_float2 &a, uv_float2 &b, uv_float2 &c) {
|
210 |
+
float det = (a.x * (b.y - c.y) + b.x * (c.y - a.y) + c.x * (a.y - b.y));
|
211 |
+
|
212 |
+
// If the determinant is negative, the triangle is oriented clockwise
|
213 |
+
if (det < 0) {
|
214 |
+
// Swap vertices b and c to ensure counter-clockwise winding
|
215 |
+
std::swap(b, c);
|
216 |
+
}
|
217 |
+
}
|
218 |
+
|
219 |
+
struct Triangle {
|
220 |
+
uv_float3 a, b, c;
|
221 |
+
|
222 |
+
Triangle(const uv_float2 &p1, const uv_float2 &q1, const uv_float2 &r1)
|
223 |
+
: a({p1.x, p1.y, 0}), b({q1.x, q1.y, 0}), c({r1.x, r1.y, 0}) {}
|
224 |
+
|
225 |
+
Triangle(const uv_float3 &p1, const uv_float3 &q1, const uv_float3 &r1)
|
226 |
+
: a(p1), b(q1), c(r1) {}
|
227 |
+
|
228 |
+
void getNormal(uv_float3 &normal) const {
|
229 |
+
uv_float3 u = b - a;
|
230 |
+
uv_float3 v = c - a;
|
231 |
+
normal = normalize(cross(u, v));
|
232 |
+
}
|
233 |
+
};
|
234 |
+
|
235 |
+
bool isTriDegenerated(const Triangle &tri) {
|
236 |
+
uv_float3 u = tri.a - tri.b;
|
237 |
+
uv_float3 v = tri.a - tri.c;
|
238 |
+
uv_float3 cr = cross(u, v);
|
239 |
+
return fabs(cr.x) < EPSILON && fabs(cr.y) < EPSILON && fabs(cr.z) < EPSILON;
|
240 |
+
}
|
241 |
+
|
242 |
+
int orient3D(const uv_float3 &a, const uv_float3 &b, const uv_float3 &c,
|
243 |
+
const uv_float3 &d) {
|
244 |
+
Matrix4 _matrix4;
|
245 |
+
_matrix4.set(a.x, a.y, a.z, 1, b.x, b.y, b.z, 1, c.x, c.y, c.z, 1, d.x, d.y,
|
246 |
+
d.z, 1);
|
247 |
+
float det = _matrix4.determinant();
|
248 |
+
|
249 |
+
if (det < -EPSILON)
|
250 |
+
return -1;
|
251 |
+
else if (det > EPSILON)
|
252 |
+
return 1;
|
253 |
+
else
|
254 |
+
return 0;
|
255 |
+
}
|
256 |
+
|
257 |
+
int orient2D(const uv_float2 &a, const uv_float2 &b, const uv_float2 &c) {
|
258 |
+
float det = (a.x * (b.y - c.y) + b.x * (c.y - a.y) + c.x * (a.y - b.y));
|
259 |
+
|
260 |
+
if (det < -EPSILON)
|
261 |
+
return -1;
|
262 |
+
else if (det > EPSILON)
|
263 |
+
return 1;
|
264 |
+
else
|
265 |
+
return 0;
|
266 |
+
}
|
267 |
+
|
268 |
+
int orient2D(const uv_float3 &a, const uv_float3 &b, const uv_float3 &c) {
|
269 |
+
uv_float2 a_2d = {a.x, a.y};
|
270 |
+
uv_float2 b_2d = {b.x, b.y};
|
271 |
+
uv_float2 c_2d = {c.x, c.y};
|
272 |
+
return orient2D(a_2d, b_2d, c_2d);
|
273 |
+
}
|
274 |
+
|
275 |
+
void permuteTriLeft(Triangle &tri) {
|
276 |
+
uv_float3 tmp = tri.a;
|
277 |
+
tri.a = tri.b;
|
278 |
+
tri.b = tri.c;
|
279 |
+
tri.c = tmp;
|
280 |
+
}
|
281 |
+
|
282 |
+
void permuteTriRight(Triangle &tri) {
|
283 |
+
uv_float3 tmp = tri.c;
|
284 |
+
tri.c = tri.b;
|
285 |
+
tri.b = tri.a;
|
286 |
+
tri.a = tmp;
|
287 |
+
}
|
288 |
+
|
289 |
+
void makeTriCounterClockwise(Triangle &tri) {
|
290 |
+
if (orient2D(tri.a, tri.b, tri.c) < 0) {
|
291 |
+
uv_float3 tmp = tri.c;
|
292 |
+
tri.c = tri.b;
|
293 |
+
tri.b = tmp;
|
294 |
+
}
|
295 |
+
}
|
296 |
+
|
297 |
+
void intersectPlane(const uv_float3 &a, const uv_float3 &b, const uv_float3 &p,
|
298 |
+
const uv_float3 &n, uv_float3 &target) {
|
299 |
+
uv_float3 u = b - a;
|
300 |
+
uv_float3 v = a - p;
|
301 |
+
float dot1 = dot(n, u);
|
302 |
+
float dot2 = dot(n, v);
|
303 |
+
u = u * (-dot2 / dot1);
|
304 |
+
target = a + u;
|
305 |
+
}
|
306 |
+
|
307 |
+
void computeLineIntersection(const Triangle &t1, const Triangle &t2,
|
308 |
+
std::vector<uv_float3> &target) {
|
309 |
+
uv_float3 n1, n2;
|
310 |
+
t1.getNormal(n1);
|
311 |
+
t2.getNormal(n2);
|
312 |
+
|
313 |
+
int o1 = orient3D(t1.a, t1.c, t2.b, t2.a);
|
314 |
+
int o2 = orient3D(t1.a, t1.b, t2.c, t2.a);
|
315 |
+
|
316 |
+
uv_float3 i1, i2;
|
317 |
+
|
318 |
+
if (o1 > 0) {
|
319 |
+
if (o2 > 0) {
|
320 |
+
intersectPlane(t1.a, t1.c, t2.a, n2, i1);
|
321 |
+
intersectPlane(t2.a, t2.c, t1.a, n1, i2);
|
322 |
+
} else {
|
323 |
+
intersectPlane(t1.a, t1.c, t2.a, n2, i1);
|
324 |
+
intersectPlane(t1.a, t1.b, t2.a, n2, i2);
|
325 |
+
}
|
326 |
+
} else {
|
327 |
+
if (o2 > 0) {
|
328 |
+
intersectPlane(t2.a, t2.b, t1.a, n1, i1);
|
329 |
+
intersectPlane(t2.a, t2.c, t1.a, n1, i2);
|
330 |
+
} else {
|
331 |
+
intersectPlane(t2.a, t2.b, t1.a, n1, i1);
|
332 |
+
intersectPlane(t1.a, t1.b, t2.a, n2, i2);
|
333 |
+
}
|
334 |
+
}
|
335 |
+
|
336 |
+
target.push_back(i1);
|
337 |
+
if (distance_to(i1, i2) >= EPSILON) {
|
338 |
+
target.push_back(i2);
|
339 |
+
}
|
340 |
+
}
|
341 |
+
|
342 |
+
void makeTriAVertexAlone(Triangle &tri, int oa, int ob, int oc) {
|
343 |
+
// Permute a, b, c so that a is alone on its side
|
344 |
+
if (oa == ob) {
|
345 |
+
// c is alone, permute right so c becomes a
|
346 |
+
permuteTriRight(tri);
|
347 |
+
} else if (oa == oc) {
|
348 |
+
// b is alone, permute so b becomes a
|
349 |
+
permuteTriLeft(tri);
|
350 |
+
} else if (ob != oc) {
|
351 |
+
// In case a, b, c have different orientation, put a on positive side
|
352 |
+
if (ob > 0) {
|
353 |
+
permuteTriLeft(tri);
|
354 |
+
} else if (oc > 0) {
|
355 |
+
permuteTriRight(tri);
|
356 |
+
}
|
357 |
+
}
|
358 |
+
}
|
359 |
+
|
360 |
+
void makeTriAVertexPositive(Triangle &tri, const Triangle &other) {
|
361 |
+
int o = orient3D(other.a, other.b, other.c, tri.a);
|
362 |
+
if (o < 0) {
|
363 |
+
std::swap(tri.b, tri.c);
|
364 |
+
}
|
365 |
+
}
|
366 |
+
|
367 |
+
bool crossIntersect(Triangle &t1, Triangle &t2, int o1a, int o1b, int o1c,
|
368 |
+
std::vector<uv_float3> *target = nullptr) {
|
369 |
+
int o2a = orient3D(t1.a, t1.b, t1.c, t2.a);
|
370 |
+
int o2b = orient3D(t1.a, t1.b, t1.c, t2.b);
|
371 |
+
int o2c = orient3D(t1.a, t1.b, t1.c, t2.c);
|
372 |
+
|
373 |
+
if (o2a == o2b && o2a == o2c) {
|
374 |
+
return false;
|
375 |
+
}
|
376 |
+
|
377 |
+
// Make a vertex alone on its side for both triangles
|
378 |
+
makeTriAVertexAlone(t1, o1a, o1b, o1c);
|
379 |
+
makeTriAVertexAlone(t2, o2a, o2b, o2c);
|
380 |
+
|
381 |
+
// Ensure the vertex on the positive side
|
382 |
+
makeTriAVertexPositive(t2, t1);
|
383 |
+
makeTriAVertexPositive(t1, t2);
|
384 |
+
|
385 |
+
int o1 = orient3D(t1.a, t1.b, t2.a, t2.b);
|
386 |
+
int o2 = orient3D(t1.a, t1.c, t2.c, t2.a);
|
387 |
+
|
388 |
+
if (o1 <= 0 && o2 <= 0) {
|
389 |
+
if (target) {
|
390 |
+
computeLineIntersection(t1, t2, *target);
|
391 |
+
}
|
392 |
+
return true;
|
393 |
+
}
|
394 |
+
|
395 |
+
return false;
|
396 |
+
}
|
397 |
+
|
398 |
+
void linesIntersect2d(const uv_float3 &a1, const uv_float3 &b1,
|
399 |
+
const uv_float3 &a2, const uv_float3 &b2,
|
400 |
+
uv_float3 &target) {
|
401 |
+
float dx1 = a1.x - b1.x;
|
402 |
+
float dx2 = a2.x - b2.x;
|
403 |
+
float dy1 = a1.y - b1.y;
|
404 |
+
float dy2 = a2.y - b2.y;
|
405 |
+
|
406 |
+
float D = dx1 * dy2 - dx2 * dy1;
|
407 |
+
|
408 |
+
float n1 = a1.x * b1.y - a1.y * b1.x;
|
409 |
+
float n2 = a2.x * b2.y - a2.y * b2.x;
|
410 |
+
|
411 |
+
target.x = (n1 * dx2 - n2 * dx1) / D;
|
412 |
+
target.y = (n1 * dy2 - n2 * dy1) / D;
|
413 |
+
target.z = 0;
|
414 |
+
}
|
415 |
+
|
416 |
+
void clipTriangle(const Triangle &t1, const Triangle &t2,
|
417 |
+
std::vector<uv_float3> &target) {
|
418 |
+
std::vector<uv_float3> clip = {t1.a, t1.b, t1.c};
|
419 |
+
std::vector<uv_float3> output = {t2.a, t2.b, t2.c};
|
420 |
+
std::vector<int> orients(output.size() * 3, 0);
|
421 |
+
uv_float3 inter;
|
422 |
+
|
423 |
+
for (int i = 0; i < 3; ++i) {
|
424 |
+
const int i_prev = (i + 2) % 3;
|
425 |
+
std::vector<uv_float3> input;
|
426 |
+
std::copy(output.begin(), output.end(), std::back_inserter(input));
|
427 |
+
output.clear();
|
428 |
+
|
429 |
+
for (size_t j = 0; j < input.size(); ++j) {
|
430 |
+
orients[j] = orient2D(clip[i_prev], clip[i], input[j]);
|
431 |
+
}
|
432 |
+
|
433 |
+
for (size_t j = 0; j < input.size(); ++j) {
|
434 |
+
const int j_prev = (j - 1 + input.size()) % input.size();
|
435 |
+
|
436 |
+
if (orients[j] >= 0) {
|
437 |
+
if (orients[j_prev] < 0) {
|
438 |
+
linesIntersect2d(clip[i_prev], clip[i], input[j_prev], input[j],
|
439 |
+
inter);
|
440 |
+
output.push_back({inter.x, inter.y, inter.z});
|
441 |
+
}
|
442 |
+
output.push_back({input[j].x, input[j].y, input[j].z});
|
443 |
+
} else if (orients[j_prev] >= 0) {
|
444 |
+
linesIntersect2d(clip[i_prev], clip[i], input[j_prev], input[j], inter);
|
445 |
+
output.push_back({inter.x, inter.y, inter.z});
|
446 |
+
}
|
447 |
+
}
|
448 |
+
}
|
449 |
+
|
450 |
+
// Clear duplicated points
|
451 |
+
for (const auto &point : output) {
|
452 |
+
int j = 0;
|
453 |
+
bool sameFound = false;
|
454 |
+
while (!sameFound && j < target.size()) {
|
455 |
+
sameFound = distance_to(point, target[j]) <= 1e-6;
|
456 |
+
j++;
|
457 |
+
}
|
458 |
+
|
459 |
+
if (!sameFound) {
|
460 |
+
target.push_back(point);
|
461 |
+
}
|
462 |
+
}
|
463 |
+
}
|
464 |
+
|
465 |
+
bool intersectionTypeR1(const Triangle &t1, const Triangle &t2) {
|
466 |
+
const uv_float3 &p1 = t1.a;
|
467 |
+
const uv_float3 &q1 = t1.b;
|
468 |
+
const uv_float3 &r1 = t1.c;
|
469 |
+
const uv_float3 &p2 = t2.a;
|
470 |
+
const uv_float3 &r2 = t2.c;
|
471 |
+
|
472 |
+
if (orient2D(r2, p2, q1) >= 0) { // I
|
473 |
+
if (orient2D(r2, p1, q1) >= 0) { // II.a
|
474 |
+
if (orient2D(p1, p2, q1) >= 0) { // III.a
|
475 |
+
return true;
|
476 |
+
} else {
|
477 |
+
if (orient2D(p1, p2, r1) >= 0) { // IV.a
|
478 |
+
if (orient2D(q1, r1, p2) >= 0) { // V
|
479 |
+
return true;
|
480 |
+
}
|
481 |
+
}
|
482 |
+
}
|
483 |
+
}
|
484 |
+
} else {
|
485 |
+
if (orient2D(r2, p2, r1) >= 0) { // II.b
|
486 |
+
if (orient2D(q1, r1, r2) >= 0) { // III.b
|
487 |
+
if (orient2D(p1, p2, r1) >= 0) { // IV.b (diverges from paper)
|
488 |
+
return true;
|
489 |
+
}
|
490 |
+
}
|
491 |
+
}
|
492 |
+
}
|
493 |
+
|
494 |
+
return false;
|
495 |
+
}
|
496 |
+
|
497 |
+
bool intersectionTypeR2(const Triangle &t1, const Triangle &t2) {
|
498 |
+
const uv_float3 &p1 = t1.a;
|
499 |
+
const uv_float3 &q1 = t1.b;
|
500 |
+
const uv_float3 &r1 = t1.c;
|
501 |
+
const uv_float3 &p2 = t2.a;
|
502 |
+
const uv_float3 &q2 = t2.b;
|
503 |
+
const uv_float3 &r2 = t2.c;
|
504 |
+
|
505 |
+
if (orient2D(r2, p2, q1) >= 0) { // I
|
506 |
+
if (orient2D(q2, r2, q1) >= 0) { // II.a
|
507 |
+
if (orient2D(p1, p2, q1) >= 0) { // III.a
|
508 |
+
if (orient2D(p1, q2, q1) <= 0) { // IV.a
|
509 |
+
return true;
|
510 |
+
}
|
511 |
+
} else {
|
512 |
+
if (orient2D(p1, p2, r1) >= 0) { // IV.b
|
513 |
+
if (orient2D(r2, p2, r1) <= 0) { // V.a
|
514 |
+
return true;
|
515 |
+
}
|
516 |
+
}
|
517 |
+
}
|
518 |
+
} else {
|
519 |
+
if (orient2D(p1, q2, q1) <= 0) { // III.b
|
520 |
+
if (orient2D(q2, r2, r1) >= 0) { // IV.c
|
521 |
+
if (orient2D(q1, r1, q2) >= 0) { // V.b
|
522 |
+
return true;
|
523 |
+
}
|
524 |
+
}
|
525 |
+
}
|
526 |
+
}
|
527 |
+
} else {
|
528 |
+
if (orient2D(r2, p2, r1) >= 0) { // II.b
|
529 |
+
if (orient2D(q1, r1, r2) >= 0) { // III.c
|
530 |
+
if (orient2D(r1, p1, p2) >= 0) { // IV.d
|
531 |
+
return true;
|
532 |
+
}
|
533 |
+
} else {
|
534 |
+
if (orient2D(q1, r1, q2) >= 0) { // IV.e
|
535 |
+
if (orient2D(q2, r2, r1) >= 0) { // V.c
|
536 |
+
return true;
|
537 |
+
}
|
538 |
+
}
|
539 |
+
}
|
540 |
+
}
|
541 |
+
}
|
542 |
+
|
543 |
+
return false;
|
544 |
+
}
|
545 |
+
|
546 |
+
bool coplanarIntersect(Triangle &t1, Triangle &t2,
|
547 |
+
std::vector<uv_float3> *target = nullptr) {
|
548 |
+
uv_float3 normal, u, v;
|
549 |
+
t1.getNormal(normal);
|
550 |
+
normal = normalize(normal);
|
551 |
+
u = normalize(t1.a - t1.b);
|
552 |
+
v = cross(normal, u);
|
553 |
+
|
554 |
+
// Move basis to t1.a
|
555 |
+
u = u + t1.a;
|
556 |
+
v = v + t1.a;
|
557 |
+
normal = normal + t1.a;
|
558 |
+
|
559 |
+
Matrix4 _matrix;
|
560 |
+
_matrix.set(t1.a.x, u.x, v.x, normal.x, t1.a.y, u.y, v.y, normal.y, t1.a.z,
|
561 |
+
u.z, v.z, normal.z, 1, 1, 1, 1);
|
562 |
+
|
563 |
+
Matrix4 _affineMatrix;
|
564 |
+
_affineMatrix.set(0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1);
|
565 |
+
|
566 |
+
_matrix.invert(); // Invert the _matrix
|
567 |
+
_matrix = _affineMatrix * _matrix;
|
568 |
+
|
569 |
+
// Apply transformation
|
570 |
+
apply_matrix4(t1.a, _matrix);
|
571 |
+
apply_matrix4(t1.b, _matrix);
|
572 |
+
apply_matrix4(t1.c, _matrix);
|
573 |
+
apply_matrix4(t2.a, _matrix);
|
574 |
+
apply_matrix4(t2.b, _matrix);
|
575 |
+
apply_matrix4(t2.c, _matrix);
|
576 |
+
|
577 |
+
makeTriCounterClockwise(t1);
|
578 |
+
makeTriCounterClockwise(t2);
|
579 |
+
|
580 |
+
const uv_float3 &p1 = t1.a;
|
581 |
+
const uv_float3 &p2 = t2.a;
|
582 |
+
const uv_float3 &q2 = t2.b;
|
583 |
+
const uv_float3 &r2 = t2.c;
|
584 |
+
|
585 |
+
int o_p2q2 = orient2D(p2, q2, p1);
|
586 |
+
int o_q2r2 = orient2D(q2, r2, p1);
|
587 |
+
int o_r2p2 = orient2D(r2, p2, p1);
|
588 |
+
|
589 |
+
bool intersecting = false;
|
590 |
+
if (o_p2q2 >= 0) {
|
591 |
+
if (o_q2r2 >= 0) {
|
592 |
+
if (o_r2p2 >= 0) {
|
593 |
+
// + + +
|
594 |
+
intersecting = true;
|
595 |
+
} else {
|
596 |
+
// + + -
|
597 |
+
intersecting = intersectionTypeR1(t1, t2);
|
598 |
+
}
|
599 |
+
} else {
|
600 |
+
if (o_r2p2 >= 0) {
|
601 |
+
// + - +
|
602 |
+
permuteTriRight(t2);
|
603 |
+
intersecting = intersectionTypeR1(t1, t2);
|
604 |
+
} else {
|
605 |
+
// + - -
|
606 |
+
intersecting = intersectionTypeR2(t1, t2);
|
607 |
+
}
|
608 |
+
}
|
609 |
+
} else {
|
610 |
+
if (o_q2r2 >= 0) {
|
611 |
+
if (o_r2p2 >= 0) {
|
612 |
+
// - + +
|
613 |
+
permuteTriLeft(t2);
|
614 |
+
intersecting = intersectionTypeR1(t1, t2);
|
615 |
+
} else {
|
616 |
+
// - + -
|
617 |
+
permuteTriLeft(t2);
|
618 |
+
intersecting = intersectionTypeR2(t1, t2);
|
619 |
+
}
|
620 |
+
} else {
|
621 |
+
if (o_r2p2 >= 0) {
|
622 |
+
// - - +
|
623 |
+
permuteTriRight(t2);
|
624 |
+
intersecting = intersectionTypeR2(t1, t2);
|
625 |
+
} else {
|
626 |
+
// - - -
|
627 |
+
std::cerr << "Triangles should not be flat." << std::endl;
|
628 |
+
return false;
|
629 |
+
}
|
630 |
+
}
|
631 |
+
}
|
632 |
+
|
633 |
+
if (intersecting && target) {
|
634 |
+
clipTriangle(t1, t2, *target);
|
635 |
+
|
636 |
+
_matrix.invert();
|
637 |
+
// Apply the transform to each target point
|
638 |
+
for (int i = 0; i < target->size(); ++i) {
|
639 |
+
apply_matrix4(target->at(i), _matrix);
|
640 |
+
}
|
641 |
+
}
|
642 |
+
|
643 |
+
return intersecting;
|
644 |
+
}
|
645 |
+
|
646 |
+
// Helper function to calculate the area of a polygon
|
647 |
+
float polygon_area(const std::vector<uv_float3> &polygon) {
|
648 |
+
if (polygon.size() < 3)
|
649 |
+
return 0.0f; // Not a polygon
|
650 |
+
|
651 |
+
uv_float3 normal = {0.0f, 0.0f, 0.0f}; // Initialize normal vector
|
652 |
+
|
653 |
+
// Calculate the cross product of edges around the polygon
|
654 |
+
for (size_t i = 0; i < polygon.size(); ++i) {
|
655 |
+
uv_float3 p1 = polygon[i];
|
656 |
+
uv_float3 p2 = polygon[(i + 1) % polygon.size()];
|
657 |
+
|
658 |
+
normal = normal + cross(p1, p2); // Accumulate the normal vector
|
659 |
+
}
|
660 |
+
|
661 |
+
float area =
|
662 |
+
magnitude(normal) / 2.0f; // Area is half the magnitude of the normal
|
663 |
+
return area;
|
664 |
+
}
|
665 |
+
|
666 |
+
bool triangle_triangle_intersection(uv_float2 p1, uv_float2 q1, uv_float2 r1,
|
667 |
+
uv_float2 p2, uv_float2 q2, uv_float2 r2) {
|
668 |
+
Triangle t1(p1, q1, r1);
|
669 |
+
Triangle t2(p2, q2, r2);
|
670 |
+
|
671 |
+
if (isTriDegenerated(t1) || isTriDegenerated(t2)) {
|
672 |
+
// std::cerr << "Degenerated triangles provided, skipping." << std::endl;
|
673 |
+
return false;
|
674 |
+
}
|
675 |
+
|
676 |
+
int o1a = orient3D(t2.a, t2.b, t2.c, t1.a);
|
677 |
+
int o1b = orient3D(t2.a, t2.b, t2.c, t1.b);
|
678 |
+
int o1c = orient3D(t2.a, t2.b, t2.c, t1.c);
|
679 |
+
|
680 |
+
std::vector<uv_float3> intersections;
|
681 |
+
bool intersects;
|
682 |
+
|
683 |
+
if (o1a == o1b && o1a == o1c) // [[likely]]
|
684 |
+
{
|
685 |
+
intersects = o1a == 0 && coplanarIntersect(t1, t2, &intersections);
|
686 |
+
} else // [[unlikely]]
|
687 |
+
{
|
688 |
+
intersects = crossIntersect(t1, t2, o1a, o1b, o1c, &intersections);
|
689 |
+
}
|
690 |
+
|
691 |
+
if (intersects) {
|
692 |
+
float area = polygon_area(intersections);
|
693 |
+
|
694 |
+
// std::cout << "Intersection area: " << area << std::endl;
|
695 |
+
if (area < 1e-10f || std::isfinite(area) == false) {
|
696 |
+
// std::cout<<"Invalid area: " << area << std::endl;
|
697 |
+
return false; // Ignore intersection if the area is too small
|
698 |
+
}
|
699 |
+
}
|
700 |
+
|
701 |
+
return intersects;
|
702 |
+
}
|
uv_unwrapper/uv_unwrapper/csrc/intersect.h
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include "common.h"
|
4 |
+
#include <vector>
|
5 |
+
|
6 |
+
bool triangle_aabb_intersection(const uv_float2 &aabb_min,
|
7 |
+
const uv_float2 &aabb_max, const uv_float2 &v0,
|
8 |
+
const uv_float2 &v1, const uv_float2 &v2);
|
9 |
+
bool triangle_triangle_intersection(uv_float2 p1, uv_float2 q1, uv_float2 r1,
|
10 |
+
uv_float2 p2, uv_float2 q2, uv_float2 r2);
|
uv_unwrapper/uv_unwrapper/csrc/unwrapper.cpp
ADDED
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "bvh.h"
|
2 |
+
#include <ATen/ATen.h>
|
3 |
+
#include <ATen/Context.h>
|
4 |
+
#include <chrono>
|
5 |
+
#include <cmath>
|
6 |
+
#include <cstring>
|
7 |
+
#include <omp.h>
|
8 |
+
#include <set>
|
9 |
+
#include <torch/extension.h>
|
10 |
+
#include <vector>
|
11 |
+
|
12 |
+
// #define TIMING
|
13 |
+
|
14 |
+
#if defined(_MSC_VER)
|
15 |
+
#include <BaseTsd.h>
|
16 |
+
typedef SSIZE_T ssize_t;
|
17 |
+
#endif
|
18 |
+
|
19 |
+
namespace UVUnwrapper {
|
20 |
+
void create_bvhs(BVH *bvhs, Triangle *triangles,
|
21 |
+
std::vector<std::set<int>> &triangle_per_face, int num_faces,
|
22 |
+
int start, int end) {
|
23 |
+
#pragma omp parallel for
|
24 |
+
for (int i = start; i < end; i++) {
|
25 |
+
int num_triangles = triangle_per_face[i].size();
|
26 |
+
Triangle *triangles_per_face = new Triangle[num_triangles];
|
27 |
+
int *indices = new int[num_triangles];
|
28 |
+
int j = 0;
|
29 |
+
for (int idx : triangle_per_face[i]) {
|
30 |
+
triangles_per_face[j] = triangles[idx];
|
31 |
+
indices[j++] = idx;
|
32 |
+
}
|
33 |
+
// Each thread writes to it's own memory space
|
34 |
+
// First check if the number of triangles is 0
|
35 |
+
if (num_triangles == 0) {
|
36 |
+
bvhs[i - start] = std::move(BVH()); // Default constructor
|
37 |
+
} else {
|
38 |
+
bvhs[i - start] = std::move(
|
39 |
+
BVH(triangles_per_face, indices,
|
40 |
+
num_triangles)); // BVH now handles memory of triangles_per_face
|
41 |
+
}
|
42 |
+
delete[] triangles_per_face;
|
43 |
+
}
|
44 |
+
}
|
45 |
+
|
46 |
+
void perform_intersection_check(BVH *bvhs, int num_bvhs, Triangle *triangles,
|
47 |
+
uv_float3 *vertex_tri_centroids,
|
48 |
+
int64_t *assign_indices_ptr,
|
49 |
+
ssize_t num_indices, int offset,
|
50 |
+
std::vector<std::set<int>> &triangle_per_face) {
|
51 |
+
std::vector<std::pair<int, int>>
|
52 |
+
unique_intersections; // Store unique intersections as pairs of triangle
|
53 |
+
// indices
|
54 |
+
|
55 |
+
// Step 1: Detect intersections in parallel
|
56 |
+
#pragma omp parallel for
|
57 |
+
for (int i = 0; i < num_indices; i++) {
|
58 |
+
if (assign_indices_ptr[i] < offset) {
|
59 |
+
continue;
|
60 |
+
}
|
61 |
+
|
62 |
+
Triangle cur_tri = triangles[i];
|
63 |
+
auto &cur_bvh = bvhs[assign_indices_ptr[i] - offset];
|
64 |
+
|
65 |
+
if (cur_bvh.bvhNode == nullptr) {
|
66 |
+
continue;
|
67 |
+
}
|
68 |
+
|
69 |
+
std::vector<int> intersections = cur_bvh.Intersect(cur_tri);
|
70 |
+
|
71 |
+
if (!intersections.empty()) {
|
72 |
+
|
73 |
+
#pragma omp critical
|
74 |
+
{
|
75 |
+
for (int intersect : intersections) {
|
76 |
+
if (i != intersect) {
|
77 |
+
// Ensure we only store unique pairs (A, B) where A < B to avoid
|
78 |
+
// duplication
|
79 |
+
if (i < intersect) {
|
80 |
+
unique_intersections.push_back(std::make_pair(i, intersect));
|
81 |
+
} else {
|
82 |
+
unique_intersections.push_back(std::make_pair(intersect, i));
|
83 |
+
}
|
84 |
+
}
|
85 |
+
}
|
86 |
+
}
|
87 |
+
}
|
88 |
+
}
|
89 |
+
|
90 |
+
// Step 2: Process unique intersections
|
91 |
+
for (int idx = 0; idx < unique_intersections.size(); idx++) {
|
92 |
+
int first = unique_intersections[idx].first;
|
93 |
+
int second = unique_intersections[idx].second;
|
94 |
+
|
95 |
+
int i_idx = assign_indices_ptr[first];
|
96 |
+
|
97 |
+
int norm_idx = i_idx % 6;
|
98 |
+
int axis = (norm_idx < 2) ? 0 : (norm_idx < 4) ? 1 : 2;
|
99 |
+
bool use_max = (i_idx % 2) == 1;
|
100 |
+
|
101 |
+
float pos_a = vertex_tri_centroids[first][axis];
|
102 |
+
float pos_b = vertex_tri_centroids[second][axis];
|
103 |
+
// Sort the intersections based on vertex_tri_centroids along the specified
|
104 |
+
// axis
|
105 |
+
if (use_max) {
|
106 |
+
if (pos_a < pos_b) {
|
107 |
+
std::swap(first, second);
|
108 |
+
}
|
109 |
+
} else {
|
110 |
+
if (pos_a > pos_b) {
|
111 |
+
std::swap(first, second);
|
112 |
+
}
|
113 |
+
}
|
114 |
+
|
115 |
+
// Update the unique intersections
|
116 |
+
unique_intersections[idx].first = first;
|
117 |
+
unique_intersections[idx].second = second;
|
118 |
+
}
|
119 |
+
|
120 |
+
// Now only get the second intersections from the pair and put them in a set
|
121 |
+
// The second intersection should always be the occluded triangle
|
122 |
+
std::set<int> second_intersections;
|
123 |
+
for (int idx = 0; idx < (int)unique_intersections.size(); idx++) {
|
124 |
+
int second = unique_intersections[idx].second;
|
125 |
+
second_intersections.insert(second);
|
126 |
+
}
|
127 |
+
|
128 |
+
for (int int_idx : second_intersections) {
|
129 |
+
// Move the second (occluded) triangle by 6
|
130 |
+
int intersect_idx = assign_indices_ptr[int_idx];
|
131 |
+
int new_index = intersect_idx + 6;
|
132 |
+
new_index = std::clamp(new_index, 0, 12);
|
133 |
+
|
134 |
+
assign_indices_ptr[int_idx] = new_index;
|
135 |
+
triangle_per_face[intersect_idx].erase(int_idx);
|
136 |
+
triangle_per_face[new_index].insert(int_idx);
|
137 |
+
}
|
138 |
+
}
|
139 |
+
|
140 |
+
torch::Tensor assign_faces_uv_to_atlas_index(torch::Tensor vertices,
|
141 |
+
torch::Tensor indices,
|
142 |
+
torch::Tensor face_uv,
|
143 |
+
torch::Tensor face_index) {
|
144 |
+
// Get the number of faces
|
145 |
+
int num_faces = indices.size(0);
|
146 |
+
torch::Tensor assign_indices =
|
147 |
+
torch::empty(
|
148 |
+
{
|
149 |
+
num_faces,
|
150 |
+
},
|
151 |
+
torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU))
|
152 |
+
.contiguous();
|
153 |
+
|
154 |
+
auto vert_accessor = vertices.accessor<float, 2>();
|
155 |
+
auto indices_accessor = indices.accessor<int64_t, 2>();
|
156 |
+
auto face_uv_accessor = face_uv.accessor<float, 2>();
|
157 |
+
|
158 |
+
const int64_t *face_index_ptr = face_index.contiguous().data_ptr<int64_t>();
|
159 |
+
int64_t *assign_indices_ptr = assign_indices.data_ptr<int64_t>();
|
160 |
+
// copy face_index to assign_indices
|
161 |
+
memcpy(assign_indices_ptr, face_index_ptr, num_faces * sizeof(int64_t));
|
162 |
+
|
163 |
+
#ifdef TIMING
|
164 |
+
auto start = std::chrono::high_resolution_clock::now();
|
165 |
+
#endif
|
166 |
+
uv_float3 *vertex_tri_centroids = new uv_float3[num_faces];
|
167 |
+
Triangle *triangles = new Triangle[num_faces];
|
168 |
+
|
169 |
+
// Use std::set to store triangles for each face
|
170 |
+
std::vector<std::set<int>> triangle_per_face;
|
171 |
+
triangle_per_face.resize(13);
|
172 |
+
|
173 |
+
#pragma omp parallel for
|
174 |
+
for (int i = 0; i < num_faces; i++) {
|
175 |
+
int face_idx = i * 3;
|
176 |
+
triangles[i].v0 = {face_uv_accessor[face_idx + 0][0],
|
177 |
+
face_uv_accessor[face_idx + 0][1]};
|
178 |
+
triangles[i].v1 = {face_uv_accessor[face_idx + 1][0],
|
179 |
+
face_uv_accessor[face_idx + 1][1]};
|
180 |
+
triangles[i].v2 = {face_uv_accessor[face_idx + 2][0],
|
181 |
+
face_uv_accessor[face_idx + 2][1]};
|
182 |
+
triangles[i].centroid =
|
183 |
+
triangle_centroid(triangles[i].v0, triangles[i].v1, triangles[i].v2);
|
184 |
+
|
185 |
+
uv_float3 v0 = {vert_accessor[indices_accessor[i][0]][0],
|
186 |
+
vert_accessor[indices_accessor[i][0]][1],
|
187 |
+
vert_accessor[indices_accessor[i][0]][2]};
|
188 |
+
uv_float3 v1 = {vert_accessor[indices_accessor[i][1]][0],
|
189 |
+
vert_accessor[indices_accessor[i][1]][1],
|
190 |
+
vert_accessor[indices_accessor[i][1]][2]};
|
191 |
+
uv_float3 v2 = {vert_accessor[indices_accessor[i][2]][0],
|
192 |
+
vert_accessor[indices_accessor[i][2]][1],
|
193 |
+
vert_accessor[indices_accessor[i][2]][2]};
|
194 |
+
vertex_tri_centroids[i] = triangle_centroid(v0, v1, v2);
|
195 |
+
|
196 |
+
// Assign the triangle to the face index
|
197 |
+
#pragma omp critical
|
198 |
+
{ triangle_per_face[face_index_ptr[i]].insert(i); }
|
199 |
+
}
|
200 |
+
|
201 |
+
#ifdef TIMING
|
202 |
+
auto start_bvh = std::chrono::high_resolution_clock::now();
|
203 |
+
#endif
|
204 |
+
|
205 |
+
BVH *bvhs = new BVH[6];
|
206 |
+
create_bvhs(bvhs, triangles, triangle_per_face, num_faces, 0, 6);
|
207 |
+
|
208 |
+
#ifdef TIMING
|
209 |
+
auto end_bvh = std::chrono::high_resolution_clock::now();
|
210 |
+
std::chrono::duration<double> elapsed_seconds = end_bvh - start_bvh;
|
211 |
+
std::cout << "BVH build time: " << elapsed_seconds.count() << "s\n";
|
212 |
+
|
213 |
+
auto start_intersection_1 = std::chrono::high_resolution_clock::now();
|
214 |
+
#endif
|
215 |
+
|
216 |
+
perform_intersection_check(bvhs, 6, triangles, vertex_tri_centroids,
|
217 |
+
assign_indices_ptr, num_faces, 0,
|
218 |
+
triangle_per_face);
|
219 |
+
|
220 |
+
#ifdef TIMING
|
221 |
+
auto end_intersection_1 = std::chrono::high_resolution_clock::now();
|
222 |
+
elapsed_seconds = end_intersection_1 - start_intersection_1;
|
223 |
+
std::cout << "Intersection 1 time: " << elapsed_seconds.count() << "s\n";
|
224 |
+
#endif
|
225 |
+
// Create 6 new bvhs and delete the old ones
|
226 |
+
BVH *new_bvhs = new BVH[6];
|
227 |
+
create_bvhs(new_bvhs, triangles, triangle_per_face, num_faces, 6, 12);
|
228 |
+
|
229 |
+
#ifdef TIMING
|
230 |
+
auto end_bvh2 = std::chrono::high_resolution_clock::now();
|
231 |
+
elapsed_seconds = end_bvh2 - end_intersection_1;
|
232 |
+
std::cout << "BVH 2 build time: " << elapsed_seconds.count() << "s\n";
|
233 |
+
auto start_intersection_2 = std::chrono::high_resolution_clock::now();
|
234 |
+
#endif
|
235 |
+
|
236 |
+
perform_intersection_check(new_bvhs, 6, triangles, vertex_tri_centroids,
|
237 |
+
assign_indices_ptr, num_faces, 6,
|
238 |
+
triangle_per_face);
|
239 |
+
|
240 |
+
#ifdef TIMING
|
241 |
+
auto end_intersection_2 = std::chrono::high_resolution_clock::now();
|
242 |
+
elapsed_seconds = end_intersection_2 - start_intersection_2;
|
243 |
+
std::cout << "Intersection 2 time: " << elapsed_seconds.count() << "s\n";
|
244 |
+
elapsed_seconds = end_intersection_2 - start;
|
245 |
+
std::cout << "Total time: " << elapsed_seconds.count() << "s\n";
|
246 |
+
#endif
|
247 |
+
|
248 |
+
// Cleanup
|
249 |
+
delete[] vertex_tri_centroids;
|
250 |
+
delete[] triangles;
|
251 |
+
delete[] bvhs;
|
252 |
+
delete[] new_bvhs;
|
253 |
+
|
254 |
+
return assign_indices;
|
255 |
+
}
|
256 |
+
|
257 |
+
// Registers _C as a Python extension module.
|
258 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
|
259 |
+
|
260 |
+
// Defines the operators
|
261 |
+
TORCH_LIBRARY(UVUnwrapper, m) {
|
262 |
+
m.def("assign_faces_uv_to_atlas_index(Tensor vertices, Tensor indices, "
|
263 |
+
"Tensor face_uv, Tensor face_index) -> Tensor");
|
264 |
+
}
|
265 |
+
|
266 |
+
// Registers CPP implementations
|
267 |
+
TORCH_LIBRARY_IMPL(UVUnwrapper, CPU, m) {
|
268 |
+
m.impl("assign_faces_uv_to_atlas_index", &assign_faces_uv_to_atlas_index);
|
269 |
+
}
|
270 |
+
|
271 |
+
} // namespace UVUnwrapper
|
uv_unwrapper/uv_unwrapper/unwrap.py
ADDED
@@ -0,0 +1,669 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch import Tensor
|
8 |
+
|
9 |
+
|
10 |
+
class Unwrapper(nn.Module):
|
11 |
+
def __init__(self):
|
12 |
+
super().__init__()
|
13 |
+
|
14 |
+
def _box_assign_vertex_to_cube_face(
|
15 |
+
self,
|
16 |
+
vertex_positions: Tensor,
|
17 |
+
vertex_normals: Tensor,
|
18 |
+
triangle_idxs: Tensor,
|
19 |
+
bbox: Tensor,
|
20 |
+
) -> Tuple[Tensor, Tensor]:
|
21 |
+
"""
|
22 |
+
Assigns each vertex to a cube face based on the face normal
|
23 |
+
|
24 |
+
Args:
|
25 |
+
vertex_positions (Tensor, Nv 3, float): Vertex positions
|
26 |
+
vertex_normals (Tensor, Nv 3, float): Vertex normals
|
27 |
+
triangle_idxs (Tensor, Nf 3, int): Triangle indices
|
28 |
+
bbox (Tensor, 2 3, float): Bounding box of the mesh
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
Tensor, Nf 3 2, float: UV coordinates
|
32 |
+
Tensor, Nf, int: Cube face indices
|
33 |
+
"""
|
34 |
+
|
35 |
+
# Test to not have a scaled model to fit the space better
|
36 |
+
# bbox_min = bbox[:1].mean(-1, keepdim=True)
|
37 |
+
# bbox_max = bbox[1:].mean(-1, keepdim=True)
|
38 |
+
# v_pos_normalized = (vertex_positions - bbox_min) / (bbox_max - bbox_min)
|
39 |
+
|
40 |
+
# Create a [0, 1] normalized vertex position
|
41 |
+
v_pos_normalized = (vertex_positions - bbox[:1]) / (bbox[1:] - bbox[:1])
|
42 |
+
# And to [-1, 1]
|
43 |
+
v_pos_normalized = 2.0 * v_pos_normalized - 1.0
|
44 |
+
|
45 |
+
# Get all vertex positions for each triangle
|
46 |
+
# Now how do we define to which face the triangle belongs? Mean face pos? Max vertex pos?
|
47 |
+
v0 = v_pos_normalized[triangle_idxs[:, 0]]
|
48 |
+
v1 = v_pos_normalized[triangle_idxs[:, 1]]
|
49 |
+
v2 = v_pos_normalized[triangle_idxs[:, 2]]
|
50 |
+
tri_stack = torch.stack([v0, v1, v2], dim=1)
|
51 |
+
|
52 |
+
vn0 = vertex_normals[triangle_idxs[:, 0]]
|
53 |
+
vn1 = vertex_normals[triangle_idxs[:, 1]]
|
54 |
+
vn2 = vertex_normals[triangle_idxs[:, 2]]
|
55 |
+
tri_stack_nrm = torch.stack([vn0, vn1, vn2], dim=1)
|
56 |
+
|
57 |
+
# Just average the normals per face
|
58 |
+
face_normal = F.normalize(torch.sum(tri_stack_nrm, 1), eps=1e-6, dim=-1)
|
59 |
+
|
60 |
+
# Now decide based on the face normal in which box map we project
|
61 |
+
# abs_x, abs_y, abs_z = tri_stack_nrm.abs().unbind(-1)
|
62 |
+
abs_x, abs_y, abs_z = tri_stack.abs().unbind(-1)
|
63 |
+
|
64 |
+
axis = torch.tensor(
|
65 |
+
[
|
66 |
+
[1, 0, 0], # 0
|
67 |
+
[-1, 0, 0], # 1
|
68 |
+
[0, 1, 0], # 2
|
69 |
+
[0, -1, 0], # 3
|
70 |
+
[0, 0, 1], # 4
|
71 |
+
[0, 0, -1], # 5
|
72 |
+
],
|
73 |
+
device=face_normal.device,
|
74 |
+
dtype=face_normal.dtype,
|
75 |
+
)
|
76 |
+
face_normal_axis = (face_normal[:, None] * axis[None]).sum(-1)
|
77 |
+
index = face_normal_axis.argmax(-1)
|
78 |
+
|
79 |
+
max_axis, uc, vc = (
|
80 |
+
torch.ones_like(abs_x),
|
81 |
+
torch.zeros_like(tri_stack[..., :1]),
|
82 |
+
torch.zeros_like(tri_stack[..., :1]),
|
83 |
+
)
|
84 |
+
mask_pos_x = index == 0
|
85 |
+
max_axis[mask_pos_x] = abs_x[mask_pos_x]
|
86 |
+
uc[mask_pos_x] = tri_stack[mask_pos_x][..., 1:2]
|
87 |
+
vc[mask_pos_x] = -tri_stack[mask_pos_x][..., -1:]
|
88 |
+
|
89 |
+
mask_neg_x = index == 1
|
90 |
+
max_axis[mask_neg_x] = abs_x[mask_neg_x]
|
91 |
+
uc[mask_neg_x] = tri_stack[mask_neg_x][..., 1:2]
|
92 |
+
vc[mask_neg_x] = -tri_stack[mask_neg_x][..., -1:]
|
93 |
+
|
94 |
+
mask_pos_y = index == 2
|
95 |
+
max_axis[mask_pos_y] = abs_y[mask_pos_y]
|
96 |
+
uc[mask_pos_y] = tri_stack[mask_pos_y][..., 0:1]
|
97 |
+
vc[mask_pos_y] = -tri_stack[mask_pos_y][..., -1:]
|
98 |
+
|
99 |
+
mask_neg_y = index == 3
|
100 |
+
max_axis[mask_neg_y] = abs_y[mask_neg_y]
|
101 |
+
uc[mask_neg_y] = tri_stack[mask_neg_y][..., 0:1]
|
102 |
+
vc[mask_neg_y] = -tri_stack[mask_neg_y][..., -1:]
|
103 |
+
|
104 |
+
mask_pos_z = index == 4
|
105 |
+
max_axis[mask_pos_z] = abs_z[mask_pos_z]
|
106 |
+
uc[mask_pos_z] = tri_stack[mask_pos_z][..., 0:1]
|
107 |
+
vc[mask_pos_z] = tri_stack[mask_pos_z][..., 1:2]
|
108 |
+
|
109 |
+
mask_neg_z = index == 5
|
110 |
+
max_axis[mask_neg_z] = abs_z[mask_neg_z]
|
111 |
+
uc[mask_neg_z] = tri_stack[mask_neg_z][..., 0:1]
|
112 |
+
vc[mask_neg_z] = -tri_stack[mask_neg_z][..., 1:2]
|
113 |
+
|
114 |
+
# UC from [-1, 1] to [0, 1]
|
115 |
+
max_dim_div = max_axis.max(dim=0, keepdim=True).values
|
116 |
+
uc = ((uc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
|
117 |
+
vc = ((vc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
|
118 |
+
|
119 |
+
uv = torch.stack([uc, vc], dim=-1)
|
120 |
+
|
121 |
+
return uv, index
|
122 |
+
|
123 |
+
def _assign_faces_uv_to_atlas_index(
|
124 |
+
self,
|
125 |
+
vertex_positions: Tensor,
|
126 |
+
triangle_idxs: Tensor,
|
127 |
+
face_uv: Tensor,
|
128 |
+
face_index: Tensor,
|
129 |
+
) -> Tensor: # noqa: F821
|
130 |
+
"""
|
131 |
+
Assigns the face UV to the atlas index
|
132 |
+
|
133 |
+
Args:
|
134 |
+
vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions
|
135 |
+
triangle_idxs (Integer[Tensor, "Nf 3"]): Triangle indices
|
136 |
+
face_uv (Float[Tensor, "Nf 3 2"]): Face UV coordinates
|
137 |
+
face_index (Integer[Tensor, "Nf"]): Face indices
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
Integer[Tensor, "Nf"]: Atlas index
|
141 |
+
"""
|
142 |
+
return torch.ops.UVUnwrapper.assign_faces_uv_to_atlas_index(
|
143 |
+
vertex_positions.cpu(),
|
144 |
+
triangle_idxs.cpu(),
|
145 |
+
face_uv.view(-1, 2).cpu(),
|
146 |
+
face_index.cpu(),
|
147 |
+
).to(vertex_positions.device)
|
148 |
+
|
149 |
+
def _find_slice_offset_and_scale(
|
150 |
+
self, index: Tensor
|
151 |
+
) -> Tuple[Tensor, Tensor, Tensor, Tensor]: # noqa: F821
|
152 |
+
"""
|
153 |
+
Find the slice offset and scale
|
154 |
+
|
155 |
+
Args:
|
156 |
+
index (Integer[Tensor, "Nf"]): Atlas index
|
157 |
+
|
158 |
+
Returns:
|
159 |
+
Float[Tensor, "Nf"]: Offset x
|
160 |
+
Float[Tensor, "Nf"]: Offset y
|
161 |
+
Float[Tensor, "Nf"]: Division x
|
162 |
+
Float[Tensor, "Nf"]: Division y
|
163 |
+
"""
|
164 |
+
|
165 |
+
# 6 due to the 6 cube faces
|
166 |
+
off = 1 / 3
|
167 |
+
dupl_off = 1 / 6
|
168 |
+
|
169 |
+
# Here, we need to decide how to pack the textures in the case of overlap
|
170 |
+
def x_offset_calc(x, i):
|
171 |
+
offset_calc = i // 6
|
172 |
+
# Initial coordinates - just 3x2 grid
|
173 |
+
if offset_calc == 0:
|
174 |
+
return off * x
|
175 |
+
else:
|
176 |
+
# Smaller 3x2 grid plus eventual shift to right for
|
177 |
+
# second overlap
|
178 |
+
return dupl_off * x + min(offset_calc - 1, 1) * 0.5
|
179 |
+
|
180 |
+
def y_offset_calc(x, i):
|
181 |
+
offset_calc = i // 6
|
182 |
+
# Initial coordinates - just a 3x2 grid
|
183 |
+
if offset_calc == 0:
|
184 |
+
return off * x
|
185 |
+
else:
|
186 |
+
# Smaller coordinates in the lowest row
|
187 |
+
return dupl_off * x + off * 2
|
188 |
+
|
189 |
+
offset_x = torch.zeros_like(index, dtype=torch.float32)
|
190 |
+
offset_y = torch.zeros_like(index, dtype=torch.float32)
|
191 |
+
offset_x_vals = [0, 1, 2, 0, 1, 2]
|
192 |
+
offset_y_vals = [0, 0, 0, 1, 1, 1]
|
193 |
+
for i in range(index.max().item() + 1):
|
194 |
+
mask = index == i
|
195 |
+
if not mask.any():
|
196 |
+
continue
|
197 |
+
offset_x[mask] = x_offset_calc(offset_x_vals[i % 6], i)
|
198 |
+
offset_y[mask] = y_offset_calc(offset_y_vals[i % 6], i)
|
199 |
+
|
200 |
+
div_x = torch.full_like(index, 6 // 2, dtype=torch.float32)
|
201 |
+
# All overlap elements are saved in half scale
|
202 |
+
div_x[index >= 6] = 6
|
203 |
+
div_y = div_x.clone() # Same for y
|
204 |
+
# Except for the random overlaps
|
205 |
+
div_x[index >= 12] = 2
|
206 |
+
# But the random overlaps are saved in a large block in the lower thirds
|
207 |
+
div_y[index >= 12] = 3
|
208 |
+
|
209 |
+
return offset_x, offset_y, div_x, div_y
|
210 |
+
|
211 |
+
def _calculate_tangents(
|
212 |
+
self,
|
213 |
+
vertex_positions: Tensor,
|
214 |
+
vertex_normals: Tensor,
|
215 |
+
triangle_idxs: Tensor,
|
216 |
+
face_uv: Tensor,
|
217 |
+
) -> Tensor:
|
218 |
+
"""
|
219 |
+
Calculate the tangents for each triangle
|
220 |
+
|
221 |
+
Args:
|
222 |
+
vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions
|
223 |
+
vertex_normals (Float[Tensor, "Nv 3"]): Vertex normals
|
224 |
+
triangle_idxs (Integer[Tensor, "Nf 3"]): Triangle indices
|
225 |
+
face_uv (Float[Tensor, "Nf 3 2"]): Face UV coordinates
|
226 |
+
|
227 |
+
Returns:
|
228 |
+
Float[Tensor, "Nf 3 4"]: Tangents
|
229 |
+
"""
|
230 |
+
vn_idx = [None] * 3
|
231 |
+
pos = [None] * 3
|
232 |
+
tex = face_uv.unbind(1)
|
233 |
+
for i in range(0, 3):
|
234 |
+
pos[i] = vertex_positions[triangle_idxs[:, i]]
|
235 |
+
# t_nrm_idx is always the same as t_pos_idx
|
236 |
+
vn_idx[i] = triangle_idxs[:, i]
|
237 |
+
|
238 |
+
if(torch.backends.mps.is_available()):
|
239 |
+
tangents = torch.zeros_like(vertex_normals).contiguous()
|
240 |
+
tansum = torch.zeros_like(vertex_normals).contiguous()
|
241 |
+
else:
|
242 |
+
tangents = torch.zeros_like(vertex_normals)
|
243 |
+
tansum = torch.zeros_like(vertex_normals)
|
244 |
+
|
245 |
+
# Compute tangent space for each triangle
|
246 |
+
duv1 = tex[1] - tex[0]
|
247 |
+
duv2 = tex[2] - tex[0]
|
248 |
+
dpos1 = pos[1] - pos[0]
|
249 |
+
dpos2 = pos[2] - pos[0]
|
250 |
+
|
251 |
+
tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
|
252 |
+
|
253 |
+
denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]
|
254 |
+
|
255 |
+
# Avoid division by zero for degenerated texture coordinates
|
256 |
+
denom_safe = denom.clip(1e-6)
|
257 |
+
tang = tng_nom / denom_safe
|
258 |
+
|
259 |
+
# Update all 3 vertices
|
260 |
+
for i in range(0, 3):
|
261 |
+
idx = vn_idx[i][:, None].repeat(1, 3)
|
262 |
+
tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang
|
263 |
+
tansum.scatter_add_(
|
264 |
+
0, idx, torch.ones_like(tang)
|
265 |
+
) # tansum[n_i] = tansum[n_i] + 1
|
266 |
+
# Also normalize it. Here we do not normalize the individual triangles first so larger area
|
267 |
+
# triangles influence the tangent space more
|
268 |
+
tangents = tangents / tansum
|
269 |
+
|
270 |
+
# Normalize and make sure tangent is perpendicular to normal
|
271 |
+
tangents = F.normalize(tangents, dim=1)
|
272 |
+
tangents = F.normalize(
|
273 |
+
tangents
|
274 |
+
- (tangents * vertex_normals).sum(-1, keepdim=True) * vertex_normals
|
275 |
+
)
|
276 |
+
|
277 |
+
return tangents
|
278 |
+
|
279 |
+
def _rotate_uv_slices_consistent_space(
|
280 |
+
self,
|
281 |
+
vertex_positions: Tensor,
|
282 |
+
vertex_normals: Tensor,
|
283 |
+
triangle_idxs: Tensor,
|
284 |
+
uv: Tensor,
|
285 |
+
index: Tensor,
|
286 |
+
) -> Tensor:
|
287 |
+
"""
|
288 |
+
Rotate the UV slices so they are in a consistent space
|
289 |
+
|
290 |
+
Args:
|
291 |
+
vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions
|
292 |
+
vertex_normals (Float[Tensor, "Nv 3"]): Vertex normals
|
293 |
+
triangle_idxs (Integer[Tensor, "Nf 3"]): Triangle indices
|
294 |
+
uv (Float[Tensor, "Nf 3 2"]): UV coordinates
|
295 |
+
index (Integer[Tensor, "Nf"]): Atlas index
|
296 |
+
|
297 |
+
Returns:
|
298 |
+
Float[Tensor, "Nf 3 2"]: Rotated UV coordinates
|
299 |
+
"""
|
300 |
+
|
301 |
+
tangents = self._calculate_tangents(
|
302 |
+
vertex_positions, vertex_normals, triangle_idxs, uv
|
303 |
+
)
|
304 |
+
pos_stack = torch.stack(
|
305 |
+
[
|
306 |
+
-vertex_positions[..., 1],
|
307 |
+
vertex_positions[..., 0],
|
308 |
+
torch.zeros_like(vertex_positions[..., 0]),
|
309 |
+
],
|
310 |
+
dim=-1,
|
311 |
+
)
|
312 |
+
expected_tangents = F.normalize(
|
313 |
+
torch.linalg.cross(
|
314 |
+
vertex_normals,
|
315 |
+
torch.linalg.cross(pos_stack, vertex_normals, dim=-1),
|
316 |
+
dim=-1,
|
317 |
+
),
|
318 |
+
-1,
|
319 |
+
)
|
320 |
+
|
321 |
+
actual_tangents = tangents[triangle_idxs]
|
322 |
+
expected_tangents = expected_tangents[triangle_idxs]
|
323 |
+
|
324 |
+
def rotation_matrix_2d(theta):
|
325 |
+
c, s = torch.cos(theta), torch.sin(theta)
|
326 |
+
return torch.tensor([[c, -s], [s, c]])
|
327 |
+
|
328 |
+
# Now find the rotation
|
329 |
+
index_mod = index % 6 # Shouldn't happen. Just for safety
|
330 |
+
for i in range(6):
|
331 |
+
mask = index_mod == i
|
332 |
+
if not mask.any():
|
333 |
+
continue
|
334 |
+
|
335 |
+
actual_mean_tangent = actual_tangents[mask].mean(dim=(0, 1))
|
336 |
+
expected_mean_tangent = expected_tangents[mask].mean(dim=(0, 1))
|
337 |
+
|
338 |
+
dot_product = torch.dot(actual_mean_tangent, expected_mean_tangent)
|
339 |
+
cross_product = (
|
340 |
+
actual_mean_tangent[0] * expected_mean_tangent[1]
|
341 |
+
- actual_mean_tangent[1] * expected_mean_tangent[0]
|
342 |
+
)
|
343 |
+
angle = torch.atan2(cross_product, dot_product)
|
344 |
+
|
345 |
+
rot_matrix = rotation_matrix_2d(angle).to(mask.device)
|
346 |
+
# Center the uv coordinate to be in the range of -1 to 1 and 0 centered
|
347 |
+
uv_cur = uv[mask] * 2 - 1 # Center it first
|
348 |
+
# Rotate it
|
349 |
+
uv[mask] = torch.einsum("ij,nfj->nfi", rot_matrix, uv_cur)
|
350 |
+
|
351 |
+
# Rescale uv[mask] to be within the 0-1 range
|
352 |
+
uv[mask] = (uv[mask] - uv[mask].min()) / (uv[mask].max() - uv[mask].min())
|
353 |
+
|
354 |
+
return uv
|
355 |
+
|
356 |
+
def _handle_slice_uvs(
|
357 |
+
self,
|
358 |
+
uv: Tensor,
|
359 |
+
index: Tensor, # noqa: F821
|
360 |
+
island_padding: float,
|
361 |
+
max_index: int = 6 * 2,
|
362 |
+
) -> Tensor: # noqa: F821
|
363 |
+
"""
|
364 |
+
Handle the slice UVs
|
365 |
+
|
366 |
+
Args:
|
367 |
+
uv (Float[Tensor, "Nf 3 2"]): UV coordinates
|
368 |
+
index (Integer[Tensor, "Nf"]): Atlas index
|
369 |
+
island_padding (float): Island padding
|
370 |
+
max_index (int): Maximum index
|
371 |
+
|
372 |
+
Returns:
|
373 |
+
Float[Tensor, "Nf 3 2"]: Updated UV coordinates
|
374 |
+
|
375 |
+
"""
|
376 |
+
uc, vc = uv.unbind(-1)
|
377 |
+
|
378 |
+
# Get the second slice (The first overlap)
|
379 |
+
index_filter = [index == i for i in range(6, max_index)]
|
380 |
+
|
381 |
+
# Normalize them to always fully fill the atlas patch
|
382 |
+
for i, fi in enumerate(index_filter):
|
383 |
+
if fi.sum() > 0:
|
384 |
+
# Scale the slice but only up to a factor of 2
|
385 |
+
# This keeps the texture resolution with the first slice in line (Half space in UV)
|
386 |
+
uc[fi] = (uc[fi] - uc[fi].min()) / (uc[fi].max() - uc[fi].min()).clip(
|
387 |
+
0.5
|
388 |
+
)
|
389 |
+
vc[fi] = (vc[fi] - vc[fi].min()) / (vc[fi].max() - vc[fi].min()).clip(
|
390 |
+
0.5
|
391 |
+
)
|
392 |
+
|
393 |
+
uc_padded = (uc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
|
394 |
+
vc_padded = (vc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
|
395 |
+
|
396 |
+
return torch.stack([uc_padded, vc_padded], dim=-1)
|
397 |
+
|
398 |
+
def _handle_remaining_uvs(
|
399 |
+
self,
|
400 |
+
uv: Tensor,
|
401 |
+
index: Tensor, # noqa: F821
|
402 |
+
island_padding: float,
|
403 |
+
) -> Tensor:
|
404 |
+
"""
|
405 |
+
Handle the remaining UVs (The ones that are not slices)
|
406 |
+
|
407 |
+
Args:
|
408 |
+
uv (Float[Tensor, "Nf 3 2"]): UV coordinates
|
409 |
+
index (Integer[Tensor, "Nf"]): Atlas index
|
410 |
+
island_padding (float): Island padding
|
411 |
+
|
412 |
+
Returns:
|
413 |
+
Float[Tensor, "Nf 3 2"]: Updated UV coordinates
|
414 |
+
"""
|
415 |
+
uc, vc = uv.unbind(-1)
|
416 |
+
# Get all remaining elements
|
417 |
+
remaining_filter = index >= 6 * 2
|
418 |
+
squares_left = remaining_filter.sum()
|
419 |
+
|
420 |
+
if squares_left == 0:
|
421 |
+
return uv
|
422 |
+
|
423 |
+
uc = uc[remaining_filter]
|
424 |
+
vc = vc[remaining_filter]
|
425 |
+
|
426 |
+
# Or remaining triangles are distributed in a rectangle
|
427 |
+
# The rectangle takes 0.5 of the entire uv space in width and 1/3 in height
|
428 |
+
ratio = 0.5 * (1 / 3) # 1.5
|
429 |
+
# sqrt(744/(0.5*(1/3)))
|
430 |
+
|
431 |
+
mult = math.sqrt(squares_left / ratio)
|
432 |
+
num_square_width = int(math.ceil(0.5 * mult))
|
433 |
+
num_square_height = int(math.ceil(squares_left / num_square_width))
|
434 |
+
|
435 |
+
width = 1 / num_square_width
|
436 |
+
height = 1 / num_square_height
|
437 |
+
|
438 |
+
# The idea is again to keep the texture resolution consistent with the first slice
|
439 |
+
# This only occupys half the region in the texture chart but the scaling on the squares
|
440 |
+
# assumes full coverage.
|
441 |
+
clip_val = min(width, height) * 1.5
|
442 |
+
# Now normalize the UVs with taking into account the maximum scaling
|
443 |
+
uc = (uc - uc.min(dim=1, keepdim=True).values) / (
|
444 |
+
uc.amax(dim=1, keepdim=True) - uc.amin(dim=1, keepdim=True)
|
445 |
+
).clip(clip_val)
|
446 |
+
vc = (vc - vc.min(dim=1, keepdim=True).values) / (
|
447 |
+
vc.amax(dim=1, keepdim=True) - vc.amin(dim=1, keepdim=True)
|
448 |
+
).clip(clip_val)
|
449 |
+
# Add a small padding
|
450 |
+
uc = (
|
451 |
+
uc * (1 - island_padding * num_square_width * 0.5)
|
452 |
+
+ island_padding * num_square_width * 0.25
|
453 |
+
).clip(0, 1)
|
454 |
+
vc = (
|
455 |
+
vc * (1 - island_padding * num_square_height * 0.5)
|
456 |
+
+ island_padding * num_square_height * 0.25
|
457 |
+
).clip(0, 1)
|
458 |
+
|
459 |
+
uc = uc * width
|
460 |
+
vc = vc * height
|
461 |
+
|
462 |
+
# And calculate offsets for each element
|
463 |
+
idx = torch.arange(uc.shape[0], device=uc.device, dtype=torch.int32)
|
464 |
+
x_idx = idx % num_square_width
|
465 |
+
y_idx = idx // num_square_width
|
466 |
+
# And move each triangle to its own spot
|
467 |
+
uc = uc + x_idx[:, None] * width
|
468 |
+
vc = vc + y_idx[:, None] * height
|
469 |
+
|
470 |
+
uc = (uc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
|
471 |
+
vc = (vc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
|
472 |
+
|
473 |
+
uv[remaining_filter] = torch.stack([uc, vc], dim=-1)
|
474 |
+
|
475 |
+
return uv
|
476 |
+
|
477 |
+
def _distribute_individual_uvs_in_atlas(
|
478 |
+
self,
|
479 |
+
face_uv: Tensor,
|
480 |
+
assigned_faces: Tensor,
|
481 |
+
offset_x: Tensor,
|
482 |
+
offset_y: Tensor,
|
483 |
+
div_x: Tensor,
|
484 |
+
div_y: Tensor,
|
485 |
+
island_padding: float,
|
486 |
+
) -> Tensor:
|
487 |
+
"""
|
488 |
+
Distribute the individual UVs in the atlas
|
489 |
+
|
490 |
+
Args:
|
491 |
+
face_uv (Float[Tensor, "Nf 3 2"]): Face UV coordinates
|
492 |
+
assigned_faces (Integer[Tensor, "Nf"]): Assigned faces
|
493 |
+
offset_x (Float[Tensor, "Nf"]): Offset x
|
494 |
+
offset_y (Float[Tensor, "Nf"]): Offset y
|
495 |
+
div_x (Float[Tensor, "Nf"]): Division x
|
496 |
+
div_y (Float[Tensor, "Nf"]): Division y
|
497 |
+
island_padding (float): Island padding
|
498 |
+
|
499 |
+
Returns:
|
500 |
+
Float[Tensor, "Nf 3 2"]: Updated UV coordinates
|
501 |
+
"""
|
502 |
+
# Place the slice first
|
503 |
+
placed_uv = self._handle_slice_uvs(face_uv, assigned_faces, island_padding)
|
504 |
+
# Then handle the remaining overlap elements
|
505 |
+
placed_uv = self._handle_remaining_uvs(
|
506 |
+
placed_uv, assigned_faces, island_padding
|
507 |
+
)
|
508 |
+
|
509 |
+
uc, vc = placed_uv.unbind(-1)
|
510 |
+
uc = uc / div_x[:, None] + offset_x[:, None]
|
511 |
+
vc = vc / div_y[:, None] + offset_y[:, None]
|
512 |
+
|
513 |
+
uv = torch.stack([uc, vc], dim=-1).view(-1, 2)
|
514 |
+
|
515 |
+
return uv
|
516 |
+
|
517 |
+
def _get_unique_face_uv(
|
518 |
+
self,
|
519 |
+
uv: Tensor,
|
520 |
+
) -> Tuple[Tensor, Tensor]:
|
521 |
+
"""
|
522 |
+
Get the unique face UV
|
523 |
+
|
524 |
+
Args:
|
525 |
+
uv (Float[Tensor, "Nf 3 2"]): UV coordinates
|
526 |
+
|
527 |
+
Returns:
|
528 |
+
Float[Tensor, "Utex 3"]: Unique UV coordinates
|
529 |
+
Integer[Tensor, "Nf"]: Vertex index
|
530 |
+
"""
|
531 |
+
unique_uv, unique_idx = torch.unique(uv, return_inverse=True, dim=0)
|
532 |
+
# And add the face to uv index mapping
|
533 |
+
vtex_idx = unique_idx.view(-1, 3)
|
534 |
+
|
535 |
+
return unique_uv, vtex_idx
|
536 |
+
|
537 |
+
def _align_mesh_with_main_axis(
|
538 |
+
self, vertex_positions: Tensor, vertex_normals: Tensor
|
539 |
+
) -> Tuple[Tensor, Tensor]:
|
540 |
+
"""
|
541 |
+
Align the mesh with the main axis
|
542 |
+
|
543 |
+
Args:
|
544 |
+
vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions
|
545 |
+
vertex_normals (Float[Tensor, "Nv 3"]): Vertex normals
|
546 |
+
|
547 |
+
Returns:
|
548 |
+
Float[Tensor, "Nv 3"]: Rotated vertex positions
|
549 |
+
Float[Tensor, "Nv 3"]: Rotated vertex normals
|
550 |
+
"""
|
551 |
+
|
552 |
+
# Use pca to find the 2 main axis (third is derived by cross product)
|
553 |
+
# Set the random seed so it's repeatable
|
554 |
+
torch.manual_seed(0)
|
555 |
+
_, _, v = torch.pca_lowrank(vertex_positions, q=2)
|
556 |
+
main_axis, seconday_axis = v[:, 0], v[:, 1]
|
557 |
+
|
558 |
+
main_axis = F.normalize(main_axis, eps=1e-6, dim=-1) # 3,
|
559 |
+
# Orthogonalize the second axis
|
560 |
+
seconday_axis = F.normalize(
|
561 |
+
seconday_axis
|
562 |
+
- (seconday_axis * main_axis).sum(-1, keepdim=True) * main_axis,
|
563 |
+
eps=1e-6,
|
564 |
+
dim=-1,
|
565 |
+
) # 3,
|
566 |
+
# Create perpendicular third axis
|
567 |
+
third_axis = F.normalize(
|
568 |
+
torch.cross(main_axis, seconday_axis, dim=-1), dim=-1, eps=1e-6
|
569 |
+
) # 3,
|
570 |
+
|
571 |
+
# Check to which canonical axis each aligns
|
572 |
+
main_axis_max_idx = main_axis.abs().argmax().item()
|
573 |
+
seconday_axis_max_idx = seconday_axis.abs().argmax().item()
|
574 |
+
third_axis_max_idx = third_axis.abs().argmax().item()
|
575 |
+
|
576 |
+
# Now sort the axes based on the argmax so they align with thecanonoical axes
|
577 |
+
# If two axes have the same argmax move one of them
|
578 |
+
all_possible_axis = {0, 1, 2}
|
579 |
+
cur_index = 1
|
580 |
+
while (
|
581 |
+
len(set([main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx]))
|
582 |
+
!= 3
|
583 |
+
):
|
584 |
+
# Find missing axis
|
585 |
+
missing_axis = all_possible_axis - set(
|
586 |
+
[main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx]
|
587 |
+
)
|
588 |
+
missing_axis = missing_axis.pop()
|
589 |
+
# Just assign it to third axis as it had the smallest contribution to the
|
590 |
+
# overall shape
|
591 |
+
if cur_index == 1:
|
592 |
+
third_axis_max_idx = missing_axis
|
593 |
+
elif cur_index == 2:
|
594 |
+
seconday_axis_max_idx = missing_axis
|
595 |
+
else:
|
596 |
+
raise ValueError("Could not find 3 unique axis")
|
597 |
+
cur_index += 1
|
598 |
+
|
599 |
+
if len({main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx}) != 3:
|
600 |
+
raise ValueError("Could not find 3 unique axis")
|
601 |
+
|
602 |
+
axes = [None] * 3
|
603 |
+
axes[main_axis_max_idx] = main_axis
|
604 |
+
axes[seconday_axis_max_idx] = seconday_axis
|
605 |
+
axes[third_axis_max_idx] = third_axis
|
606 |
+
# Create rotation matrix from the individual axes
|
607 |
+
rot_mat = torch.stack(axes, dim=1).T
|
608 |
+
|
609 |
+
# Now rotate the vertex positions and vertex normals so the mesh aligns with the main axis
|
610 |
+
vertex_positions = torch.einsum("ij,nj->ni", rot_mat, vertex_positions)
|
611 |
+
vertex_normals = torch.einsum("ij,nj->ni", rot_mat, vertex_normals)
|
612 |
+
|
613 |
+
return vertex_positions, vertex_normals
|
614 |
+
|
615 |
+
def forward(
|
616 |
+
self,
|
617 |
+
vertex_positions: Tensor,
|
618 |
+
vertex_normals: Tensor,
|
619 |
+
triangle_idxs: Tensor,
|
620 |
+
island_padding: float,
|
621 |
+
) -> Tuple[Tensor, Tensor]:
|
622 |
+
"""
|
623 |
+
Unwrap the mesh
|
624 |
+
|
625 |
+
Args:
|
626 |
+
vertex_positions (Float[Tensor, "Nv 3"]): Vertex positions
|
627 |
+
vertex_normals (Float[Tensor, "Nv 3"]): Vertex normals
|
628 |
+
triangle_idxs (Integer[Tensor, "Nf 3"]): Triangle indices
|
629 |
+
island_padding (float): Island padding
|
630 |
+
|
631 |
+
Returns:
|
632 |
+
Float[Tensor, "Utex 3"]: Unique UV coordinates
|
633 |
+
Integer[Tensor, "Nf"]: Vertex index
|
634 |
+
"""
|
635 |
+
vertex_positions, vertex_normals = self._align_mesh_with_main_axis(
|
636 |
+
vertex_positions, vertex_normals
|
637 |
+
)
|
638 |
+
bbox = torch.stack(
|
639 |
+
[vertex_positions.min(dim=0).values, vertex_positions.max(dim=0).values],
|
640 |
+
dim=0,
|
641 |
+
) # 2, 3
|
642 |
+
|
643 |
+
face_uv, face_index = self._box_assign_vertex_to_cube_face(
|
644 |
+
vertex_positions, vertex_normals, triangle_idxs, bbox
|
645 |
+
)
|
646 |
+
|
647 |
+
face_uv = self._rotate_uv_slices_consistent_space(
|
648 |
+
vertex_positions, vertex_normals, triangle_idxs, face_uv, face_index
|
649 |
+
)
|
650 |
+
|
651 |
+
assigned_atlas_index = self._assign_faces_uv_to_atlas_index(
|
652 |
+
vertex_positions, triangle_idxs, face_uv, face_index
|
653 |
+
)
|
654 |
+
|
655 |
+
offset_x, offset_y, div_x, div_y = self._find_slice_offset_and_scale(
|
656 |
+
assigned_atlas_index
|
657 |
+
)
|
658 |
+
|
659 |
+
placed_uv = self._distribute_individual_uvs_in_atlas(
|
660 |
+
face_uv,
|
661 |
+
assigned_atlas_index,
|
662 |
+
offset_x,
|
663 |
+
offset_y,
|
664 |
+
div_x,
|
665 |
+
div_y,
|
666 |
+
island_padding,
|
667 |
+
)
|
668 |
+
|
669 |
+
return self._get_unique_face_uv(placed_uv)
|