Spaces: Running on Zero

JAMESYJL committed · Commit faccdf3
Parent(s): a97bc04

v1

This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +0 -35
- LICENSE +21 -0
- README.md +49 -14
- app.py +342 -109
- configs/generation/slat_flow_img_dit_L_64l8p2_fp16.json +102 -0
- configs/generation/slat_flow_txt_dit_B_64l8p2_fp16.json +101 -0
- configs/generation/slat_flow_txt_dit_L_64l8p2_fp16.json +101 -0
- configs/generation/slat_flow_txt_dit_XL_64l8p2_fp16.json +101 -0
- configs/generation/ss_flow_img_dit_L_16l8_fp16.json +70 -0
- configs/generation/ss_flow_txt_dit_B_16l8_fp16.json +69 -0
- configs/generation/ss_flow_txt_dit_L_16l8_fp16.json +69 -0
- configs/generation/ss_flow_txt_dit_XL_16l8_fp16.json +70 -0
- configs/vae/slat_vae_dec_mesh_swin8_B_64l8_fp16.json +73 -0
- configs/vae/slat_vae_dec_rf_swin8_B_64l8_fp16.json +71 -0
- configs/vae/slat_vae_enc_dec_gs_swin8_B_64l8_fp16.json +105 -0
- configs/vae/ss_vae_conv3d_16l8_fp16.json +65 -0
- dataset_toolkits/blender_script/io_scene_usdz.zip +0 -0
- dataset_toolkits/blender_script/render.py +528 -0
- dataset_toolkits/build_metadata.py +270 -0
- dataset_toolkits/datasets/3D-FUTURE.py +97 -0
- dataset_toolkits/datasets/ABO.py +96 -0
- dataset_toolkits/datasets/HSSD.py +103 -0
- dataset_toolkits/datasets/ObjaverseXL.py +92 -0
- dataset_toolkits/datasets/Toys4k.py +92 -0
- dataset_toolkits/download.py +52 -0
- dataset_toolkits/encode_latent.py +127 -0
- dataset_toolkits/encode_ss_latent.py +128 -0
- dataset_toolkits/extract_feature.py +179 -0
- dataset_toolkits/render.py +121 -0
- dataset_toolkits/render_cond.py +125 -0
- dataset_toolkits/setup.sh +1 -0
- dataset_toolkits/stat_latent.py +66 -0
- dataset_toolkits/utils.py +43 -0
- dataset_toolkits/voxelize.py +86 -0
- examples/airplane.png +3 -0
- examples/airplane2.png +3 -0
- examples/bear.png +3 -0
- examples/car.png +3 -0
- examples/car2.png +3 -0
- examples/gun1.png +3 -0
- examples/gun2.png +3 -0
- examples/icecream.png +3 -0
- examples/knife.png +3 -0
- examples/man1.png +3 -0
- examples/man2.png +3 -0
- examples/man3.png +3 -0
- examples/robot1.png +3 -0
- examples/robot2.png +3 -0
- examples/shoe.png +3 -0
- examples/sweater.png +3 -0
.gitattributes
CHANGED
@@ -1,37 +1,2 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
 *.png filter=lfs diff=lfs merge=lfs -text
 *.glb filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Junliang Ye

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,14 +1,49 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+<p align="center">
+  <h3 align="center"><strong>ShapeLLM-omni: A Native Multimodal LLM for 3D Generation and Understanding</strong></h3>
+
+  <p align="center">
+    <a href="https://jamesyjl.github.io/">Junliang Ye</a><sup>1,2*</sup>,
+    <a href="https://thuwzy.github.io/">Zhengyi Wang</a><sup>1,2*</sup>,
+    <a href="https://zhaorw02.github.io/">Ruowen Zhao</a><sup>1*</sup>,
+    <a href="">Shenghao Xie</a><sup>3</sup>,
+    <a href="https://ml.cs.tsinghua.edu.cn/~jun/index.shtml">Jun Zhu</a><sup>1,2†</sup>
+    <br>
+    <sup>*</sup>Equal contribution.
+    <br>
+    <sup>†</sup>Corresponding author.
+    <br>
+    <sup>1</sup>Tsinghua University,
+    <sup>2</sup>ShengShu,
+    <sup>3</sup>Peking University
+  </p>
+
+<div align="center">
+
+<a href='https://arxiv.org/abs/2503.15265'><img src='https://img.shields.io/badge/arXiv-2503.15265-b31b1b.svg'></a>
+<a href='https://jamesyjl.github.io/ShapeLLM/'><img src='https://img.shields.io/badge/Project-Page-Green'></a>
+<a><img src='https://img.shields.io/badge/License-MIT-blue'></a>
+<a href="https://huggingface.co/zzzrw/DeepMesh/tree/main"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Weights-HF-orange"></a>
+<a href='https://www.youtube.com/watch?v=6grL7bSbQ2w'><img src='https://img.shields.io/badge/Youtube-Video-b31b1b.svg'></a>
+
+</div>
+
+## Release
+- [6/03] 🔥🔥 We released the pretrained weights for both **ShapeLLM-omni** (7B) and **3DVQVAE**.
+- [6/03] 🔥🔥 We released 50k high-quality 3D editing data pairs.
+
+## Important Notes
+- Please refer to our [project page](https://zhaorw02.github.io/DeepMesh/) for more examples.
+
+## Todo
+- [ ] Release of training code.
+- [ ] Release of model weights featuring multi-turn dialogue and 3D editing capabilities.
+- [ ] Release of the entire 3D-Alpaca dataset.
+
+## Acknowledgement
+Our code is based on these wonderful repos:
+* **[LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory)**
+* **[TRELLIS](https://github.com/microsoft/TRELLIS)**
+* **[PointLLM](https://github.com/OpenRobotLab/PointLLM)**
+* **[Qwen2.5-VL](https://github.com/QwenLM/Qwen2.5-VL)**
+* [LLaMA-Mesh](https://github.com/nv-tlabs/LLaMA-Mesh)
app.py
CHANGED
@@ -4,33 +4,84 @@ from threading import Thread
 import gradio as gr
 from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer, AutoTokenizer
 from qwen_vl_utils import process_vision_info
+from trellis.pipelines import TrellisImageTo3DPipeline, TrellisTextTo3DPipeline
+from trellis.utils import render_utils, postprocessing_utils
 import trimesh
 from trimesh.exchange.gltf import export_glb
 import numpy as np
 import tempfile
 import copy
+import plotly.graph_objs as go
+from PIL import Image
+import plotly.express as px
+import random
+import open3d as o3d
+from huggingface_hub import hf_hub_download
 
 def _remove_image_special(text):
     text = text.replace('<ref>', '').replace('</ref>', '')
     return re.sub(r'<box>.*?(</box>|$)', '', text)
 
 def is_video_file(filename):
     video_extensions = ['.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg']
     return any(filename.lower().endswith(ext) for ext in video_extensions)
 
+def token_to_mesh(full_response):
+    # Parse a "<mesh-start><mesh0><mesh1>...<mesh-end>" tag stream back into
+    # the 1024 VQ-VAE codebook indices it encodes.
+    d1 = full_response.split("><mesh")
+    d2 = []
+    for i in range(len(d1)):
+        try:
+            if d1[i][:5] == "<mesh":
+                d2.append(int(d1[i][5:]))
+            else:
+                d2.append(int(d1[i]))
+        except ValueError:
+            pass  # the "<mesh-start" and "mesh-end>" fragments carry no index
+    while len(d2) < 1024:  # pad truncated generations to the fixed length
+        d2.append(d2[-1])
+    encoding_indices = torch.tensor(d2).unsqueeze(0)
+    return encoding_indices
+
+def save_ply_from_array(verts):
+    header = [
+        "ply",
+        "format ascii 1.0",
+        f"element vertex {verts.shape[0]}",
+        "property float x",
+        "property float y",
+        "property float z",
+        "end_header"
+    ]
+    tmpf = tempfile.NamedTemporaryFile(suffix=".ply", delete=False)
+    tmpf.write(("\n".join(header) + "\n").encode("utf-8"))
+    np.savetxt(tmpf, verts, fmt="%.6f")
+    tmpf.flush(); tmpf.close()
+    return tmpf.name
+
-def predict(_chatbot, task_history):
+def predict(_chatbot, task_history, viewer_voxel, viewer_mesh, task_new, seed, top_k, top_p, temperature):
+    torch.manual_seed(seed)
     chat_query = _chatbot[-1][0]
     query = task_history[-1][0]
     if len(chat_query) == 0:
         _chatbot.pop()
         task_history.pop()
-        return _chatbot
+        return _chatbot, task_history, viewer_voxel, viewer_mesh, task_new
     print("User: " + _parse_text(query))
     history_cp = copy.deepcopy(task_history)
     full_response = ""
     messages = []
     content = []
+
+    image_lst = []
+    for q, a in task_new:
+        if isinstance(q, (tuple, list)):
+            image_lst.append(q[0])  # collect media paths from the new turns
+    task_new.clear()
     for q, a in history_cp:
         if isinstance(q, (tuple, list)):
             if is_video_file(q[0]):
@@ -44,45 +95,92 @@ def predict(_chatbot, task_history):
     content = []
     messages.pop()
     messages = _transform_messages(messages)
-    text = processor.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True)
+    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     image_inputs, video_inputs = process_vision_info(messages)
-    inputs = processor(text=[text], images=image_inputs,
-                       videos=video_inputs, padding=True, return_tensors='pt')
+    inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors='pt')
     inputs = inputs.to(model.device)
 
+    eos_token_id = [tokenizer.eos_token_id, 159858]
+    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+    gen_kwargs = {'max_new_tokens': 2048, 'streamer': streamer, 'eos_token_id': eos_token_id,
+                  'top_k': top_k, 'top_p': top_p, 'temperature': temperature, **inputs}
 
     thread = Thread(target=model.generate, kwargs=gen_kwargs)
     thread.start()
-    # initialize the response text
     full_response = ""
+    encoding_indices = None
+    _chatbot[-1] = (_parse_text(chat_query), "")
-    # handle the streamed output
     for new_text in streamer:
+        if new_text:
+            if "<mesh" in new_text:
+                encoding_indices = token_to_mesh(new_text)
+                new_text = new_text.replace("><", ",")[1:-1]
+                new_text = new_text.split("mesh-start,")[1].split(",mesh-end")[0]
+                new_text = f"mesh-start\n{new_text}\nmesh-end"
+            full_response += new_text
+            _chatbot[-1] = (_parse_text(chat_query), _parse_text(full_response))
+            yield _chatbot, viewer_voxel, viewer_mesh, task_new
 
-    # final processing (if the full response needs to be saved)
     task_history[-1] = (chat_query, full_response)
+    yield _chatbot, viewer_voxel, viewer_mesh, task_new
+
+    if encoding_indices is not None:
+        print("processing mesh...")
+        recon = vqvae.Decode(encoding_indices.to(model.device))
+        z_s = recon[0].detach().cpu()
+        z_s = (z_s > 0) * 1  # binarize the decoded occupancy grid
+        indices = torch.nonzero(z_s[0] == 1)
+        position_recon = (indices.float() + 0.5) / 64 - 0.5
+        fig = make_pointcloud_figure(position_recon)
+        yield _chatbot, fig, viewer_mesh, task_new
+
+        position = position_recon
+        coords = ((position + 0.5) * 64).int().contiguous()
+        ss = torch.zeros(1, 64, 64, 64, dtype=torch.long)
+        ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
+        ss = ss.unsqueeze(0)
+        coords = torch.argwhere(ss > 0)[:, [0, 2, 3, 4]].int()
+        coords = coords.to(model.device)
+        try:
+            print("processing mesh...")
+            if len(image_lst) == 0:
+                # text to 3d
+                with torch.no_grad():
+                    prompt = chat_query
+                    cond = pipeline_text.get_cond([prompt])
+                    slat = pipeline_text.sample_slat(cond, coords)
+                    outputs = pipeline_text.decode_slat(slat, ['mesh', 'gaussian'])
+                glb = postprocessing_utils.to_glb(
+                    outputs['gaussian'][0],
+                    outputs['mesh'][0],
+                    simplify=0.95,
+                    texture_size=1024,
+                    verbose=False
+                )
+                glb.export("temper.glb")
+                print("processing mesh over...")
+                yield _chatbot, fig, "temper.glb", task_new
+            else:
+                # image to 3d
+                with torch.no_grad():
+                    img = pipeline_image.preprocess_image(Image.open(image_lst[-1]))
+                    cond = pipeline_image.get_cond([img])
+                    slat = pipeline_image.sample_slat(cond, coords)
+                    outputs = pipeline_image.decode_slat(slat, ['mesh', 'gaussian'])
+                glb = postprocessing_utils.to_glb(
+                    outputs['gaussian'][0],
+                    outputs['mesh'][0],
+                    simplify=0.95,
+                    texture_size=1024,
+                    verbose=False
+                )
+                glb.export("temper.glb")
+                print("processing mesh over...")
+                yield _chatbot, fig, "temper.glb", task_new
+        except Exception:
+            print("processing mesh...bug")
+            yield _chatbot, fig, viewer_mesh, task_new
 
 def regenerate(_chatbot, task_history):
     if not task_history:
@@ -131,20 +229,45 @@ def _parse_text(text):
     text = "".join(lines)
     return text
 
+def add_text_prefix(text):
+    text = f"Please generate a 3D asset based on the prompt I provided: {text}"
+    return gr.update(value=text)
+
+def token_to_words(token):
+    # Serialize 1024 VQ indices as "<mesh-start><mesh{i}>...<mesh-end>".
+    mesh = "<mesh-start>"
+    for j in range(1024):
+        mesh += f"<mesh{token[j]}>"
+    mesh += "<mesh-end>"
+    return mesh
+
-def add_text(history, task_history, text):
+def add_text(history, task_history, text, task_new):
     task_text = text
     history = history if history is not None else []
     task_history = task_history if task_history is not None else []
     history = history + [(_parse_text(text), None)]
     task_history = task_history + [(task_text, None)]
+    task_new = task_new + [(task_text, None)]
+    return history, task_history, task_new
+
+def add_file(history, task_history, file, task_new, fig, query):
+    if file.name.endswith(('.obj', '.glb')):
+        position_recon = load_vertices(file.name)  # (N, 3)
+        coords = ((torch.from_numpy(position_recon) + 0.5) * 64).int().contiguous()
+        ss = torch.zeros(1, 64, 64, 64, dtype=torch.long)
+        ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
+        token = vqvae.Encode(ss.to(dtype=torch.float32).unsqueeze(0).to("cuda"))
+        token = token[0].cpu().numpy().tolist()
+        words = token_to_words(token)
+        fig = make_pointcloud_figure(position_recon, rotate=True)
+        return history, task_history, file.name, task_new, fig, gr.update(
+            value=f"{words}\nGive a quick overview of the object represented by this 3D mesh.")
     history = history if history is not None else []
     task_history = task_history if task_history is not None else []
     history = history + [((file.name,), None)]
     task_history = task_history + [((file.name,), None)]
+    task_new = task_new + [((file.name,), None)]
+    return history, task_history, file.name, task_new, fig, query
 
 def reset_user_input():
     return gr.update(value="")
@@ -153,6 +276,96 @@ def reset_state(task_history):
     task_history.clear()
     return []
 
+def make_pointcloud_figure(verts, rotate=False):
+    if rotate:
+        verts = verts.copy()
+        verts[:, 0] *= -1.0
+    N = len(verts)
+    soft_palette = ["#FFEBEE", "#FFF3E0", "#FFFDE7", "#E8F5E9"]
+    palette = px.colors.qualitative.Set3
+    base_colors = [palette[i % len(palette)] for i in range(N)]
+    random.shuffle(base_colors)
+
+    camera = dict(
+        eye=dict(x=0.0, y=2.5, z=0.0),
+        center=dict(x=0.0, y=0.0, z=0.0),
+        up=dict(x=0.0, y=0.0, z=1.0),
+        projection=dict(type="orthographic")
+    )
+
+    scatter = go.Scatter3d(
+        x=verts[:, 0],
+        y=verts[:, 1],
+        z=verts[:, 2],
+        mode='markers',
+        marker=dict(
+            size=2,
+            color=base_colors,
+            opacity=1,
+            line=dict(width=1)
+        )
+    )
+    layout = go.Layout(
+        width=700,
+        height=200,
+        scene=dict(
+            xaxis=dict(visible=False),
+            yaxis=dict(visible=False),
+            zaxis=dict(visible=False),
+            camera=camera
+        ),
+        margin=dict(l=0, r=0, b=0, t=0)
+    )
+    fig = go.Figure(data=[scatter], layout=layout)
+    return fig
+
+def rotate_points(points, axis='x', angle_deg=90):
+    angle_rad = np.deg2rad(angle_deg)
+    if axis == 'x':
+        R = trimesh.transformations.rotation_matrix(angle_rad, [1, 0, 0])[:3, :3]
+    elif axis == 'y':
+        R = trimesh.transformations.rotation_matrix(angle_rad, [0, 1, 0])[:3, :3]
+    elif axis == 'z':
+        R = trimesh.transformations.rotation_matrix(angle_rad, [0, 0, 1])[:3, :3]
+    else:
+        raise ValueError("axis must be 'x', 'y', or 'z'")
+    return points @ R.T
+
+def convert_trimesh_to_open3d(trimesh_mesh):
+    o3d_mesh = o3d.geometry.TriangleMesh()
+    o3d_mesh.vertices = o3d.utility.Vector3dVector(
+        np.asarray(trimesh_mesh.vertices, dtype=np.float64)
+    )
+    o3d_mesh.triangles = o3d.utility.Vector3iVector(
+        np.asarray(trimesh_mesh.faces, dtype=np.int32)
+    )
+    return o3d_mesh
+
+def load_vertices(filepath):
+    mesh = trimesh.load(filepath, force='mesh')
+    mesh = convert_trimesh_to_open3d(mesh)
+    vertices = np.asarray(mesh.vertices)
+    min_vals = vertices.min()
+    max_vals = vertices.max()
+    vertices_normalized = (vertices - min_vals) / (max_vals - min_vals)
+    vertices = vertices_normalized * 1.0 - 0.5
+    vertices = np.clip(vertices, -0.5 + 1e-6, 0.5 - 1e-6)
+    mesh.vertices = o3d.utility.Vector3dVector(vertices)
+    voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(
+        mesh, voxel_size=1/64, min_bound=(-0.5, -0.5, -0.5), max_bound=(0.5, 0.5, 0.5))
+    vertices = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()])
+    assert np.all(vertices >= 0) and np.all(vertices < 64), "Some vertices are out of bounds"
+    vertices = (vertices + 0.5) / 64 - 0.5
+    voxel = rotate_points(vertices, axis='x', angle_deg=90)
+    return voxel
+
+def add_file2(history, task_history, file, task_new):
+    history = history if history is not None else []
+    task_history = task_history if task_history is not None else []
+    history = history + [((file,), None)]
+    task_history = task_history + [((file,), None)]
+    task_new = task_new + [((file,), None)]
+    return history, task_history, file, task_new
+
 def _transform_messages(original_messages):
     transformed_messages = []
     for message in original_messages:
@@ -173,84 +386,104 @@ def _transform_messages(original_messages):
 
     return transformed_messages
 
+from trellis.models.sparse_structure_vqvae import VQVAE3D
+device = torch.device("cuda")
+vqvae = VQVAE3D(num_embeddings=8192)
+vqvae.eval()
+filepath = hf_hub_download(repo_id="yejunliang23/3DVQVAE", filename="3DVQVAE.bin")
+state_dict = torch.load(filepath, map_location="cpu")
+vqvae.load_state_dict(state_dict)
+vqvae = vqvae.to(device)
+
 MODEL_DIR = "yejunliang23/ShapeLLM-7B-omni"
-model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_DIR,
-    torch_dtype=torch.float16,
-    device_map="auto",
-    trust_remote_code=True
-)
-processor = AutoProcessor.from_pretrained(MODEL_DIR)
+model_ckpt_path = MODEL_DIR
+model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_ckpt_path, torch_dtype="auto", device_map={"": 0})
+processor = AutoProcessor.from_pretrained(model_ckpt_path)
 tokenizer = processor.tokenizer
-text = processor.apply_chat_template(
-    messages, tokenize=False, add_generation_prompt=True)
-image_inputs, video_inputs = process_vision_info(messages)
-inputs = processor(text=[text], images=image_inputs,
-                   videos=video_inputs, padding=True, return_tensors='pt')
-inputs = inputs.to(model.device)
-streamer = TextIteratorStreamer(
-    tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-gen_kwargs = {'max_new_tokens': 512, 'streamer': streamer, **inputs}
-thread = Thread(target=model.generate, kwargs=gen_kwargs)
-thread.start()
-#for new_text in streamer:
-#    yield new_text
-buffer = []
-for chunk in streamer:
-    buffer.append(chunk)
-    yield "".join(buffer)
-css = """
-h1 { text-align: center; }
-"""
-PLACEHOLDER = (
-    "<div style='padding:30px;text-align:center;display:flex;flex-direction:column;align-items:center;'>"
-    "<h1 style='font-size:28px;opacity:0.55;'>Qwen2.5-VL Local Chat</h1>"
-    "<p style='font-size:18px;opacity:0.65;'>Ask anything or generate images!</p></div>"
-)
+
+pipeline_text = TrellisTextTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-text-xlarge")
+pipeline_text.to(device)
+pipeline_image = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
+pipeline_image.to(device)
+
+_DESCRIPTION = '''
+* Project page of ShapeLLM-Omni: https://jamesyjl.github.io/ShapeLLM/
+* As generation tasks currently lack support for multi-turn dialogue, it is strongly recommended to clear the chat history before starting a new task.
+* The model's 3D understanding is limited to shape only, so color and texture should be ignored in 3D captioning tasks.
+'''
 with gr.Blocks() as demo:
-    gr.Markdown("
-    chatbot = gr.Chatbot(label='ShapeLLM-4o', elem_classes="control-height", height=500)
-    query = gr.Textbox(lines=2, label='Input')
-    task_history = gr.State([])
+    gr.Markdown("# ShapeLLM-omni: A Native Multimodal LLM for 3D Generation and Understanding")
+    gr.Markdown(_DESCRIPTION)
     with gr.Row():
+        with gr.Column():
+            chatbot = gr.Chatbot(label='ShapeLLM-Omni', elem_classes="control-height", height=500)
+            seed = gr.Number(value=42, label="seed", precision=0)
+            top_k = gr.Slider(label="top_k", minimum=1024, maximum=8194, value=1024, step=10)
+            top_p = gr.Slider(label="top_p", minimum=0.1, maximum=1.0, value=0.1, step=0.05)
+            temperature = gr.Slider(label="temperature", minimum=0.1, maximum=1.0, value=0.1, step=0.05)
+
+            query = gr.Textbox(lines=2, label='Input')
+            image_input = gr.Image(visible=False, type="filepath", label="Image Input")
+        with gr.Column():
+            with gr.Row():
+                addfile_btn = gr.UploadButton("📁 Upload", file_types=["image", "video", ".obj", ".glb"])
+                submit_btn = gr.Button("🚀 Submit")
+            with gr.Row():
+                regen_btn = gr.Button("🤔️ Regenerate")
+                empty_bin = gr.Button("🧹 Clear History")
+            task_history = gr.State([])
+            task_new = gr.State([])
+        with gr.Column():
+            viewer_plot = gr.Plot(label="Voxel Visual", scale=1.0)
+            viewer_mesh = gr.Model3D(label="Mesh Visual", height=200, scale=1.0)
+
+    examples_text = gr.Examples(
+        examples=[
+            ["A drone with four propellers and a central body."],
+            ["A stone axe with a handle."],
+            ["the titanic, aerial view."],
+            ["A 3D model of a small yellow and blue robot with wheels and two pots."],
+            ["A futuristic vehicle with a sleek design and multiple wheels."],
+            ["A car with four wheels and a roof."],
+        ],
+        inputs=[query],
+        label="text-to-3d examples",
+        fn=add_text_prefix,
+        outputs=[query],
+        cache_examples=True,
+    )
+
+    examples_text.dataset.click(
+        fn=add_text,
+        inputs=[chatbot, task_history, query, task_new],
+        outputs=[chatbot, task_history, task_new],
+    )
+    examples_image = gr.Examples(
+        label="image-to-3d examples",
+        examples=[os.path.join("examples", i) for i in os.listdir("examples")],
+        inputs=[image_input],
+        examples_per_page=20,
+    )
+    image_input.change(
+        fn=add_file2,
+        inputs=[chatbot, task_history, image_input, task_new],
+        outputs=[chatbot, task_history, viewer_mesh, task_new],
+        show_progress=True
+    )
+
+    submit_btn.click(add_text, [chatbot, task_history, query, task_new],
+                     [chatbot, task_history, task_new]).then(
+        predict, [chatbot, task_history, viewer_plot, viewer_mesh, task_new, seed, top_k, top_p, temperature],
+        [chatbot, viewer_plot, viewer_mesh, task_new], show_progress=True
+    )
     submit_btn.click(reset_user_input, [], [query])
     empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
-    regen_btn.click(regenerate,
-    addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn
+    regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
+    addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn, task_new, viewer_plot, query],
+                       [chatbot, task_history, viewer_mesh, task_new, viewer_plot, query],
+                       show_progress=True)
 
-demo.launch()
+demo.launch()
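For reference, the mesh-token protocol this commit builds app.py around can be exercised in isolation. The sketch below is hypothetical standalone usage, mirroring the token_to_words and token_to_mesh helpers added above: it serializes 1024 VQ-VAE codebook indices into the <mesh-start><mesh{i}>...<mesh-end> tag stream the LLM emits, then parses them back.

    import torch

    def token_to_words(token):
        # serialize 1024 codebook indices -> "<mesh-start><mesh{i}>...<mesh-end>"
        mesh = "<mesh-start>"
        for j in range(1024):
            mesh += f"<mesh{token[j]}>"
        return mesh + "<mesh-end>"

    def token_to_mesh(full_response):
        # parse the tag stream back into a (1, 1024) index tensor
        d2 = []
        for piece in full_response.split("><mesh"):
            try:
                d2.append(int(piece[5:]) if piece[:5] == "<mesh" else int(piece))
            except ValueError:
                pass  # skip the "<mesh-start" and "-end>" fragments
        while len(d2) < 1024:  # pad truncated generations with the last index
            d2.append(d2[-1])
        return torch.tensor(d2).unsqueeze(0)

    indices = list(range(1024))
    assert token_to_mesh(token_to_words(indices))[0].tolist() == indices

Padding with the last index keeps a truncated generation decodable at the fixed 1024-token length the 3D VQ-VAE expects.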
configs/generation/slat_flow_img_dit_L_64l8p2_fp16.json
ADDED
@@ -0,0 +1,102 @@
{
  "models": {
    "denoiser": {
      "name": "ElasticSLatFlowModel",
      "args": {
        "resolution": 64,
        "in_channels": 8,
        "out_channels": 8,
        "model_channels": 1024,
        "cond_channels": 1024,
        "num_blocks": 24,
        "num_heads": 16,
        "mlp_ratio": 4,
        "patch_size": 2,
        "num_io_res_blocks": 2,
        "io_block_channels": [128],
        "pe_mode": "ape",
        "qk_rms_norm": true,
        "use_fp16": true
      }
    }
  },
  "dataset": {
    "name": "ImageConditionedSLat",
    "args": {
      "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
      "min_aesthetic_score": 4.5,
      "max_num_voxels": 32768,
      "image_size": 518,
      "normalization": {
        "mean": [-2.1687545776367188, -0.004347046371549368, -0.13352349400520325, -0.08418072760105133, -0.5271206498146057, 0.7238689064979553, -1.1414450407028198, 1.2039363384246826],
        "std": [2.377650737762451, 2.386378288269043, 2.124418020248413, 2.1748552322387695, 2.663944721221924, 2.371192216873169, 2.6217446327209473, 2.684523105621338]
      },
      "pretrained_slat_dec": "microsoft/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16"
    }
  },
  "trainer": {
    "name": "ImageConditionedSparseFlowMatchingCFGTrainer",
    "args": {
      "max_steps": 1000000,
      "batch_size_per_gpu": 8,
      "batch_split": 4,
      "optimizer": {
        "name": "AdamW",
        "args": { "lr": 0.0001, "weight_decay": 0.0 }
      },
      "ema_rate": [0.9999],
      "fp16_mode": "inflat_all",
      "fp16_scale_growth": 0.001,
      "elastic": {
        "name": "LinearMemoryController",
        "args": { "target_ratio": 0.75, "max_mem_ratio_start": 0.5 }
      },
      "grad_clip": {
        "name": "AdaptiveGradClipper",
        "args": { "max_norm": 1.0, "clip_percentile": 95 }
      },
      "i_log": 500,
      "i_sample": 10000,
      "i_save": 10000,
      "p_uncond": 0.1,
      "t_schedule": {
        "name": "logitNormal",
        "args": { "mean": 1.0, "std": 1.0 }
      },
      "sigma_min": 1e-5,
      "image_cond_model": "dinov2_vitl14_reg"
    }
  }
}
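The "normalization" block in these generation configs carries per-channel statistics for the 8-channel structured latent. A minimal sketch of reading and applying them (not part of this commit; standardizing by mean/std is the conventional use of such fields, and the variable names here are illustrative):

    import json
    import torch

    with open("configs/generation/slat_flow_img_dit_L_64l8p2_fp16.json") as f:
        cfg = json.load(f)

    norm = cfg["dataset"]["args"]["normalization"]
    mean = torch.tensor(norm["mean"])    # (8,) per-channel means
    std = torch.tensor(norm["std"])      # (8,) per-channel stds

    latents = torch.randn(32768, 8)      # hypothetical per-voxel latents (max_num_voxels x channels)
    normalized = (latents - mean) / std  # broadcasts over the channel dimension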
configs/generation/slat_flow_txt_dit_B_64l8p2_fp16.json
ADDED
@@ -0,0 +1,101 @@
{
  "models": {
    "denoiser": {
      "name": "ElasticSLatFlowModel",
      "args": {
        "resolution": 64,
        "in_channels": 8,
        "out_channels": 8,
        "model_channels": 768,
        "cond_channels": 768,
        "num_blocks": 12,
        "num_heads": 12,
        "mlp_ratio": 4,
        "patch_size": 2,
        "num_io_res_blocks": 2,
        "io_block_channels": [128],
        "pe_mode": "ape",
        "qk_rms_norm": true,
        "use_fp16": true
      }
    }
  },
  "dataset": {
    "name": "TextConditionedSLat",
    "args": {
      "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
      "min_aesthetic_score": 4.5,
      "max_num_voxels": 32768,
      "normalization": {
        "mean": [-2.1687545776367188, -0.004347046371549368, -0.13352349400520325, -0.08418072760105133, -0.5271206498146057, 0.7238689064979553, -1.1414450407028198, 1.2039363384246826],
        "std": [2.377650737762451, 2.386378288269043, 2.124418020248413, 2.1748552322387695, 2.663944721221924, 2.371192216873169, 2.6217446327209473, 2.684523105621338]
      },
      "pretrained_slat_dec": "microsoft/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16"
    }
  },
  "trainer": {
    "name": "TextConditionedSparseFlowMatchingCFGTrainer",
    "args": {
      "max_steps": 1000000,
      "batch_size_per_gpu": 16,
      "batch_split": 4,
      "optimizer": {
        "name": "AdamW",
        "args": { "lr": 0.0001, "weight_decay": 0.0 }
      },
      "ema_rate": [0.9999],
      "fp16_mode": "inflat_all",
      "fp16_scale_growth": 0.001,
      "elastic": {
        "name": "LinearMemoryController",
        "args": { "target_ratio": 0.75, "max_mem_ratio_start": 0.5 }
      },
      "grad_clip": {
        "name": "AdaptiveGradClipper",
        "args": { "max_norm": 1.0, "clip_percentile": 95 }
      },
      "i_log": 500,
      "i_sample": 10000,
      "i_save": 10000,
      "p_uncond": 0.1,
      "t_schedule": {
        "name": "logitNormal",
        "args": { "mean": 1.0, "std": 1.0 }
      },
      "sigma_min": 1e-5,
      "text_cond_model": "openai/clip-vit-large-patch14"
    }
  }
}
configs/generation/slat_flow_txt_dit_L_64l8p2_fp16.json
ADDED
@@ -0,0 +1,101 @@
{
  "models": {
    "denoiser": {
      "name": "ElasticSLatFlowModel",
      "args": {
        "resolution": 64,
        "in_channels": 8,
        "out_channels": 8,
        "model_channels": 1024,
        "cond_channels": 768,
        "num_blocks": 24,
        "num_heads": 16,
        "mlp_ratio": 4,
        "patch_size": 2,
        "num_io_res_blocks": 2,
        "io_block_channels": [128],
        "pe_mode": "ape",
        "qk_rms_norm": true,
        "use_fp16": true
      }
    }
  },
  "dataset": {
    "name": "TextConditionedSLat",
    "args": {
      "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
      "min_aesthetic_score": 4.5,
      "max_num_voxels": 32768,
      "normalization": {
        "mean": [-2.1687545776367188, -0.004347046371549368, -0.13352349400520325, -0.08418072760105133, -0.5271206498146057, 0.7238689064979553, -1.1414450407028198, 1.2039363384246826],
        "std": [2.377650737762451, 2.386378288269043, 2.124418020248413, 2.1748552322387695, 2.663944721221924, 2.371192216873169, 2.6217446327209473, 2.684523105621338]
      },
      "pretrained_slat_dec": "microsoft/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16"
    }
  },
  "trainer": {
    "name": "TextConditionedSparseFlowMatchingCFGTrainer",
    "args": {
      "max_steps": 1000000,
      "batch_size_per_gpu": 8,
      "batch_split": 4,
      "optimizer": {
        "name": "AdamW",
        "args": { "lr": 0.0001, "weight_decay": 0.0 }
      },
      "ema_rate": [0.9999],
      "fp16_mode": "inflat_all",
      "fp16_scale_growth": 0.001,
      "elastic": {
        "name": "LinearMemoryController",
        "args": { "target_ratio": 0.75, "max_mem_ratio_start": 0.5 }
      },
      "grad_clip": {
        "name": "AdaptiveGradClipper",
        "args": { "max_norm": 1.0, "clip_percentile": 95 }
      },
      "i_log": 500,
      "i_sample": 10000,
      "i_save": 10000,
      "p_uncond": 0.1,
      "t_schedule": {
        "name": "logitNormal",
        "args": { "mean": 1.0, "std": 1.0 }
      },
      "sigma_min": 1e-5,
      "text_cond_model": "openai/clip-vit-large-patch14"
    }
  }
}
configs/generation/slat_flow_txt_dit_XL_64l8p2_fp16.json
ADDED
@@ -0,0 +1,101 @@
{
  "models": {
    "denoiser": {
      "name": "ElasticSLatFlowModel",
      "args": {
        "resolution": 64,
        "in_channels": 8,
        "out_channels": 8,
        "model_channels": 1280,
        "cond_channels": 768,
        "num_blocks": 28,
        "num_heads": 16,
        "mlp_ratio": 4,
        "patch_size": 2,
        "num_io_res_blocks": 3,
        "io_block_channels": [256],
        "pe_mode": "ape",
        "qk_rms_norm": true,
        "use_fp16": true
      }
    }
  },
  "dataset": {
    "name": "TextConditionedSLat",
    "args": {
      "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
      "min_aesthetic_score": 4.5,
      "max_num_voxels": 32768,
      "normalization": {
        "mean": [-2.1687545776367188, -0.004347046371549368, -0.13352349400520325, -0.08418072760105133, -0.5271206498146057, 0.7238689064979553, -1.1414450407028198, 1.2039363384246826],
        "std": [2.377650737762451, 2.386378288269043, 2.124418020248413, 2.1748552322387695, 2.663944721221924, 2.371192216873169, 2.6217446327209473, 2.684523105621338]
      },
      "pretrained_slat_dec": "microsoft/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16"
    }
  },
  "trainer": {
    "name": "TextConditionedSparseFlowMatchingCFGTrainer",
    "args": {
      "max_steps": 1000000,
      "batch_size_per_gpu": 4,
      "batch_split": 4,
      "optimizer": {
        "name": "AdamW",
        "args": { "lr": 0.0001, "weight_decay": 0.0 }
      },
      "ema_rate": [0.9999],
      "fp16_mode": "inflat_all",
      "fp16_scale_growth": 0.001,
      "elastic": {
        "name": "LinearMemoryController",
        "args": { "target_ratio": 0.75, "max_mem_ratio_start": 0.5 }
      },
      "grad_clip": {
        "name": "AdaptiveGradClipper",
        "args": { "max_norm": 1.0, "clip_percentile": 95 }
      },
      "i_log": 500,
      "i_sample": 10000,
      "i_save": 10000,
      "p_uncond": 0.1,
      "t_schedule": {
        "name": "logitNormal",
        "args": { "mean": 1.0, "std": 1.0 }
      },
      "sigma_min": 1e-5,
      "text_cond_model": "openai/clip-vit-large-patch14"
    }
  }
}
configs/generation/ss_flow_img_dit_L_16l8_fp16.json
ADDED
@@ -0,0 +1,70 @@
{
  "models": {
    "denoiser": {
      "name": "SparseStructureFlowModel",
      "args": {
        "resolution": 16,
        "in_channels": 8,
        "out_channels": 8,
        "model_channels": 1024,
        "cond_channels": 1024,
        "num_blocks": 24,
        "num_heads": 16,
        "mlp_ratio": 4,
        "patch_size": 1,
        "pe_mode": "ape",
        "qk_rms_norm": true,
        "use_fp16": true
      }
    }
  },
  "dataset": {
    "name": "ImageConditionedSparseStructureLatent",
    "args": {
      "latent_model": "ss_enc_conv3d_16l8_fp16",
      "min_aesthetic_score": 4.5,
      "image_size": 518,
      "pretrained_ss_dec": "microsoft/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16"
    }
  },
  "trainer": {
    "name": "ImageConditionedFlowMatchingCFGTrainer",
    "args": {
      "max_steps": 1000000,
      "batch_size_per_gpu": 8,
      "batch_split": 1,
      "optimizer": {
        "name": "AdamW",
        "args": { "lr": 0.0001, "weight_decay": 0.0 }
      },
      "ema_rate": [0.9999],
      "fp16_mode": "inflat_all",
      "fp16_scale_growth": 0.001,
      "grad_clip": {
        "name": "AdaptiveGradClipper",
        "args": { "max_norm": 1.0, "clip_percentile": 95 }
      },
      "i_log": 500,
      "i_sample": 10000,
      "i_save": 10000,
      "p_uncond": 0.1,
      "t_schedule": {
        "name": "logitNormal",
        "args": { "mean": 1.0, "std": 1.0 }
      },
      "sigma_min": 1e-5,
      "image_cond_model": "dinov2_vitl14_reg"
    }
  }
}
configs/generation/ss_flow_txt_dit_B_16l8_fp16.json
ADDED
@@ -0,0 +1,69 @@
{
  "models": {
    "denoiser": {
      "name": "SparseStructureFlowModel",
      "args": {
        "resolution": 16,
        "in_channels": 8,
        "out_channels": 8,
        "model_channels": 768,
        "cond_channels": 768,
        "num_blocks": 12,
        "num_heads": 12,
        "mlp_ratio": 4,
        "patch_size": 1,
        "pe_mode": "ape",
        "qk_rms_norm": true,
        "use_fp16": true
      }
    }
  },
  "dataset": {
    "name": "TextConditionedSparseStructureLatent",
    "args": {
      "latent_model": "ss_enc_conv3d_16l8_fp16",
      "min_aesthetic_score": 4.5,
      "pretrained_ss_dec": "microsoft/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16"
    }
  },
  "trainer": {
    "name": "TextConditionedFlowMatchingCFGTrainer",
    "args": {
      "max_steps": 1000000,
      "batch_size_per_gpu": 16,
      "batch_split": 1,
      "optimizer": {
        "name": "AdamW",
        "args": { "lr": 0.0001, "weight_decay": 0.0 }
      },
      "ema_rate": [0.9999],
      "fp16_mode": "inflat_all",
      "fp16_scale_growth": 0.001,
      "grad_clip": {
        "name": "AdaptiveGradClipper",
        "args": { "max_norm": 1.0, "clip_percentile": 95 }
      },
      "i_log": 500,
      "i_sample": 10000,
      "i_save": 10000,
      "p_uncond": 0.1,
      "t_schedule": {
        "name": "logitNormal",
        "args": { "mean": 1.0, "std": 1.0 }
      },
      "sigma_min": 1e-5,
      "text_cond_model": "openai/clip-vit-large-patch14"
    }
  }
}
configs/generation/ss_flow_txt_dit_L_16l8_fp16.json
ADDED
@@ -0,0 +1,69 @@
{
  "models": {
    "denoiser": {
      "name": "SparseStructureFlowModel",
      "args": {
        "resolution": 16,
        "in_channels": 8,
        "out_channels": 8,
        "model_channels": 1024,
        "cond_channels": 768,
        "num_blocks": 24,
        "num_heads": 16,
        "mlp_ratio": 4,
        "patch_size": 1,
        "pe_mode": "ape",
        "qk_rms_norm": true,
        "use_fp16": true
      }
    }
  },
  "dataset": {
    "name": "TextConditionedSparseStructureLatent",
    "args": {
      "latent_model": "ss_enc_conv3d_16l8_fp16",
      "min_aesthetic_score": 4.5,
      "pretrained_ss_dec": "microsoft/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16"
    }
  },
  "trainer": {
    "name": "TextConditionedFlowMatchingCFGTrainer",
    "args": {
      "max_steps": 1000000,
      "batch_size_per_gpu": 8,
      "batch_split": 1,
      "optimizer": {
        "name": "AdamW",
        "args": { "lr": 0.0001, "weight_decay": 0.0 }
      },
      "ema_rate": [0.9999],
      "fp16_mode": "inflat_all",
      "fp16_scale_growth": 0.001,
      "grad_clip": {
        "name": "AdaptiveGradClipper",
        "args": { "max_norm": 1.0, "clip_percentile": 95 }
      },
      "i_log": 500,
      "i_sample": 10000,
      "i_save": 10000,
      "p_uncond": 0.1,
      "t_schedule": {
        "name": "logitNormal",
        "args": { "mean": 1.0, "std": 1.0 }
      },
      "sigma_min": 1e-5,
      "text_cond_model": "openai/clip-vit-large-patch14"
    }
  }
}
configs/generation/ss_flow_txt_dit_XL_16l8_fp16.json
ADDED
@@ -0,0 +1,70 @@
{
  "models": {
    "denoiser": {
      "name": "SparseStructureFlowModel",
      "args": {
        "resolution": 16,
        "in_channels": 8,
        "out_channels": 8,
        "model_channels": 1280,
        "cond_channels": 768,
        "num_blocks": 28,
        "num_heads": 16,
        "mlp_ratio": 4,
        "patch_size": 1,
        "pe_mode": "ape",
        "qk_rms_norm": true,
        "qk_rms_norm_cross": true,
        "use_fp16": true
      }
    }
  },
  "dataset": {
    "name": "TextConditionedSparseStructureLatent",
    "args": {
      "latent_model": "ss_enc_conv3d_16l8_fp16",
      "min_aesthetic_score": 4.5,
      "pretrained_ss_dec": "microsoft/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16"
    }
  },
  "trainer": {
    "name": "TextConditionedFlowMatchingCFGTrainer",
    "args": {
      "max_steps": 1000000,
      "batch_size_per_gpu": 4,
      "batch_split": 1,
      "optimizer": {
        "name": "AdamW",
        "args": { "lr": 0.0001, "weight_decay": 0.0 }
      },
      "ema_rate": [0.9999],
      "fp16_mode": "inflat_all",
      "fp16_scale_growth": 0.001,
      "grad_clip": {
        "name": "AdaptiveGradClipper",
        "args": { "max_norm": 1.0, "clip_percentile": 95 }
      },
      "i_log": 500,
      "i_sample": 10000,
      "i_save": 10000,
      "p_uncond": 0.1,
      "t_schedule": {
        "name": "logitNormal",
        "args": { "mean": 1.0, "std": 1.0 }
      },
      "sigma_min": 1e-5,
      "text_cond_model": "openai/clip-vit-large-patch14"
    }
  }
}
configs/vae/slat_vae_dec_mesh_swin8_B_64l8_fp16.json
ADDED
@@ -0,0 +1,73 @@
{
    "models": {
        "decoder": {
            "name": "ElasticSLatMeshDecoder",
            "args": {
                "resolution": 64,
                "model_channels": 768,
                "latent_channels": 8,
                "num_blocks": 12,
                "num_heads": 12,
                "mlp_ratio": 4,
                "attn_mode": "swin",
                "window_size": 8,
                "use_fp16": true,
                "representation_config": {
                    "use_color": true
                }
            }
        }
    },
    "dataset": {
        "name": "Slat2RenderGeo",
        "args": {
            "image_size": 512,
            "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
            "min_aesthetic_score": 4.5,
            "max_num_voxels": 32768
        }
    },
    "trainer": {
        "name": "SLatVaeMeshDecoderTrainer",
        "args": {
            "max_steps": 1000000,
            "batch_size_per_gpu": 4,
            "batch_split": 4,
            "optimizer": {
                "name": "AdamW",
                "args": {
                    "lr": 1e-4,
                    "weight_decay": 0.0
                }
            },
            "ema_rate": [
                0.9999
            ],
            "fp16_mode": "inflat_all",
            "fp16_scale_growth": 0.001,
            "elastic": {
                "name": "LinearMemoryController",
                "args": {
                    "target_ratio": 0.75,
                    "max_mem_ratio_start": 0.5
                }
            },
            "grad_clip": {
                "name": "AdaptiveGradClipper",
                "args": {
                    "max_norm": 1.0,
                    "clip_percentile": 95
                }
            },
            "i_log": 500,
            "i_sample": 10000,
            "i_save": 10000,
            "lambda_ssim": 0.2,
            "lambda_lpips": 0.2,
            "lambda_tsdf": 0.01,
            "lambda_depth": 10.0,
            "lambda_color": 0.1,
            "depth_loss_type": "smooth_l1"
        }
    }
}
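Note: the lambda_* entries in the trainer block above are loss weights. A minimal sketch of how the mesh decoder's terms might be combined, assuming the per-term losses are computed elsewhere (only the weights come from the config):

def weighted_total(terms, weights):
    # terms and weights are dicts keyed by loss name, e.g. {"ssim": ..., "lpips": ...}
    return sum(weights[k] * terms[k] for k in weights)

weights = {"ssim": 0.2, "lpips": 0.2, "tsdf": 0.01, "depth": 10.0, "color": 0.1}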
configs/vae/slat_vae_dec_rf_swin8_B_64l8_fp16.json
ADDED
@@ -0,0 +1,71 @@
{
    "models": {
        "decoder": {
            "name": "ElasticSLatRadianceFieldDecoder",
            "args": {
                "resolution": 64,
                "model_channels": 768,
                "latent_channels": 8,
                "num_blocks": 12,
                "num_heads": 12,
                "mlp_ratio": 4,
                "attn_mode": "swin",
                "window_size": 8,
                "use_fp16": true,
                "representation_config": {
                    "rank": 16,
                    "dim": 8
                }
            }
        }
    },
    "dataset": {
        "name": "SLat2Render",
        "args": {
            "image_size": 512,
            "latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
            "min_aesthetic_score": 4.5,
            "max_num_voxels": 32768
        }
    },
    "trainer": {
        "name": "SLatVaeRadianceFieldDecoderTrainer",
        "args": {
            "max_steps": 1000000,
            "batch_size_per_gpu": 4,
            "batch_split": 2,
            "optimizer": {
                "name": "AdamW",
                "args": {
                    "lr": 1e-4,
                    "weight_decay": 0.0
                }
            },
            "ema_rate": [
                0.9999
            ],
            "fp16_mode": "inflat_all",
            "fp16_scale_growth": 0.001,
            "elastic": {
                "name": "LinearMemoryController",
                "args": {
                    "target_ratio": 0.75,
                    "max_mem_ratio_start": 0.5
                }
            },
            "grad_clip": {
                "name": "AdaptiveGradClipper",
                "args": {
                    "max_norm": 1.0,
                    "clip_percentile": 95
                }
            },
            "i_log": 500,
            "i_sample": 10000,
            "i_save": 10000,
            "loss_type": "l1",
            "lambda_ssim": 0.2,
            "lambda_lpips": 0.2
        }
    }
}
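Note: every trainer here carries "ema_rate": [0.9999], i.e. an exponential moving average of the weights is maintained alongside the raw parameters. A minimal sketch of one EMA step over torch parameter tensors (an assumed formulation; the trainer's exact bookkeeping may differ):

import torch

@torch.no_grad()
def ema_update(ema_params, params, rate=0.9999):
    # ema <- rate * ema + (1 - rate) * param, element-wise.
    for e, p in zip(ema_params, params):
        e.mul_(rate).add_(p, alpha=1.0 - rate)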
configs/vae/slat_vae_enc_dec_gs_swin8_B_64l8_fp16.json
ADDED
@@ -0,0 +1,105 @@
{
    "models": {
        "encoder": {
            "name": "ElasticSLatEncoder",
            "args": {
                "resolution": 64,
                "in_channels": 1024,
                "model_channels": 768,
                "latent_channels": 8,
                "num_blocks": 12,
                "num_heads": 12,
                "mlp_ratio": 4,
                "attn_mode": "swin",
                "window_size": 8,
                "use_fp16": true
            }
        },
        "decoder": {
            "name": "ElasticSLatGaussianDecoder",
            "args": {
                "resolution": 64,
                "model_channels": 768,
                "latent_channels": 8,
                "num_blocks": 12,
                "num_heads": 12,
                "mlp_ratio": 4,
                "attn_mode": "swin",
                "window_size": 8,
                "use_fp16": true,
                "representation_config": {
                    "lr": {
                        "_xyz": 1.0,
                        "_features_dc": 1.0,
                        "_opacity": 1.0,
                        "_scaling": 1.0,
                        "_rotation": 0.1
                    },
                    "perturb_offset": true,
                    "voxel_size": 1.5,
                    "num_gaussians": 32,
                    "2d_filter_kernel_size": 0.1,
                    "3d_filter_kernel_size": 9e-4,
                    "scaling_bias": 4e-3,
                    "opacity_bias": 0.1,
                    "scaling_activation": "softplus"
                }
            }
        }
    },
    "dataset": {
        "name": "SparseFeat2Render",
        "args": {
            "image_size": 512,
            "model": "dinov2_vitl14_reg",
            "resolution": 64,
            "min_aesthetic_score": 4.5,
            "max_num_voxels": 32768
        }
    },
    "trainer": {
        "name": "SLatVaeGaussianTrainer",
        "args": {
            "max_steps": 1000000,
            "batch_size_per_gpu": 4,
            "batch_split": 2,
            "optimizer": {
                "name": "AdamW",
                "args": {
                    "lr": 1e-4,
                    "weight_decay": 0.0
                }
            },
            "ema_rate": [
                0.9999
            ],
            "fp16_mode": "inflat_all",
            "fp16_scale_growth": 0.001,
            "elastic": {
                "name": "LinearMemoryController",
                "args": {
                    "target_ratio": 0.75,
                    "max_mem_ratio_start": 0.5
                }
            },
            "grad_clip": {
                "name": "AdaptiveGradClipper",
                "args": {
                    "max_norm": 1.0,
                    "clip_percentile": 95
                }
            },
            "i_log": 500,
            "i_sample": 10000,
            "i_save": 10000,
            "loss_type": "l1",
            "lambda_ssim": 0.2,
            "lambda_lpips": 0.2,
            "lambda_kl": 1e-06,
            "regularizations": {
                "lambda_vol": 10000.0,
                "lambda_opacity": 0.001
            }
        }
    }
}
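Note: the Gaussian VAE above adds a very small KL penalty ("lambda_kl": 1e-06) on the latent distribution. A sketch of the standard diagonal-Gaussian KL term such a weight would multiply (assumed formulation; the trainer's exact reduction may differ):

import torch

def kl_divergence(mu, logvar):
    # KL(N(mu, diag(exp(logvar))) || N(0, I)), averaged over the batch.
    return 0.5 * (mu.pow(2) + logvar.exp() - 1.0 - logvar).sum(dim=-1).mean()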
configs/vae/ss_vae_conv3d_16l8_fp16.json
ADDED
@@ -0,0 +1,65 @@
{
    "models": {
        "encoder": {
            "name": "SparseStructureEncoder",
            "args": {
                "in_channels": 1,
                "latent_channels": 8,
                "num_res_blocks": 2,
                "num_res_blocks_middle": 2,
                "channels": [32, 128, 512],
                "use_fp16": true
            }
        },
        "decoder": {
            "name": "SparseStructureDecoder",
            "args": {
                "out_channels": 1,
                "latent_channels": 8,
                "num_res_blocks": 2,
                "num_res_blocks_middle": 2,
                "channels": [512, 128, 32],
                "use_fp16": true
            }
        }
    },
    "dataset": {
        "name": "SparseStructure",
        "args": {
            "resolution": 64,
            "min_aesthetic_score": 4.5
        }
    },
    "trainer": {
        "name": "SparseStructureVaeTrainer",
        "args": {
            "max_steps": 1000000,
            "batch_size_per_gpu": 4,
            "batch_split": 1,
            "optimizer": {
                "name": "AdamW",
                "args": {
                    "lr": 1e-4,
                    "weight_decay": 0.0
                }
            },
            "ema_rate": [
                0.9999
            ],
            "fp16_mode": "inflat_all",
            "fp16_scale_growth": 0.001,
            "grad_clip": {
                "name": "AdaptiveGradClipper",
                "args": {
                    "max_norm": 1.0,
                    "clip_percentile": 95
                }
            },
            "i_log": 500,
            "i_sample": 10000,
            "i_save": 10000,
            "loss_type": "dice",
            "lambda_kl": 0.001
        }
    }
}
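Note: the sparse-structure VAE reconstructs a binary occupancy grid, so it uses "loss_type": "dice" rather than a pixel loss. A common soft-Dice formulation (the trainer's exact variant may differ):

import torch

def dice_loss(logits, target, eps=1e-6):
    # Soft Dice over a binary occupancy volume; logits and target share shape.
    prob = torch.sigmoid(logits)
    inter = (prob * target).sum()
    return 1.0 - (2.0 * inter + eps) / (prob.sum() + target.sum() + eps)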
dataset_toolkits/blender_script/io_scene_usdz.zip
ADDED
Binary file (34.7 kB)
dataset_toolkits/blender_script/render.py
ADDED
@@ -0,0 +1,528 @@
import argparse, sys, os, math, re, glob
from typing import *
import bpy
from mathutils import Vector, Matrix
import numpy as np
import json


"""=============== BLENDER ==============="""

IMPORT_FUNCTIONS: Dict[str, Callable] = {
    "obj": bpy.ops.import_scene.obj,
    "glb": bpy.ops.import_scene.gltf,
    "gltf": bpy.ops.import_scene.gltf,
    "usd": bpy.ops.import_scene.usd,
    "fbx": bpy.ops.import_scene.fbx,
    "stl": bpy.ops.import_mesh.stl,
    "usda": bpy.ops.import_scene.usda,
    "dae": bpy.ops.wm.collada_import,
    "ply": bpy.ops.import_mesh.ply,
    "abc": bpy.ops.wm.alembic_import,
    "blend": bpy.ops.wm.append,
}

EXT = {
    'PNG': 'png',
    'JPEG': 'jpg',
    'OPEN_EXR': 'exr',
    'TIFF': 'tiff',
    'BMP': 'bmp',
    'HDR': 'hdr',
    'TARGA': 'tga'
}

def init_render(engine='CYCLES', resolution=512, geo_mode=False):
    bpy.context.scene.render.engine = engine
    bpy.context.scene.render.resolution_x = resolution
    bpy.context.scene.render.resolution_y = resolution
    bpy.context.scene.render.resolution_percentage = 100
    bpy.context.scene.render.image_settings.file_format = 'PNG'
    bpy.context.scene.render.image_settings.color_mode = 'RGBA'
    bpy.context.scene.render.film_transparent = True

    bpy.context.scene.cycles.device = 'GPU'
    bpy.context.scene.cycles.samples = 128 if not geo_mode else 1
    bpy.context.scene.cycles.filter_type = 'BOX'
    bpy.context.scene.cycles.filter_width = 1
    bpy.context.scene.cycles.diffuse_bounces = 1
    bpy.context.scene.cycles.glossy_bounces = 1
    bpy.context.scene.cycles.transparent_max_bounces = 3 if not geo_mode else 0
    bpy.context.scene.cycles.transmission_bounces = 3 if not geo_mode else 1
    bpy.context.scene.cycles.use_denoising = True

    bpy.context.preferences.addons['cycles'].preferences.get_devices()
    bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA'

def init_nodes(save_depth=False, save_normal=False, save_albedo=False, save_mist=False):
    if not any([save_depth, save_normal, save_albedo, save_mist]):
        return {}, {}
    outputs = {}
    spec_nodes = {}

    bpy.context.scene.use_nodes = True
    bpy.context.scene.view_layers['View Layer'].use_pass_z = save_depth
    bpy.context.scene.view_layers['View Layer'].use_pass_normal = save_normal
    bpy.context.scene.view_layers['View Layer'].use_pass_diffuse_color = save_albedo
    bpy.context.scene.view_layers['View Layer'].use_pass_mist = save_mist

    nodes = bpy.context.scene.node_tree.nodes
    links = bpy.context.scene.node_tree.links
    for n in nodes:
        nodes.remove(n)

    render_layers = nodes.new('CompositorNodeRLayers')

    if save_depth:
        depth_file_output = nodes.new('CompositorNodeOutputFile')
        depth_file_output.base_path = ''
        depth_file_output.file_slots[0].use_node_format = True
        depth_file_output.format.file_format = 'PNG'
        depth_file_output.format.color_depth = '16'
        depth_file_output.format.color_mode = 'BW'
        # Remap to 0-1
        map = nodes.new(type="CompositorNodeMapRange")
        map.inputs[1].default_value = 0   # (min value you will be getting)
        map.inputs[2].default_value = 10  # (max value you will be getting)
        map.inputs[3].default_value = 0   # (min value you will map to)
        map.inputs[4].default_value = 1   # (max value you will map to)

        links.new(render_layers.outputs['Depth'], map.inputs[0])
        links.new(map.outputs[0], depth_file_output.inputs[0])

        outputs['depth'] = depth_file_output
        spec_nodes['depth_map'] = map

    if save_normal:
        normal_file_output = nodes.new('CompositorNodeOutputFile')
        normal_file_output.base_path = ''
        normal_file_output.file_slots[0].use_node_format = True
        normal_file_output.format.file_format = 'OPEN_EXR'
        normal_file_output.format.color_mode = 'RGB'
        normal_file_output.format.color_depth = '16'

        links.new(render_layers.outputs['Normal'], normal_file_output.inputs[0])

        outputs['normal'] = normal_file_output

    if save_albedo:
        albedo_file_output = nodes.new('CompositorNodeOutputFile')
        albedo_file_output.base_path = ''
        albedo_file_output.file_slots[0].use_node_format = True
        albedo_file_output.format.file_format = 'PNG'
        albedo_file_output.format.color_mode = 'RGBA'
        albedo_file_output.format.color_depth = '8'

        alpha_albedo = nodes.new('CompositorNodeSetAlpha')

        links.new(render_layers.outputs['DiffCol'], alpha_albedo.inputs['Image'])
        links.new(render_layers.outputs['Alpha'], alpha_albedo.inputs['Alpha'])
        links.new(alpha_albedo.outputs['Image'], albedo_file_output.inputs[0])

        outputs['albedo'] = albedo_file_output

    if save_mist:
        bpy.data.worlds['World'].mist_settings.start = 0
        bpy.data.worlds['World'].mist_settings.depth = 10

        mist_file_output = nodes.new('CompositorNodeOutputFile')
        mist_file_output.base_path = ''
        mist_file_output.file_slots[0].use_node_format = True
        mist_file_output.format.file_format = 'PNG'
        mist_file_output.format.color_mode = 'BW'
        mist_file_output.format.color_depth = '16'

        links.new(render_layers.outputs['Mist'], mist_file_output.inputs[0])

        outputs['mist'] = mist_file_output

    return outputs, spec_nodes

def init_scene() -> None:
    """Resets the scene to a clean state.

    Returns:
        None
    """
    # delete everything
    for obj in bpy.data.objects:
        bpy.data.objects.remove(obj, do_unlink=True)

    # delete all the materials
    for material in bpy.data.materials:
        bpy.data.materials.remove(material, do_unlink=True)

    # delete all the textures
    for texture in bpy.data.textures:
        bpy.data.textures.remove(texture, do_unlink=True)

    # delete all the images
    for image in bpy.data.images:
        bpy.data.images.remove(image, do_unlink=True)

def init_camera():
    cam = bpy.data.objects.new('Camera', bpy.data.cameras.new('Camera'))
    bpy.context.collection.objects.link(cam)
    bpy.context.scene.camera = cam
    cam.data.sensor_height = cam.data.sensor_width = 32
    cam_constraint = cam.constraints.new(type='TRACK_TO')
    cam_constraint.track_axis = 'TRACK_NEGATIVE_Z'
    cam_constraint.up_axis = 'UP_Y'
    cam_empty = bpy.data.objects.new("Empty", None)
    cam_empty.location = (0, 0, 0)
    bpy.context.scene.collection.objects.link(cam_empty)
    cam_constraint.target = cam_empty
    return cam

def init_lighting():
    # Clear existing lights
    bpy.ops.object.select_all(action="DESELECT")
    bpy.ops.object.select_by_type(type="LIGHT")
    bpy.ops.object.delete()

    # Create key light
    default_light = bpy.data.objects.new("Default_Light", bpy.data.lights.new("Default_Light", type="POINT"))
    bpy.context.collection.objects.link(default_light)
    default_light.data.energy = 1000
    default_light.location = (4, 1, 6)
    default_light.rotation_euler = (0, 0, 0)

    # create top light
    top_light = bpy.data.objects.new("Top_Light", bpy.data.lights.new("Top_Light", type="AREA"))
    bpy.context.collection.objects.link(top_light)
    top_light.data.energy = 10000
    top_light.location = (0, 0, 10)
    top_light.scale = (100, 100, 100)

    # create bottom light
    bottom_light = bpy.data.objects.new("Bottom_Light", bpy.data.lights.new("Bottom_Light", type="AREA"))
    bpy.context.collection.objects.link(bottom_light)
    bottom_light.data.energy = 1000
    bottom_light.location = (0, 0, -10)
    bottom_light.rotation_euler = (0, 0, 0)

    return {
        "default_light": default_light,
        "top_light": top_light,
        "bottom_light": bottom_light
    }


def load_object(object_path: str) -> None:
    """Loads a model with a supported file extension into the scene.

    Args:
        object_path (str): Path to the model file.

    Raises:
        ValueError: If the file extension is not supported.

    Returns:
        None
    """
    file_extension = object_path.split(".")[-1].lower()
    # validate against the supported importers (usdz is handled separately below)
    if file_extension != "usdz" and file_extension not in IMPORT_FUNCTIONS:
        raise ValueError(f"Unsupported file type: {object_path}")

    if file_extension == "usdz":
        # install usdz io package
        dirname = os.path.dirname(os.path.realpath(__file__))
        usdz_package = os.path.join(dirname, "io_scene_usdz.zip")
        bpy.ops.preferences.addon_install(filepath=usdz_package)
        # enable it
        addon_name = "io_scene_usdz"
        bpy.ops.preferences.addon_enable(module=addon_name)
        # import the usdz
        from io_scene_usdz.import_usdz import import_usdz

        import_usdz(bpy.context, filepath=object_path, materials=True, animations=True)
        return None

    # load from existing import functions
    import_function = IMPORT_FUNCTIONS[file_extension]

    print(f"Loading object from {object_path}")
    if file_extension == "blend":
        import_function(directory=object_path, link=False)
    elif file_extension in {"glb", "gltf"}:
        import_function(filepath=object_path, merge_vertices=True, import_shading='NORMALS')
    else:
        import_function(filepath=object_path)

def delete_invisible_objects() -> None:
    """Deletes all invisible objects in the scene.

    Returns:
        None
    """
    # bpy.ops.object.mode_set(mode="OBJECT")
    bpy.ops.object.select_all(action="DESELECT")
    for obj in bpy.context.scene.objects:
        if obj.hide_viewport or obj.hide_render:
            obj.hide_viewport = False
            obj.hide_render = False
            obj.hide_select = False
            obj.select_set(True)
    bpy.ops.object.delete()

    # Delete invisible collections
    invisible_collections = [col for col in bpy.data.collections if col.hide_viewport]
    for col in invisible_collections:
        bpy.data.collections.remove(col)

def split_mesh_normal():
    bpy.ops.object.select_all(action="DESELECT")
    objs = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"]
    bpy.context.view_layer.objects.active = objs[0]
    for obj in objs:
        obj.select_set(True)
    bpy.ops.object.mode_set(mode="EDIT")
    bpy.ops.mesh.select_all(action='SELECT')
    bpy.ops.mesh.split_normals()
    bpy.ops.object.mode_set(mode='OBJECT')
    bpy.ops.object.select_all(action="DESELECT")

def delete_custom_normals():
    for this_obj in bpy.data.objects:
        if this_obj.type == "MESH":
            bpy.context.view_layer.objects.active = this_obj
            bpy.ops.mesh.customdata_custom_splitnormals_clear()

def override_material():
    new_mat = bpy.data.materials.new(name="Override0123456789")
    new_mat.use_nodes = True
    new_mat.node_tree.nodes.clear()
    bsdf = new_mat.node_tree.nodes.new('ShaderNodeBsdfDiffuse')
    bsdf.inputs[0].default_value = (0.5, 0.5, 0.5, 1)
    bsdf.inputs[1].default_value = 1
    output = new_mat.node_tree.nodes.new('ShaderNodeOutputMaterial')
    new_mat.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface'])
    bpy.context.scene.view_layers['View Layer'].material_override = new_mat

def unhide_all_objects() -> None:
    """Unhides all objects in the scene.

    Returns:
        None
    """
    for obj in bpy.context.scene.objects:
        obj.hide_set(False)

def convert_to_meshes() -> None:
    """Converts all objects in the scene to meshes.

    Returns:
        None
    """
    bpy.ops.object.select_all(action="DESELECT")
    bpy.context.view_layer.objects.active = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"][0]
    for obj in bpy.context.scene.objects:
        obj.select_set(True)
    bpy.ops.object.convert(target="MESH")

def triangulate_meshes() -> None:
    """Triangulates all meshes in the scene.

    Returns:
        None
    """
    bpy.ops.object.select_all(action="DESELECT")
    objs = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"]
    bpy.context.view_layer.objects.active = objs[0]
    for obj in objs:
        obj.select_set(True)
    bpy.ops.object.mode_set(mode="EDIT")
    bpy.ops.mesh.reveal()
    bpy.ops.mesh.select_all(action="SELECT")
    bpy.ops.mesh.quads_convert_to_tris(quad_method="BEAUTY", ngon_method="BEAUTY")
    bpy.ops.object.mode_set(mode="OBJECT")
    bpy.ops.object.select_all(action="DESELECT")

def scene_bbox() -> Tuple[Vector, Vector]:
    """Returns the bounding box of the scene.

    Taken from Shap-E rendering script
    (https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82)

    Returns:
        Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box.
    """
    bbox_min = (math.inf,) * 3
    bbox_max = (-math.inf,) * 3
    found = False
    scene_meshes = [obj for obj in bpy.context.scene.objects.values() if isinstance(obj.data, bpy.types.Mesh)]
    for obj in scene_meshes:
        found = True
        for coord in obj.bound_box:
            coord = Vector(coord)
            coord = obj.matrix_world @ coord
            bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
            bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
    if not found:
        raise RuntimeError("no objects in scene to compute bounding box for")
    return Vector(bbox_min), Vector(bbox_max)

def normalize_scene() -> Tuple[float, Vector]:
    """Normalizes the scene by scaling and translating it to fit in a unit cube centered
    at the origin.

    Mostly taken from the Point-E / Shap-E rendering script
    (https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112),
    but fixed for multiple root objects (see bug report here:
    https://github.com/openai/shap-e/pull/60).

    Returns:
        Tuple[float, Vector]: The scale factor and the offset applied to the scene.
    """
    scene_root_objects = [obj for obj in bpy.context.scene.objects.values() if not obj.parent]
    if len(scene_root_objects) > 1:
        # create an empty object to be used as a parent for all root objects
        scene = bpy.data.objects.new("ParentEmpty", None)
        bpy.context.scene.collection.objects.link(scene)

        # parent all root objects to the empty object
        for obj in scene_root_objects:
            obj.parent = scene
    else:
        scene = scene_root_objects[0]

    bbox_min, bbox_max = scene_bbox()
    scale = 1 / max(bbox_max - bbox_min)
    scene.scale = scene.scale * scale

    # Apply scale to matrix_world.
    bpy.context.view_layer.update()
    bbox_min, bbox_max = scene_bbox()
    offset = -(bbox_min + bbox_max) / 2
    scene.matrix_world.translation += offset
    bpy.ops.object.select_all(action="DESELECT")

    return scale, offset

def get_transform_matrix(obj: bpy.types.Object) -> list:
    pos, rt, _ = obj.matrix_world.decompose()
    rt = rt.to_matrix()
    matrix = []
    for ii in range(3):
        a = []
        for jj in range(3):
            a.append(rt[ii][jj])
        a.append(pos[ii])
        matrix.append(a)
    matrix.append([0, 0, 0, 1])
    return matrix

def main(arg):
    os.makedirs(arg.output_folder, exist_ok=True)

    # Initialize context
    init_render(engine=arg.engine, resolution=arg.resolution, geo_mode=arg.geo_mode)
    outputs, spec_nodes = init_nodes(
        save_depth=arg.save_depth,
        save_normal=arg.save_normal,
        save_albedo=arg.save_albedo,
        save_mist=arg.save_mist
    )
    if arg.object.endswith(".blend"):
        delete_invisible_objects()
    else:
        init_scene()
        load_object(arg.object)
        if arg.split_normal:
            split_mesh_normal()
        # delete_custom_normals()
    print('[INFO] Scene initialized.')

    # normalize scene
    scale, offset = normalize_scene()
    print('[INFO] Scene normalized.')

    # Initialize camera and lighting
    cam = init_camera()
    init_lighting()
    print('[INFO] Camera and lighting initialized.')

    # Override material
    if arg.geo_mode:
        override_material()

    # Create a list of views
    to_export = {
        "aabb": [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
        "scale": scale,
        "offset": [offset.x, offset.y, offset.z],
        "frames": []
    }
    views = json.loads(arg.views)
    for i, view in enumerate(views):
        cam.location = (
            view['radius'] * np.cos(view['yaw']) * np.cos(view['pitch']),
            view['radius'] * np.sin(view['yaw']) * np.cos(view['pitch']),
            view['radius'] * np.sin(view['pitch'])
        )
        cam.data.lens = 16 / np.tan(view['fov'] / 2)

        if arg.save_depth:
            spec_nodes['depth_map'].inputs[1].default_value = view['radius'] - 0.5 * np.sqrt(3)
            spec_nodes['depth_map'].inputs[2].default_value = view['radius'] + 0.5 * np.sqrt(3)

        bpy.context.scene.render.filepath = os.path.join(arg.output_folder, f'{i:03d}.png')
        for name, output in outputs.items():
            output.file_slots[0].path = os.path.join(arg.output_folder, f'{i:03d}_{name}')

        # Render the scene
        bpy.ops.render.render(write_still=True)
        bpy.context.view_layer.update()
        for name, output in outputs.items():
            ext = EXT[output.format.file_format]
            path = glob.glob(f'{output.file_slots[0].path}*.{ext}')[0]
            os.rename(path, f'{output.file_slots[0].path}.{ext}')

        # Save camera parameters
        metadata = {
            "file_path": f'{i:03d}.png',
            "camera_angle_x": view['fov'],
            "transform_matrix": get_transform_matrix(cam)
        }
        if arg.save_depth:
            metadata['depth'] = {
                'min': view['radius'] - 0.5 * np.sqrt(3),
                'max': view['radius'] + 0.5 * np.sqrt(3)
            }
        to_export["frames"].append(metadata)

    # Save the camera parameters
    with open(os.path.join(arg.output_folder, 'transforms.json'), 'w') as f:
        json.dump(to_export, f, indent=4)

    if arg.save_mesh:
        # triangulate meshes
        unhide_all_objects()
        convert_to_meshes()
        triangulate_meshes()
        print('[INFO] Meshes triangulated.')

        # export ply mesh
        bpy.ops.export_mesh.ply(filepath=os.path.join(arg.output_folder, 'mesh.ply'))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Renders a given model file by rotating a camera around it.')
    parser.add_argument('--views', type=str, help='JSON string of views. Contains a list of {yaw, pitch, radius, fov} objects.')
    parser.add_argument('--object', type=str, help='Path to the 3D model file to be rendered.')
    parser.add_argument('--output_folder', type=str, default='/tmp', help='The path the output will be dumped to.')
    parser.add_argument('--resolution', type=int, default=512, help='Resolution of the images.')
    parser.add_argument('--engine', type=str, default='CYCLES', help='Blender internal engine for rendering. E.g. CYCLES, BLENDER_EEVEE, ...')
    parser.add_argument('--geo_mode', action='store_true', help='Geometry mode for rendering.')
    parser.add_argument('--save_depth', action='store_true', help='Save the depth maps.')
    parser.add_argument('--save_normal', action='store_true', help='Save the normal maps.')
    parser.add_argument('--save_albedo', action='store_true', help='Save the albedo maps.')
    parser.add_argument('--save_mist', action='store_true', help='Save the mist distance maps.')
    parser.add_argument('--split_normal', action='store_true', help='Split the normals of the mesh.')
    parser.add_argument('--save_mesh', action='store_true', help='Save the mesh as a .ply file.')
    argv = sys.argv[sys.argv.index("--") + 1:]
    args = parser.parse_args(argv)

    main(args)
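Note: render.py is meant to run inside Blender, which forwards only the arguments after the "--" separator to the script (see the argv slicing at the bottom). A hypothetical headless invocation from Python; the model path, output folder, and view values are placeholders:

import json
import subprocess

views = [{"yaw": 0.0, "pitch": 0.5, "radius": 2.0, "fov": 0.69}]
subprocess.run([
    "blender", "-b", "-P", "dataset_toolkits/blender_script/render.py", "--",
    "--views", json.dumps(views),
    "--object", "model.glb",
    "--output_folder", "renders/model",
    "--resolution", "512",
    "--save_depth",
], check=True)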
dataset_toolkits/build_metadata.py
ADDED
@@ -0,0 +1,270 @@
import os
import shutil
import sys
import time
import importlib
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
import utils3d

def get_first_directory(path):
    with os.scandir(path) as it:
        for entry in it:
            if entry.is_dir():
                return entry.name
    return None

def need_process(key):
    return key in opt.field or opt.field == ['all']

if __name__ == '__main__':
    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--field', type=str, default='all',
                        help='Fields to process, separated by commas')
    parser.add_argument('--from_file', action='store_true',
                        help='Build metadata from files instead of from records of processings. '
                             'Useful when some processing fails to generate records but the file already exists.')
    dataset_utils.add_args(parser)
    opt = parser.parse_args(sys.argv[2:])
    opt = edict(vars(opt))

    os.makedirs(opt.output_dir, exist_ok=True)
    os.makedirs(os.path.join(opt.output_dir, 'merged_records'), exist_ok=True)

    opt.field = opt.field.split(',')

    timestamp = str(int(time.time()))

    # get file list
    if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        print('Loading previous metadata...')
        metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    else:
        metadata = dataset_utils.get_metadata(**opt)
    metadata.set_index('sha256', inplace=True)

    # merge downloaded
    df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('downloaded_') and f.endswith('.csv')]
    df_parts = []
    for f in df_files:
        try:
            df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
        except:
            pass
    if len(df_parts) > 0:
        df = pd.concat(df_parts)
        df.set_index('sha256', inplace=True)
        if 'local_path' in metadata.columns:
            metadata.update(df, overwrite=True)
        else:
            metadata = metadata.join(df, on='sha256', how='left')
        for f in df_files:
            shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))

    # detect models
    image_models = []
    if os.path.exists(os.path.join(opt.output_dir, 'features')):
        image_models = os.listdir(os.path.join(opt.output_dir, 'features'))
    latent_models = []
    if os.path.exists(os.path.join(opt.output_dir, 'latents')):
        latent_models = os.listdir(os.path.join(opt.output_dir, 'latents'))
    ss_latent_models = []
    if os.path.exists(os.path.join(opt.output_dir, 'ss_latents')):
        ss_latent_models = os.listdir(os.path.join(opt.output_dir, 'ss_latents'))
    print(f'Image models: {image_models}')
    print(f'Latent models: {latent_models}')
    print(f'Sparse Structure latent models: {ss_latent_models}')

    if 'rendered' not in metadata.columns:
        metadata['rendered'] = [False] * len(metadata)
    if 'voxelized' not in metadata.columns:
        metadata['voxelized'] = [False] * len(metadata)
    if 'num_voxels' not in metadata.columns:
        metadata['num_voxels'] = [0] * len(metadata)
    if 'cond_rendered' not in metadata.columns:
        metadata['cond_rendered'] = [False] * len(metadata)
    for model in image_models:
        if f'feature_{model}' not in metadata.columns:
            metadata[f'feature_{model}'] = [False] * len(metadata)
    for model in latent_models:
        if f'latent_{model}' not in metadata.columns:
            metadata[f'latent_{model}'] = [False] * len(metadata)
    for model in ss_latent_models:
        if f'ss_latent_{model}' not in metadata.columns:
            metadata[f'ss_latent_{model}'] = [False] * len(metadata)

    # merge rendered
    df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('rendered_') and f.endswith('.csv')]
    df_parts = []
    for f in df_files:
        try:
            df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
        except:
            pass
    if len(df_parts) > 0:
        df = pd.concat(df_parts)
        df.set_index('sha256', inplace=True)
        metadata.update(df, overwrite=True)
        for f in df_files:
            shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))

    # merge voxelized
    df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('voxelized_') and f.endswith('.csv')]
    df_parts = []
    for f in df_files:
        try:
            df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
        except:
            pass
    if len(df_parts) > 0:
        df = pd.concat(df_parts)
        df.set_index('sha256', inplace=True)
        metadata.update(df, overwrite=True)
        for f in df_files:
            shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))

    # merge cond_rendered
    df_files = [f for f in os.listdir(opt.output_dir) if f.startswith('cond_rendered_') and f.endswith('.csv')]
    df_parts = []
    for f in df_files:
        try:
            df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
        except:
            pass
    if len(df_parts) > 0:
        df = pd.concat(df_parts)
        df.set_index('sha256', inplace=True)
        metadata.update(df, overwrite=True)
        for f in df_files:
            shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))

    # merge features
    for model in image_models:
        df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(f'feature_{model}_') and f.endswith('.csv')]
        df_parts = []
        for f in df_files:
            try:
                df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
            except:
                pass
        if len(df_parts) > 0:
            df = pd.concat(df_parts)
            df.set_index('sha256', inplace=True)
            metadata.update(df, overwrite=True)
            for f in df_files:
                shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))

    # merge latents
    for model in latent_models:
        df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(f'latent_{model}_') and f.endswith('.csv')]
        df_parts = []
        for f in df_files:
            try:
                df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
            except:
                pass
        if len(df_parts) > 0:
            df = pd.concat(df_parts)
            df.set_index('sha256', inplace=True)
            metadata.update(df, overwrite=True)
            for f in df_files:
                shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))

    # merge sparse structure latents
    for model in ss_latent_models:
        df_files = [f for f in os.listdir(opt.output_dir) if f.startswith(f'ss_latent_{model}_') and f.endswith('.csv')]
        df_parts = []
        for f in df_files:
            try:
                df_parts.append(pd.read_csv(os.path.join(opt.output_dir, f)))
            except:
                pass
        if len(df_parts) > 0:
            df = pd.concat(df_parts)
            df.set_index('sha256', inplace=True)
            metadata.update(df, overwrite=True)
            for f in df_files:
                shutil.move(os.path.join(opt.output_dir, f), os.path.join(opt.output_dir, 'merged_records', f'{timestamp}_{f}'))

    # build metadata from files
    if opt.from_file:
        with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
            tqdm(total=len(metadata), desc="Building metadata") as pbar:
            def worker(sha256):
                try:
                    if need_process('rendered') and metadata.loc[sha256, 'rendered'] == False and \
                            os.path.exists(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json')):
                        metadata.loc[sha256, 'rendered'] = True
                    if need_process('voxelized') and metadata.loc[sha256, 'rendered'] == True and metadata.loc[sha256, 'voxelized'] == False and \
                            os.path.exists(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply')):
                        try:
                            pts = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0]
                            metadata.loc[sha256, 'voxelized'] = True
                            metadata.loc[sha256, 'num_voxels'] = len(pts)
                        except Exception:
                            pass
                    if need_process('cond_rendered') and metadata.loc[sha256, 'cond_rendered'] == False and \
                            os.path.exists(os.path.join(opt.output_dir, 'renders_cond', sha256, 'transforms.json')):
                        metadata.loc[sha256, 'cond_rendered'] = True
                    for model in image_models:
                        if need_process(f'feature_{model}') and \
                                metadata.loc[sha256, f'feature_{model}'] == False and \
                                metadata.loc[sha256, 'rendered'] == True and \
                                metadata.loc[sha256, 'voxelized'] == True and \
                                os.path.exists(os.path.join(opt.output_dir, 'features', model, f'{sha256}.npz')):
                            metadata.loc[sha256, f'feature_{model}'] = True
                    for model in latent_models:
                        if need_process(f'latent_{model}') and \
                                metadata.loc[sha256, f'latent_{model}'] == False and \
                                metadata.loc[sha256, 'rendered'] == True and \
                                metadata.loc[sha256, 'voxelized'] == True and \
                                os.path.exists(os.path.join(opt.output_dir, 'latents', model, f'{sha256}.npz')):
                            metadata.loc[sha256, f'latent_{model}'] = True
                    for model in ss_latent_models:
                        if need_process(f'ss_latent_{model}') and \
                                metadata.loc[sha256, f'ss_latent_{model}'] == False and \
                                metadata.loc[sha256, 'voxelized'] == True and \
                                os.path.exists(os.path.join(opt.output_dir, 'ss_latents', model, f'{sha256}.npz')):
                            metadata.loc[sha256, f'ss_latent_{model}'] = True
                    pbar.update()
                except Exception as e:
                    print(f'Error processing {sha256}: {e}')
                    pbar.update()

            executor.map(worker, metadata.index)
            executor.shutdown(wait=True)

    # statistics
    metadata.to_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    num_downloaded = metadata['local_path'].count() if 'local_path' in metadata.columns else 0
    with open(os.path.join(opt.output_dir, 'statistics.txt'), 'w') as f:
        f.write('Statistics:\n')
        f.write(f' - Number of assets: {len(metadata)}\n')
        f.write(f' - Number of assets downloaded: {num_downloaded}\n')
        f.write(f' - Number of assets rendered: {metadata["rendered"].sum()}\n')
        f.write(f' - Number of assets voxelized: {metadata["voxelized"].sum()}\n')
        if len(image_models) != 0:
            f.write(f' - Number of assets with image features extracted:\n')
            for model in image_models:
                f.write(f' - {model}: {metadata[f"feature_{model}"].sum()}\n')
        if len(latent_models) != 0:
            f.write(f' - Number of assets with latents extracted:\n')
            for model in latent_models:
                f.write(f' - {model}: {metadata[f"latent_{model}"].sum()}\n')
        if len(ss_latent_models) != 0:
            f.write(f' - Number of assets with sparse structure latents extracted:\n')
            for model in ss_latent_models:
                f.write(f' - {model}: {metadata[f"ss_latent_{model}"].sum()}\n')
        f.write(f' - Number of assets with captions: {metadata["captions"].count()}\n')
        f.write(f' - Number of assets with image conditions: {metadata["cond_rendered"].sum()}\n')

    with open(os.path.join(opt.output_dir, 'statistics.txt'), 'r') as f:
        print(f.read())
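Note: every merge block above follows the same pattern: per-worker CSVs keyed by sha256 are concatenated, folded into the master frame with DataFrame.update, then archived under merged_records. A self-contained illustration of the update semantics:

import pandas as pd

master = pd.DataFrame({"sha256": ["a", "b"], "rendered": [False, False]}).set_index("sha256")
part = pd.DataFrame({"sha256": ["b"], "rendered": [True]}).set_index("sha256")
master.update(part, overwrite=True)  # only rows present in `part` change
print(master.loc["b", "rendered"])   # True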
dataset_toolkits/datasets/3D-FUTURE.py
ADDED
@@ -0,0 +1,97 @@
import os
import re
import argparse
import zipfile
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import pandas as pd
from utils import get_file_hash


def add_args(parser: argparse.ArgumentParser):
    pass


def get_metadata(**kwargs):
    metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/3D-FUTURE.csv")
    return metadata


def download(metadata, output_dir, **kwargs):
    os.makedirs(output_dir, exist_ok=True)

    if not os.path.exists(os.path.join(output_dir, 'raw', '3D-FUTURE-model.zip')):
        print("\033[93m")
        print("3D-FUTURE has to be downloaded manually")
        print(f"Please download the 3D-FUTURE-model.zip file and place it in the {output_dir}/raw directory")
        print("Visit https://tianchi.aliyun.com/specials/promotion/alibaba-3d-future for more information")
        print("\033[0m")
        raise FileNotFoundError("3D-FUTURE-model.zip not found")

    downloaded = {}
    metadata = metadata.set_index("file_identifier")
    with zipfile.ZipFile(os.path.join(output_dir, 'raw', '3D-FUTURE-model.zip')) as zip_ref:
        all_names = zip_ref.namelist()
        instances = [instance[:-1] for instance in all_names if re.match(r"^3D-FUTURE-model/[^/]+/$", instance)]
        instances = list(filter(lambda x: x in metadata.index, instances))

        with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
            tqdm(total=len(instances), desc="Extracting") as pbar:
            def worker(instance: str) -> str:
                try:
                    instance_files = list(filter(lambda x: x.startswith(f"{instance}/") and not x.endswith("/"), all_names))
                    zip_ref.extractall(os.path.join(output_dir, 'raw'), members=instance_files)
                    sha256 = get_file_hash(os.path.join(output_dir, 'raw', f"{instance}/image.jpg"))
                    pbar.update()
                    return sha256
                except Exception as e:
                    pbar.update()
                    print(f"Error extracting for {instance}: {e}")
                    return None

            sha256s = executor.map(worker, instances)
            executor.shutdown(wait=True)

    for k, sha256 in zip(instances, sha256s):
        if sha256 is not None:
            if sha256 == metadata.loc[k, "sha256"]:
                downloaded[sha256] = os.path.join("raw", f"{k}/raw_model.obj")
            else:
                print(f"Error downloading {k}: sha256s do not match")

    return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])


def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
    import os
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm

    # load metadata
    metadata = metadata.to_dict('records')

    # processing objects
    records = []
    max_workers = max_workers or os.cpu_count()
    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor, \
            tqdm(total=len(metadata), desc=desc) as pbar:
            def worker(metadatum):
                # bind sha256 first so the error message below can always use it
                sha256 = metadatum['sha256']
                try:
                    local_path = metadatum['local_path']
                    file = os.path.join(output_dir, local_path)
                    record = func(file, sha256)
                    if record is not None:
                        records.append(record)
                    pbar.update()
                except Exception as e:
                    print(f"Error processing object {sha256}: {e}")
                    pbar.update()

            executor.map(worker, metadata)
            executor.shutdown(wait=True)
    except:
        print("Error happened during processing.")

    return pd.DataFrame.from_records(records)
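Note: each module under dataset_toolkits/datasets exposes the same interface — add_args, get_metadata, download, and foreach_instance — and build_metadata.py picks one by name via importlib. A sketch of the contract (the lambda is a placeholder processing function):

import importlib

dataset_utils = importlib.import_module("datasets.ABO")
# metadata = dataset_utils.get_metadata()
# downloaded = dataset_utils.download(metadata, output_dir="datasets/ABO")
# records = dataset_utils.foreach_instance(
#     metadata, "datasets/ABO", func=lambda file, sha256: {"sha256": sha256})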
dataset_toolkits/datasets/ABO.py
ADDED
@@ -0,0 +1,96 @@
import os
import re
import argparse
import tarfile
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import pandas as pd
from utils import get_file_hash


def add_args(parser: argparse.ArgumentParser):
    pass


def get_metadata(**kwargs):
    metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ABO.csv")
    return metadata


def download(metadata, output_dir, **kwargs):
    os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)

    if not os.path.exists(os.path.join(output_dir, 'raw', 'abo-3dmodels.tar')):
        try:
            os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)
            os.system(f"wget -O {output_dir}/raw/abo-3dmodels.tar https://amazon-berkeley-objects.s3.amazonaws.com/archives/abo-3dmodels.tar")
        except:
            print("\033[93m")
            print("Error downloading ABO dataset. Please check your internet connection and try again.")
            print(f"Or, you can manually download the abo-3dmodels.tar file and place it in the {output_dir}/raw directory")
            print("Visit https://amazon-berkeley-objects.s3.amazonaws.com/index.html for more information")
            print("\033[0m")
            raise FileNotFoundError("Error downloading ABO dataset")

    downloaded = {}
    metadata = metadata.set_index("file_identifier")
    with tarfile.open(os.path.join(output_dir, 'raw', 'abo-3dmodels.tar')) as tar:
        with ThreadPoolExecutor(max_workers=1) as executor, \
            tqdm(total=len(metadata), desc="Extracting") as pbar:
            def worker(instance: str) -> str:
                try:
                    tar.extract(f"3dmodels/original/{instance}", path=os.path.join(output_dir, 'raw'))
                    sha256 = get_file_hash(os.path.join(output_dir, 'raw/3dmodels/original', instance))
                    pbar.update()
                    return sha256
                except Exception as e:
                    pbar.update()
                    print(f"Error extracting for {instance}: {e}")
                    return None

            sha256s = executor.map(worker, metadata.index)
            executor.shutdown(wait=True)

    for k, sha256 in zip(metadata.index, sha256s):
        if sha256 is not None:
            if sha256 == metadata.loc[k, "sha256"]:
                downloaded[sha256] = os.path.join('raw/3dmodels/original', k)
            else:
                print(f"Error downloading {k}: sha256s do not match")

    return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])


def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
    import os
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm

    # load metadata
    metadata = metadata.to_dict('records')

    # processing objects
    records = []
    max_workers = max_workers or os.cpu_count()
    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor, \
            tqdm(total=len(metadata), desc=desc) as pbar:
            def worker(metadatum):
                # bind sha256 first so the error message below can always use it
                sha256 = metadatum['sha256']
                try:
                    local_path = metadatum['local_path']
                    file = os.path.join(output_dir, local_path)
                    record = func(file, sha256)
                    if record is not None:
                        records.append(record)
                    pbar.update()
                except Exception as e:
                    print(f"Error processing object {sha256}: {e}")
                    pbar.update()

            executor.map(worker, metadata)
            executor.shutdown(wait=True)
    except:
        print("Error happened during processing.")

    return pd.DataFrame.from_records(records)
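Note: every downloader verifies each extracted file against the sha256 recorded in the metadata CSV via utils.get_file_hash. Its implementation is not shown in this excerpt; a typical streaming sha256 looks like this (a sketch only, the repo's version may differ):

import hashlib

def get_file_hash(path, chunk_size=1 << 20):
    # Stream the file so large models need not fit in memory.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()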
dataset_toolkits/datasets/HSSD.py
ADDED
@@ -0,0 +1,103 @@
import os
import re
import argparse
import tarfile
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import pandas as pd
import huggingface_hub
from utils import get_file_hash


def add_args(parser: argparse.ArgumentParser):
    pass


def get_metadata(**kwargs):
    metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/HSSD.csv")
    return metadata


def download(metadata, output_dir, **kwargs):
    os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)

    # check login
    try:
        huggingface_hub.whoami()
    except:
        print("\033[93m")
        print("Haven't logged in to the Hugging Face Hub.")
        print("Visit https://huggingface.co/settings/tokens to get a token.")
        print("\033[0m")
        huggingface_hub.login()

    try:
        huggingface_hub.hf_hub_download(repo_id="hssd/hssd-models", filename="README.md", repo_type="dataset")
    except:
        print("\033[93m")
        print("Error downloading HSSD dataset.")
        print("Check if you have access to the HSSD dataset.")
        print("Visit https://huggingface.co/datasets/hssd/hssd-models for more information")
        print("\033[0m")

    downloaded = {}
    metadata = metadata.set_index("file_identifier")
    with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
        tqdm(total=len(metadata), desc="Downloading") as pbar:
        def worker(instance: str) -> str:
            try:
                huggingface_hub.hf_hub_download(repo_id="hssd/hssd-models", filename=instance, repo_type="dataset", local_dir=os.path.join(output_dir, 'raw'))
                sha256 = get_file_hash(os.path.join(output_dir, 'raw', instance))
                pbar.update()
                return sha256
            except Exception as e:
                pbar.update()
                print(f"Error downloading {instance}: {e}")
                return None

        sha256s = executor.map(worker, metadata.index)
        executor.shutdown(wait=True)

    for k, sha256 in zip(metadata.index, sha256s):
        if sha256 is not None:
            if sha256 == metadata.loc[k, "sha256"]:
                downloaded[sha256] = os.path.join('raw', k)
            else:
                print(f"Error downloading {k}: sha256s do not match")

    return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])


def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
    import os
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm

    # load metadata
    metadata = metadata.to_dict('records')

    # processing objects
    records = []
    max_workers = max_workers or os.cpu_count()
    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor, \
            tqdm(total=len(metadata), desc=desc) as pbar:
            def worker(metadatum):
                # bind sha256 first so the error message below can always use it
                sha256 = metadatum['sha256']
                try:
                    local_path = metadatum['local_path']
                    file = os.path.join(output_dir, local_path)
                    record = func(file, sha256)
                    if record is not None:
                        records.append(record)
                    pbar.update()
                except Exception as e:
                    print(f"Error processing object {sha256}: {e}")
                    pbar.update()

            executor.map(worker, metadata)
            executor.shutdown(wait=True)
    except:
        print("Error happened during processing.")

    return pd.DataFrame.from_records(records)
dataset_toolkits/datasets/ObjaverseXL.py
ADDED
@@ -0,0 +1,92 @@
import os
import argparse
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import pandas as pd
import objaverse.xl as oxl
from utils import get_file_hash


def add_args(parser: argparse.ArgumentParser):
    parser.add_argument('--source', type=str, default='sketchfab',
                        help='Data source to download annotations from (github, sketchfab)')


def get_metadata(source, **kwargs):
    if source == 'sketchfab':
        metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ObjaverseXL_sketchfab.csv")
    elif source == 'github':
        metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ObjaverseXL_github.csv")
    else:
        raise ValueError(f"Invalid source: {source}")
    return metadata


def download(metadata, output_dir, **kwargs):
    os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)

    # download annotations
    annotations = oxl.get_annotations()
    annotations = annotations[annotations['sha256'].isin(metadata['sha256'].values)]

    # download objects
    file_paths = oxl.download_objects(
        annotations,
        download_dir=os.path.join(output_dir, "raw"),
        save_repo_format="zip",
    )

    downloaded = {}
    metadata = metadata.set_index("file_identifier")
    for k, v in file_paths.items():
        sha256 = metadata.loc[k, "sha256"]
        downloaded[sha256] = os.path.relpath(v, output_dir)

    return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])


def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
    import os
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm
    import tempfile
    import zipfile

    # load metadata
    metadata = metadata.to_dict('records')

    # process objects
    records = []
    max_workers = max_workers or os.cpu_count()
    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor, \
             tqdm(total=len(metadata), desc=desc) as pbar:
            def worker(metadatum):
                try:
                    local_path = metadatum['local_path']
                    sha256 = metadatum['sha256']
                    if local_path.startswith('raw/github/repos/'):
                        # GitHub objects live inside per-repo zip archives;
                        # extract the target file to a temporary directory first
                        path_parts = local_path.split('/')
                        file_name = os.path.join(*path_parts[5:])
                        zip_file = os.path.join(output_dir, *path_parts[:5])
                        with tempfile.TemporaryDirectory() as tmp_dir:
                            with zipfile.ZipFile(zip_file, 'r') as zip_ref:
                                zip_ref.extractall(tmp_dir)
                            file = os.path.join(tmp_dir, file_name)
                            record = func(file, sha256)
                    else:
                        file = os.path.join(output_dir, local_path)
                        record = func(file, sha256)
                    if record is not None:
                        records.append(record)
                    pbar.update()
                except Exception as e:
                    print(f"Error processing object {metadatum['sha256']}: {e}")
                    pbar.update()

            executor.map(worker, metadata)
            executor.shutdown(wait=True)
    except Exception as e:
        print(f"Error happened during processing: {e}")

    return pd.DataFrame.from_records(records)
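The `raw/github/repos/` branch above relies on `save_repo_format="zip"`: each GitHub repo is saved as one zip archive, and `local_path` appends the object's path inside the repo after the archive name. A worked split on a hypothetical path ('/'.join stands in for the os.path.join calls above):

# hypothetical local_path; the first five components name the archive,
# the remainder is the member path inside it
local_path = 'raw/github/repos/some-user/some-repo.zip/models/chair.glb'
path_parts = local_path.split('/')
zip_file = '/'.join(path_parts[:5])   # 'raw/github/repos/some-user/some-repo.zip'
file_name = '/'.join(path_parts[5:])  # 'models/chair.glb'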
dataset_toolkits/datasets/Toys4k.py
ADDED
@@ -0,0 +1,92 @@
import os
import re
import argparse
import zipfile
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import pandas as pd
from utils import get_file_hash


def add_args(parser: argparse.ArgumentParser):
    pass


def get_metadata(**kwargs):
    metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/Toys4k.csv")
    return metadata


def download(metadata, output_dir, **kwargs):
    os.makedirs(output_dir, exist_ok=True)

    if not os.path.exists(os.path.join(output_dir, 'raw', 'toys4k_blend_files.zip')):
        print("\033[93m")
        print("Toys4k has to be downloaded manually.")
        print(f"Please download the toys4k_blend_files.zip file and place it in the {output_dir}/raw directory.")
        print("Visit https://github.com/rehg-lab/lowshot-shapebias/tree/main/toys4k for more information.")
        print("\033[0m")
        raise FileNotFoundError("toys4k_blend_files.zip not found")

    downloaded = {}
    metadata = metadata.set_index("file_identifier")
    with zipfile.ZipFile(os.path.join(output_dir, 'raw', 'toys4k_blend_files.zip')) as zip_ref:
        with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
             tqdm(total=len(metadata), desc="Extracting") as pbar:
            def worker(instance: str) -> str:
                try:
                    zip_ref.extract(os.path.join('toys4k_blend_files', instance), os.path.join(output_dir, 'raw'))
                    sha256 = get_file_hash(os.path.join(output_dir, 'raw/toys4k_blend_files', instance))
                    pbar.update()
                    return sha256
                except Exception as e:
                    pbar.update()
                    print(f"Error extracting {instance}: {e}")
                    return None

            sha256s = executor.map(worker, metadata.index)
            executor.shutdown(wait=True)

    for k, sha256 in zip(metadata.index, sha256s):
        if sha256 is not None:
            if sha256 == metadata.loc[k, "sha256"]:
                downloaded[sha256] = os.path.join("raw/toys4k_blend_files", k)
            else:
                print(f"Error downloading {k}: sha256s do not match")

    return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])


def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
    import os
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm

    # load metadata
    metadata = metadata.to_dict('records')

    # process objects
    records = []
    max_workers = max_workers or os.cpu_count()
    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor, \
             tqdm(total=len(metadata), desc=desc) as pbar:
            def worker(metadatum):
                try:
                    local_path = metadatum['local_path']
                    sha256 = metadatum['sha256']
                    file = os.path.join(output_dir, local_path)
                    record = func(file, sha256)
                    if record is not None:
                        records.append(record)
                    pbar.update()
                except Exception as e:
                    print(f"Error processing object {metadatum['sha256']}: {e}")
                    pbar.update()

            executor.map(worker, metadata)
            executor.shutdown(wait=True)
    except Exception as e:
        print(f"Error happened during processing: {e}")

    return pd.DataFrame.from_records(records)
dataset_toolkits/download.py
ADDED
@@ -0,0 +1,52 @@
import os
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict

if __name__ == '__main__':
    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    dataset_utils.add_args(parser)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    opt = parser.parse_args(sys.argv[2:])
    opt = edict(vars(opt))

    os.makedirs(opt.output_dir, exist_ok=True)

    # get file list
    if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    if opt.instances is None:
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        if 'local_path' in metadata.columns:
            metadata = metadata[metadata['local_path'].isna()]
    else:
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]

    print(f'Processing {len(metadata)} objects...')

    # process objects
    downloaded = dataset_utils.download(metadata, **opt)
    downloaded.to_csv(os.path.join(opt.output_dir, f'downloaded_{opt.rank}.csv'), index=False)
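A typical invocation, assuming metadata.csv has already been written to the output directory by build_metadata.py (the dataset name and paths here are examples):

python dataset_toolkits/download.py ObjaverseXL --source sketchfab --output_dir datasets/ObjaverseXL_sketchfab

The positional dataset name selects the matching module under datasets/ via importlib, and --rank/--world_size shard the metadata so multiple machines can download disjoint slices.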
dataset_toolkits/encode_latent.py
ADDED
@@ -0,0 +1,127 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import copy
import json
import argparse
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
from queue import Queue

import trellis.models as models
import trellis.modules.sparse as sp


torch.set_grad_enabled(False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--feat_model', type=str, default='dinov2_vitl14_reg',
                        help='Feature model')
    parser.add_argument('--enc_pretrained', type=str, default='microsoft/TRELLIS-image-large/ckpts/slat_enc_swin8_B_64l8_fp16',
                        help='Pretrained encoder model')
    parser.add_argument('--model_root', type=str, default='results',
                        help='Root directory of models')
    parser.add_argument('--enc_model', type=str, default=None,
                        help='Encoder model. If specified, use this model instead of the pretrained one')
    parser.add_argument('--ckpt', type=str, default=None,
                        help='Checkpoint to load')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    opt = parser.parse_args()
    opt = edict(vars(opt))

    if opt.enc_model is None:
        latent_name = f'{opt.feat_model}_{opt.enc_pretrained.split("/")[-1]}'
        encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
    else:
        latent_name = f'{opt.feat_model}_{opt.enc_model}_{opt.ckpt}'
        cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
        encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
        ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
        encoder.load_state_dict(torch.load(ckpt_path), strict=False)
        encoder.eval()
        print(f'Loaded model from {ckpt_path}')

    os.makedirs(os.path.join(opt.output_dir, 'latents', latent_name), exist_ok=True)

    # get file list
    if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    else:
        raise ValueError('metadata.csv not found')
    if opt.instances is not None:
        with open(opt.instances, 'r') as f:
            sha256s = [line.strip() for line in f]
        metadata = metadata[metadata['sha256'].isin(sha256s)]
    else:
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        metadata = metadata[metadata[f'feature_{opt.feat_model}'] == True]
        if f'latent_{latent_name}' in metadata.columns:
            metadata = metadata[metadata[f'latent_{latent_name}'] == False]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

    # filter out objects that are already processed
    sha256s = list(metadata['sha256'].values)
    for sha256 in copy.copy(sha256s):
        if os.path.exists(os.path.join(opt.output_dir, 'latents', latent_name, f'{sha256}.npz')):
            records.append({'sha256': sha256, f'latent_{latent_name}': True})
            sha256s.remove(sha256)

    # encode latents
    load_queue = Queue(maxsize=4)
    try:
        with ThreadPoolExecutor(max_workers=32) as loader_executor, \
             ThreadPoolExecutor(max_workers=32) as saver_executor:
            def loader(sha256):
                try:
                    feats = np.load(os.path.join(opt.output_dir, 'features', opt.feat_model, f'{sha256}.npz'))
                    load_queue.put((sha256, feats))
                except Exception as e:
                    print(f"Error loading features for {sha256}: {e}")
            loader_executor.map(loader, sha256s)

            def saver(sha256, pack):
                save_path = os.path.join(opt.output_dir, 'latents', latent_name, f'{sha256}.npz')
                np.savez_compressed(save_path, **pack)
                records.append({'sha256': sha256, f'latent_{latent_name}': True})

            for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
                sha256, feats = load_queue.get()
                feats = sp.SparseTensor(
                    feats=torch.from_numpy(feats['patchtokens']).float(),
                    coords=torch.cat([
                        torch.zeros(feats['patchtokens'].shape[0], 1).int(),
                        torch.from_numpy(feats['indices']).int(),
                    ], dim=1),
                ).cuda()
                latent = encoder(feats, sample_posterior=False)
                assert torch.isfinite(latent.feats).all(), "Non-finite latent"
                pack = {
                    'feats': latent.feats.cpu().numpy().astype(np.float32),
                    'coords': latent.coords[:, 1:].cpu().numpy().astype(np.uint8),
                }
                saver_executor.submit(saver, sha256, pack)

            saver_executor.shutdown(wait=True)
    except Exception as e:
        print(f"Error happened during processing: {e}")

    records = pd.DataFrame.from_records(records)
    records.to_csv(os.path.join(opt.output_dir, f'latent_{latent_name}_{opt.rank}.csv'), index=False)
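The `Queue(maxsize=4)` is what overlaps the three stages (npz loading, GPU encoding, compressed saving) without buffering the whole dataset in RAM: loader threads block once four feature packs are waiting. A stripped-down sketch of the same producer/consumer pattern, where `load_fn`, `encode_fn`, and `save_fn` are placeholders standing in for np.load, the encoder, and np.savez_compressed:

from concurrent.futures import ThreadPoolExecutor
from queue import Queue

def run_pipeline(keys, load_fn, encode_fn, save_fn, depth=4):
    q = Queue(maxsize=depth)  # bounded: applies backpressure to the loaders
    with ThreadPoolExecutor(max_workers=8) as loaders, \
         ThreadPoolExecutor(max_workers=8) as savers:
        loaders.map(lambda k: q.put((k, load_fn(k))), keys)
        for _ in range(len(keys)):
            k, x = q.get()  # items arrive in completion order, not submission order
            savers.submit(save_fn, k, encode_fn(x))
        savers.shutdown(wait=True)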
dataset_toolkits/encode_ss_latent.py
ADDED
@@ -0,0 +1,128 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import copy
import json
import argparse
import torch
import numpy as np
import pandas as pd
import utils3d
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
from queue import Queue

import trellis.models as models


torch.set_grad_enabled(False)


def get_voxels(instance):
    position = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{instance}.ply'))[0]
    coords = ((torch.tensor(position) + 0.5) * opt.resolution).int().contiguous()
    ss = torch.zeros(1, opt.resolution, opt.resolution, opt.resolution, dtype=torch.long)
    ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
    return ss


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--enc_pretrained', type=str, default='microsoft/TRELLIS-image-large/ckpts/ss_enc_conv3d_16l8_fp16',
                        help='Pretrained encoder model')
    parser.add_argument('--model_root', type=str, default='results',
                        help='Root directory of models')
    parser.add_argument('--enc_model', type=str, default=None,
                        help='Encoder model. If specified, use this model instead of the pretrained one')
    parser.add_argument('--ckpt', type=str, default=None,
                        help='Checkpoint to load')
    parser.add_argument('--resolution', type=int, default=64,
                        help='Resolution')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    opt = parser.parse_args()
    opt = edict(vars(opt))

    if opt.enc_model is None:
        latent_name = f'{opt.enc_pretrained.split("/")[-1]}'
        encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
    else:
        latent_name = f'{opt.enc_model}_{opt.ckpt}'
        cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
        encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
        ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
        encoder.load_state_dict(torch.load(ckpt_path), strict=False)
        encoder.eval()
        print(f'Loaded model from {ckpt_path}')

    os.makedirs(os.path.join(opt.output_dir, 'ss_latents', latent_name), exist_ok=True)

    # get file list
    if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    else:
        raise ValueError('metadata.csv not found')
    if opt.instances is not None:
        with open(opt.instances, 'r') as f:
            instances = f.read().splitlines()
        metadata = metadata[metadata['sha256'].isin(instances)]
    else:
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        metadata = metadata[metadata['voxelized'] == True]
        if f'ss_latent_{latent_name}' in metadata.columns:
            metadata = metadata[metadata[f'ss_latent_{latent_name}'] == False]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

    # filter out objects that are already processed
    sha256s = list(metadata['sha256'].values)
    for sha256 in copy.copy(sha256s):
        if os.path.exists(os.path.join(opt.output_dir, 'ss_latents', latent_name, f'{sha256}.npz')):
            records.append({'sha256': sha256, f'ss_latent_{latent_name}': True})
            sha256s.remove(sha256)

    # encode latents
    load_queue = Queue(maxsize=4)
    try:
        with ThreadPoolExecutor(max_workers=32) as loader_executor, \
             ThreadPoolExecutor(max_workers=32) as saver_executor:
            def loader(sha256):
                try:
                    ss = get_voxels(sha256)[None].float()
                    load_queue.put((sha256, ss))
                except Exception as e:
                    print(f"Error loading voxels for {sha256}: {e}")
            loader_executor.map(loader, sha256s)

            def saver(sha256, pack):
                save_path = os.path.join(opt.output_dir, 'ss_latents', latent_name, f'{sha256}.npz')
                np.savez_compressed(save_path, **pack)
                records.append({'sha256': sha256, f'ss_latent_{latent_name}': True})

            for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
                sha256, ss = load_queue.get()
                ss = ss.cuda().float()
                latent = encoder(ss, sample_posterior=False)
                assert torch.isfinite(latent).all(), "Non-finite latent"
                pack = {
                    'mean': latent[0].cpu().numpy(),
                }
                saver_executor.submit(saver, sha256, pack)

            saver_executor.shutdown(wait=True)
    except Exception as e:
        print(f"Error happened during processing: {e}")

    records = pd.DataFrame.from_records(records)
    records.to_csv(os.path.join(opt.output_dir, f'ss_latent_{latent_name}_{opt.rank}.csv'), index=False)
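`get_voxels` inverts the coordinate convention used by voxelize.py: voxel centers stored in [-0.5, 0.5] map back to integer indices in a resolution³ occupancy grid. A quick check of the round trip at the default resolution of 64:

resolution = 64
index = 10
center = (index + 0.5) / resolution - 0.5     # -0.3359375, as written by voxelize.py
recovered = int((center + 0.5) * resolution)  # truncates 10.5 back to 10
assert recovered == index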
dataset_toolkits/extract_feature.py
ADDED
@@ -0,0 +1,179 @@
import os
import copy
import sys
import json
import importlib
import argparse
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import utils3d
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from torchvision import transforms
from PIL import Image


torch.set_grad_enabled(False)


def get_data(frames, sha256):
    with ThreadPoolExecutor(max_workers=16) as executor:
        def worker(view):
            image_path = os.path.join(opt.output_dir, 'renders', sha256, view['file_path'])
            try:
                image = Image.open(image_path)
            except Exception:
                print(f"Error loading image {image_path}")
                return None
            image = image.resize((518, 518), Image.Resampling.LANCZOS)
            image = np.array(image).astype(np.float32) / 255
            image = image[:, :, :3] * image[:, :, 3:]  # premultiply alpha
            image = torch.from_numpy(image).permute(2, 0, 1).float()

            c2w = torch.tensor(view['transform_matrix'])
            c2w[:3, 1:3] *= -1  # OpenGL -> OpenCV camera convention
            extrinsics = torch.inverse(c2w)
            fov = view['camera_angle_x']
            intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov))

            return {
                'image': image,
                'extrinsics': extrinsics,
                'intrinsics': intrinsics
            }

        datas = executor.map(worker, frames)
        for data in datas:
            if data is not None:
                yield data


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--model', type=str, default='dinov2_vitl14_reg',
                        help='Feature extraction model')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    opt = parser.parse_args()
    opt = edict(vars(opt))

    feature_name = opt.model
    os.makedirs(os.path.join(opt.output_dir, 'features', feature_name), exist_ok=True)

    # load model
    dinov2_model = torch.hub.load('facebookresearch/dinov2', opt.model)
    dinov2_model.eval().cuda()
    transform = transforms.Compose([
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    n_patch = 518 // 14

    # get file list
    if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    else:
        raise ValueError('metadata.csv not found')
    if opt.instances is not None:
        with open(opt.instances, 'r') as f:
            instances = f.read().splitlines()
        metadata = metadata[metadata['sha256'].isin(instances)]
    else:
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        if f'feature_{feature_name}' in metadata.columns:
            metadata = metadata[metadata[f'feature_{feature_name}'] == False]
        metadata = metadata[metadata['voxelized'] == True]
        metadata = metadata[metadata['rendered'] == True]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

    # filter out objects that are already processed
    sha256s = list(metadata['sha256'].values)
    for sha256 in copy.copy(sha256s):
        if os.path.exists(os.path.join(opt.output_dir, 'features', feature_name, f'{sha256}.npz')):
            records.append({'sha256': sha256, f'feature_{feature_name}': True})
            sha256s.remove(sha256)

    # extract features
    load_queue = Queue(maxsize=4)
    try:
        with ThreadPoolExecutor(max_workers=8) as loader_executor, \
             ThreadPoolExecutor(max_workers=8) as saver_executor:
            def loader(sha256):
                try:
                    with open(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json'), 'r') as f:
                        metadata = json.load(f)
                    frames = metadata['frames']
                    data = []
                    for datum in get_data(frames, sha256):
                        datum['image'] = transform(datum['image'])
                        data.append(datum)
                    positions = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0]
                    load_queue.put((sha256, data, positions))
                except Exception as e:
                    print(f"Error loading data for {sha256}: {e}")

            loader_executor.map(loader, sha256s)

            def saver(sha256, pack, patchtokens, uv):
                pack['patchtokens'] = F.grid_sample(
                    patchtokens,
                    uv.unsqueeze(1),
                    mode='bilinear',
                    align_corners=False,
                ).squeeze(2).permute(0, 2, 1).cpu().numpy()
                pack['patchtokens'] = np.mean(pack['patchtokens'], axis=0).astype(np.float16)
                save_path = os.path.join(opt.output_dir, 'features', feature_name, f'{sha256}.npz')
                np.savez_compressed(save_path, **pack)
                records.append({'sha256': sha256, f'feature_{feature_name}': True})

            for _ in tqdm(range(len(sha256s)), desc="Extracting features"):
                sha256, data, positions = load_queue.get()
                positions = torch.from_numpy(positions).float().cuda()
                indices = ((positions + 0.5) * 64).long()
                assert torch.all(indices >= 0) and torch.all(indices < 64), "Some vertices are out of bounds"
                n_views = len(data)
                N = positions.shape[0]
                pack = {
                    'indices': indices.cpu().numpy().astype(np.uint8),
                }
                patchtokens_lst = []
                uv_lst = []
                for i in range(0, n_views, opt.batch_size):
                    batch_data = data[i:i+opt.batch_size]
                    bs = len(batch_data)
                    batch_images = torch.stack([d['image'] for d in batch_data]).cuda()
                    batch_extrinsics = torch.stack([d['extrinsics'] for d in batch_data]).cuda()
                    batch_intrinsics = torch.stack([d['intrinsics'] for d in batch_data]).cuda()
                    features = dinov2_model(batch_images, is_training=True)
                    uv = utils3d.torch.project_cv(positions, batch_extrinsics, batch_intrinsics)[0] * 2 - 1
                    patchtokens = features['x_prenorm'][:, dinov2_model.num_register_tokens + 1:].permute(0, 2, 1).reshape(bs, 1024, n_patch, n_patch)
                    patchtokens_lst.append(patchtokens)
                    uv_lst.append(uv)
                patchtokens = torch.cat(patchtokens_lst, dim=0)
                uv = torch.cat(uv_lst, dim=0)

                # save features
                saver_executor.submit(saver, sha256, pack, patchtokens, uv)

            saver_executor.shutdown(wait=True)
    except Exception as e:
        print(f"Error happened during processing: {e}")

    records = pd.DataFrame.from_records(records)
    records.to_csv(os.path.join(opt.output_dir, f'feature_{feature_name}_{opt.rank}.csv'), index=False)
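The 518x518 resize and `n_patch = 518 // 14` go together because DINOv2 ViT-L/14 tokenizes the image into 14-pixel patches with a 1024-dimensional embedding; the `(bs, 1024, n_patch, n_patch)` reshape in the loop depends on this arithmetic:

image_size, patch_size, embed_dim = 518, 14, 1024  # DINOv2 ViT-L/14
n_patch = image_size // patch_size                 # 37 patches per side
assert n_patch ** 2 == 1369                        # patch tokens per image, after dropping CLS/register tokens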
dataset_toolkits/render.py
ADDED
@@ -0,0 +1,121 @@
import os
import json
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict
from functools import partial
from subprocess import DEVNULL, call
import numpy as np
from utils import sphere_hammersley_sequence


BLENDER_LINK = 'https://download.blender.org/release/Blender3.0/blender-3.0.1-linux-x64.tar.xz'
BLENDER_INSTALLATION_PATH = '/tmp'
BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64/blender'

def _install_blender():
    if not os.path.exists(BLENDER_PATH):
        os.system('sudo apt-get update')
        os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6')
        os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}')
        os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}')


def _render(file_path, sha256, output_dir, num_views):
    output_folder = os.path.join(output_dir, 'renders', sha256)

    # Build camera {yaw, pitch, radius, fov}
    yaws = []
    pitchs = []
    offset = (np.random.rand(), np.random.rand())
    for i in range(num_views):
        y, p = sphere_hammersley_sequence(i, num_views, offset)
        yaws.append(y)
        pitchs.append(p)
    radius = [2] * num_views
    fov = [40 / 180 * np.pi] * num_views
    views = [{'yaw': y, 'pitch': p, 'radius': r, 'fov': f} for y, p, r, f in zip(yaws, pitchs, radius, fov)]

    args = [
        BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'render.py'),
        '--',
        '--views', json.dumps(views),
        '--object', os.path.expanduser(file_path),
        '--resolution', '512',
        '--output_folder', output_folder,
        '--engine', 'CYCLES',
        '--save_mesh',
    ]
    if file_path.endswith('.blend'):
        args.insert(1, file_path)

    call(args, stdout=DEVNULL, stderr=DEVNULL)

    if os.path.exists(os.path.join(output_folder, 'transforms.json')):
        return {'sha256': sha256, 'rendered': True}


if __name__ == '__main__':
    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--num_views', type=int, default=150,
                        help='Number of views to render')
    dataset_utils.add_args(parser)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    parser.add_argument('--max_workers', type=int, default=8)
    opt = parser.parse_args(sys.argv[2:])
    opt = edict(vars(opt))

    os.makedirs(os.path.join(opt.output_dir, 'renders'), exist_ok=True)

    # install blender
    print('Checking blender...', flush=True)
    _install_blender()

    # get file list
    if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    if opt.instances is None:
        metadata = metadata[metadata['local_path'].notna()]
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        if 'rendered' in metadata.columns:
            metadata = metadata[metadata['rendered'] == False]
    else:
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

    # filter out objects that are already processed
    for sha256 in copy.copy(metadata['sha256'].values):
        if os.path.exists(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json')):
            records.append({'sha256': sha256, 'rendered': True})
            metadata = metadata[metadata['sha256'] != sha256]

    print(f'Processing {len(metadata)} objects...')

    # process objects
    func = partial(_render, output_dir=opt.output_dir, num_views=opt.num_views)
    rendered = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Rendering objects')
    rendered = pd.concat([rendered, pd.DataFrame.from_records(records)])
    rendered.to_csv(os.path.join(opt.output_dir, f'rendered_{opt.rank}.csv'), index=False)
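A typical invocation, run after download.py so that `local_path` is populated (dataset name and paths are examples); the script shards with --rank/--world_size the same way the other stages do:

python dataset_toolkits/render.py ObjaverseXL --output_dir datasets/ObjaverseXL_sketchfab --num_views 150 --max_workers 8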
dataset_toolkits/render_cond.py
ADDED
@@ -0,0 +1,125 @@
import os
import json
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict
from functools import partial
from subprocess import DEVNULL, call
import numpy as np
from utils import sphere_hammersley_sequence


BLENDER_LINK = 'https://download.blender.org/release/Blender3.0/blender-3.0.1-linux-x64.tar.xz'
BLENDER_INSTALLATION_PATH = '/tmp'
BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64/blender'

def _install_blender():
    if not os.path.exists(BLENDER_PATH):
        os.system('sudo apt-get update')
        os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6')
        os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}')
        os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}')


def _render_cond(file_path, sha256, output_dir, num_views):
    output_folder = os.path.join(output_dir, 'renders_cond', sha256)

    # Build camera {yaw, pitch, radius, fov}
    yaws = []
    pitchs = []
    offset = (np.random.rand(), np.random.rand())
    for i in range(num_views):
        y, p = sphere_hammersley_sequence(i, num_views, offset)
        yaws.append(y)
        pitchs.append(p)
    fov_min, fov_max = 10, 70
    radius_min = np.sqrt(3) / 2 / np.sin(fov_max / 360 * np.pi)
    radius_max = np.sqrt(3) / 2 / np.sin(fov_min / 360 * np.pi)
    k_min = 1 / radius_max**2
    k_max = 1 / radius_min**2
    ks = np.random.uniform(k_min, k_max, (1000000,))  # zip below truncates this to num_views
    radius = [1 / np.sqrt(k) for k in ks]
    fov = [2 * np.arcsin(np.sqrt(3) / 2 / r) for r in radius]
    views = [{'yaw': y, 'pitch': p, 'radius': r, 'fov': f} for y, p, r, f in zip(yaws, pitchs, radius, fov)]

    args = [
        BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'render.py'),
        '--',
        '--views', json.dumps(views),
        '--object', os.path.expanduser(file_path),
        '--output_folder', os.path.expanduser(output_folder),
        '--resolution', '1024',
    ]
    if file_path.endswith('.blend'):
        args.insert(1, file_path)

    call(args, stdout=DEVNULL)

    if os.path.exists(os.path.join(output_folder, 'transforms.json')):
        return {'sha256': sha256, 'cond_rendered': True}


if __name__ == '__main__':
    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--num_views', type=int, default=24,
                        help='Number of views to render')
    dataset_utils.add_args(parser)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    parser.add_argument('--max_workers', type=int, default=8)
    opt = parser.parse_args(sys.argv[2:])
    opt = edict(vars(opt))

    os.makedirs(os.path.join(opt.output_dir, 'renders_cond'), exist_ok=True)

    # install blender
    print('Checking blender...', flush=True)
    _install_blender()

    # get file list
    if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    if opt.instances is None:
        metadata = metadata[metadata['local_path'].notna()]
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        if 'cond_rendered' in metadata.columns:
            metadata = metadata[metadata['cond_rendered'] == False]
    else:
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

    # filter out objects that are already processed
    for sha256 in copy.copy(metadata['sha256'].values):
        if os.path.exists(os.path.join(opt.output_dir, 'renders_cond', sha256, 'transforms.json')):
            records.append({'sha256': sha256, 'cond_rendered': True})
            metadata = metadata[metadata['sha256'] != sha256]

    print(f'Processing {len(metadata)} objects...')

    # process objects
    func = partial(_render_cond, output_dir=opt.output_dir, num_views=opt.num_views)
    cond_rendered = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Rendering objects')
    cond_rendered = pd.concat([cond_rendered, pd.DataFrame.from_records(records)])
    cond_rendered.to_csv(os.path.join(opt.output_dir, f'cond_rendered_{opt.rank}.csv'), index=False)
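The radius bounds above come from fitting the bounding sphere of the unit cube (radius sqrt(3)/2) exactly into the field of view: sin(fov/2) = (sqrt(3)/2)/radius, so each sampled radius pairs with fov = 2*arcsin(sqrt(3)/2/radius). A numeric check of the pairing:

import numpy as np
fov_max = 70  # degrees
radius_min = np.sqrt(3) / 2 / np.sin(fov_max / 360 * np.pi)  # fov_max/360*pi is fov_max/2 in radians
fov_back = 2 * np.arcsin(np.sqrt(3) / 2 / radius_min)
assert np.isclose(np.degrees(fov_back), fov_max)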
dataset_toolkits/setup.sh
ADDED
@@ -0,0 +1 @@
pip install pillow imageio imageio-ffmpeg tqdm easydict opencv-python-headless pandas open3d objaverse huggingface_hub
dataset_toolkits/stat_latent.py
ADDED
@@ -0,0 +1,66 @@
import os
import json
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--model', type=str, default='dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16',
                        help='Latent model to use')
    parser.add_argument('--num_samples', type=int, default=50000,
                        help='Number of samples to use for calculating stats')
    opt = parser.parse_args()
    opt = edict(vars(opt))

    # get file list
    if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    else:
        raise ValueError('metadata.csv not found')
    if opt.filter_low_aesthetic_score is not None:
        metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
    metadata = metadata[metadata[f'latent_{opt.model}'] == True]
    sha256s = metadata['sha256'].values
    sha256s = np.random.choice(sha256s, min(opt.num_samples, len(sha256s)), replace=False)

    # stats
    means = []
    mean2s = []
    with ThreadPoolExecutor(max_workers=16) as executor, \
         tqdm(total=len(sha256s), desc="Extracting features") as pbar:
        def worker(sha256):
            try:
                feats = np.load(os.path.join(opt.output_dir, 'latents', opt.model, f'{sha256}.npz'))
                feats = feats['feats']
                means.append(feats.mean(axis=0))
                mean2s.append((feats ** 2).mean(axis=0))
                pbar.update()
            except Exception as e:
                print(f"Error extracting features for {sha256}: {e}")
                pbar.update()

        executor.map(worker, sha256s)
        executor.shutdown(wait=True)

    mean = np.array(means).mean(axis=0)
    mean2 = np.array(mean2s).mean(axis=0)
    std = np.sqrt(mean2 - mean ** 2)

    print('mean:', mean)
    print('std:', std)

    with open(os.path.join(opt.output_dir, 'latents', opt.model, 'stats.json'), 'w') as f:
        json.dump({
            'mean': mean.tolist(),
            'std': std.tolist(),
        }, f, indent=4)
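The statistics rely on the identity Var[x] = E[x²] − E[x]²; note that since `means` holds one per-file average, the combined result is exact only when every file contributes the same number of latents, and otherwise weights files equally rather than per latent. A quick check of the identity:

import numpy as np
x = np.random.randn(1000, 8)
assert np.allclose(np.sqrt((x ** 2).mean(0) - x.mean(0) ** 2), x.std(0))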
dataset_toolkits/utils.py
ADDED
@@ -0,0 +1,43 @@
from typing import *
import hashlib
import numpy as np


def get_file_hash(file: str) -> str:
    sha256 = hashlib.sha256()
    # read the file in chunks and update the hash with each block
    with open(file, "rb") as f:
        for byte_block in iter(lambda: f.read(4096), b""):
            sha256.update(byte_block)
    return sha256.hexdigest()

# =============== LOW-DISCREPANCY SEQUENCES ================

PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]

def radical_inverse(base, n):
    val = 0
    inv_base = 1.0 / base
    inv_base_n = inv_base
    while n > 0:
        digit = n % base
        val += digit * inv_base_n
        n //= base
        inv_base_n *= inv_base
    return val

def halton_sequence(dim, n):
    return [radical_inverse(PRIMES[d], n) for d in range(dim)]

def hammersley_sequence(dim, n, num_samples):
    return [n / num_samples] + halton_sequence(dim - 1, n)

def sphere_hammersley_sequence(n, num_samples, offset=(0, 0)):
    u, v = hammersley_sequence(2, n, num_samples)
    u += offset[0] / num_samples
    v += offset[1]
    u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
    theta = np.arccos(1 - 2 * u) - np.pi / 2
    phi = v * 2 * np.pi
    return [phi, theta]
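`radical_inverse` mirrors the base-b digits of n about the radix point (the van der Corput sequence), and `sphere_hammersley_sequence` lifts the 2D Hammersley point to (phi, theta) on the sphere; the piecewise remap of u sends a quarter of the views below the equator and three quarters above. A worked check of the digit mirroring:

from utils import radical_inverse, hammersley_sequence  # as defined above

# 3 = 11 in base 2 -> 0.11b = 0.75; 5 = 101b -> 0.101b = 0.625
assert radical_inverse(2, 3) == 0.75
assert radical_inverse(2, 5) == 0.625
# hammersley_sequence pairs n/num_samples with the Halton point of n
assert hammersley_sequence(2, 3, 8) == [3 / 8, 0.75]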
dataset_toolkits/voxelize.py
ADDED
@@ -0,0 +1,86 @@
import os
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict
from functools import partial
import numpy as np
import open3d as o3d
import utils3d


def _voxelize(file, sha256, output_dir):
    mesh = o3d.io.read_triangle_mesh(os.path.join(output_dir, 'renders', sha256, 'mesh.ply'))
    # clamp vertices to the range [-0.5, 0.5]
    vertices = np.clip(np.asarray(mesh.vertices), -0.5 + 1e-6, 0.5 - 1e-6)
    mesh.vertices = o3d.utility.Vector3dVector(vertices)
    voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(mesh, voxel_size=1/64, min_bound=(-0.5, -0.5, -0.5), max_bound=(0.5, 0.5, 0.5))
    vertices = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()])
    assert np.all(vertices >= 0) and np.all(vertices < 64), "Some vertices are out of bounds"
    vertices = (vertices + 0.5) / 64 - 0.5
    utils3d.io.write_ply(os.path.join(output_dir, 'voxels', f'{sha256}.ply'), vertices)
    return {'sha256': sha256, 'voxelized': True, 'num_voxels': len(vertices)}


if __name__ == '__main__':
    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--num_views', type=int, default=150,
                        help='Number of views to render')
    dataset_utils.add_args(parser)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    parser.add_argument('--max_workers', type=int, default=None)
    opt = parser.parse_args(sys.argv[2:])
    opt = edict(vars(opt))

    os.makedirs(os.path.join(opt.output_dir, 'voxels'), exist_ok=True)

    # get file list
    if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    if opt.instances is None:
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        if 'rendered' not in metadata.columns:
            raise ValueError('metadata.csv does not have "rendered" column, please run "build_metadata.py" first')
        metadata = metadata[metadata['rendered'] == True]
        if 'voxelized' in metadata.columns:
            metadata = metadata[metadata['voxelized'] == False]
    else:
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]

    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

    # filter out objects that are already processed
    for sha256 in copy.copy(metadata['sha256'].values):
        if os.path.exists(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply')):
            pts = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0]
            records.append({'sha256': sha256, 'voxelized': True, 'num_voxels': len(pts)})
            metadata = metadata[metadata['sha256'] != sha256]

    print(f'Processing {len(metadata)} objects...')

    # process objects
    func = partial(_voxelize, output_dir=opt.output_dir)
    voxelized = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Voxelizing')
    voxelized = pd.concat([voxelized, pd.DataFrame.from_records(records)])
    voxelized.to_csv(os.path.join(opt.output_dir, f'voxelized_{opt.rank}.csv'), index=False)
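voxelize.py consumes the mesh.ply that render.py exports via --save_mesh, so it runs after the rendering stage; the 1/64 voxel size over [-0.5, 0.5]³ yields the same 64³ grid that extract_feature.py and encode_ss_latent.py index into. Example invocation (paths are illustrative):

python dataset_toolkits/voxelize.py ObjaverseXL --output_dir datasets/ObjaverseXL_sketchfab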
examples/airplane.png
ADDED
examples/airplane2.png
ADDED
examples/bear.png
ADDED
examples/car.png
ADDED
examples/car2.png
ADDED
examples/gun1.png
ADDED
examples/gun2.png
ADDED
examples/icecream.png
ADDED
examples/knife.png
ADDED
examples/man1.png
ADDED
examples/man2.png
ADDED
examples/man3.png
ADDED
examples/robot1.png
ADDED
examples/robot2.png
ADDED
examples/shoe.png
ADDED
examples/sweater.png
ADDED