diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..b597ee4e8097c1563905d8d1a6769f9b9e9041cc
--- /dev/null
+++ b/app.py
@@ -0,0 +1,152 @@
+import gradio as gr
+# from gradio_litmodel3d import LitModel3D
+
+import os
+from typing import *
+import imageio
+import uuid
+from PIL import Image
+from trellis.pipelines import TrellisImageTo3DPipeline
+from trellis.utils import render_utils, postprocessing_utils
+
+
+def preprocess_image(image: Image.Image) -> Image.Image:
+    """
+    Preprocess the input image with the pipeline's image preprocessor (e.g. background removal and recentering).
+
+    Args:
+        image (Image.Image): The input image.
+
+    Returns:
+        Image.Image: The preprocessed image.
+    """
+    return pipeline.preprocess_image(image)
+
+
+def image_to_3d(image: Image.Image) -> Tuple[dict, str]:
+    """
+    Convert an image to a 3D model.
+
+    Args:
+        image (Image.Image): The input image.
+
+    Returns:
+        dict: The generated 3D model state (Gaussian and mesh outputs plus a model_id),
+            kept in gr.State for the later GLB extraction step.
+        str: The path to the rendered preview video of the 3D model.
+    """
+    outputs = pipeline(image, formats=["gaussian", "mesh"], preprocess_image=False)
+    video = render_utils.render_video(outputs['gaussian'][0])['color']
+    model_id = uuid.uuid4()
+    video_path = f"/tmp/Trellis-demo/{model_id}.mp4"
+    os.makedirs(os.path.dirname(video_path), exist_ok=True)
+    imageio.mimsave(video_path, video, fps=30)
+    model = {'gaussian': outputs['gaussian'][0], 'mesh': outputs['mesh'][0], 'model_id': model_id}
+    return model, video_path
+
+
+def extract_glb(model: dict, mesh_simplify: float, texture_size: int) -> Tuple[str, str]:
+    """
+    Extract a GLB file from the 3D model.
+
+    Args:
+        model (dict): The generated 3D model.
+        mesh_simplify (float): The mesh simplification factor.
+        texture_size (int): The texture resolution.
+
+    Returns:
+        str: The path to the extracted GLB file (shown in the Model3D viewer).
+        str: The same path, wired to the download button.
+    """
+    glb = postprocessing_utils.to_glb(model['gaussian'], model['mesh'], simplify=mesh_simplify, texture_size=texture_size)
+    glb_path = f"/tmp/Trellis-demo/{model['model_id']}.glb"
+    glb.export(glb_path)
+    return glb_path, glb_path
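+
+
+# Illustrative headless use of the helpers above (a sketch, not part of the Gradio UI;
+# it assumes the module-level `pipeline` has been created as in the __main__ block below):
+#
+#   pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
+#   pipeline.cuda()
+#   img = preprocess_image(Image.open("assets/example_image/T.png"))
+#   state, video_path = image_to_3d(img)
+#   glb_path, _ = extract_glb(state, mesh_simplify=0.95, texture_size=1024)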
+
+
+def activate_button() -> gr.Button:
+    return gr.Button(interactive=True)
+
+
+def deactivate_button() -> gr.Button:
+    return gr.Button(interactive=False)
+
+
+with gr.Blocks() as demo:
+    with gr.Row():
+        with gr.Column():
+            image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil", height=300)
+            generate_btn = gr.Button("Generate", interactive=False)
+
+            mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
+            texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
+            extract_glb_btn = gr.Button("Extract GLB", interactive=False)
+
+        with gr.Column():
+            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
+            model_output = gr.Model3D(label="Extracted GLB", height=300)
+            download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
+
+    # Example images at the bottom of the page
+    with gr.Row():
+        examples = gr.Examples(
+            examples=[
+                f'assets/example_image/{image}'
+                for image in os.listdir("assets/example_image")
+            ],
+            inputs=[image_prompt],
+            fn=lambda image: (preprocess_image(image), gr.Button(interactive=True)),
+            outputs=[image_prompt, generate_btn],
+            run_on_click=True,
+            examples_per_page=64,
+        )
+
+    model = gr.State()
+
+    # Handlers
+    image_prompt.upload(
+        preprocess_image,
+        inputs=[image_prompt],
+        outputs=[image_prompt],
+    ).then(
+        activate_button,
+        outputs=[generate_btn],
+    )
+
+    image_prompt.clear(
+        deactivate_button,
+        outputs=[generate_btn],
+    )
+
+    generate_btn.click(
+        image_to_3d,
+        inputs=[image_prompt],
+        outputs=[model, video_output],
+    ).then(
+        activate_button,
+        outputs=[extract_glb_btn],
+    )
+
+    video_output.clear(
+        deactivate_button,
+        outputs=[extract_glb_btn],
+    )
+
+    extract_glb_btn.click(
+        extract_glb,
+        inputs=[model, mesh_simplify, texture_size],
+        outputs=[model_output, download_glb],
+    ).then(
+        activate_button,
+        outputs=[download_glb],
+    )
+
+    model_output.clear(
+        deactivate_button,
+        outputs=[download_glb],
+    )
+    
+
+# Launch the Gradio app
+if __name__ == "__main__":
+    pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
+    pipeline.cuda()
+    demo.launch()
diff --git a/assets/example_image/T.png b/assets/example_image/T.png
new file mode 100644
index 0000000000000000000000000000000000000000..79af51bc9d711951fbc63be16b7d07c84294355b
Binary files /dev/null and b/assets/example_image/T.png differ
diff --git a/assets/example_image/typical_building_building.png b/assets/example_image/typical_building_building.png
new file mode 100644
index 0000000000000000000000000000000000000000..4f9adcf79d4297bb9b906608c23c311f9b8f23d2
Binary files /dev/null and b/assets/example_image/typical_building_building.png differ
diff --git a/assets/example_image/typical_building_castle.png b/assets/example_image/typical_building_castle.png
new file mode 100644
index 0000000000000000000000000000000000000000..5f5f50733f3b8ed168026340ed679df357ccb9ec
Binary files /dev/null and b/assets/example_image/typical_building_castle.png differ
diff --git a/assets/example_image/typical_building_colorful_cottage.png b/assets/example_image/typical_building_colorful_cottage.png
new file mode 100644
index 0000000000000000000000000000000000000000..94616b19be6a896413c287b3168b3b7886d64a56
Binary files /dev/null and b/assets/example_image/typical_building_colorful_cottage.png differ
diff --git a/assets/example_image/typical_building_maya_pyramid.png b/assets/example_image/typical_building_maya_pyramid.png
new file mode 100644
index 0000000000000000000000000000000000000000..1d87f3c3980f80eee878b4f8ab69e32279a5ea50
Binary files /dev/null and b/assets/example_image/typical_building_maya_pyramid.png differ
diff --git a/assets/example_image/typical_building_mushroom.png b/assets/example_image/typical_building_mushroom.png
new file mode 100644
index 0000000000000000000000000000000000000000..c4db49d4284e6fec83b2ea548e7e85d489711b84
Binary files /dev/null and b/assets/example_image/typical_building_mushroom.png differ
diff --git a/assets/example_image/typical_building_space_station.png b/assets/example_image/typical_building_space_station.png
new file mode 100644
index 0000000000000000000000000000000000000000..e37a5806d403dccc098d82492d687d71afa36850
Binary files /dev/null and b/assets/example_image/typical_building_space_station.png differ
diff --git a/assets/example_image/typical_creature_dragon.png b/assets/example_image/typical_creature_dragon.png
new file mode 100644
index 0000000000000000000000000000000000000000..c3fb92ff0400451c69f44fb75156f50db93da6ec
Binary files /dev/null and b/assets/example_image/typical_creature_dragon.png differ
diff --git a/assets/example_image/typical_creature_elephant.png b/assets/example_image/typical_creature_elephant.png
new file mode 100644
index 0000000000000000000000000000000000000000..6fc3cf1776c66b91e739cad8e839b52074a57e4c
Binary files /dev/null and b/assets/example_image/typical_creature_elephant.png differ
diff --git a/assets/example_image/typical_creature_furry.png b/assets/example_image/typical_creature_furry.png
new file mode 100644
index 0000000000000000000000000000000000000000..eb4e8d6c6cac1e03a206429eaf7de261c6f14072
Binary files /dev/null and b/assets/example_image/typical_creature_furry.png differ
diff --git a/assets/example_image/typical_creature_quadruped.png b/assets/example_image/typical_creature_quadruped.png
new file mode 100644
index 0000000000000000000000000000000000000000..b246e08e05702051fb22cced1366ab765cd6fbb0
Binary files /dev/null and b/assets/example_image/typical_creature_quadruped.png differ
diff --git a/assets/example_image/typical_creature_robot_crab.png b/assets/example_image/typical_creature_robot_crab.png
new file mode 100644
index 0000000000000000000000000000000000000000..8b4e10b353e0e9b60634ea272ff8fd9135fdd640
Binary files /dev/null and b/assets/example_image/typical_creature_robot_crab.png differ
diff --git a/assets/example_image/typical_creature_robot_dinosour.png b/assets/example_image/typical_creature_robot_dinosour.png
new file mode 100644
index 0000000000000000000000000000000000000000..7f8f51728fe1fecb0532673756b1601ef46edc2c
Binary files /dev/null and b/assets/example_image/typical_creature_robot_dinosour.png differ
diff --git a/assets/example_image/typical_creature_rock_monster.png b/assets/example_image/typical_creature_rock_monster.png
new file mode 100644
index 0000000000000000000000000000000000000000..29dc243b197d9b3ee4df9355a5f08752ef0b9b9e
Binary files /dev/null and b/assets/example_image/typical_creature_rock_monster.png differ
diff --git a/assets/example_image/typical_humanoid_block_robot.png b/assets/example_image/typical_humanoid_block_robot.png
new file mode 100644
index 0000000000000000000000000000000000000000..195212e38e6a8e331b02c2d58728ba41dba429a1
Binary files /dev/null and b/assets/example_image/typical_humanoid_block_robot.png differ
diff --git a/assets/example_image/typical_humanoid_dragonborn.png b/assets/example_image/typical_humanoid_dragonborn.png
new file mode 100644
index 0000000000000000000000000000000000000000..61ca2d9e69634c12ee9ae6f7e77f84839df83fdb
Binary files /dev/null and b/assets/example_image/typical_humanoid_dragonborn.png differ
diff --git a/assets/example_image/typical_humanoid_dwarf.png b/assets/example_image/typical_humanoid_dwarf.png
new file mode 100644
index 0000000000000000000000000000000000000000..16de1631fff3cc42a3a5d6a8b0f638da75ad7b2f
Binary files /dev/null and b/assets/example_image/typical_humanoid_dwarf.png differ
diff --git a/assets/example_image/typical_humanoid_goblin.png b/assets/example_image/typical_humanoid_goblin.png
new file mode 100644
index 0000000000000000000000000000000000000000..4e4fe04517801d5722817e8dfaed2af83b31d67e
Binary files /dev/null and b/assets/example_image/typical_humanoid_goblin.png differ
diff --git a/assets/example_image/typical_humanoid_mech.png b/assets/example_image/typical_humanoid_mech.png
new file mode 100644
index 0000000000000000000000000000000000000000..f0fbbdf6cda5636f517b6e2fa3f20e15e56e3777
Binary files /dev/null and b/assets/example_image/typical_humanoid_mech.png differ
diff --git a/assets/example_image/typical_misc_crate.png b/assets/example_image/typical_misc_crate.png
new file mode 100644
index 0000000000000000000000000000000000000000..c3086f885bf9fc27c398b5bacfb04a65bd7dfbd9
Binary files /dev/null and b/assets/example_image/typical_misc_crate.png differ
diff --git a/assets/example_image/typical_misc_fireplace.png b/assets/example_image/typical_misc_fireplace.png
new file mode 100644
index 0000000000000000000000000000000000000000..82d79bc10346604a8b8b9cc8e2c317e8dc6d8c47
Binary files /dev/null and b/assets/example_image/typical_misc_fireplace.png differ
diff --git a/assets/example_image/typical_misc_gate.png b/assets/example_image/typical_misc_gate.png
new file mode 100644
index 0000000000000000000000000000000000000000..fa77919f9d9faabc26b9287b35c4dd3b4006163e
Binary files /dev/null and b/assets/example_image/typical_misc_gate.png differ
diff --git a/assets/example_image/typical_misc_lantern.png b/assets/example_image/typical_misc_lantern.png
new file mode 100644
index 0000000000000000000000000000000000000000..4c93f5dea2638a5a169dd1557a36f6d34b57144d
Binary files /dev/null and b/assets/example_image/typical_misc_lantern.png differ
diff --git a/assets/example_image/typical_misc_magicbook.png b/assets/example_image/typical_misc_magicbook.png
new file mode 100644
index 0000000000000000000000000000000000000000..7dc521a10fda176694c30170811809050f478a66
Binary files /dev/null and b/assets/example_image/typical_misc_magicbook.png differ
diff --git a/assets/example_image/typical_misc_mailbox.png b/assets/example_image/typical_misc_mailbox.png
new file mode 100644
index 0000000000000000000000000000000000000000..b6e8bc50cd270bb7462eee2af7a6d5649ef54cf2
Binary files /dev/null and b/assets/example_image/typical_misc_mailbox.png differ
diff --git a/assets/example_image/typical_misc_monster_chest.png b/assets/example_image/typical_misc_monster_chest.png
new file mode 100644
index 0000000000000000000000000000000000000000..6d544370fa306138e7dbab3e548d6e05b8ef2317
Binary files /dev/null and b/assets/example_image/typical_misc_monster_chest.png differ
diff --git a/assets/example_image/typical_misc_paper_machine.png b/assets/example_image/typical_misc_paper_machine.png
new file mode 100644
index 0000000000000000000000000000000000000000..a630074dbfe32c53f52f2f27e5b6b3eff8469a9e
Binary files /dev/null and b/assets/example_image/typical_misc_paper_machine.png differ
diff --git a/assets/example_image/typical_misc_phonograph.png b/assets/example_image/typical_misc_phonograph.png
new file mode 100644
index 0000000000000000000000000000000000000000..668662d741344ac16427259fc966186ef8ca97a9
Binary files /dev/null and b/assets/example_image/typical_misc_phonograph.png differ
diff --git a/assets/example_image/typical_misc_portal2.png b/assets/example_image/typical_misc_portal2.png
new file mode 100644
index 0000000000000000000000000000000000000000..666daa75fbaf7df55585f7143906d158175be6be
Binary files /dev/null and b/assets/example_image/typical_misc_portal2.png differ
diff --git a/assets/example_image/typical_misc_storage_chest.png b/assets/example_image/typical_misc_storage_chest.png
new file mode 100644
index 0000000000000000000000000000000000000000..38f4bd31f8eb62badcc5e1a51d4612e528b4069e
Binary files /dev/null and b/assets/example_image/typical_misc_storage_chest.png differ
diff --git a/assets/example_image/typical_misc_telephone.png b/assets/example_image/typical_misc_telephone.png
new file mode 100644
index 0000000000000000000000000000000000000000..a0a7d65a300d9f1adc55b2fc36951731f5abb355
Binary files /dev/null and b/assets/example_image/typical_misc_telephone.png differ
diff --git a/assets/example_image/typical_misc_television.png b/assets/example_image/typical_misc_television.png
new file mode 100644
index 0000000000000000000000000000000000000000..1d6b5882b42ce532f6a60080ad55bda7053530c0
Binary files /dev/null and b/assets/example_image/typical_misc_television.png differ
diff --git a/assets/example_image/typical_misc_workbench.png b/assets/example_image/typical_misc_workbench.png
new file mode 100644
index 0000000000000000000000000000000000000000..88024f960ff56aa619b0c496f85de390076bbf5a
Binary files /dev/null and b/assets/example_image/typical_misc_workbench.png differ
diff --git a/assets/example_image/typical_vehicle_biplane.png b/assets/example_image/typical_vehicle_biplane.png
new file mode 100644
index 0000000000000000000000000000000000000000..7427cad3270d8ed33dad05c7a2ae1b0092b4beb2
Binary files /dev/null and b/assets/example_image/typical_vehicle_biplane.png differ
diff --git a/assets/example_image/typical_vehicle_bulldozer.png b/assets/example_image/typical_vehicle_bulldozer.png
new file mode 100644
index 0000000000000000000000000000000000000000..17ffe389498d9561ef92766de654ef17b5755f60
Binary files /dev/null and b/assets/example_image/typical_vehicle_bulldozer.png differ
diff --git a/assets/example_image/typical_vehicle_cart.png b/assets/example_image/typical_vehicle_cart.png
new file mode 100644
index 0000000000000000000000000000000000000000..137bb4887f3879691ffff21227951790eb1840b4
Binary files /dev/null and b/assets/example_image/typical_vehicle_cart.png differ
diff --git a/assets/example_image/typical_vehicle_excavator.png b/assets/example_image/typical_vehicle_excavator.png
new file mode 100644
index 0000000000000000000000000000000000000000..c434e8b0ab142ecc35caf91d42df5f4541825b8f
Binary files /dev/null and b/assets/example_image/typical_vehicle_excavator.png differ
diff --git a/assets/example_image/typical_vehicle_helicopter.png b/assets/example_image/typical_vehicle_helicopter.png
new file mode 100644
index 0000000000000000000000000000000000000000..39c2497d22ea519cddf576c6504954c338f943e4
Binary files /dev/null and b/assets/example_image/typical_vehicle_helicopter.png differ
diff --git a/assets/example_image/typical_vehicle_locomotive.png b/assets/example_image/typical_vehicle_locomotive.png
new file mode 100644
index 0000000000000000000000000000000000000000..dac6a2a2de9e8830bac53d3893aa1d3741916b1a
Binary files /dev/null and b/assets/example_image/typical_vehicle_locomotive.png differ
diff --git a/assets/example_image/typical_vehicle_pirate_ship.png b/assets/example_image/typical_vehicle_pirate_ship.png
new file mode 100644
index 0000000000000000000000000000000000000000..9eed1529f309c64fc6237caba97631cc1f2bab53
Binary files /dev/null and b/assets/example_image/typical_vehicle_pirate_ship.png differ
diff --git a/assets/example_image/weatherworn_misc_paper_machine3.png b/assets/example_image/weatherworn_misc_paper_machine3.png
new file mode 100644
index 0000000000000000000000000000000000000000..46e8a9dc123aaf71a41e994a8e50eabb4e53e721
Binary files /dev/null and b/assets/example_image/weatherworn_misc_paper_machine3.png differ
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..93a5d99a3436c988b46fe75255da205d66dde789
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,28 @@
+--extra-index-url https://download.pytorch.org/whl/cu118
+--find-links https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.4.0_cu121.html
+
+
+torch==2.4.0
+torchvision==0.19.0
+pillow==10.4.0
+imageio==2.36.1
+imageio-ffmpeg==0.5.1
+tqdm==4.67.1
+easydict==1.13
+opencv-python-headless==4.10.0.84
+scipy==1.14.1
+rembg==2.0.60
+onnxruntime==1.20.1
+trimesh==4.5.3
+xatlas==0.0.9
+pyvista==0.44.2
+pymeshfix==0.17.0
+igraph==0.11.8
+git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8
+xformers==0.0.27.post2+cu118
+flash-attn==2.7.0.post2
+kaolin==0.17.0
+spconv-cu118==2.3.6
+transformers==4.46.3
+wheels/nvdiffrast-0.3.3-py3-none-any.whl
+wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl
\ No newline at end of file
diff --git a/trellis/__init__.py b/trellis/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..20d240afc9c26a21aee76954628b3d4ef9a1ccbd
--- /dev/null
+++ b/trellis/__init__.py
@@ -0,0 +1,6 @@
+from . import models
+from . import modules
+from . import pipelines
+from . import renderers
+from . import representations
+from . import utils
diff --git a/trellis/models/__init__.py b/trellis/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d90e9f9ab48e7028a370a0df663182f4b8ccadc5
--- /dev/null
+++ b/trellis/models/__init__.py
@@ -0,0 +1,70 @@
+import importlib
+
+__attributes = {
+    'SparseStructureEncoder': 'sparse_structure_vae',
+    'SparseStructureDecoder': 'sparse_structure_vae',
+    'SparseStructureFlowModel': 'sparse_structure_flow',
+    'SLatEncoder': 'structured_latent_vae',
+    'SLatGaussianDecoder': 'structured_latent_vae',
+    'SLatRadianceFieldDecoder': 'structured_latent_vae',
+    'SLatMeshDecoder': 'structured_latent_vae',
+    'SLatFlowModel': 'structured_latent_flow',
+}
+
+__submodules = []
+
+__all__ = list(__attributes.keys()) + __submodules
+
+def __getattr__(name):
+    if name not in globals():
+        if name in __attributes:
+            module_name = __attributes[name]
+            module = importlib.import_module(f".{module_name}", __name__)
+            globals()[name] = getattr(module, name)
+        elif name in __submodules:
+            module = importlib.import_module(f".{name}", __name__)
+            globals()[name] = module
+        else:
+            raise AttributeError(f"module {__name__} has no attribute {name}")
+    return globals()[name]
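+
+
+# Note on the lazy loader above: the first access to e.g. `trellis.models.SLatFlowModel`
+# imports `.structured_latent_flow` via importlib and caches the class in this module's
+# globals; later accesses are plain attribute lookups. A minimal sketch:
+#
+#   from trellis import models
+#   flow_cls = models.SLatFlowModel   # triggers the submodule import on first use only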
+
+
+def from_pretrained(path: str, **kwargs):
+    """
+    Load a model from a pretrained checkpoint.
+
+    Args:
+        path: The path to the checkpoint. Can be either a local path or a Hugging Face model name.
+              NOTE: the config file and model file are expected to be named f'{path}.json' and f'{path}.safetensors', respectively.
+        **kwargs: Additional arguments for the model constructor.
+    """
+    import os
+    import json
+    from safetensors.torch import load_file
+    is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors")
+
+    if is_local:
+        config_file = f"{path}.json"
+        model_file = f"{path}.safetensors"
+    else:
+        from huggingface_hub import hf_hub_download
+        path_parts = path.split('/')
+        repo_id = f'{path_parts[0]}/{path_parts[1]}'
+        model_name = '/'.join(path_parts[2:])
+        config_file = hf_hub_download(repo_id, f"{model_name}.json")
+        model_file = hf_hub_download(repo_id, f"{model_name}.safetensors")
+
+    with open(config_file, 'r') as f:
+        config = json.load(f)
+    model = __getattr__(config['name'])(**config['args'], **kwargs)
+    model.load_state_dict(load_file(model_file))
+
+    return model
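+
+
+# Usage sketch for from_pretrained (illustrative only; the checkpoint sub-path below is an
+# assumption and may differ from the names actually shipped with the pretrained repo):
+#
+#   from trellis import models
+#   decoder = models.from_pretrained(
+#       "JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16"
+#   )
+#   # resolves to repo_id="JeffreyXiang/TRELLIS-image-large" and downloads
+#   # "ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16.json" / ".safetensors" from the Hub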
+
+
+# For Pylance
+if __name__ == '__main__':
+    from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder
+    from .sparse_structure_flow import SparseStructureFlowModel
+    from .structured_latent_vae import SLatEncoder, SLatGaussianDecoder, SLatRadianceFieldDecoder, SLatMeshDecoder
+    from .structured_latent_flow import SLatFlowModel
diff --git a/trellis/models/sparse_structure_flow.py b/trellis/models/sparse_structure_flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..aee71a9686fd3795960cf1df970e9b8db0ebd57a
--- /dev/null
+++ b/trellis/models/sparse_structure_flow.py
@@ -0,0 +1,200 @@
+from typing import *
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from ..modules.utils import convert_module_to_f16, convert_module_to_f32
+from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock
+from ..modules.spatial import patchify, unpatchify
+
+
+class TimestepEmbedder(nn.Module):
+    """
+    Embeds scalar timesteps into vector representations.
+    """
+    def __init__(self, hidden_size, frequency_embedding_size=256):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+            nn.SiLU(),
+            nn.Linear(hidden_size, hidden_size, bias=True),
+        )
+        self.frequency_embedding_size = frequency_embedding_size
+
+    @staticmethod
+    def timestep_embedding(t, dim, max_period=10000):
+        """
+        Create sinusoidal timestep embeddings.
+
+        Args:
+            t: a 1-D Tensor of N indices, one per batch element.
+                These may be fractional.
+            dim: the dimension of the output.
+            max_period: controls the minimum frequency of the embeddings.
+
+        Returns:
+            an (N, D) Tensor of positional embeddings.
+        """
+        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+        half = dim // 2
+        freqs = torch.exp(
+            -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+        ).to(device=t.device)
+        args = t[:, None].float() * freqs[None]
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        if dim % 2:
+            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+        return embedding
+
+    def forward(self, t):
+        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+        t_emb = self.mlp(t_freq)
+        return t_emb
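+
+
+# Worked example for the sinusoidal embedding above (comment-only sketch):
+#
+#   t = torch.tensor([0., 10., 100., 999.])                 # (4,) timesteps
+#   emb = TimestepEmbedder.timestep_embedding(t, dim=256)   # (4, 256)
+#
+# The first 128 columns are cos(t * f_k) and the last 128 are sin(t * f_k), with
+# f_k = exp(-ln(10000) * k / 128) for k = 0..127; TimestepEmbedder.forward then
+# projects these features to (N, hidden_size) with its two-layer MLP.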
+
+
+class SparseStructureFlowModel(nn.Module):
+    def __init__(
+        self,
+        resolution: int,
+        in_channels: int,
+        model_channels: int,
+        cond_channels: int,
+        out_channels: int,
+        num_blocks: int,
+        num_heads: Optional[int] = None,
+        num_head_channels: Optional[int] = 64,
+        mlp_ratio: float = 4,
+        patch_size: int = 2,
+        pe_mode: Literal["ape", "rope"] = "ape",
+        use_fp16: bool = False,
+        use_checkpoint: bool = False,
+        share_mod: bool = False,
+        qk_rms_norm: bool = False,
+        qk_rms_norm_cross: bool = False,
+    ):
+        super().__init__()
+        self.resolution = resolution
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        self.cond_channels = cond_channels
+        self.out_channels = out_channels
+        self.num_blocks = num_blocks
+        self.num_heads = num_heads or model_channels // num_head_channels
+        self.mlp_ratio = mlp_ratio
+        self.patch_size = patch_size
+        self.pe_mode = pe_mode
+        self.use_fp16 = use_fp16
+        self.use_checkpoint = use_checkpoint
+        self.share_mod = share_mod
+        self.qk_rms_norm = qk_rms_norm
+        self.qk_rms_norm_cross = qk_rms_norm_cross
+        self.dtype = torch.float16 if use_fp16 else torch.float32
+
+        self.t_embedder = TimestepEmbedder(model_channels)
+        if share_mod:
+            self.adaLN_modulation = nn.Sequential(
+                nn.SiLU(),
+                nn.Linear(model_channels, 6 * model_channels, bias=True)
+            )
+
+        if pe_mode == "ape":
+            pos_embedder = AbsolutePositionEmbedder(model_channels, 3)
+            coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution // patch_size] * 3], indexing='ij')
+            coords = torch.stack(coords, dim=-1).reshape(-1, 3)
+            pos_emb = pos_embedder(coords)
+            self.register_buffer("pos_emb", pos_emb)
+
+        self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels)
+            
+        self.blocks = nn.ModuleList([
+            ModulatedTransformerCrossBlock(
+                model_channels,
+                cond_channels,
+                num_heads=self.num_heads,
+                mlp_ratio=self.mlp_ratio,
+                attn_mode='full',
+                use_checkpoint=self.use_checkpoint,
+                use_rope=(pe_mode == "rope"),
+                share_mod=share_mod,
+                qk_rms_norm=self.qk_rms_norm,
+                qk_rms_norm_cross=self.qk_rms_norm_cross,
+            )
+            for _ in range(num_blocks)
+        ])
+
+        self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3)
+
+        self.initialize_weights()
+        if use_fp16:
+            self.convert_to_fp16()
+
+    @property
+    def device(self) -> torch.device:
+        """
+        Return the device of the model.
+        """
+        return next(self.parameters()).device
+
+    def convert_to_fp16(self) -> None:
+        """
+        Convert the torso of the model to float16.
+        """
+        self.blocks.apply(convert_module_to_f16)
+
+    def convert_to_fp32(self) -> None:
+        """
+        Convert the torso of the model to float32.
+        """
+        self.blocks.apply(convert_module_to_f32)
+
+    def initialize_weights(self) -> None:
+        # Initialize transformer layers:
+        def _basic_init(module):
+            if isinstance(module, nn.Linear):
+                torch.nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    nn.init.constant_(module.bias, 0)
+        self.apply(_basic_init)
+
+        # Initialize timestep embedding MLP:
+        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+        # Zero-out adaLN modulation layers in DiT blocks:
+        if self.share_mod:
+            nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
+            nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
+        else:
+            for block in self.blocks:
+                nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+                nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+        # Zero-out output layers:
+        nn.init.constant_(self.out_layer.weight, 0)
+        nn.init.constant_(self.out_layer.bias, 0)
+
+    def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
+        assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
+                f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
+
+        h = patchify(x, self.patch_size)
+        h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous()
+
+        h = self.input_layer(h)
+        h = h + self.pos_emb[None]
+        t_emb = self.t_embedder(t)
+        if self.share_mod:
+            t_emb = self.adaLN_modulation(t_emb)
+        t_emb = t_emb.type(self.dtype)
+        h = h.type(self.dtype)
+        cond = cond.type(self.dtype)
+        for block in self.blocks:
+            h = block(h, t_emb, cond)
+        h = h.type(x.dtype)
+        h = F.layer_norm(h, h.shape[-1:])
+        h = self.out_layer(h)
+
+        h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3)
+        h = unpatchify(h, self.patch_size).contiguous()
+
+        return h
diff --git a/trellis/models/sparse_structure_vae.py b/trellis/models/sparse_structure_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3e09136cf294c4c1b47b0f09fa6ee57bad2166d
--- /dev/null
+++ b/trellis/models/sparse_structure_vae.py
@@ -0,0 +1,306 @@
+from typing import *
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ..modules.norm import GroupNorm32, ChannelLayerNorm32
+from ..modules.spatial import pixel_shuffle_3d
+from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
+
+
+def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module:
+    """
+    Return a normalization layer.
+    """
+    if norm_type == "group":
+        return GroupNorm32(32, *args, **kwargs)
+    elif norm_type == "layer":
+        return ChannelLayerNorm32(*args, **kwargs)
+    else:
+        raise ValueError(f"Invalid norm type {norm_type}")
+
+
+class ResBlock3d(nn.Module):
+    def __init__(
+        self,
+        channels: int,
+        out_channels: Optional[int] = None,
+        norm_type: Literal["group", "layer"] = "layer",
+    ):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+
+        self.norm1 = norm_layer(norm_type, channels)
+        self.norm2 = norm_layer(norm_type, self.out_channels)
+        self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
+        self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1))
+        self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
+    
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        h = self.norm1(x)
+        h = F.silu(h)
+        h = self.conv1(h)
+        h = self.norm2(h)
+        h = F.silu(h)
+        h = self.conv2(h)
+        h = h + self.skip_connection(x)
+        return h
+
+
+class DownsampleBlock3d(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        mode: Literal["conv", "avgpool"] = "conv",
+    ):
+        assert mode in ["conv", "avgpool"], f"Invalid mode {mode}"
+
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        if mode == "conv":
+            self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2)
+        elif mode == "avgpool":
+            assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels"
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if hasattr(self, "conv"):
+            return self.conv(x)
+        else:
+            return F.avg_pool3d(x, 2)
+
+
+class UpsampleBlock3d(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        mode: Literal["conv", "nearest"] = "conv",
+    ):
+        assert mode in ["conv", "nearest"], f"Invalid mode {mode}"
+
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        if mode == "conv":
+            self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1)
+        elif mode == "nearest":
+            assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels"
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if hasattr(self, "conv"):
+            x = self.conv(x)
+            return pixel_shuffle_3d(x, 2)
+        else:
+            return F.interpolate(x, scale_factor=2, mode="nearest")
+        
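+
+# Shape sketch for the "conv" upsampling path above (comment only): the 3x3x3 conv maps
+# (B, C_in, D, H, W) to (B, 8*C_out, D, H, W), and pixel_shuffle_3d(x, 2) rearranges the
+# factor of 8 = 2^3 channels into space, yielding (B, C_out, 2D, 2H, 2W), i.e. the 3D
+# analogue of sub-pixel (pixel-shuffle) upsampling.
+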
+
+class SparseStructureEncoder(nn.Module):
+    """
+    Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3).
+    
+    Args:
+        in_channels (int): Channels of the input.
+        latent_channels (int): Channels of the latent representation.
+        num_res_blocks (int): Number of residual blocks at each resolution.
+        channels (List[int]): Channels of the encoder blocks.
+        num_res_blocks_middle (int): Number of residual blocks in the middle.
+        norm_type (Literal["group", "layer"]): Type of normalization layer.
+        use_fp16 (bool): Whether to use FP16.
+    """
+    def __init__(
+        self,
+        in_channels: int,
+        latent_channels: int,
+        num_res_blocks: int,
+        channels: List[int],
+        num_res_blocks_middle: int = 2,
+        norm_type: Literal["group", "layer"] = "layer",
+        use_fp16: bool = False,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.latent_channels = latent_channels
+        self.num_res_blocks = num_res_blocks
+        self.channels = channels
+        self.num_res_blocks_middle = num_res_blocks_middle
+        self.norm_type = norm_type
+        self.use_fp16 = use_fp16
+        self.dtype = torch.float16 if use_fp16 else torch.float32
+
+        self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1)
+
+        self.blocks = nn.ModuleList([])
+        for i, ch in enumerate(channels):
+            self.blocks.extend([
+                ResBlock3d(ch, ch)
+                for _ in range(num_res_blocks)
+            ])
+            if i < len(channels) - 1:
+                self.blocks.append(
+                    DownsampleBlock3d(ch, channels[i+1])
+                )
+        
+        self.middle_block = nn.Sequential(*[
+            ResBlock3d(channels[-1], channels[-1])
+            for _ in range(num_res_blocks_middle)
+        ])
+
+        self.out_layer = nn.Sequential(
+            norm_layer(norm_type, channels[-1]),
+            nn.SiLU(),
+            nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1)
+        )
+
+        if use_fp16:
+            self.convert_to_fp16()
+
+    @property
+    def device(self) -> torch.device:
+        """
+        Return the device of the model.
+        """
+        return next(self.parameters()).device
+
+    def convert_to_fp16(self) -> None:
+        """
+        Convert the torso of the model to float16.
+        """
+        self.use_fp16 = True
+        self.dtype = torch.float16
+        self.blocks.apply(convert_module_to_f16)
+        self.middle_block.apply(convert_module_to_f16)
+
+    def convert_to_fp32(self) -> None:
+        """
+        Convert the torso of the model to float32.
+        """
+        self.use_fp16 = False
+        self.dtype = torch.float32
+        self.blocks.apply(convert_module_to_f32)
+        self.middle_block.apply(convert_module_to_f32)
+
+    def forward(self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False) -> torch.Tensor:
+        h = self.input_layer(x)
+        h = h.type(self.dtype)
+
+        for block in self.blocks:
+            h = block(h)
+        h = self.middle_block(h)
+
+        h = h.type(x.dtype)
+        h = self.out_layer(h)
+
+        mean, logvar = h.chunk(2, dim=1)
+
+        if sample_posterior:
+            std = torch.exp(0.5 * logvar)
+            z = mean + std * torch.randn_like(std)
+        else:
+            z = mean
+            
+        if return_raw:
+            return z, mean, logvar
+        return z
+        
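+
+# Instantiation sketch for the encoder above (comment only; the hyperparameters are
+# assumptions rather than the shipped TRELLIS config):
+#
+#   enc = SparseStructureEncoder(in_channels=1, latent_channels=8,
+#                                num_res_blocks=2, channels=[32, 128, 512])
+#   z = enc(torch.randn(1, 1, 64, 64, 64))   # -> (1, 8, 16, 16, 16) after two downsamples
+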
+
+class SparseStructureDecoder(nn.Module):
+    """
+    Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3).
+    
+    Args:
+        out_channels (int): Channels of the output.
+        latent_channels (int): Channels of the latent representation.
+        num_res_blocks (int): Number of residual blocks at each resolution.
+        channels (List[int]): Channels of the decoder blocks.
+        num_res_blocks_middle (int): Number of residual blocks in the middle.
+        norm_type (Literal["group", "layer"]): Type of normalization layer.
+        use_fp16 (bool): Whether to use FP16.
+    """ 
+    def __init__(
+        self,
+        out_channels: int,
+        latent_channels: int,
+        num_res_blocks: int,
+        channels: List[int],
+        num_res_blocks_middle: int = 2,
+        norm_type: Literal["group", "layer"] = "layer",
+        use_fp16: bool = False,
+    ):
+        super().__init__()
+        self.out_channels = out_channels
+        self.latent_channels = latent_channels
+        self.num_res_blocks = num_res_blocks
+        self.channels = channels
+        self.num_res_blocks_middle = num_res_blocks_middle
+        self.norm_type = norm_type
+        self.use_fp16 = use_fp16
+        self.dtype = torch.float16 if use_fp16 else torch.float32
+
+        self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1)
+
+        self.middle_block = nn.Sequential(*[
+            ResBlock3d(channels[0], channels[0])
+            for _ in range(num_res_blocks_middle)
+        ])
+
+        self.blocks = nn.ModuleList([])
+        for i, ch in enumerate(channels):
+            self.blocks.extend([
+                ResBlock3d(ch, ch)
+                for _ in range(num_res_blocks)
+            ])
+            if i < len(channels) - 1:
+                self.blocks.append(
+                    UpsampleBlock3d(ch, channels[i+1])
+                )
+
+        self.out_layer = nn.Sequential(
+            norm_layer(norm_type, channels[-1]),
+            nn.SiLU(),
+            nn.Conv3d(channels[-1], out_channels, 3, padding=1)
+        )
+
+        if use_fp16:
+            self.convert_to_fp16()
+
+    @property
+    def device(self) -> torch.device:
+        """
+        Return the device of the model.
+        """
+        return next(self.parameters()).device
+    
+    def convert_to_fp16(self) -> None:
+        """
+        Convert the torso of the model to float16.
+        """
+        self.use_fp16 = True
+        self.dtype = torch.float16
+        self.blocks.apply(convert_module_to_f16)
+        self.middle_block.apply(convert_module_to_f16)
+
+    def convert_to_fp32(self) -> None:
+        """
+        Convert the torso of the model to float32.
+        """
+        self.use_fp16 = False
+        self.dtype = torch.float32
+        self.blocks.apply(convert_module_to_f32)
+        self.middle_block.apply(convert_module_to_f32)
+    
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        h = self.input_layer(x)
+        
+        h = h.type(self.dtype)
+                
+        h = self.middle_block(h)
+        for block in self.blocks:
+            h = block(h)
+
+        h = h.type(x.dtype)
+        h = self.out_layer(h)
+        return h
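+
+
+# Round-trip sketch pairing the decoder with the encoder above (comment only; reuses the
+# assumed hyperparameters from the encoder example):
+#
+#   dec = SparseStructureDecoder(out_channels=1, latent_channels=8,
+#                                num_res_blocks=2, channels=[512, 128, 32])
+#   logits = dec(z)   # (1, 8, 16, 16, 16) -> (1, 1, 64, 64, 64)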
diff --git a/trellis/models/structured_latent_flow.py b/trellis/models/structured_latent_flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1463d79bc472ce3ef6859a42e10a06de1f9ebf7
--- /dev/null
+++ b/trellis/models/structured_latent_flow.py
@@ -0,0 +1,262 @@
+from typing import *
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
+from ..modules.transformer import AbsolutePositionEmbedder
+from ..modules.norm import LayerNorm32
+from ..modules import sparse as sp
+from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock
+from .sparse_structure_flow import TimestepEmbedder
+
+
+class SparseResBlock3d(nn.Module):
+    def __init__(
+        self,
+        channels: int,
+        emb_channels: int,
+        out_channels: Optional[int] = None,
+        downsample: bool = False,
+        upsample: bool = False,
+    ):
+        super().__init__()
+        self.channels = channels
+        self.emb_channels = emb_channels
+        self.out_channels = out_channels or channels
+        self.downsample = downsample
+        self.upsample = upsample
+        
+        assert not (downsample and upsample), "Cannot downsample and upsample at the same time"
+
+        self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
+        self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
+        self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
+        self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
+        self.emb_layers = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(emb_channels, 2 * self.out_channels, bias=True),
+        )
+        self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
+        self.updown = None
+        if self.downsample:
+            self.updown = sp.SparseDownsample(2)
+        elif self.upsample:
+            self.updown = sp.SparseUpsample(2)
+
+    def _updown(self, x: sp.SparseTensor) -> sp.SparseTensor:
+        if self.updown is not None:
+            x = self.updown(x)
+        return x
+
+    def forward(self, x: sp.SparseTensor, emb: torch.Tensor) -> sp.SparseTensor:
+        emb_out = self.emb_layers(emb).type(x.dtype)
+        scale, shift = torch.chunk(emb_out, 2, dim=1)
+
+        x = self._updown(x)
+        h = x.replace(self.norm1(x.feats))
+        h = h.replace(F.silu(h.feats))
+        h = self.conv1(h)
+        h = h.replace(self.norm2(h.feats)) * (1 + scale) + shift
+        h = h.replace(F.silu(h.feats))
+        h = self.conv2(h)
+        h = h + self.skip_connection(x)
+
+        return h
+    
+
+class SLatFlowModel(nn.Module):
+    def __init__(
+        self,
+        resolution: int,
+        in_channels: int,
+        model_channels: int,
+        cond_channels: int,
+        out_channels: int,
+        num_blocks: int,
+        num_heads: Optional[int] = None,
+        num_head_channels: Optional[int] = 64,
+        mlp_ratio: float = 4,
+        patch_size: int = 2,
+        num_io_res_blocks: int = 2,
+        io_block_channels: Optional[List[int]] = None,
+        pe_mode: Literal["ape", "rope"] = "ape",
+        use_fp16: bool = False,
+        use_checkpoint: bool = False,
+        use_skip_connection: bool = True,
+        share_mod: bool = False,
+        qk_rms_norm: bool = False,
+        qk_rms_norm_cross: bool = False,
+    ):
+        super().__init__()
+        self.resolution = resolution
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        self.cond_channels = cond_channels
+        self.out_channels = out_channels
+        self.num_blocks = num_blocks
+        self.num_heads = num_heads or model_channels // num_head_channels
+        self.mlp_ratio = mlp_ratio
+        self.patch_size = patch_size
+        self.num_io_res_blocks = num_io_res_blocks
+        self.io_block_channels = io_block_channels
+        self.pe_mode = pe_mode
+        self.use_fp16 = use_fp16
+        self.use_checkpoint = use_checkpoint
+        self.use_skip_connection = use_skip_connection
+        self.share_mod = share_mod
+        self.qk_rms_norm = qk_rms_norm
+        self.qk_rms_norm_cross = qk_rms_norm_cross
+        self.dtype = torch.float16 if use_fp16 else torch.float32
+
+        assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2"
+        assert np.log2(patch_size) == len(io_block_channels), "Length of io_block_channels must match the number of downsampling stages (log2 of patch_size)"
+
+        self.t_embedder = TimestepEmbedder(model_channels)
+        if share_mod:
+            self.adaLN_modulation = nn.Sequential(
+                nn.SiLU(),
+                nn.Linear(model_channels, 6 * model_channels, bias=True)
+            )
+
+        if pe_mode == "ape":
+            self.pos_embedder = AbsolutePositionEmbedder(model_channels)
+
+        self.input_layer = sp.SparseLinear(in_channels, io_block_channels[0])
+        self.input_blocks = nn.ModuleList([])
+        for chs, next_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]):
+            self.input_blocks.extend([
+                SparseResBlock3d(
+                    chs,
+                    model_channels,
+                    out_channels=chs,
+                )
+                for _ in range(num_io_res_blocks-1)
+            ])
+            self.input_blocks.append(
+                SparseResBlock3d(
+                    chs,
+                    model_channels,
+                    out_channels=next_chs,
+                    downsample=True,
+                )
+            )
+            
+        self.blocks = nn.ModuleList([
+            ModulatedSparseTransformerCrossBlock(
+                model_channels,
+                cond_channels,
+                num_heads=self.num_heads,
+                mlp_ratio=self.mlp_ratio,
+                attn_mode='full',
+                use_checkpoint=self.use_checkpoint,
+                use_rope=(pe_mode == "rope"),
+                share_mod=self.share_mod,
+                qk_rms_norm=self.qk_rms_norm,
+                qk_rms_norm_cross=self.qk_rms_norm_cross,
+            )
+            for _ in range(num_blocks)
+        ])
+
+        self.out_blocks = nn.ModuleList([])
+        for chs, prev_chs in zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))):
+            self.out_blocks.append(
+                SparseResBlock3d(
+                    prev_chs * 2 if self.use_skip_connection else prev_chs,
+                    model_channels,
+                    out_channels=chs,
+                    upsample=True,
+                )
+            )
+            self.out_blocks.extend([
+                SparseResBlock3d(
+                    chs * 2 if self.use_skip_connection else chs,
+                    model_channels,
+                    out_channels=chs,
+                )
+                for _ in range(num_io_res_blocks-1)
+            ])
+        self.out_layer = sp.SparseLinear(io_block_channels[0], out_channels)
+
+        self.initialize_weights()
+        if use_fp16:
+            self.convert_to_fp16()
+
+    @property
+    def device(self) -> torch.device:
+        """
+        Return the device of the model.
+        """
+        return next(self.parameters()).device
+
+    def convert_to_fp16(self) -> None:
+        """
+        Convert the torso of the model to float16.
+        """
+        self.input_blocks.apply(convert_module_to_f16)
+        self.blocks.apply(convert_module_to_f16)
+        self.out_blocks.apply(convert_module_to_f16)
+
+    def convert_to_fp32(self) -> None:
+        """
+        Convert the torso of the model to float32.
+        """
+        self.input_blocks.apply(convert_module_to_f32)
+        self.blocks.apply(convert_module_to_f32)
+        self.out_blocks.apply(convert_module_to_f32)
+
+    def initialize_weights(self) -> None:
+        # Initialize transformer layers:
+        def _basic_init(module):
+            if isinstance(module, nn.Linear):
+                torch.nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    nn.init.constant_(module.bias, 0)
+        self.apply(_basic_init)
+
+        # Initialize timestep embedding MLP:
+        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+        # Zero-out adaLN modulation layers in DiT blocks:
+        if self.share_mod:
+            nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
+            nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
+        else:
+            for block in self.blocks:
+                nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+                nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+        # Zero-out output layers:
+        nn.init.constant_(self.out_layer.weight, 0)
+        nn.init.constant_(self.out_layer.bias, 0)
+
+    def forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor) -> sp.SparseTensor:
+        h = self.input_layer(x).type(self.dtype)
+        t_emb = self.t_embedder(t)
+        if self.share_mod:
+            t_emb = self.adaLN_modulation(t_emb)
+        t_emb = t_emb.type(self.dtype)
+        cond = cond.type(self.dtype)
+
+        skips = []
+        # Downsampling input blocks: keep each stage's features as U-Net-style skip connections
+        for block in self.input_blocks:
+            h = block(h, t_emb)
+            skips.append(h.feats)
+        
+        if self.pe_mode == "ape":
+            h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype)
+        for block in self.blocks:
+            h = block(h, t_emb, cond)
+
+        # Upsampling output blocks: concatenate the stored skip features in reverse order
+        for block, skip in zip(self.out_blocks, reversed(skips)):
+            if self.use_skip_connection:
+                h = block(h.replace(torch.cat([h.feats, skip], dim=1)), t_emb)
+            else:
+                h = block(h, t_emb)
+
+        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
+        h = self.out_layer(h.type(x.dtype))
+        return h
diff --git a/trellis/models/structured_latent_vae/__init__.py b/trellis/models/structured_latent_vae/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..75603bc1d86c3036972c3d740ca7cb93d872f836
--- /dev/null
+++ b/trellis/models/structured_latent_vae/__init__.py
@@ -0,0 +1,4 @@
+from .encoder import SLatEncoder
+from .decoder_gs import SLatGaussianDecoder
+from .decoder_rf import SLatRadianceFieldDecoder
+from .decoder_mesh import SLatMeshDecoder
diff --git a/trellis/models/structured_latent_vae/base.py b/trellis/models/structured_latent_vae/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab0bf6a850b1c146e081c32ad92c7c44ead5ef6e
--- /dev/null
+++ b/trellis/models/structured_latent_vae/base.py
@@ -0,0 +1,117 @@
+from typing import *
+import torch
+import torch.nn as nn
+from ...modules.utils import convert_module_to_f16, convert_module_to_f32
+from ...modules import sparse as sp
+from ...modules.transformer import AbsolutePositionEmbedder
+from ...modules.sparse.transformer import SparseTransformerBlock
+
+
+def block_attn_config(self):
+    """
+    Return the attention configuration of the model.
+    """
+    for i in range(self.num_blocks):
+        if self.attn_mode == "shift_window":
+            yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER
+        elif self.attn_mode == "shift_sequence":
+            yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER
+        elif self.attn_mode == "shift_order":
+            yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4]
+        elif self.attn_mode == "full":
+            yield "full", None, None, None, None
+        elif self.attn_mode == "swin":
+            yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None
+
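+# Example of what the generator above yields (comment only): for attn_mode="swin" with
+# window_size=8, consecutive blocks alternate between
+#   ("windowed", 8, None, 0, None)  and  ("windowed", 8, None, 4, None)
+# i.e. unshifted and half-window-shifted windowed attention, in the spirit of Swin.
+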
+
+class SparseTransformerBase(nn.Module):
+    """
+    Sparse Transformer without output layers.
+    Serves as the base class for the encoder and decoders.
+    """
+    def __init__(
+        self,
+        in_channels: int,
+        model_channels: int,
+        num_blocks: int,
+        num_heads: Optional[int] = None,
+        num_head_channels: Optional[int] = 64,
+        mlp_ratio: float = 4.0,
+        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
+        window_size: Optional[int] = None,
+        pe_mode: Literal["ape", "rope"] = "ape",
+        use_fp16: bool = False,
+        use_checkpoint: bool = False,
+        qk_rms_norm: bool = False,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        self.num_blocks = num_blocks
+        self.window_size = window_size
+        self.num_heads = num_heads or model_channels // num_head_channels
+        self.mlp_ratio = mlp_ratio
+        self.attn_mode = attn_mode
+        self.pe_mode = pe_mode
+        self.use_fp16 = use_fp16
+        self.use_checkpoint = use_checkpoint
+        self.qk_rms_norm = qk_rms_norm
+        self.dtype = torch.float16 if use_fp16 else torch.float32
+
+        if pe_mode == "ape":
+            self.pos_embedder = AbsolutePositionEmbedder(model_channels)
+
+        self.input_layer = sp.SparseLinear(in_channels, model_channels)
+        self.blocks = nn.ModuleList([
+            SparseTransformerBlock(
+                model_channels,
+                num_heads=self.num_heads,
+                mlp_ratio=self.mlp_ratio,
+                attn_mode=attn_mode,
+                window_size=window_size,
+                shift_sequence=shift_sequence,
+                shift_window=shift_window,
+                serialize_mode=serialize_mode,
+                use_checkpoint=self.use_checkpoint,
+                use_rope=(pe_mode == "rope"),
+                qk_rms_norm=self.qk_rms_norm,
+            )
+            for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self)
+        ])
+
+    @property
+    def device(self) -> torch.device:
+        """
+        Return the device of the model.
+        """
+        return next(self.parameters()).device
+
+    def convert_to_fp16(self) -> None:
+        """
+        Convert the torso of the model to float16.
+        """
+        self.blocks.apply(convert_module_to_f16)
+
+    def convert_to_fp32(self) -> None:
+        """
+        Convert the torso of the model to float32.
+        """
+        self.blocks.apply(convert_module_to_f32)
+
+    def initialize_weights(self) -> None:
+        # Initialize transformer layers:
+        def _basic_init(module):
+            if isinstance(module, nn.Linear):
+                torch.nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    nn.init.constant_(module.bias, 0)
+        self.apply(_basic_init)
+
+    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
+        h = self.input_layer(x)
+        if self.pe_mode == "ape":
+            h = h + self.pos_embedder(x.coords[:, 1:])
+        h = h.type(self.dtype)
+        for block in self.blocks:
+            h = block(h)
+        return h
diff --git a/trellis/models/structured_latent_vae/decoder_gs.py b/trellis/models/structured_latent_vae/decoder_gs.py
new file mode 100644
index 0000000000000000000000000000000000000000..b893cfcfb2a166c7d57f96086a79317bd91884b9
--- /dev/null
+++ b/trellis/models/structured_latent_vae/decoder_gs.py
@@ -0,0 +1,122 @@
+from typing import *
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ...modules import sparse as sp
+from ...utils.random_utils import hammersley_sequence
+from .base import SparseTransformerBase
+from ...representations import Gaussian
+
+
+class SLatGaussianDecoder(SparseTransformerBase):
+    def __init__(
+        self,
+        resolution: int,
+        model_channels: int,
+        latent_channels: int,
+        num_blocks: int,
+        num_heads: Optional[int] = None,
+        num_head_channels: Optional[int] = 64,
+        mlp_ratio: float = 4,
+        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
+        window_size: int = 8,
+        pe_mode: Literal["ape", "rope"] = "ape",
+        use_fp16: bool = False,
+        use_checkpoint: bool = False,
+        qk_rms_norm: bool = False,
+        representation_config: Optional[dict] = None,
+    ):
+        super().__init__(
+            in_channels=latent_channels,
+            model_channels=model_channels,
+            num_blocks=num_blocks,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            mlp_ratio=mlp_ratio,
+            attn_mode=attn_mode,
+            window_size=window_size,
+            pe_mode=pe_mode,
+            use_fp16=use_fp16,
+            use_checkpoint=use_checkpoint,
+            qk_rms_norm=qk_rms_norm,
+        )
+        self.resolution = resolution
+        self.rep_config = representation_config
+        self._calc_layout()
+        self.out_layer = sp.SparseLinear(model_channels, self.out_channels)
+        self._build_perturbation()
+
+        self.initialize_weights()
+        if use_fp16:
+            self.convert_to_fp16()
+
+    def initialize_weights(self) -> None:
+        super().initialize_weights()
+        # Zero-out output layers:
+        nn.init.constant_(self.out_layer.weight, 0)
+        nn.init.constant_(self.out_layer.bias, 0)
+
+    def _build_perturbation(self) -> None:
+        perturbation = [hammersley_sequence(3, i, self.rep_config['num_gaussians']) for i in range(self.rep_config['num_gaussians'])]
+        perturbation = torch.tensor(perturbation).float() * 2 - 1
+        perturbation = perturbation / self.rep_config['voxel_size']
+        perturbation = torch.atanh(perturbation).to(self.device)
+        self.register_buffer('offset_perturbation', perturbation)
+
+    def _calc_layout(self) -> None:
+        self.layout = {
+            '_xyz' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3},
+            '_features_dc' : {'shape': (self.rep_config['num_gaussians'], 1, 3), 'size': self.rep_config['num_gaussians'] * 3},
+            '_scaling' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3},
+            '_rotation' : {'shape': (self.rep_config['num_gaussians'], 4), 'size': self.rep_config['num_gaussians'] * 4},
+            '_opacity' : {'shape': (self.rep_config['num_gaussians'], 1), 'size': self.rep_config['num_gaussians']},
+        }
+        start = 0
+        for k, v in self.layout.items():
+            v['range'] = (start, start + v['size'])
+            start += v['size']
+        self.out_channels = start
+    
+    def to_representation(self, x: sp.SparseTensor) -> List[Gaussian]:
+        """
+        Convert a batch of network outputs to 3D representations.
+
+        Args:
+            x: The [N x * x C] sparse tensor output by the network.
+
+        Returns:
+            list of representations
+        """
+        ret = []
+        for i in range(x.shape[0]):
+            representation = Gaussian(
+                sh_degree=0,
+                aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
+                mininum_kernel_size=self.rep_config['3d_filter_kernel_size'],
+                scaling_bias=self.rep_config['scaling_bias'],
+                opacity_bias=self.rep_config['opacity_bias'],
+                scaling_activation=self.rep_config['scaling_activation']
+            )
+            xyz = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution
+            for k, v in self.layout.items():
+                if k == '_xyz':
+                    offset = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape'])
+                    offset = offset * self.rep_config['lr'][k]
+                    if self.rep_config['perturb_offset']:
+                        offset = offset + self.offset_perturbation
+                    offset = torch.tanh(offset) / self.resolution * 0.5 * self.rep_config['voxel_size']
+                    _xyz = xyz.unsqueeze(1) + offset
+                    setattr(representation, k, _xyz.flatten(0, 1))
+                else:
+                    feats = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1)
+                    feats = feats * self.rep_config['lr'][k]
+                    setattr(representation, k, feats)
+            ret.append(representation)
+        return ret
+
+    def forward(self, x: sp.SparseTensor) -> List[Gaussian]:
+        h = super().forward(x)
+        h = h.type(x.dtype)
+        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
+        h = self.out_layer(h)
+        return self.to_representation(h)
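The decoder above packs all per-voxel Gaussian attributes into one feature vector whose channel ranges come from _calc_layout, and keeps the predicted centers inside their voxel by squashing the offsets with tanh. A standalone sketch of that bookkeeping; num_gaussians, voxel_size and resolution below are illustrative values, not the shipped configuration:

import torch

num_gaussians, voxel_size, resolution = 4, 1.0, 64
layout = {
    '_xyz': (num_gaussians, 3), '_features_dc': (num_gaussians, 1, 3),
    '_scaling': (num_gaussians, 3), '_rotation': (num_gaussians, 4),
    '_opacity': (num_gaussians, 1),
}
ranges, start = {}, 0
for k, shape in layout.items():
    size = int(torch.tensor(shape).prod())
    ranges[k] = (start, start + size)
    start += size
print(ranges, 'out_channels =', start)

feats = torch.randn(10, start)                       # 10 active voxels
lo, hi = ranges['_xyz']
offset = torch.tanh(feats[:, lo:hi].reshape(-1, num_gaussians, 3))
# |tanh| < 1, so after scaling each Gaussian stays within half a voxel of its center.
print(offset.abs().max() < 1, (offset / resolution * 0.5 * voxel_size).abs().max() < 0.5 / resolution)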
diff --git a/trellis/models/structured_latent_vae/decoder_mesh.py b/trellis/models/structured_latent_vae/decoder_mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..75c1b1ec7b6fdc28e787be283e55589b36461e50
--- /dev/null
+++ b/trellis/models/structured_latent_vae/decoder_mesh.py
@@ -0,0 +1,167 @@
+from typing import *
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
+from ...modules import sparse as sp
+from .base import SparseTransformerBase
+from ...representations import MeshExtractResult
+from ...representations.mesh import SparseFeatures2Mesh
+
+
+class SparseSubdivideBlock3d(nn.Module):
+    """
+    A 3D subdivide block that can subdivide the sparse tensor.
+
+    Args:
+        channels: channels in the inputs and outputs.
+        resolution: resolution of the input grid; the output grid is subdivided to twice this resolution.
+        out_channels: if specified, the number of output channels.
+        num_groups: the number of groups for the group norm.
+    """
+    def __init__(
+        self,
+        channels: int,
+        resolution: int,
+        out_channels: Optional[int] = None,
+        num_groups: int = 32
+    ):
+        super().__init__()
+        self.channels = channels
+        self.resolution = resolution
+        self.out_resolution = resolution * 2
+        self.out_channels = out_channels or channels
+
+        self.act_layers = nn.Sequential(
+            sp.SparseGroupNorm32(num_groups, channels),
+            sp.SparseSiLU()
+        )
+        
+        self.sub = sp.SparseSubdivide()
+        
+        self.out_layers = nn.Sequential(
+            sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"),
+            sp.SparseGroupNorm32(num_groups, self.out_channels),
+            sp.SparseSiLU(),
+            zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")),
+        )
+        
+        if self.out_channels == channels:
+            self.skip_connection = nn.Identity()
+        else:
+            self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}")
+        
+    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
+        """
+        Apply the block to a Tensor, conditioned on a timestep embedding.
+
+        Args:
+            x: an [N x C x ...] Tensor of features.
+        Returns:
+            an [N x C x ...] Tensor of outputs.
+        """
+        h = self.act_layers(x)
+        h = self.sub(h)
+        x = self.sub(x)
+        h = self.out_layers(h)
+        h = h + self.skip_connection(x)
+        return h
+
+
+class SLatMeshDecoder(SparseTransformerBase):
+    def __init__(
+        self,
+        resolution: int,
+        model_channels: int,
+        latent_channels: int,
+        num_blocks: int,
+        num_heads: Optional[int] = None,
+        num_head_channels: Optional[int] = 64,
+        mlp_ratio: float = 4,
+        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
+        window_size: int = 8,
+        pe_mode: Literal["ape", "rope"] = "ape",
+        use_fp16: bool = False,
+        use_checkpoint: bool = False,
+        qk_rms_norm: bool = False,
+        representation_config: Optional[dict] = None,
+    ):
+        super().__init__(
+            in_channels=latent_channels,
+            model_channels=model_channels,
+            num_blocks=num_blocks,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            mlp_ratio=mlp_ratio,
+            attn_mode=attn_mode,
+            window_size=window_size,
+            pe_mode=pe_mode,
+            use_fp16=use_fp16,
+            use_checkpoint=use_checkpoint,
+            qk_rms_norm=qk_rms_norm,
+        )
+        self.resolution = resolution
+        self.rep_config = representation_config
+        self.mesh_extractor = SparseFeatures2Mesh(res=self.resolution*4, use_color=self.rep_config.get('use_color', False))
+        self.out_channels = self.mesh_extractor.feats_channels
+        self.upsample = nn.ModuleList([
+            SparseSubdivideBlock3d(
+                channels=model_channels,
+                resolution=resolution,
+                out_channels=model_channels // 4
+            ),
+            SparseSubdivideBlock3d(
+                channels=model_channels // 4,
+                resolution=resolution * 2,
+                out_channels=model_channels // 8
+            )
+        ])
+        self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels)
+
+        self.initialize_weights()
+        if use_fp16:
+            self.convert_to_fp16()
+
+    def initialize_weights(self) -> None:
+        super().initialize_weights()
+        # Zero-out output layers:
+        nn.init.constant_(self.out_layer.weight, 0)
+        nn.init.constant_(self.out_layer.bias, 0)
+
+    def convert_to_fp16(self) -> None:
+        """
+        Convert the torso of the model to float16.
+        """
+        super().convert_to_fp16()
+        self.upsample.apply(convert_module_to_f16)
+
+    def convert_to_fp32(self) -> None:
+        """
+        Convert the torso of the model to float32.
+        """
+        super().convert_to_fp32()
+        self.upsample.apply(convert_module_to_f32)  
+    
+    def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
+        """
+        Convert a batch of network outputs to 3D representations.
+
+        Args:
+            x: The [N x * x C] sparse tensor output by the network.
+
+        Returns:
+            list of representations
+        """
+        ret = []
+        for i in range(x.shape[0]):
+            mesh = self.mesh_extractor(x[i], training=self.training)
+            ret.append(mesh)
+        return ret
+
+    def forward(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
+        h = super().forward(x)
+        for block in self.upsample:
+            h = block(h)
+        h = h.type(x.dtype)
+        h = self.out_layer(h)
+        return self.to_representation(h)
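A quick sanity check of the resolution bookkeeping in SLatMeshDecoder: the two SparseSubdivideBlock3d stages each double the grid, so the features handed to the mesh extractor live on a grid of resolution * 4, which is exactly the res passed to SparseFeatures2Mesh. The values below are illustrative:

resolution, model_channels = 64, 768            # illustrative values
block_in_res = [resolution, resolution * 2]     # inputs to the two subdivide blocks
block_out_res = [r * 2 for r in block_in_res]   # each block doubles the grid
block_out_ch = [model_channels // 4, model_channels // 8]
assert block_out_res[-1] == resolution * 4      # matches SparseFeatures2Mesh(res=resolution * 4)
print(list(zip(block_out_res, block_out_ch)))   # [(128, 192), (256, 96)]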
diff --git a/trellis/models/structured_latent_vae/decoder_rf.py b/trellis/models/structured_latent_vae/decoder_rf.py
new file mode 100644
index 0000000000000000000000000000000000000000..968bb30596647224292da0392dfdefeed49d214d
--- /dev/null
+++ b/trellis/models/structured_latent_vae/decoder_rf.py
@@ -0,0 +1,104 @@
+from typing import *
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from ...modules import sparse as sp
+from .base import SparseTransformerBase
+from ...representations import Strivec
+
+
+class SLatRadianceFieldDecoder(SparseTransformerBase):
+    def __init__(
+        self,
+        resolution: int,
+        model_channels: int,
+        latent_channels: int,
+        num_blocks: int,
+        num_heads: Optional[int] = None,
+        num_head_channels: Optional[int] = 64,
+        mlp_ratio: float = 4,
+        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
+        window_size: int = 8,
+        pe_mode: Literal["ape", "rope"] = "ape",
+        use_fp16: bool = False,
+        use_checkpoint: bool = False,
+        qk_rms_norm: bool = False,
+        representation_config: Optional[dict] = None,
+    ):
+        super().__init__(
+            in_channels=latent_channels,
+            model_channels=model_channels,
+            num_blocks=num_blocks,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            mlp_ratio=mlp_ratio,
+            attn_mode=attn_mode,
+            window_size=window_size,
+            pe_mode=pe_mode,
+            use_fp16=use_fp16,
+            use_checkpoint=use_checkpoint,
+            qk_rms_norm=qk_rms_norm,
+        )
+        self.resolution = resolution
+        self.rep_config = representation_config
+        self._calc_layout()
+        self.out_layer = sp.SparseLinear(model_channels, self.out_channels)
+
+        self.initialize_weights()
+        if use_fp16:
+            self.convert_to_fp16()
+
+    def initialize_weights(self) -> None:
+        super().initialize_weights()
+        # Zero-out output layers:
+        nn.init.constant_(self.out_layer.weight, 0)
+        nn.init.constant_(self.out_layer.bias, 0)
+
+    def _calc_layout(self) -> None:
+        self.layout = {
+            'trivec': {'shape': (self.rep_config['rank'], 3, self.rep_config['dim']), 'size': self.rep_config['rank'] * 3 * self.rep_config['dim']},
+            'density': {'shape': (self.rep_config['rank'],), 'size': self.rep_config['rank']},
+            'features_dc': {'shape': (self.rep_config['rank'], 1, 3), 'size': self.rep_config['rank'] * 3},
+        }
+        start = 0
+        for k, v in self.layout.items():
+            v['range'] = (start, start + v['size'])
+            start += v['size']
+        self.out_channels = start    
+    
+    def to_representation(self, x: sp.SparseTensor) -> List[Strivec]:
+        """
+        Convert a batch of network outputs to 3D representations.
+
+        Args:
+            x: The [N x * x C] sparse tensor output by the network.
+
+        Returns:
+            list of representations
+        """
+        ret = []
+        for i in range(x.shape[0]):
+            representation = Strivec(
+                sh_degree=0,
+                resolution=self.resolution,
+                aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
+                rank=self.rep_config['rank'],
+                dim=self.rep_config['dim'],
+                device='cuda',
+            )
+            representation.density_shift = 0.0
+            representation.position = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution
+            representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda')
+            for k, v in self.layout.items():
+                setattr(representation, k, x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']))
+            representation.trivec = representation.trivec + 1
+            ret.append(representation)
+        return ret
+
+    def forward(self, x: sp.SparseTensor) -> List[Strivec]:
+        h = super().forward(x)
+        h = h.type(x.dtype)
+        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
+        h = self.out_layer(h)
+        return self.to_representation(h)
diff --git a/trellis/models/structured_latent_vae/encoder.py b/trellis/models/structured_latent_vae/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..8370921d8d61954b43dcf3e251b8d9b315f4f536
--- /dev/null
+++ b/trellis/models/structured_latent_vae/encoder.py
@@ -0,0 +1,72 @@
+from typing import *
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ...modules import sparse as sp
+from .base import SparseTransformerBase
+
+
+class SLatEncoder(SparseTransformerBase):
+    def __init__(
+        self,
+        resolution: int,
+        in_channels: int,
+        model_channels: int,
+        latent_channels: int,
+        num_blocks: int,
+        num_heads: Optional[int] = None,
+        num_head_channels: Optional[int] = 64,
+        mlp_ratio: float = 4,
+        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
+        window_size: int = 8,
+        pe_mode: Literal["ape", "rope"] = "ape",
+        use_fp16: bool = False,
+        use_checkpoint: bool = False,
+        qk_rms_norm: bool = False,
+    ):
+        super().__init__(
+            in_channels=in_channels,
+            model_channels=model_channels,
+            num_blocks=num_blocks,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            mlp_ratio=mlp_ratio,
+            attn_mode=attn_mode,
+            window_size=window_size,
+            pe_mode=pe_mode,
+            use_fp16=use_fp16,
+            use_checkpoint=use_checkpoint,
+            qk_rms_norm=qk_rms_norm,
+        )
+        self.resolution = resolution
+        self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels)
+
+        self.initialize_weights()
+        if use_fp16:
+            self.convert_to_fp16()
+
+    def initialize_weights(self) -> None:
+        super().initialize_weights()
+        # Zero-out output layers:
+        nn.init.constant_(self.out_layer.weight, 0)
+        nn.init.constant_(self.out_layer.bias, 0)
+
+    def forward(self, x: sp.SparseTensor, sample_posterior=True, return_raw=False):
+        h = super().forward(x)
+        h = h.type(x.dtype)
+        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
+        h = self.out_layer(h)
+        
+        # Sample from the posterior distribution
+        mean, logvar = h.feats.chunk(2, dim=-1)
+        if sample_posterior:
+            std = torch.exp(0.5 * logvar)
+            z = mean + std * torch.randn_like(std)
+        else:
+            z = mean
+        z = h.replace(z)
+            
+        if return_raw:
+            return z, mean, logvar
+        else:
+            return z
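The encoder predicts a mean and log-variance per voxel and samples with the reparameterization trick. A dense, self-contained sketch of that sampling step, together with the standard-normal KL term a VAE loss would typically pair with it (the loss itself is not part of this file):

import torch

h = torch.randn(7, 2 * 8)                  # 7 voxels, 2 * latent_channels features
mean, logvar = h.chunk(2, dim=-1)
std = torch.exp(0.5 * logvar)
z = mean + std * torch.randn_like(std)     # reparameterization: gradients flow through mean and std
# KL divergence to a standard normal, commonly used with such an encoder:
kl = 0.5 * (mean.pow(2) + logvar.exp() - 1 - logvar).sum(dim=-1).mean()
print(z.shape, kl.item() >= 0)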
diff --git a/trellis/modules/attention/__init__.py b/trellis/modules/attention/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..f452320d5dbc4c0aa1664e33f76c56ff4bbe2039
--- /dev/null
+++ b/trellis/modules/attention/__init__.py
@@ -0,0 +1,36 @@
+from typing import *
+
+BACKEND = 'flash_attn' 
+DEBUG = False
+
+def __from_env():
+    import os
+    
+    global BACKEND
+    global DEBUG
+    
+    env_attn_backend = os.environ.get('ATTN_BACKEND')
+    env_attn_debug = os.environ.get('ATTN_DEBUG')
+
+    if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']:
+        BACKEND = env_attn_backend
+    if env_attn_debug is not None:
+        DEBUG = env_attn_debug == '1'
+
+    print(f"[ATTENTION] Using backend: {BACKEND}")
+        
+
+__from_env()
+    
+
+def set_backend(backend: Literal['xformers', 'flash_attn', 'sdpa', 'naive']):
+    global BACKEND
+    BACKEND = backend
+
+def set_debug(debug: bool):
+    global DEBUG
+    DEBUG = debug
+
+
+from .full_attn import *
+from .modules import *
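Note on configuration: full_attn.py (next file) binds BACKEND with a plain `from . import BACKEND` at import time, so the environment variable has to be set before the first trellis import; calling set_backend() later updates this package attribute but not the copy already held by the imported kernel module. A usage sketch, assuming the package is importable in your environment:

import os

os.environ['ATTN_BACKEND'] = 'sdpa'       # or 'xformers' / 'flash_attn' / 'naive'
os.environ['ATTN_DEBUG'] = '0'

from trellis.modules import attention     # prints "[ATTENTION] Using backend: sdpa"
print(attention.BACKEND)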
diff --git a/trellis/modules/attention/full_attn.py b/trellis/modules/attention/full_attn.py
new file mode 100755
index 0000000000000000000000000000000000000000..d9ebf6380a78906d4c6e969c63223fb7b398e5a7
--- /dev/null
+++ b/trellis/modules/attention/full_attn.py
@@ -0,0 +1,140 @@
+from typing import *
+import torch
+import math
+from . import DEBUG, BACKEND
+
+if BACKEND == 'xformers':
+    import xformers.ops as xops
+elif BACKEND == 'flash_attn':
+    import flash_attn
+elif BACKEND == 'sdpa':
+    from torch.nn.functional import scaled_dot_product_attention as sdpa
+elif BACKEND == 'naive':
+    pass
+else:
+    raise ValueError(f"Unknown attention backend: {BACKEND}")
+
+
+__all__ = [
+    'scaled_dot_product_attention',
+]
+
+
+def _naive_sdpa(q, k, v):
+    """
+    Naive implementation of scaled dot product attention.
+    """
+    q = q.permute(0, 2, 1, 3)   # [N, H, L, C]
+    k = k.permute(0, 2, 1, 3)   # [N, H, L, C]
+    v = v.permute(0, 2, 1, 3)   # [N, H, L, C]
+    scale_factor = 1 / math.sqrt(q.size(-1))
+    attn_weight = q @ k.transpose(-2, -1) * scale_factor
+    attn_weight = torch.softmax(attn_weight, dim=-1)
+    out = attn_weight @ v
+    out = out.permute(0, 2, 1, 3)   # [N, L, H, C]
+    return out
+
+
+@overload
+def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor:
+    """
+    Apply scaled dot product attention.
+
+    Args:
+        qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs.
+    """
+    ...
+
+@overload
+def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
+    """
+    Apply scaled dot product attention.
+
+    Args:
+        q (torch.Tensor): A [N, L, H, C] tensor containing Qs.
+        kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs.
+    """
+    ...
+
+@overload
+def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+    """
+    Apply scaled dot product attention.
+
+    Args:
+        q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs.
+        k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks.
+        v (torch.Tensor): A [N, L, H, Co] tensor containing Vs.
+
+    Note:
+        k and v are assumed to have the same coordinate map.
+    """
+    ...
+
+def scaled_dot_product_attention(*args, **kwargs):
+    arg_names_dict = {
+        1: ['qkv'],
+        2: ['q', 'kv'],
+        3: ['q', 'k', 'v']
+    }
+    num_all_args = len(args) + len(kwargs)
+    assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
+    for key in arg_names_dict[num_all_args][len(args):]:
+        assert key in kwargs, f"Missing argument {key}"
+
+    if num_all_args == 1:
+        qkv = args[0] if len(args) > 0 else kwargs['qkv']
+        assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]"
+        device = qkv.device
+
+    elif num_all_args == 2:
+        q = args[0] if len(args) > 0 else kwargs['q']
+        kv = args[1] if len(args) > 1 else kwargs['kv']
+        assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
+        assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
+        assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
+        device = q.device
+
+    elif num_all_args == 3:
+        q = args[0] if len(args) > 0 else kwargs['q']
+        k = args[1] if len(args) > 1 else kwargs['k']
+        v = args[2] if len(args) > 2 else kwargs['v']
+        assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
+        assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
+        assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
+        assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
+        device = q.device    
+
+    if BACKEND == 'xformers':
+        if num_all_args == 1:
+            q, k, v = qkv.unbind(dim=2)
+        elif num_all_args == 2:
+            k, v = kv.unbind(dim=2)
+        out = xops.memory_efficient_attention(q, k, v)
+    elif BACKEND == 'flash_attn':
+        if num_all_args == 1:
+            out = flash_attn.flash_attn_qkvpacked_func(qkv)
+        elif num_all_args == 2:
+            out = flash_attn.flash_attn_kvpacked_func(q, kv)
+        elif num_all_args == 3:
+            out = flash_attn.flash_attn_func(q, k, v)
+    elif BACKEND == 'sdpa':
+        if num_all_args == 1:
+            q, k, v = qkv.unbind(dim=2)
+        elif num_all_args == 2:
+            k, v = kv.unbind(dim=2)
+        q = q.permute(0, 2, 1, 3)   # [N, H, L, C]
+        k = k.permute(0, 2, 1, 3)   # [N, H, L, C]
+        v = v.permute(0, 2, 1, 3)   # [N, H, L, C]
+        out = sdpa(q, k, v)         # [N, H, L, C]
+        out = out.permute(0, 2, 1, 3)   # [N, L, H, C]
+    elif BACKEND == 'naive':
+        if num_all_args == 1:
+            q, k, v = qkv.unbind(dim=2)
+        elif num_all_args == 2:
+            k, v = kv.unbind(dim=2)
+        out = _naive_sdpa(q, k, v)
+    else:
+        raise ValueError(f"Unknown attention module: {BACKEND}")
+    
+    return out
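The dispatcher above accepts a packed [N, L, 3, H, C] qkv tensor, a (q, kv) pair, or separate (q, k, v) tensors. The sketch below checks those layout conventions and the naive math against torch's built-in SDPA without importing this module, so it runs regardless of which attention backend is installed; the tensor sizes are arbitrary:

import math
import torch
import torch.nn.functional as F

N, L, H, C = 2, 5, 4, 8
qkv = torch.randn(N, L, 3, H, C)            # packed layout accepted by scaled_dot_product_attention(qkv)
q, k, v = qkv.unbind(dim=2)                 # equivalent (q, k, v) call
# Reference computation in the [N, H, L, C] layout used by torch's SDPA:
ref = F.scaled_dot_product_attention(*(t.permute(0, 2, 1, 3) for t in (q, k, v))).permute(0, 2, 1, 3)
# Naive equivalent, mirroring _naive_sdpa above:
attn = torch.softmax(q.permute(0, 2, 1, 3) @ k.permute(0, 2, 1, 3).transpose(-2, -1) / math.sqrt(C), dim=-1)
naive = (attn @ v.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)
print(torch.allclose(ref, naive, atol=1e-5))   # True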
diff --git a/trellis/modules/attention/modules.py b/trellis/modules/attention/modules.py
new file mode 100755
index 0000000000000000000000000000000000000000..dbe6235c27134f0477e48d3e12de3068c6a500ef
--- /dev/null
+++ b/trellis/modules/attention/modules.py
@@ -0,0 +1,146 @@
+from typing import *
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .full_attn import scaled_dot_product_attention
+
+
+class MultiHeadRMSNorm(nn.Module):
+    def __init__(self, dim: int, heads: int):
+        super().__init__()
+        self.scale = dim ** 0.5
+        self.gamma = nn.Parameter(torch.ones(heads, dim))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
+
+
+class RotaryPositionEmbedder(nn.Module):
+    def __init__(self, hidden_size: int, in_channels: int = 3):
+        super().__init__()
+        assert hidden_size % 2 == 0, "Hidden size must be divisible by 2"
+        self.hidden_size = hidden_size
+        self.in_channels = in_channels
+        self.freq_dim = hidden_size // in_channels // 2
+        self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
+        self.freqs = 1.0 / (10000 ** self.freqs)
+        
+    def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
+        self.freqs = self.freqs.to(indices.device)
+        phases = torch.outer(indices, self.freqs)
+        phases = torch.polar(torch.ones_like(phases), phases)
+        return phases
+        
+    def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
+        x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+        x_rotated = x_complex * phases
+        x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
+        return x_embed
+        
+    def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            q (torch.Tensor): [..., N, D] tensor of queries
+            k (torch.Tensor): [..., N, D] tensor of keys
+            indices (Optional[torch.Tensor]): [..., N, C] tensor of spatial positions; defaults to sequence indices when omitted
+        """
+        if indices is None:
+            indices = torch.arange(q.shape[-2], device=q.device)
+            if len(q.shape) > 2:
+                indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,))
+        
+        phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
+        if phases.shape[1] < self.hidden_size // 2:
+            phases = torch.cat([phases, torch.polar(
+                torch.ones(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device),
+                torch.zeros(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device)
+            )], dim=-1)
+        q_embed = self._rotary_embedding(q, phases)
+        k_embed = self._rotary_embedding(k, phases)
+        return q_embed, k_embed
+    
+
+class MultiHeadAttention(nn.Module):
+    def __init__(
+        self,
+        channels: int,
+        num_heads: int,
+        ctx_channels: Optional[int]=None,
+        type: Literal["self", "cross"] = "self",
+        attn_mode: Literal["full", "windowed"] = "full",
+        window_size: Optional[int] = None,
+        shift_window: Optional[Tuple[int, int, int]] = None,
+        qkv_bias: bool = True,
+        use_rope: bool = False,
+        qk_rms_norm: bool = False,
+    ):
+        super().__init__()
+        assert channels % num_heads == 0
+        assert type in ["self", "cross"], f"Invalid attention type: {type}"
+        assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
+        assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
+        
+        if attn_mode == "windowed":
+            raise NotImplementedError("Windowed attention is not yet implemented")
+        
+        self.channels = channels
+        self.head_dim = channels // num_heads
+        self.ctx_channels = ctx_channels if ctx_channels is not None else channels
+        self.num_heads = num_heads
+        self._type = type
+        self.attn_mode = attn_mode
+        self.window_size = window_size
+        self.shift_window = shift_window
+        self.use_rope = use_rope
+        self.qk_rms_norm = qk_rms_norm
+
+        if self._type == "self":
+            self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
+        else:
+            self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
+            self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
+            
+        if self.qk_rms_norm:
+            self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
+            self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
+            
+        self.to_out = nn.Linear(channels, channels)
+
+        if use_rope:
+            self.rope = RotaryPositionEmbedder(channels)
+    
+    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor:
+        B, L, C = x.shape
+        if self._type == "self":
+            qkv = self.to_qkv(x)
+            qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
+            if self.use_rope:
+                q, k, v = qkv.unbind(dim=2)
+                q, k = self.rope(q, k, indices)
+                qkv = torch.stack([q, k, v], dim=2)
+            if self.attn_mode == "full":
+                if self.qk_rms_norm:
+                    q, k, v = qkv.unbind(dim=2)
+                    q = self.q_rms_norm(q)
+                    k = self.k_rms_norm(k)
+                    h = scaled_dot_product_attention(q, k, v)
+                else:
+                    h = scaled_dot_product_attention(qkv)
+            elif self.attn_mode == "windowed":
+                raise NotImplementedError("Windowed attention is not yet implemented")
+        else:
+            Lkv = context.shape[1]
+            q = self.to_q(x)
+            kv = self.to_kv(context)
+            q = q.reshape(B, L, self.num_heads, -1)
+            kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
+            if self.qk_rms_norm:
+                q = self.q_rms_norm(q)
+                k, v = kv.unbind(dim=2)
+                k = self.k_rms_norm(k)
+                h = scaled_dot_product_attention(q, k, v)
+            else:
+                h = scaled_dot_product_attention(q, kv)
+        h = h.reshape(B, L, -1)
+        h = self.to_out(h)
+        return h
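A usage sketch for MultiHeadAttention, assuming the trellis package imports cleanly and ATTN_BACKEND is set to a CPU-friendly backend such as 'sdpa' before the first import; the channel sizes are arbitrary:

import os
os.environ.setdefault('ATTN_BACKEND', 'sdpa')

import torch
from trellis.modules.attention import MultiHeadAttention

self_attn = MultiHeadAttention(channels=64, num_heads=8, qk_rms_norm=True)
cross_attn = MultiHeadAttention(channels=64, num_heads=8, ctx_channels=32, type="cross")

x = torch.randn(2, 16, 64)                   # [B, L, C]
ctx = torch.randn(2, 10, 32)                 # [B, Lkv, ctx_channels]
print(self_attn(x).shape)                    # torch.Size([2, 16, 64])
print(cross_attn(x, context=ctx).shape)      # torch.Size([2, 16, 64])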
diff --git a/trellis/modules/norm.py b/trellis/modules/norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..09035726081fb7afda2c62504d5474cfa483c58f
--- /dev/null
+++ b/trellis/modules/norm.py
@@ -0,0 +1,25 @@
+import torch
+import torch.nn as nn
+
+
+class LayerNorm32(nn.LayerNorm):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return super().forward(x.float()).type(x.dtype)
+    
+
+class GroupNorm32(nn.GroupNorm):
+    """
+    A GroupNorm layer that converts to float32 before the forward pass.
+    """
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return super().forward(x.float()).type(x.dtype)
+    
+    
+class ChannelLayerNorm32(LayerNorm32):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        DIM = x.dim()
+        x = x.permute(0, *range(2, DIM), 1).contiguous()
+        x = super().forward(x)
+        x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous()
+        return x
+    
\ No newline at end of file
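ChannelLayerNorm32 normalizes a channels-first tensor by moving the channel axis last, applying LayerNorm in float32, and moving it back. A dense sketch of the same computation in plain PyTorch, for reference:

import torch
import torch.nn as nn

x = torch.randn(2, 16, 8, 8, dtype=torch.float16)    # [N, C, H, W], channels first
ln = nn.LayerNorm(16)                                 # normalizes over the channel dim
y = ln(x.float().permute(0, 2, 3, 1)).permute(0, 3, 1, 2).to(x.dtype)
print(y.shape, y.dtype)                               # torch.Size([2, 16, 8, 8]) torch.float16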
diff --git a/trellis/modules/sparse/__init__.py b/trellis/modules/sparse/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..5b04954c494a12c9534e36724d0113cb67916788
--- /dev/null
+++ b/trellis/modules/sparse/__init__.py
@@ -0,0 +1,100 @@
+from typing import *
+
+BACKEND = 'spconv' 
+DEBUG = False
+ATTN = 'flash_attn'
+
+def __from_env():
+    import os
+    
+    global BACKEND
+    global DEBUG
+    global ATTN
+    
+    env_sparse_backend = os.environ.get('SPARSE_BACKEND')
+    env_sparse_debug = os.environ.get('SPARSE_DEBUG')
+    env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND')
+    if env_sparse_attn is None:
+        env_sparse_attn = os.environ.get('ATTN_BACKEND')
+
+    if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']:
+        BACKEND = env_sparse_backend
+    if env_sparse_debug is not None:
+        DEBUG = env_sparse_debug == '1'
+    if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
+        ATTN = env_sparse_attn
+        
+
+__from_env()
+    
+
+def set_backend(backend: Literal['spconv', 'torchsparse']):
+    global BACKEND
+    BACKEND = backend
+
+def set_debug(debug: bool):
+    global DEBUG
+    DEBUG = debug
+
+def set_attn(attn: Literal['xformers', 'flash_attn']):
+    global ATTN
+    ATTN = attn
+    
+    
+import importlib
+
+__attributes = {
+    'SparseTensor': 'basic',
+    'sparse_batch_broadcast': 'basic',
+    'sparse_batch_op': 'basic',
+    'sparse_cat': 'basic',
+    'sparse_unbind': 'basic',
+    'SparseGroupNorm': 'norm',
+    'SparseLayerNorm': 'norm',
+    'SparseGroupNorm32': 'norm',
+    'SparseLayerNorm32': 'norm',
+    'SparseReLU': 'nonlinearity',
+    'SparseSiLU': 'nonlinearity',
+    'SparseGELU': 'nonlinearity',
+    'SparseActivation': 'nonlinearity',
+    'SparseLinear': 'linear',
+    'sparse_scaled_dot_product_attention': 'attention',
+    'SerializeMode': 'attention',
+    'sparse_serialized_scaled_dot_product_self_attention': 'attention',
+    'sparse_windowed_scaled_dot_product_self_attention': 'attention',
+    'SparseMultiHeadAttention': 'attention',
+    'SparseConv3d': 'conv',
+    'SparseInverseConv3d': 'conv',
+    'SparseDownsample': 'spatial',
+    'SparseUpsample': 'spatial',
+    'SparseSubdivide' : 'spatial'
+}
+
+__submodules = ['transformer']
+
+__all__ = list(__attributes.keys()) + __submodules
+
+def __getattr__(name):
+    if name not in globals():
+        if name in __attributes:
+            module_name = __attributes[name]
+            module = importlib.import_module(f".{module_name}", __name__)
+            globals()[name] = getattr(module, name)
+        elif name in __submodules:
+            module = importlib.import_module(f".{name}", __name__)
+            globals()[name] = module
+        else:
+            raise AttributeError(f"module {__name__} has no attribute {name}")
+    return globals()[name]
+
+
+# For Pylance
+if __name__ == '__main__':
+    from .basic import *
+    from .norm import *
+    from .nonlinearity import *
+    from .linear import *
+    from .attention import *
+    from .conv import *
+    from .spatial import *
+    from . import transformer
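The __getattr__ hook above is the PEP 562 pattern for lazy module attributes: a submodule is only imported when one of its symbols is first accessed, and the result is cached in globals() so the hook runs once per name. A minimal standalone illustration of the same pattern, written as a package __init__.py; pkg and _heavy are hypothetical names, not part of trellis:

# pkg/__init__.py  (hypothetical package, for illustration only)
import importlib

_attributes = {'heavy_fn': '_heavy'}            # public symbol -> submodule that defines it

def __getattr__(name):
    if name in _attributes:
        module = importlib.import_module(f".{_attributes[name]}", __name__)
        value = getattr(module, name)
        globals()[name] = value                 # cache: __getattr__ only fires for missing names
        return value
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")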
diff --git a/trellis/modules/sparse/attention/__init__.py b/trellis/modules/sparse/attention/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..32b3c2c837c613e41755ac4c85f9ed057a6f5bfb
--- /dev/null
+++ b/trellis/modules/sparse/attention/__init__.py
@@ -0,0 +1,4 @@
+from .full_attn import *
+from .serialized_attn import *
+from .windowed_attn import *
+from .modules import *
diff --git a/trellis/modules/sparse/attention/full_attn.py b/trellis/modules/sparse/attention/full_attn.py
new file mode 100755
index 0000000000000000000000000000000000000000..e9e27aeb98419621f3f9999fd3b11eebf2b90a40
--- /dev/null
+++ b/trellis/modules/sparse/attention/full_attn.py
@@ -0,0 +1,215 @@
+from typing import *
+import torch
+from .. import SparseTensor
+from .. import DEBUG, ATTN
+
+if ATTN == 'xformers':
+    import xformers.ops as xops
+elif ATTN == 'flash_attn':
+    import flash_attn
+else:
+    raise ValueError(f"Unknown attention module: {ATTN}")
+
+
+__all__ = [
+    'sparse_scaled_dot_product_attention',
+]
+
+
+@overload
+def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor:
+    """
+    Apply scaled dot product attention to a sparse tensor.
+
+    Args:
+        qkv (SparseTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
+    """
+    ...
+
+@overload
+def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor:
+    """
+    Apply scaled dot product attention to a sparse tensor.
+
+    Args:
+        q (SparseTensor): A [N, *, H, C] sparse tensor containing Qs.
+        kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs.
+    """
+    ...
+
+@overload
+def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor:
+    """
+    Apply scaled dot product attention to a sparse tensor.
+
+    Args:
+        q (torch.Tensor): A [N, L, H, C] dense tensor containing Qs.
+        kv (SparseTensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs.
+    """
+    ...
+
+@overload
+def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor:
+    """
+    Apply scaled dot product attention to a sparse tensor.
+
+    Args:
+        q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
+        k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
+        v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.
+
+    Note:
+        k and v are assumed to have the same coordinate map.
+    """
+    ...
+
+@overload
+def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor:
+    """
+    Apply scaled dot product attention to a sparse tensor.
+
+    Args:
+        q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
+        k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks.
+        v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs.
+    """
+    ...
+
+@overload
+def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor:
+    """
+    Apply scaled dot product attention to a sparse tensor.
+
+    Args:
+        q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs.
+        k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
+        v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.
+    """
+    ...
+
+def sparse_scaled_dot_product_attention(*args, **kwargs):
+    arg_names_dict = {
+        1: ['qkv'],
+        2: ['q', 'kv'],
+        3: ['q', 'k', 'v']
+    }
+    num_all_args = len(args) + len(kwargs)
+    assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
+    for key in arg_names_dict[num_all_args][len(args):]:
+        assert key in kwargs, f"Missing argument {key}"
+
+    if num_all_args == 1:
+        qkv = args[0] if len(args) > 0 else kwargs['qkv']
+        assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}"
+        assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
+        device = qkv.device
+
+        s = qkv
+        q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
+        kv_seqlen = q_seqlen
+        qkv = qkv.feats     # [T, 3, H, C]
+
+    elif num_all_args == 2:
+        q = args[0] if len(args) > 0 else kwargs['q']
+        kv = args[1] if len(args) > 1 else kwargs['kv']
+        assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \
+               isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \
+               f"Invalid types, got {type(q)} and {type(kv)}"
+        assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
+        device = q.device
+
+        if isinstance(q, SparseTensor):
+            assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]"
+            s = q
+            q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
+            q = q.feats     # [T_Q, H, C]
+        else:
+            assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
+            s = None
+            N, L, H, C = q.shape
+            q_seqlen = [L] * N
+            q = q.reshape(N * L, H, C)   # [T_Q, H, C]
+
+        if isinstance(kv, SparseTensor):
+            assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
+            kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
+            kv = kv.feats     # [T_KV, 2, H, C]
+        else:
+            assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
+            N, L, _, H, C = kv.shape
+            kv_seqlen = [L] * N
+            kv = kv.reshape(N * L, 2, H, C)   # [T_KV, 2, H, C]
+
+    elif num_all_args == 3:
+        q = args[0] if len(args) > 0 else kwargs['q']
+        k = args[1] if len(args) > 1 else kwargs['k']
+        v = args[2] if len(args) > 2 else kwargs['v']
+        assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \
+               isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \
+               f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}"
+        assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
+        device = q.device
+
+        if isinstance(q, SparseTensor):
+            assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]"
+            s = q
+            q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
+            q = q.feats     # [T_Q, H, Ci]
+        else:
+            assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
+            s = None
+            N, L, H, CI = q.shape
+            q_seqlen = [L] * N
+            q = q.reshape(N * L, H, CI)  # [T_Q, H, Ci]
+
+        if isinstance(k, SparseTensor):
+            assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]"
+            assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]"
+            kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])]
+            k = k.feats     # [T_KV, H, Ci]
+            v = v.feats     # [T_KV, H, Co]
+        else:
+            assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
+            assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
+            N, L, H, CI, CO = *k.shape, v.shape[-1]
+            kv_seqlen = [L] * N
+            k = k.reshape(N * L, H, CI)     # [T_KV, H, Ci]
+            v = v.reshape(N * L, H, CO)     # [T_KV, H, Co]
+
+    if DEBUG:
+        if s is not None:
+            for i in range(s.shape[0]):
+                assert (s.coords[s.layout[i]] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch"
+        if num_all_args in [2, 3]:
+            assert q.shape[0] == sum(q_seqlen), "sparse_scaled_dot_product_attention: q shape mismatch"
+        if num_all_args == 3:
+            assert k.shape[0] == sum(kv_seqlen), "sparse_scaled_dot_product_attention: k shape mismatch"
+            assert v.shape[0] == sum(kv_seqlen), "sparse_scaled_dot_product_attention: v shape mismatch"
+
+    if ATTN == 'xformers':
+        if num_all_args == 1:
+            q, k, v = qkv.unbind(dim=1)
+        elif num_all_args == 2:
+            k, v = kv.unbind(dim=1)
+        q = q.unsqueeze(0)
+        k = k.unsqueeze(0)
+        v = v.unsqueeze(0)
+        mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
+        out = xops.memory_efficient_attention(q, k, v, mask)[0]
+    elif ATTN == 'flash_attn':
+        cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
+        if num_all_args in [2, 3]:
+            cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
+        if num_all_args == 1:
+            out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
+        elif num_all_args == 2:
+            out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
+        elif num_all_args == 3:
+            out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
+    else:
+        raise ValueError(f"Unknown attention module: {ATTN}")
+    
+    if s is not None:
+        return s.replace(out)
+    else:
+        return out.reshape(N, L, H, -1)
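For reference, the flash_attn path above flattens every sample into one packed [T, ...] tensor and describes the boundaries with cumulative sequence lengths. A small sketch of that bookkeeping; the sequence lengths are arbitrary:

import torch

q_seqlen = [5, 3, 7]                                   # tokens per sample in the batch
cu_seqlens_q = torch.cat([torch.tensor([0]),
                          torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int()
print(cu_seqlens_q.tolist())                           # [0, 5, 8, 15]

packed_q = torch.randn(int(cu_seqlens_q[-1]), 4, 16)   # the packed [T, H, C] tensor
lo, hi = int(cu_seqlens_q[1]), int(cu_seqlens_q[2])
print(packed_q[lo:hi].shape)                           # sample 1 occupies rows 5..8 -> [3, 4, 16]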
diff --git a/trellis/modules/sparse/attention/modules.py b/trellis/modules/sparse/attention/modules.py
new file mode 100755
index 0000000000000000000000000000000000000000..5d2fe782b0947700e308e9ec0325e7e91c84e3c2
--- /dev/null
+++ b/trellis/modules/sparse/attention/modules.py
@@ -0,0 +1,139 @@
+from typing import *
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .. import SparseTensor
+from .full_attn import sparse_scaled_dot_product_attention
+from .serialized_attn import SerializeMode, sparse_serialized_scaled_dot_product_self_attention
+from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention
+from ...attention import RotaryPositionEmbedder
+
+
+class SparseMultiHeadRMSNorm(nn.Module):
+    def __init__(self, dim: int, heads: int):
+        super().__init__()
+        self.scale = dim ** 0.5
+        self.gamma = nn.Parameter(torch.ones(heads, dim))
+
+    def forward(self, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
+        x_type = x.dtype
+        x = x.float()
+        if isinstance(x, SparseTensor):
+            x = x.replace(F.normalize(x.feats, dim=-1))
+        else:
+            x = F.normalize(x, dim=-1)            
+        return (x * self.gamma * self.scale).to(x_type)
+
+
+class SparseMultiHeadAttention(nn.Module):
+    def __init__(
+        self,
+        channels: int,
+        num_heads: int,
+        ctx_channels: Optional[int] = None,
+        type: Literal["self", "cross"] = "self",
+        attn_mode: Literal["full", "serialized", "windowed"] = "full",
+        window_size: Optional[int] = None,
+        shift_sequence: Optional[int] = None,
+        shift_window: Optional[Tuple[int, int, int]] = None,
+        serialize_mode: Optional[SerializeMode] = None,
+        qkv_bias: bool = True,
+        use_rope: bool = False,
+        qk_rms_norm: bool = False,
+    ):
+        super().__init__()
+        assert channels % num_heads == 0
+        assert type in ["self", "cross"], f"Invalid attention type: {type}"
+        assert attn_mode in ["full", "serialized", "windowed"], f"Invalid attention mode: {attn_mode}"
+        assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
+        assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention"
+        self.channels = channels
+        self.ctx_channels = ctx_channels if ctx_channels is not None else channels
+        self.num_heads = num_heads
+        self._type = type
+        self.attn_mode = attn_mode
+        self.window_size = window_size
+        self.shift_sequence = shift_sequence
+        self.shift_window = shift_window
+        self.serialize_mode = serialize_mode
+        self.use_rope = use_rope
+        self.qk_rms_norm = qk_rms_norm
+
+        if self._type == "self":
+            self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
+        else:
+            self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
+            self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
+        
+        if self.qk_rms_norm:
+            self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
+            self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
+            
+        self.to_out = nn.Linear(channels, channels)
+
+        if use_rope:
+            self.rope = RotaryPositionEmbedder(channels)
+
+    @staticmethod
+    def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
+        if isinstance(x, SparseTensor):
+            return x.replace(module(x.feats))
+        else:
+            return module(x)
+
+    @staticmethod
+    def _reshape_chs(x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[SparseTensor, torch.Tensor]:
+        if isinstance(x, SparseTensor):
+            return x.reshape(*shape)
+        else:
+            return x.reshape(*x.shape[:2], *shape)
+
+    def _fused_pre(self, x: Union[SparseTensor, torch.Tensor], num_fused: int) -> Union[SparseTensor, torch.Tensor]:
+        if isinstance(x, SparseTensor):
+            x_feats = x.feats.unsqueeze(0)
+        else:
+            x_feats = x
+        x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
+        return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats
+
+    def _rope(self, qkv: SparseTensor) -> SparseTensor:
+        q, k, v = qkv.feats.unbind(dim=1)   # [T, H, C]
+        q, k = self.rope(q, k, qkv.coords[:, 1:])
+        qkv = qkv.replace(torch.stack([q, k, v], dim=1)) 
+        return qkv
+    
+    def forward(self, x: Union[SparseTensor, torch.Tensor], context: Optional[Union[SparseTensor, torch.Tensor]] = None) -> Union[SparseTensor, torch.Tensor]:
+        if self._type == "self":
+            qkv = self._linear(self.to_qkv, x)
+            qkv = self._fused_pre(qkv, num_fused=3)
+            if self.use_rope:
+                qkv = self._rope(qkv)
+            if self.qk_rms_norm:
+                q, k, v = qkv.unbind(dim=1)
+                q = self.q_rms_norm(q)
+                k = self.k_rms_norm(k)
+                qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
+            if self.attn_mode == "full":
+                h = sparse_scaled_dot_product_attention(qkv)
+            elif self.attn_mode == "serialized":
+                h = sparse_serialized_scaled_dot_product_self_attention(
+                    qkv, self.window_size, serialize_mode=self.serialize_mode, shift_sequence=self.shift_sequence, shift_window=self.shift_window
+                )
+            elif self.attn_mode == "windowed":
+                h = sparse_windowed_scaled_dot_product_self_attention(
+                    qkv, self.window_size, shift_window=self.shift_window
+                )
+        else:
+            q = self._linear(self.to_q, x)
+            q = self._reshape_chs(q, (self.num_heads, -1))
+            kv = self._linear(self.to_kv, context)
+            kv = self._fused_pre(kv, num_fused=2)
+            if self.qk_rms_norm:
+                q = self.q_rms_norm(q)
+                k, v = kv.unbind(dim=1)
+                k = self.k_rms_norm(k)
+                kv = kv.replace(torch.stack([k.feats, v.feats], dim=1))
+            h = sparse_scaled_dot_product_attention(q, kv)
+        h = self._reshape_chs(h, (-1,))
+        h = self._linear(self.to_out, h)
+        return h
diff --git a/trellis/modules/sparse/attention/serialized_attn.py b/trellis/modules/sparse/attention/serialized_attn.py
new file mode 100755
index 0000000000000000000000000000000000000000..5950b75b2f5a6d6e79ab6d472b8501aaa5ec4a26
--- /dev/null
+++ b/trellis/modules/sparse/attention/serialized_attn.py
@@ -0,0 +1,193 @@
+from typing import *
+from enum import Enum
+import torch
+import math
+from .. import SparseTensor
+from .. import DEBUG, ATTN
+
+if ATTN == 'xformers':
+    import xformers.ops as xops
+elif ATTN == 'flash_attn':
+    import flash_attn
+else:
+    raise ValueError(f"Unknown attention module: {ATTN}")
+
+
+__all__ = [
+    'sparse_serialized_scaled_dot_product_self_attention',
+]
+
+
+class SerializeMode(Enum):
+    Z_ORDER = 0
+    Z_ORDER_TRANSPOSED = 1
+    HILBERT = 2
+    HILBERT_TRANSPOSED = 3
+
+
+SerializeModes = [
+    SerializeMode.Z_ORDER,
+    SerializeMode.Z_ORDER_TRANSPOSED,
+    SerializeMode.HILBERT,
+    SerializeMode.HILBERT_TRANSPOSED
+]
+
+
+def calc_serialization(
+    tensor: SparseTensor,
+    window_size: int,
+    serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
+    shift_sequence: int = 0,
+    shift_window: Tuple[int, int, int] = (0, 0, 0)
+) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
+    """
+    Calculate serialization and partitioning for a set of coordinates.
+
+    Args:
+        tensor (SparseTensor): The input tensor.
+        window_size (int): The window size to use.
+        serialize_mode (SerializeMode): The serialization mode to use.
+        shift_sequence (int): The shift of serialized sequence.
+        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
+
+    Returns:
+        (torch.Tensor, torch.Tensor, List[int], List[int]): Forward indices, backward indices,
+            the length of each serialized window, and the batch index of each window.
+    """
+    fwd_indices = []
+    bwd_indices = []
+    seq_lens = []
+    seq_batch_indices = []
+    offsets = [0]
+    
+    if 'vox2seq' not in globals():
+        import vox2seq
+
+    # Serialize the input
+    serialize_coords = tensor.coords[:, 1:].clone()
+    serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3)
+    if serialize_mode == SerializeMode.Z_ORDER:
+        code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2])
+    elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED:
+        code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2])
+    elif serialize_mode == SerializeMode.HILBERT:
+        code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2])
+    elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED:
+        code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2])
+    else:
+        raise ValueError(f"Unknown serialize mode: {serialize_mode}")
+    
+    for bi, s in enumerate(tensor.layout):
+        num_points = s.stop - s.start
+        num_windows = (num_points + window_size - 1) // window_size
+        valid_window_size = num_points / num_windows
+        to_ordered = torch.argsort(code[s.start:s.stop])
+        if num_windows == 1:
+            fwd_indices.append(to_ordered)
+            bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device)))
+            fwd_indices[-1] += s.start
+            bwd_indices[-1] += offsets[-1]
+            seq_lens.append(num_points)
+            seq_batch_indices.append(bi)
+            offsets.append(offsets[-1] + seq_lens[-1])
+        else:
+            # Partition the input
+            offset = 0
+            mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)]
+            split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)]
+            bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device)
+            for i in range(num_windows):
+                mid = mids[i]
+                valid_start = split[i]
+                valid_end = split[i + 1]
+                padded_start = math.floor(mid - 0.5 * window_size)
+                padded_end = padded_start + window_size
+                fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points])
+                offset += valid_start - padded_start
+                bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device))
+                offset += padded_end - valid_start
+                fwd_indices[-1] += s.start
+            seq_lens.extend([window_size] * num_windows)
+            seq_batch_indices.extend([bi] * num_windows)
+            bwd_indices.append(bwd_index + offsets[-1])
+            offsets.append(offsets[-1] + num_windows * window_size)
+
+    fwd_indices = torch.cat(fwd_indices)
+    bwd_indices = torch.cat(bwd_indices)
+
+    return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
+    
+
+def sparse_serialized_scaled_dot_product_self_attention(
+    qkv: SparseTensor,
+    window_size: int,
+    serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
+    shift_sequence: int = 0,
+    shift_window: Tuple[int, int, int] = (0, 0, 0)
+) -> SparseTensor:
+    """
+    Apply serialized scaled dot product self attention to a sparse tensor.
+
+    Args:
+        qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
+        window_size (int): The window size to use.
+        serialize_mode (SerializeMode): The serialization mode to use.
+        shift_sequence (int): The shift of serialized sequence.
+        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
+    """
+    assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
+
+    serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}'
+    serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
+    if serialization_spatial_cache is None:
+        fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window)
+        qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
+    else:
+        fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache
+
+    M = fwd_indices.shape[0]
+    T = qkv.feats.shape[0]
+    H = qkv.feats.shape[2]
+    C = qkv.feats.shape[3]
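+    # M: number of tokens after serialization/padding; T: original token count; H: attention heads; C: head dim.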
+    
+    qkv_feats = qkv.feats[fwd_indices]      # [M, 3, H, C]
+
+    if DEBUG:
+        start = 0
+        qkv_coords = qkv.coords[fwd_indices]
+        for i in range(len(seq_lens)):
+            assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
+            start += seq_lens[i]
+
+    if all([seq_len == window_size for seq_len in seq_lens]):
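+        # Every window holds exactly `window_size` tokens, so the packed QKV can be reshaped into a
+        # dense [B, N, 3, H, C] batch and run through regular batched attention.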
+        B = len(seq_lens)
+        N = window_size
+        qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
+        if ATTN == 'xformers':
+            q, k, v = qkv_feats.unbind(dim=2)                       # [B, N, H, C]
+            out = xops.memory_efficient_attention(q, k, v)          # [B, N, H, C]
+        elif ATTN == 'flash_attn':
+            out = flash_attn.flash_attn_qkvpacked_func(qkv_feats)   # [B, N, H, C]
+        else:
+            raise ValueError(f"Unknown attention module: {ATTN}")
+        out = out.reshape(B * N, H, C)                              # [M, H, C]
+    else:
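+        # Windows have different lengths: fall back to variable-length attention with a block-diagonal
+        # mask (xformers) or cumulative sequence lengths (flash-attn varlen).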
+        if ATTN == 'xformers':
+            q, k, v = qkv_feats.unbind(dim=1)                       # [M, H, C]
+            q = q.unsqueeze(0)                                      # [1, M, H, C]
+            k = k.unsqueeze(0)                                      # [1, M, H, C]
+            v = v.unsqueeze(0)                                      # [1, M, H, C]
+            mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
+            out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
+        elif ATTN == 'flash_attn':
+            cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
+                        .to(qkv.device).int()
+            out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
+
+    out = out[bwd_indices]      # [T, H, C]
+
+    if DEBUG:
+        qkv_coords = qkv_coords[bwd_indices]
+        assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"
+
+    return qkv.replace(out)
diff --git a/trellis/modules/sparse/attention/windowed_attn.py b/trellis/modules/sparse/attention/windowed_attn.py
new file mode 100755
index 0000000000000000000000000000000000000000..cd642c5252e29a3a5e59fad7ed3880b7b00bcf9a
--- /dev/null
+++ b/trellis/modules/sparse/attention/windowed_attn.py
@@ -0,0 +1,135 @@
+from typing import *
+import torch
+import math
+from .. import SparseTensor
+from .. import DEBUG, ATTN
+
+if ATTN == 'xformers':
+    import xformers.ops as xops
+elif ATTN == 'flash_attn':
+    import flash_attn
+else:
+    raise ValueError(f"Unknown attention module: {ATTN}")
+
+
+__all__ = [
+    'sparse_windowed_scaled_dot_product_self_attention',
+]
+
+
+def calc_window_partition(
+    tensor: SparseTensor,
+    window_size: Union[int, Tuple[int, ...]],
+    shift_window: Union[int, Tuple[int, ...]] = 0
+) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
+    """
+    Calculate the window partition for a sparse tensor.
+
+    Args:
+        tensor (SparseTensor): The input tensor.
+        window_size (Union[int, Tuple[int, ...]]): The window size to use.
+        shift_window (Union[int, Tuple[int, ...]]): The shift applied to the window grid.
+
+    Returns:
+        (torch.Tensor): Forward indices.
+        (torch.Tensor): Backward indices.
+        (List[int]): Sequence lengths.
+        (List[int]): Sequence batch indices.
+    """
+    DIM = tensor.coords.shape[1] - 1
+    shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
+    window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
+    shifted_coords = tensor.coords.clone().detach()
+    shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)
+
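+    # Quantize the shifted coordinates by the window size and fold (batch, wx, wy, wz) into a single
+    # flat window id using mixed-radix offsets, so every token gets the id of the window it falls into.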
+    MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist()
+    NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
+    OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]
+
+    shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
+    shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)
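+    # Sort tokens by window id so tokens of the same window are contiguous; bwd_indices is the inverse
+    # permutation, and bincount gives the number of tokens in each (possibly empty) window.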
+    fwd_indices = torch.argsort(shifted_indices)
+    bwd_indices = torch.empty_like(fwd_indices)
+    bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
+    seq_lens = torch.bincount(shifted_indices)
+    seq_batch_indices = torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) // OFFSET[0]
+    mask = seq_lens != 0
+    seq_lens = seq_lens[mask].tolist()
+    seq_batch_indices = seq_batch_indices[mask].tolist()
+
+    return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
+    
+
+def sparse_windowed_scaled_dot_product_self_attention(
+    qkv: SparseTensor,
+    window_size: int,
+    shift_window: Tuple[int, int, int] = (0, 0, 0)
+) -> SparseTensor:
+    """
+    Apply windowed scaled dot product self attention to a sparse tensor.
+
+    Args:
+        qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
+        window_size (int): The window size to use.
+        shift_window (Tuple[int, int, int]): The shift applied to the window grid.
+    """
+    assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
+
+    serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}'
+    serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
+    if serialization_spatial_cache is None:
+        fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(qkv, window_size, shift_window)
+        qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
+    else:
+        fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache
+
+    M = fwd_indices.shape[0]
+    T = qkv.feats.shape[0]
+    H = qkv.feats.shape[2]
+    C = qkv.feats.shape[3]
+    
+    qkv_feats = qkv.feats[fwd_indices]      # [M, 3, H, C]
+
+    if DEBUG:
+        start = 0
+        qkv_coords = qkv.coords[fwd_indices]
+        for i in range(len(seq_lens)):
+            seq_coords = qkv_coords[start:start+seq_lens[i]]
+            assert (seq_coords[:, 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
+            assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \
+                    f"SparseWindowedScaledDotProductSelfAttention: window size exceeded"
+            start += seq_lens[i]
+
+    if all([seq_len == window_size for seq_len in seq_lens]):
+        B = len(seq_lens)
+        N = window_size
+        qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
+        if ATTN == 'xformers':
+            q, k, v = qkv_feats.unbind(dim=2)                       # [B, N, H, C]
+            out = xops.memory_efficient_attention(q, k, v)          # [B, N, H, C]
+        elif ATTN == 'flash_attn':
+            out = flash_attn.flash_attn_qkvpacked_func(qkv_feats)   # [B, N, H, C]
+        else:
+            raise ValueError(f"Unknown attention module: {ATTN}")
+        out = out.reshape(B * N, H, C)                              # [M, H, C]
+    else:
+        if ATTN == 'xformers':
+            q, k, v = qkv_feats.unbind(dim=1)                       # [M, H, C]
+            q = q.unsqueeze(0)                                      # [1, M, H, C]
+            k = k.unsqueeze(0)                                      # [1, M, H, C]
+            v = v.unsqueeze(0)                                      # [1, M, H, C]
+            mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
+            out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
+        elif ATTN == 'flash_attn':
+            cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
+                        .to(qkv.device).int()
+            out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
+
+    out = out[bwd_indices]      # [T, H, C]
+
+    if DEBUG:
+        qkv_coords = qkv_coords[bwd_indices]
+        assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"
+
+    return qkv.replace(out)
diff --git a/trellis/modules/sparse/basic.py b/trellis/modules/sparse/basic.py
new file mode 100755
index 0000000000000000000000000000000000000000..8837f44052f6d573d09e3bfb897e659e10516bb5
--- /dev/null
+++ b/trellis/modules/sparse/basic.py
@@ -0,0 +1,459 @@
+from typing import *
+import torch
+import torch.nn as nn
+from . import BACKEND, DEBUG
+SparseTensorData = None # Lazy import
+
+
+__all__ = [
+    'SparseTensor',
+    'sparse_batch_broadcast',
+    'sparse_batch_op',
+    'sparse_cat',
+    'sparse_unbind',
+]
+
+
+class SparseTensor:
+    """
+    Sparse tensor with support for both torchsparse and spconv backends.
+    
+    Parameters:
+    - feats (torch.Tensor): Features of the sparse tensor.
+    - coords (torch.Tensor): Coordinates of the sparse tensor.
+    - shape (torch.Size): Shape of the sparse tensor.
+    - layout (List[slice]): Layout of the sparse tensor for each batch.
+    - data (SparseTensorData): Sparse tensor data used for convolution.
+
+    NOTE:
+    - Data corresponding to the same batch should be contiguous.
+    - Coords should be in [0, 1023].
+    """
+    @overload
+    def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...
+
+    @overload
+    def __init__(self, data, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...
+
+    def __init__(self, *args, **kwargs):
+        # Lazy import of sparse tensor backend
+        global SparseTensorData
+        if SparseTensorData is None:
+            import importlib
+            if BACKEND == 'torchsparse':
+                SparseTensorData = importlib.import_module('torchsparse').SparseTensor
+            elif BACKEND == 'spconv':
+                SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor
+                
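+        # Two construction modes: method_id 0 builds from raw (feats, coords); method_id 1 wraps an
+        # existing backend tensor (torchsparse.SparseTensor or spconv.SparseConvTensor).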
+        method_id = 0
+        if len(args) != 0:
+            method_id = 0 if isinstance(args[0], torch.Tensor) else 1
+        else:
+            method_id = 1 if 'data' in kwargs else 0
+
+        if method_id == 0:
+            feats, coords, shape, layout = args + (None,) * (4 - len(args))
+            if 'feats' in kwargs:
+                feats = kwargs['feats']
+                del kwargs['feats']
+            if 'coords' in kwargs:
+                coords = kwargs['coords']
+                del kwargs['coords']
+            if 'shape' in kwargs:
+                shape = kwargs['shape']
+                del kwargs['shape']
+            if 'layout' in kwargs:
+                layout = kwargs['layout']
+                del kwargs['layout']
+
+            if shape is None:
+                shape = self.__cal_shape(feats, coords)
+            if layout is None:
+                layout = self.__cal_layout(coords, shape[0])
+            if BACKEND == 'torchsparse':
+                self.data = SparseTensorData(feats, coords, **kwargs)
+            elif BACKEND == 'spconv':
+                spatial_shape = list(coords.max(0)[0] + 1)[1:]
+                self.data = SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape, shape[0], **kwargs)
+                self.data._features = feats
+        elif method_id == 1:
+            data, shape, layout = args + (None,) * (3 - len(args))
+            if 'data' in kwargs:
+                data = kwargs['data']
+                del kwargs['data']
+            if 'shape' in kwargs:
+                shape = kwargs['shape']
+                del kwargs['shape']
+            if 'layout' in kwargs:
+                layout = kwargs['layout']
+                del kwargs['layout']
+
+            self.data = data
+            if shape is None:
+                shape = self.__cal_shape(self.feats, self.coords)
+            if layout is None:
+                layout = self.__cal_layout(self.coords, shape[0])
+
+        self._shape = shape
+        self._layout = layout
+        self._scale = kwargs.get('scale', (1, 1, 1))
+        self._spatial_cache = kwargs.get('spatial_cache', {})
+
+        if DEBUG:
+            try:
+                assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}"
+                assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}"
+                assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}"
+                for i in range(self.shape[0]):
+                    assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous"
+            except Exception as e:
+                print('Debugging information:')
+                print(f"- Shape: {self.shape}")
+                print(f"- Layout: {self.layout}")
+                print(f"- Scale: {self._scale}")
+                print(f"- Coords: {self.coords}")
+                raise e
+        
+    def __cal_shape(self, feats, coords):
+        shape = []
+        shape.append(coords[:, 0].max().item() + 1)
+        shape.extend([*feats.shape[1:]])
+        return torch.Size(shape)
+    
+    def __cal_layout(self, coords, batch_size):
+        seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
+        offset = torch.cumsum(seq_len, dim=0) 
+        layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)]
+        return layout
+    
+    @property
+    def shape(self) -> torch.Size:
+        return self._shape
+    
+    def dim(self) -> int:
+        return len(self.shape)
+    
+    @property
+    def layout(self) -> List[slice]:
+        return self._layout
+
+    @property
+    def feats(self) -> torch.Tensor:
+        if BACKEND == 'torchsparse':
+            return self.data.F
+        elif BACKEND == 'spconv':
+            return self.data.features
+    
+    @feats.setter
+    def feats(self, value: torch.Tensor):
+        if BACKEND == 'torchsparse':
+            self.data.F = value
+        elif BACKEND == 'spconv':
+            self.data.features = value
+
+    @property
+    def coords(self) -> torch.Tensor:
+        if BACKEND == 'torchsparse':
+            return self.data.C
+        elif BACKEND == 'spconv':
+            return self.data.indices
+        
+    @coords.setter
+    def coords(self, value: torch.Tensor):
+        if BACKEND == 'torchsparse':
+            self.data.C = value
+        elif BACKEND == 'spconv':
+            self.data.indices = value
+
+    @property
+    def dtype(self):
+        return self.feats.dtype
+
+    @property
+    def device(self):
+        return self.feats.device
+
+    @overload
+    def to(self, dtype: torch.dtype) -> 'SparseTensor': ...
+
+    @overload
+    def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': ...
+
+    def to(self, *args, **kwargs) -> 'SparseTensor':
+        device = None
+        dtype = None
+        if len(args) == 2:
+            device, dtype = args
+        elif len(args) == 1:
+            if isinstance(args[0], torch.dtype):
+                dtype = args[0]
+            else:
+                device = args[0]
+        if 'dtype' in kwargs:
+            assert dtype is None, "to() received multiple values for argument 'dtype'"
+            dtype = kwargs['dtype']
+        if 'device' in kwargs:
+            assert device is None, "to() received multiple values for argument 'device'"
+            device = kwargs['device']
+        
+        new_feats = self.feats.to(device=device, dtype=dtype)
+        new_coords = self.coords.to(device=device)
+        return self.replace(new_feats, new_coords)
+
+    def type(self, dtype):
+        new_feats = self.feats.type(dtype)
+        return self.replace(new_feats)
+
+    def cpu(self) -> 'SparseTensor':
+        new_feats = self.feats.cpu()
+        new_coords = self.coords.cpu()
+        return self.replace(new_feats, new_coords)
+    
+    def cuda(self) -> 'SparseTensor':
+        new_feats = self.feats.cuda()
+        new_coords = self.coords.cuda()
+        return self.replace(new_feats, new_coords)
+
+    def half(self) -> 'SparseTensor':
+        new_feats = self.feats.half()
+        return self.replace(new_feats)
+    
+    def float(self) -> 'SparseTensor':
+        new_feats = self.feats.float()
+        return self.replace(new_feats)
+    
+    def detach(self) -> 'SparseTensor':
+        new_coords = self.coords.detach()
+        new_feats = self.feats.detach()
+        return self.replace(new_feats, new_coords)
+
+    def dense(self) -> torch.Tensor:
+        if BACKEND == 'torchsparse':
+            return self.data.dense()
+        elif BACKEND == 'spconv':
+            return self.data.dense()
+
+    def reshape(self, *shape) -> 'SparseTensor':
+        new_feats = self.feats.reshape(self.feats.shape[0], *shape)
+        return self.replace(new_feats)
+    
+    def unbind(self, dim: int) -> List['SparseTensor']:
+        return sparse_unbind(self, dim)
+
+    def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
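+        # Return a new SparseTensor that shares this tensor's coordinates, scale and cached index
+        # structures but carries new features (and, optionally, new coordinates).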
+        new_shape = [self.shape[0]]
+        new_shape.extend(feats.shape[1:])
+        if BACKEND == 'torchsparse':
+            new_data = SparseTensorData(
+                feats=feats,
+                coords=self.data.coords if coords is None else coords,
+                stride=self.data.stride,
+                spatial_range=self.data.spatial_range,
+            )
+            new_data._caches = self.data._caches
+        elif BACKEND == 'spconv':
+            new_data = SparseTensorData(
+                self.data.features.reshape(self.data.features.shape[0], -1),
+                self.data.indices,
+                self.data.spatial_shape,
+                self.data.batch_size,
+                self.data.grid,
+                self.data.voxel_num,
+                self.data.indice_dict
+            )
+            new_data._features = feats
+            new_data.benchmark = self.data.benchmark
+            new_data.benchmark_record = self.data.benchmark_record
+            new_data.thrust_allocator = self.data.thrust_allocator
+            new_data._timer = self.data._timer
+            new_data.force_algo = self.data.force_algo
+            new_data.int8_scale = self.data.int8_scale
+            if coords is not None:
+                new_data.indices = coords
+        new_tensor = SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache)
+        return new_tensor
+
+    @staticmethod
+    def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor':
+        N, C = dim
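+        # Enumerate every integer coordinate inside the (inclusive) AABB, replicate it for each of the
+        # N batch entries, and fill all features with the constant `value`.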
+        x = torch.arange(aabb[0], aabb[3] + 1)
+        y = torch.arange(aabb[1], aabb[4] + 1)
+        z = torch.arange(aabb[2], aabb[5] + 1)
+        coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3)
+        coords = torch.cat([
+            torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
+            coords.repeat(N, 1),
+        ], dim=1).to(dtype=torch.int32, device=device)
+        feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
+        return SparseTensor(feats=feats, coords=coords)
+
+    def __merge_sparse_cache(self, other: 'SparseTensor') -> dict:
+        new_cache = {}
+        for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())):
+            if k in self._spatial_cache:
+                new_cache[k] = self._spatial_cache[k]
+            if k in other._spatial_cache:
+                if k not in new_cache:
+                    new_cache[k] = other._spatial_cache[k]
+                else:
+                    new_cache[k].update(other._spatial_cache[k])
+        return new_cache
+
+    def __neg__(self) -> 'SparseTensor':
+        return self.replace(-self.feats)
+    
+    def __elemwise__(self, other: Union[torch.Tensor, 'SparseTensor'], op: callable) -> 'SparseTensor':
+        if isinstance(other, torch.Tensor):
+            try:
+                other = torch.broadcast_to(other, self.shape)
+                other = sparse_batch_broadcast(self, other)
+            except:
+                pass
+        if isinstance(other, SparseTensor):
+            other = other.feats
+        new_feats = op(self.feats, other)
+        new_tensor = self.replace(new_feats)
+        if isinstance(other, SparseTensor):
+            new_tensor._spatial_cache = self.__merge_sparse_cache(other)
+        return new_tensor
+
+    def __add__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+        return self.__elemwise__(other, torch.add)
+
+    def __radd__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+        return self.__elemwise__(other, torch.add)
+    
+    def __sub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+        return self.__elemwise__(other, torch.sub)
+    
+    def __rsub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+        return self.__elemwise__(other, lambda x, y: torch.sub(y, x))
+
+    def __mul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+        return self.__elemwise__(other, torch.mul)
+
+    def __rmul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+        return self.__elemwise__(other, torch.mul)
+
+    def __truediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+        return self.__elemwise__(other, torch.div)
+
+    def __rtruediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
+        return self.__elemwise__(other, lambda x, y: torch.div(y, x))
+
+    def __getitem__(self, idx):
+        if isinstance(idx, int):
+            idx = [idx]
+        elif isinstance(idx, slice):
+            idx = range(*idx.indices(self.shape[0]))
+        elif isinstance(idx, torch.Tensor):
+            if idx.dtype == torch.bool:
+                assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
+                idx = idx.nonzero().squeeze(1)
+            elif idx.dtype in [torch.int32, torch.int64]:
+                assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
+            else:
+                raise ValueError(f"Unknown index type: {idx.dtype}")
+        else:
+            raise ValueError(f"Unknown index type: {type(idx)}")
+        
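+        # Gather the selected batch entries and renumber their batch indices to 0..len(idx)-1.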
+        coords = []
+        feats = []
+        for new_idx, old_idx in enumerate(idx):
+            coords.append(self.coords[self.layout[old_idx]].clone())
+            coords[-1][:, 0] = new_idx
+            feats.append(self.feats[self.layout[old_idx]])
+        coords = torch.cat(coords, dim=0).contiguous()
+        feats = torch.cat(feats, dim=0).contiguous()
+        return SparseTensor(feats=feats, coords=coords)
+
+    def register_spatial_cache(self, key, value) -> None:
+        """
+        Register a spatial cache.
+        The spatial cache can be anything you want to cache.
+        Registration and retrieval of the cache are keyed by the current scale.
+        """
+        scale_key = str(self._scale)
+        if scale_key not in self._spatial_cache:
+            self._spatial_cache[scale_key] = {}
+        self._spatial_cache[scale_key][key] = value
+
+    def get_spatial_cache(self, key=None):
+        """
+        Get a spatial cache.
+        """
+        scale_key = str(self._scale)
+        cur_scale_cache = self._spatial_cache.get(scale_key, {})
+        if key is None:
+            return cur_scale_cache
+        return cur_scale_cache.get(key, None)
+
+
+def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor:
+    """
+    Broadcast a per-batch tensor to every token of a sparse tensor along the batch dimension.
+
+    Args:
+        input (SparseTensor): Sparse tensor whose layout defines the batch boundaries.
+        other (torch.Tensor): Tensor of shape [batch_size, ...] to broadcast to each token.
+    """
+    coords, feats = input.coords, input.feats
+    broadcasted = torch.zeros_like(feats)
+    for k in range(input.shape[0]):
+        broadcasted[input.layout[k]] = other[k]
+    return broadcasted
+
+
+def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = torch.add) -> SparseTensor:
+    """
+    Broadcast a per-batch tensor to a sparse tensor along the batch dimension, then apply an elementwise operation.
+
+    Args:
+        input (SparseTensor): Sparse tensor to operate on.
+        other (torch.Tensor): Tensor of shape [batch_size, ...] to broadcast.
+        op (callable): Operation to apply after broadcasting. Defaults to torch.add.
+    """
+    return input.replace(op(input.feats, sparse_batch_broadcast(input, other)))
+
+
+def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
+    """
+    Concatenate a list of sparse tensors.
+    
+    Args:
+        inputs (List[SparseTensor]): List of sparse tensors to concatenate.
+    """
+    if dim == 0:
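+        # Concatenate along the batch dimension: offset each input's batch indices so entries from
+        # different tensors remain distinct.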
+        start = 0
+        coords = []
+        for input in inputs:
+            coords.append(input.coords.clone())
+            coords[-1][:, 0] += start
+            start += input.shape[0]
+        coords = torch.cat(coords, dim=0)
+        feats = torch.cat([input.feats for input in inputs], dim=0)
+        output = SparseTensor(
+            coords=coords,
+            feats=feats,
+        )
+    else:
+        feats = torch.cat([input.feats for input in inputs], dim=dim)
+        output = inputs[0].replace(feats)
+
+    return output
+
+
+def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
+    """
+    Unbind a sparse tensor along a dimension.
+    
+    Args:
+        input (SparseTensor): Sparse tensor to unbind.
+        dim (int): Dimension to unbind.
+    """
+    if dim == 0:
+        return [input[i] for i in range(input.shape[0])]
+    else:
+        feats = input.feats.unbind(dim)
+        return [input.replace(f) for f in feats]
diff --git a/trellis/modules/sparse/conv/__init__.py b/trellis/modules/sparse/conv/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..8fe2fefb0823b7e489d75d392b2cf6d1a6fa34e7
--- /dev/null
+++ b/trellis/modules/sparse/conv/__init__.py
@@ -0,0 +1,6 @@
+from .. import BACKEND
+
+if BACKEND == 'torchsparse':
+    from .conv_torchsparse import *
+elif BACKEND == 'spconv':
+    from .conv_spconv import *
\ No newline at end of file
diff --git a/trellis/modules/sparse/conv/conv_spconv.py b/trellis/modules/sparse/conv/conv_spconv.py
new file mode 100755
index 0000000000000000000000000000000000000000..6cfa3d33d20fe7d9f980b4278b0f9c5a747bec81
--- /dev/null
+++ b/trellis/modules/sparse/conv/conv_spconv.py
@@ -0,0 +1,74 @@
+import torch
+import torch.nn as nn
+from .. import SparseTensor
+from .. import DEBUG
+
+class SparseConv3d(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
+        super(SparseConv3d, self).__init__()
+        global spconv  # the lazy import must bind at module scope so forward() can also see it
+        if 'spconv' not in globals():
+            import spconv.pytorch as spconv
+        if stride == 1 and (padding is None):
+            self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key)
+        else:
+            self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key)
+        self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
+        self.padding = padding
+
+    def forward(self, x: SparseTensor) -> SparseTensor:
+        spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None)
+        new_data = self.conv(x.data)
+        new_shape = [x.shape[0], self.conv.out_channels]
+        new_layout = None if spatial_changed else x.layout
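+        # If the convolution changed the spatial resolution, the cached per-batch layout is stale and
+        # is left as None so SparseTensor recomputes it from the new coordinates.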
+
+        if spatial_changed and (x.shape[0] != 1):
+            # spconv with a non-unit stride breaks the batch-contiguity of the output tensor, so re-sort by batch index
+            fwd = new_data.indices[:, 0].argsort()
+            bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device))
+            sorted_feats = new_data.features[fwd]
+            sorted_coords = new_data.indices[fwd]
+            unsorted_data = new_data
+            new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size)  # type: ignore
+
+        out = SparseTensor(
+            new_data, shape=torch.Size(new_shape), layout=new_layout,
+            scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]),
+            spatial_cache=x._spatial_cache,
+        )
+
+        if spatial_changed and (x.shape[0] != 1):
+            out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data)
+            out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd)
+ 
+        return out
+
+
+class SparseInverseConv3d(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
+        super(SparseInverseConv3d, self).__init__()
+        global spconv  # keep the lazy import at module scope, consistent with SparseConv3d
+        if 'spconv' not in globals():
+            import spconv.pytorch as spconv
+        self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key)
+        self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
+
+    def forward(self, x: SparseTensor) -> SparseTensor:
+        spatial_changed = any(s != 1 for s in self.stride)
+        if spatial_changed:
+            # recover the original spconv order
+            data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data')
+            bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd')
+            data = data.replace_feature(x.feats[bwd])
+            if DEBUG:
+                assert torch.equal(data.indices, x.coords[bwd]), 'Failed to recover the original spconv order'
+        else:
+            data = x.data
+
+        new_data = self.conv(data)
+        new_shape = [x.shape[0], self.conv.out_channels]
+        new_layout = None if spatial_changed else x.layout
+        out = SparseTensor(
+            new_data, shape=torch.Size(new_shape), layout=new_layout,
+            scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]),
+            spatial_cache=x._spatial_cache,
+        )
+        return out
diff --git a/trellis/modules/sparse/conv/conv_torchsparse.py b/trellis/modules/sparse/conv/conv_torchsparse.py
new file mode 100755
index 0000000000000000000000000000000000000000..1d612582d4b31f90aca3c00b693bbbc2550dc62c
--- /dev/null
+++ b/trellis/modules/sparse/conv/conv_torchsparse.py
@@ -0,0 +1,38 @@
+import torch
+import torch.nn as nn
+from .. import SparseTensor
+
+
+class SparseConv3d(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
+        super(SparseConv3d, self).__init__()
+        if 'torchsparse' not in globals():
+            import torchsparse
+        self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias)
+
+    def forward(self, x: SparseTensor) -> SparseTensor:
+        out = self.conv(x.data)
+        new_shape = [x.shape[0], self.conv.out_channels]
+        out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
+        out._spatial_cache = x._spatial_cache
+        out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)])
+        return out
+
+
+class SparseInverseConv3d(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
+        super(SparseInverseConv3d, self).__init__()
+        if 'torchsparse' not in globals():
+            import torchsparse
+        self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True)
+
+    def forward(self, x: SparseTensor) -> SparseTensor:
+        out = self.conv(x.data)        
+        new_shape = [x.shape[0], self.conv.out_channels]
+        out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
+        out._spatial_cache = x._spatial_cache
+        out._scale = tuple([s // stride for s, stride in zip(x._scale, self.conv.stride)])
+        return out
+
+
+
diff --git a/trellis/modules/sparse/linear.py b/trellis/modules/sparse/linear.py
new file mode 100755
index 0000000000000000000000000000000000000000..a854e77ce87d1a190b9730d91f363a821ff250bd
--- /dev/null
+++ b/trellis/modules/sparse/linear.py
@@ -0,0 +1,15 @@
+import torch
+import torch.nn as nn
+from . import SparseTensor
+
+__all__ = [
+    'SparseLinear'
+]
+
+
+class SparseLinear(nn.Linear):
+    def __init__(self, in_features, out_features, bias=True):
+        super(SparseLinear, self).__init__(in_features, out_features, bias)
+
+    def forward(self, input: SparseTensor) -> SparseTensor:
+        return input.replace(super().forward(input.feats))
diff --git a/trellis/modules/sparse/nonlinearity.py b/trellis/modules/sparse/nonlinearity.py
new file mode 100755
index 0000000000000000000000000000000000000000..f200098dd82011a3aeee1688b9eb17018fa78295
--- /dev/null
+++ b/trellis/modules/sparse/nonlinearity.py
@@ -0,0 +1,35 @@
+import torch
+import torch.nn as nn
+from . import SparseTensor
+
+__all__ = [
+    'SparseReLU',
+    'SparseSiLU',
+    'SparseGELU',
+    'SparseActivation'
+]
+
+
+class SparseReLU(nn.ReLU):
+    def forward(self, input: SparseTensor) -> SparseTensor:
+        return input.replace(super().forward(input.feats))
+    
+
+class SparseSiLU(nn.SiLU):
+    def forward(self, input: SparseTensor) -> SparseTensor:
+        return input.replace(super().forward(input.feats))
+
+
+class SparseGELU(nn.GELU):
+    def forward(self, input: SparseTensor) -> SparseTensor:
+        return input.replace(super().forward(input.feats))
+
+
+class SparseActivation(nn.Module):
+    def __init__(self, activation: nn.Module):
+        super().__init__()
+        self.activation = activation
+
+    def forward(self, input: SparseTensor) -> SparseTensor:
+        return input.replace(self.activation(input.feats))
+    
diff --git a/trellis/modules/sparse/norm.py b/trellis/modules/sparse/norm.py
new file mode 100755
index 0000000000000000000000000000000000000000..6b38a36682c098210000dc31d68ddc31ccd2929d
--- /dev/null
+++ b/trellis/modules/sparse/norm.py
@@ -0,0 +1,58 @@
+import torch
+import torch.nn as nn
+from . import SparseTensor
+from . import DEBUG
+
+__all__ = [
+    'SparseGroupNorm',
+    'SparseLayerNorm',
+    'SparseGroupNorm32',
+    'SparseLayerNorm32',
+]
+
+
+class SparseGroupNorm(nn.GroupNorm):
+    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
+        super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine)
+
+    def forward(self, input: SparseTensor) -> SparseTensor:
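+        # Normalize each batch entry separately: the tokens of batch k are reshaped to (1, C, N_k) so
+        # that GroupNorm statistics are computed per sample rather than across the whole sparse batch.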
+        nfeats = torch.zeros_like(input.feats)
+        for k in range(input.shape[0]):
+            if DEBUG:
+                assert (input.coords[input.layout[k], 0] == k).all(), f"SparseGroupNorm: batch index mismatch"
+            bfeats = input.feats[input.layout[k]]
+            bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
+            bfeats = super().forward(bfeats)
+            bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
+            nfeats[input.layout[k]] = bfeats
+        return input.replace(nfeats)
+
+
+class SparseLayerNorm(nn.LayerNorm):
+    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
+        super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine)
+
+    def forward(self, input: SparseTensor) -> SparseTensor:
+        nfeats = torch.zeros_like(input.feats)
+        for k in range(input.shape[0]):
+            bfeats = input.feats[input.layout[k]]
+            bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
+            bfeats = super().forward(bfeats)
+            bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
+            nfeats[input.layout[k]] = bfeats
+        return input.replace(nfeats)
+
+
+class SparseGroupNorm32(SparseGroupNorm):
+    """
+    A GroupNorm layer that converts to float32 before the forward pass.
+    """
+    def forward(self, x: SparseTensor) -> SparseTensor:
+        return super().forward(x.float()).type(x.dtype)
+
+class SparseLayerNorm32(SparseLayerNorm):
+    """
+    A LayerNorm layer that converts to float32 before the forward pass.
+    """
+    def forward(self, x: SparseTensor) -> SparseTensor:
+        return super().forward(x.float()).type(x.dtype)
diff --git a/trellis/modules/sparse/spatial.py b/trellis/modules/sparse/spatial.py
new file mode 100755
index 0000000000000000000000000000000000000000..ad7121473f335b307e2f7ea5f05c964d3aec0440
--- /dev/null
+++ b/trellis/modules/sparse/spatial.py
@@ -0,0 +1,110 @@
+from typing import *
+import torch
+import torch.nn as nn
+from . import SparseTensor
+
+__all__ = [
+    'SparseDownsample',
+    'SparseUpsample',
+    'SparseSubdivide'
+]
+
+
+class SparseDownsample(nn.Module):
+    """
+    Downsample a sparse tensor by a factor of `factor`.
+    Implemented as average pooling.
+    """
+    def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]):
+        super(SparseDownsample, self).__init__()
+        self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
+
+    def forward(self, input: SparseTensor) -> SparseTensor:
+        DIM = input.coords.shape[-1] - 1
+        factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
+        assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.'
+
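+        # Quantize coordinates by the downsample factor, assign each token a unique code per coarse
+        # voxel, average-pool features into those voxels with scatter_reduce, and cache the fine-level
+        # coords/layout/idx so a paired SparseUpsample can invert the operation.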
+        coord = list(input.coords.unbind(dim=-1))
+        for i, f in enumerate(factor):
+            coord[i+1] = coord[i+1] // f
+
+        MAX = [coord[i+1].max().item() + 1 for i in range(DIM)]
+        OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
+        code = sum([c * o for c, o in zip(coord, OFFSET)])
+        code, idx = code.unique(return_inverse=True)
+
+        new_feats = torch.scatter_reduce(
+            torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=input.feats.dtype),
+            dim=0,
+            index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]),
+            src=input.feats,
+            reduce='mean'
+        )
+        new_coords = torch.stack(
+            [code // OFFSET[0]] +
+            [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
+            dim=-1
+        )
+        out = SparseTensor(new_feats, new_coords, input.shape,)
+        out._scale = tuple([s // f for s, f in zip(input._scale, factor)])
+        out._spatial_cache = input._spatial_cache
+
+        out.register_spatial_cache(f'upsample_{factor}_coords', input.coords)
+        out.register_spatial_cache(f'upsample_{factor}_layout', input.layout)
+        out.register_spatial_cache(f'upsample_{factor}_idx', idx)
+
+        return out
+
+
+class SparseUpsample(nn.Module):
+    """
+    Upsample a sparse tensor by a factor of `factor`.
+    Implemented as nearest neighbor interpolation.
+    """
+    def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]):
+        super(SparseUpsample, self).__init__()
+        self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
+
+    def forward(self, input: SparseTensor) -> SparseTensor:
+        DIM = input.coords.shape[-1] - 1
+        factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
+        assert DIM == len(factor), 'Input coordinates must have the same dimension as the upsample factor.'
+
+        new_coords = input.get_spatial_cache(f'upsample_{factor}_coords')
+        new_layout = input.get_spatial_cache(f'upsample_{factor}_layout')
+        idx = input.get_spatial_cache(f'upsample_{factor}_idx')
+        if any([x is None for x in [new_coords, new_layout, idx]]):
+            raise ValueError('Upsample cache not found. SparseUpsample must be paired with SparseDownsample.')
+        new_feats = input.feats[idx]
+        out = SparseTensor(new_feats, new_coords, input.shape, new_layout)
+        out._scale = tuple([s * f for s, f in zip(input._scale, factor)])
+        out._spatial_cache = input._spatial_cache
+        return out
+    
+class SparseSubdivide(nn.Module):
+    """
+    Subdivide a sparse tensor: each voxel is split into 2^DIM children that inherit the parent's features.
+    """
+    def __init__(self):
+        super(SparseSubdivide, self).__init__()
+
+    def forward(self, input: SparseTensor) -> SparseTensor:
+        DIM = input.coords.shape[-1] - 1
+        # upsample scale=2^DIM
+        n_cube = torch.ones([2] * DIM, device=input.device, dtype=torch.int)
+        n_coords = torch.nonzero(n_cube)
+        n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1)
+        factor = n_coords.shape[0]
+        assert factor == 2 ** DIM
+        new_coords = input.coords.clone()
+        new_coords[:, 1:] *= 2
+        new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype)
+        
+        new_feats = input.feats.unsqueeze(1).expand(input.feats.shape[0], factor, *input.feats.shape[1:])
+        out = SparseTensor(new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape)
+        out._scale = tuple([s * 2 for s in input._scale])
+        out._spatial_cache = input._spatial_cache
+        return out
+
diff --git a/trellis/modules/sparse/transformer/__init__.py b/trellis/modules/sparse/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b08b0d4e5bc24060a2cdc8df75d06dce122972bd
--- /dev/null
+++ b/trellis/modules/sparse/transformer/__init__.py
@@ -0,0 +1,2 @@
+from .blocks import *
+from .modulated import *
\ No newline at end of file
diff --git a/trellis/modules/sparse/transformer/blocks.py b/trellis/modules/sparse/transformer/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d037a49bf83e1c2dfb2f8c4b23d2e9d6c51e9f0
--- /dev/null
+++ b/trellis/modules/sparse/transformer/blocks.py
@@ -0,0 +1,151 @@
+from typing import *
+import torch
+import torch.nn as nn
+from ..basic import SparseTensor
+from ..linear import SparseLinear
+from ..nonlinearity import SparseGELU
+from ..attention import SparseMultiHeadAttention, SerializeMode
+from ...norm import LayerNorm32
+
+
+class SparseFeedForwardNet(nn.Module):
+    def __init__(self, channels: int, mlp_ratio: float = 4.0):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            SparseLinear(channels, int(channels * mlp_ratio)),
+            SparseGELU(approximate="tanh"),
+            SparseLinear(int(channels * mlp_ratio), channels),
+        )
+
+    def forward(self, x: SparseTensor) -> SparseTensor:
+        return self.mlp(x)
+
+
+class SparseTransformerBlock(nn.Module):
+    """
+    Sparse Transformer block (MSA + FFN).
+    """
+    def __init__(
+        self,
+        channels: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
+        window_size: Optional[int] = None,
+        shift_sequence: Optional[int] = None,
+        shift_window: Optional[Tuple[int, int, int]] = None,
+        serialize_mode: Optional[SerializeMode] = None,
+        use_checkpoint: bool = False,
+        use_rope: bool = False,
+        qk_rms_norm: bool = False,
+        qkv_bias: bool = True,
+        ln_affine: bool = False,
+    ):
+        super().__init__()
+        self.use_checkpoint = use_checkpoint
+        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+        self.attn = SparseMultiHeadAttention(
+            channels,
+            num_heads=num_heads,
+            attn_mode=attn_mode,
+            window_size=window_size,
+            shift_sequence=shift_sequence,
+            shift_window=shift_window,
+            serialize_mode=serialize_mode,
+            qkv_bias=qkv_bias,
+            use_rope=use_rope,
+            qk_rms_norm=qk_rms_norm,
+        )
+        self.mlp = SparseFeedForwardNet(
+            channels,
+            mlp_ratio=mlp_ratio,
+        )
+
+    def _forward(self, x: SparseTensor) -> SparseTensor:
+        h = x.replace(self.norm1(x.feats))
+        h = self.attn(h)
+        x = x + h
+        h = x.replace(self.norm2(x.feats))
+        h = self.mlp(h)
+        x = x + h
+        return x
+
+    def forward(self, x: SparseTensor) -> SparseTensor:
+        if self.use_checkpoint:
+            return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
+        else:
+            return self._forward(x)
+
+
+class SparseTransformerCrossBlock(nn.Module):
+    """
+    Sparse Transformer cross-attention block (MSA + MCA + FFN).
+    """
+    def __init__(
+        self,
+        channels: int,
+        ctx_channels: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
+        window_size: Optional[int] = None,
+        shift_sequence: Optional[int] = None,
+        shift_window: Optional[Tuple[int, int, int]] = None,
+        serialize_mode: Optional[SerializeMode] = None,
+        use_checkpoint: bool = False,
+        use_rope: bool = False,
+        qk_rms_norm: bool = False,
+        qk_rms_norm_cross: bool = False,
+        qkv_bias: bool = True,
+        ln_affine: bool = False,
+    ):
+        super().__init__()
+        self.use_checkpoint = use_checkpoint
+        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+        self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+        self.self_attn = SparseMultiHeadAttention(
+            channels,
+            num_heads=num_heads,
+            type="self",
+            attn_mode=attn_mode,
+            window_size=window_size,
+            shift_sequence=shift_sequence,
+            shift_window=shift_window,
+            serialize_mode=serialize_mode,
+            qkv_bias=qkv_bias,
+            use_rope=use_rope,
+            qk_rms_norm=qk_rms_norm,
+        )
+        self.cross_attn = SparseMultiHeadAttention(
+            channels,
+            ctx_channels=ctx_channels,
+            num_heads=num_heads,
+            type="cross",
+            attn_mode="full",
+            qkv_bias=qkv_bias,
+            qk_rms_norm=qk_rms_norm_cross,
+        )
+        self.mlp = SparseFeedForwardNet(
+            channels,
+            mlp_ratio=mlp_ratio,
+        )
+
+    def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor):
+        h = x.replace(self.norm1(x.feats))
+        h = self.self_attn(h)
+        x = x + h
+        h = x.replace(self.norm2(x.feats))
+        h = self.cross_attn(h, context)
+        x = x + h
+        h = x.replace(self.norm3(x.feats))
+        h = self.mlp(h)
+        x = x + h
+        return x
+
+    def forward(self, x: SparseTensor, context: torch.Tensor):
+        if self.use_checkpoint:
+            return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
+        else:
+            return self._forward(x, context)
diff --git a/trellis/modules/sparse/transformer/modulated.py b/trellis/modules/sparse/transformer/modulated.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a8416559f39acbed9e5996e9891c97f95c80c8f
--- /dev/null
+++ b/trellis/modules/sparse/transformer/modulated.py
@@ -0,0 +1,166 @@
+from typing import *
+import torch
+import torch.nn as nn
+from ..basic import SparseTensor
+from ..attention import SparseMultiHeadAttention, SerializeMode
+from ...norm import LayerNorm32
+from .blocks import SparseFeedForwardNet
+
+
+class ModulatedSparseTransformerBlock(nn.Module):
+    """
+    Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning.
+    """
+    def __init__(
+        self,
+        channels: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
+        window_size: Optional[int] = None,
+        shift_sequence: Optional[int] = None,
+        shift_window: Optional[Tuple[int, int, int]] = None,
+        serialize_mode: Optional[SerializeMode] = None,
+        use_checkpoint: bool = False,
+        use_rope: bool = False,
+        qk_rms_norm: bool = False,
+        qkv_bias: bool = True,
+        share_mod: bool = False,
+    ):
+        super().__init__()
+        self.use_checkpoint = use_checkpoint
+        self.share_mod = share_mod
+        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
+        self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
+        self.attn = SparseMultiHeadAttention(
+            channels,
+            num_heads=num_heads,
+            attn_mode=attn_mode,
+            window_size=window_size,
+            shift_sequence=shift_sequence,
+            shift_window=shift_window,
+            serialize_mode=serialize_mode,
+            qkv_bias=qkv_bias,
+            use_rope=use_rope,
+            qk_rms_norm=qk_rms_norm,
+        )
+        self.mlp = SparseFeedForwardNet(
+            channels,
+            mlp_ratio=mlp_ratio,
+        )
+        if not share_mod:
+            self.adaLN_modulation = nn.Sequential(
+                nn.SiLU(),
+                nn.Linear(channels, 6 * channels, bias=True)
+            )
+
+    def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
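+        # Adaptive LayerNorm (adaLN): `mod` supplies per-sample shift/scale applied after each pre-norm
+        # and a gate applied to each branch output, for both the attention and MLP branches.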
+        if self.share_mod:
+            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
+        else:
+            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
+        h = x.replace(self.norm1(x.feats))
+        h = h * (1 + scale_msa) + shift_msa
+        h = self.attn(h)
+        h = h * gate_msa
+        x = x + h
+        h = x.replace(self.norm2(x.feats))
+        h = h * (1 + scale_mlp) + shift_mlp
+        h = self.mlp(h)
+        h = h * gate_mlp
+        x = x + h
+        return x
+
+    def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
+        if self.use_checkpoint:
+            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
+        else:
+            return self._forward(x, mod)
+
+
+class ModulatedSparseTransformerCrossBlock(nn.Module):
+    """
+    Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
+    """
+    def __init__(
+        self,
+        channels: int,
+        ctx_channels: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
+        window_size: Optional[int] = None,
+        shift_sequence: Optional[int] = None,
+        shift_window: Optional[Tuple[int, int, int]] = None,
+        serialize_mode: Optional[SerializeMode] = None,
+        use_checkpoint: bool = False,
+        use_rope: bool = False,
+        qk_rms_norm: bool = False,
+        qk_rms_norm_cross: bool = False,
+        qkv_bias: bool = True,
+        share_mod: bool = False,
+
+    ):
+        super().__init__()
+        self.use_checkpoint = use_checkpoint
+        self.share_mod = share_mod
+        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
+        self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
+        self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
+        self.self_attn = SparseMultiHeadAttention(
+            channels,
+            num_heads=num_heads,
+            type="self",
+            attn_mode=attn_mode,
+            window_size=window_size,
+            shift_sequence=shift_sequence,
+            shift_window=shift_window,
+            serialize_mode=serialize_mode,
+            qkv_bias=qkv_bias,
+            use_rope=use_rope,
+            qk_rms_norm=qk_rms_norm,
+        )
+        self.cross_attn = SparseMultiHeadAttention(
+            channels,
+            ctx_channels=ctx_channels,
+            num_heads=num_heads,
+            type="cross",
+            attn_mode="full",
+            qkv_bias=qkv_bias,
+            qk_rms_norm=qk_rms_norm_cross,
+        )
+        self.mlp = SparseFeedForwardNet(
+            channels,
+            mlp_ratio=mlp_ratio,
+        )
+        if not share_mod:
+            self.adaLN_modulation = nn.Sequential(
+                nn.SiLU(),
+                nn.Linear(channels, 6 * channels, bias=True)
+            )
+
+    def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor:
+        if self.share_mod:
+            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
+        else:
+            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
+        h = x.replace(self.norm1(x.feats))
+        h = h * (1 + scale_msa) + shift_msa
+        h = self.self_attn(h)
+        h = h * gate_msa
+        x = x + h
+        h = x.replace(self.norm2(x.feats))
+        h = self.cross_attn(h, context)
+        x = x + h
+        h = x.replace(self.norm3(x.feats))
+        h = h * (1 + scale_mlp) + shift_mlp
+        h = self.mlp(h)
+        h = h * gate_mlp
+        x = x + h
+        return x
+
+    def forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor:
+        if self.use_checkpoint:
+            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
+        else:
+            return self._forward(x, mod, context)
diff --git a/trellis/modules/spatial.py b/trellis/modules/spatial.py
new file mode 100644
index 0000000000000000000000000000000000000000..79e268d36c2ba49b0275744022a1a1e19983dae3
--- /dev/null
+++ b/trellis/modules/spatial.py
@@ -0,0 +1,48 @@
+import torch
+
+
+def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
+    """
+    3D pixel shuffle.
+    """
+    B, C, H, W, D = x.shape
+    C_ = C // scale_factor**3
+    x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D)
+    x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)
+    x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor)
+    return x
+
+
+def patchify(x: torch.Tensor, patch_size: int):
+    """
+    Patchify a tensor.
+
+    Args:
+        x (torch.Tensor): (N, C, *spatial) tensor
+        patch_size (int): Patch size
+    """
+    DIM = x.dim() - 2
+    for d in range(2, DIM + 2):
+        assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}"
+
+    x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], []))
+    x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)]))
+    x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:]))
+    return x
+
+
+def unpatchify(x: torch.Tensor, patch_size: int):
+    """
+    Unpatchify a tensor.
+
+    Args:
+        x (torch.Tensor): (N, C, *spatial) tensor
+        patch_size (int): Patch size
+    """
+    DIM = x.dim() - 2
+    assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}"
+
+    x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:]))
+    x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], [])))
+    x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)])
+    return x
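A quick shape and round-trip check for patchify and unpatchify (a sketch, assuming the trellis package is importable and the spatial dims are divisible by the patch size): spatial resolution shrinks by patch_size per axis, the channel count grows by patch_size ** DIM, and unpatchify inverts the operation exactly.

import torch
from trellis.modules.spatial import patchify, unpatchify

x = torch.randn(2, 4, 16, 16, 16)                    # (N, C, D, H, W)
p = patchify(x, patch_size=2)                        # -> (2, 4 * 2**3, 8, 8, 8)
assert p.shape == (2, 32, 8, 8, 8)
assert torch.equal(unpatchify(p, patch_size=2), x)   # pure reshape/permute, exact inverse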
diff --git a/trellis/modules/transformer/__init__.py b/trellis/modules/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b08b0d4e5bc24060a2cdc8df75d06dce122972bd
--- /dev/null
+++ b/trellis/modules/transformer/__init__.py
@@ -0,0 +1,2 @@
+from .blocks import *
+from .modulated import *
\ No newline at end of file
diff --git a/trellis/modules/transformer/blocks.py b/trellis/modules/transformer/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..c37eb7ed92f4aacfc9e974a63b247589d95977da
--- /dev/null
+++ b/trellis/modules/transformer/blocks.py
@@ -0,0 +1,182 @@
+from typing import *
+import torch
+import torch.nn as nn
+from ..attention import MultiHeadAttention
+from ..norm import LayerNorm32
+
+
+class AbsolutePositionEmbedder(nn.Module):
+    """
+    Embeds spatial positions into vector representations.
+    """
+    def __init__(self, channels: int, in_channels: int = 3):
+        super().__init__()
+        self.channels = channels
+        self.in_channels = in_channels
+        self.freq_dim = channels // in_channels // 2
+        self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
+        self.freqs = 1.0 / (10000 ** self.freqs)
+        
+    def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Create sinusoidal position embeddings.
+
+        Args:
+            x: a 1-D Tensor of N indices
+
+        Returns:
+            an (N, D) Tensor of positional embeddings.
+        """
+        self.freqs = self.freqs.to(x.device)
+        out = torch.outer(x, self.freqs)
+        out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1)
+        return out
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x (torch.Tensor): (N, D) tensor of spatial positions
+        """
+        N, D = x.shape
+        assert D == self.in_channels, "Input dimension must match number of input channels"
+        embed = self._sin_cos_embedding(x.reshape(-1))
+        embed = embed.reshape(N, -1)
+        if embed.shape[1] < self.channels:
+            embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1)
+        return embed
+
+
+class FeedForwardNet(nn.Module):
+    def __init__(self, channels: int, mlp_ratio: float = 4.0):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            nn.Linear(channels, int(channels * mlp_ratio)),
+            nn.GELU(approximate="tanh"),
+            nn.Linear(int(channels * mlp_ratio), channels),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.mlp(x)
+
+
+class TransformerBlock(nn.Module):
+    """
+    Transformer block (MSA + FFN).
+    """
+    def __init__(
+        self,
+        channels: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        attn_mode: Literal["full", "windowed"] = "full",
+        window_size: Optional[int] = None,
+        shift_window: Optional[int] = None,
+        use_checkpoint: bool = False,
+        use_rope: bool = False,
+        qk_rms_norm: bool = False,
+        qkv_bias: bool = True,
+        ln_affine: bool = False,
+    ):
+        super().__init__()
+        self.use_checkpoint = use_checkpoint
+        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+        self.attn = MultiHeadAttention(
+            channels,
+            num_heads=num_heads,
+            attn_mode=attn_mode,
+            window_size=window_size,
+            shift_window=shift_window,
+            qkv_bias=qkv_bias,
+            use_rope=use_rope,
+            qk_rms_norm=qk_rms_norm,
+        )
+        self.mlp = FeedForwardNet(
+            channels,
+            mlp_ratio=mlp_ratio,
+        )
+
+    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+        h = self.norm1(x)
+        h = self.attn(h)
+        x = x + h
+        h = self.norm2(x)
+        h = self.mlp(h)
+        x = x + h
+        return x
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.use_checkpoint:
+            return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
+        else:
+            return self._forward(x)
+
+
+class TransformerCrossBlock(nn.Module):
+    """
+    Transformer cross-attention block (MSA + MCA + FFN).
+    """
+    def __init__(
+        self,
+        channels: int,
+        ctx_channels: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        attn_mode: Literal["full", "windowed"] = "full",
+        window_size: Optional[int] = None,
+        shift_window: Optional[Tuple[int, int, int]] = None,
+        use_checkpoint: bool = False,
+        use_rope: bool = False,
+        qk_rms_norm: bool = False,
+        qk_rms_norm_cross: bool = False,
+        qkv_bias: bool = True,
+        ln_affine: bool = False,
+    ):
+        super().__init__()
+        self.use_checkpoint = use_checkpoint
+        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+        self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
+        self.self_attn = MultiHeadAttention(
+            channels,
+            num_heads=num_heads,
+            type="self",
+            attn_mode=attn_mode,
+            window_size=window_size,
+            shift_window=shift_window,
+            qkv_bias=qkv_bias,
+            use_rope=use_rope,
+            qk_rms_norm=qk_rms_norm,
+        )
+        self.cross_attn = MultiHeadAttention(
+            channels,
+            ctx_channels=ctx_channels,
+            num_heads=num_heads,
+            type="cross",
+            attn_mode="full",
+            qkv_bias=qkv_bias,
+            qk_rms_norm=qk_rms_norm_cross,
+        )
+        self.mlp = FeedForwardNet(
+            channels,
+            mlp_ratio=mlp_ratio,
+        )
+
+    def _forward(self, x: torch.Tensor, context: torch.Tensor):
+        h = self.norm1(x)
+        h = self.self_attn(h)
+        x = x + h
+        h = self.norm2(x)
+        h = self.cross_attn(h, context)
+        x = x + h
+        h = self.norm3(x)
+        h = self.mlp(h)
+        x = x + h
+        return x
+
+    def forward(self, x: torch.Tensor, context: torch.Tensor):
+        if self.use_checkpoint:
+            return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
+        else:
+            return self._forward(x, context)
+        
\ No newline at end of file
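As a rough shape reference: AbsolutePositionEmbedder maps (N, 3) coordinates to (N, channels) sinusoidal embeddings, zero-padding when channels is not divisible by 2 * in_channels, while TransformerBlock / TransformerCrossBlock are standard pre-norm residual blocks over (B, L, C) tensors. The sketch below only exercises the embedder, since the attention blocks additionally need one of the repo's attention backends to be installed:

import torch
from trellis.modules.transformer.blocks import AbsolutePositionEmbedder

coords = torch.randint(0, 64, (100, 3)).float()   # (N, 3) voxel coordinates
emb = AbsolutePositionEmbedder(channels=1024)(coords)
print(emb.shape)                                  # torch.Size([100, 1024]); the last 4 dims are zero padding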
diff --git a/trellis/modules/transformer/modulated.py b/trellis/modules/transformer/modulated.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4aeca0689e68f656b08f7aa822b7be839aa727d
--- /dev/null
+++ b/trellis/modules/transformer/modulated.py
@@ -0,0 +1,157 @@
+from typing import *
+import torch
+import torch.nn as nn
+from ..attention import MultiHeadAttention
+from ..norm import LayerNorm32
+from .blocks import FeedForwardNet
+
+
+class ModulatedTransformerBlock(nn.Module):
+    """
+    Transformer block (MSA + FFN) with adaptive layer norm conditioning.
+    """
+    def __init__(
+        self,
+        channels: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        attn_mode: Literal["full", "windowed"] = "full",
+        window_size: Optional[int] = None,
+        shift_window: Optional[Tuple[int, int, int]] = None,
+        use_checkpoint: bool = False,
+        use_rope: bool = False,
+        qk_rms_norm: bool = False,
+        qkv_bias: bool = True,
+        share_mod: bool = False,
+    ):
+        super().__init__()
+        self.use_checkpoint = use_checkpoint
+        self.share_mod = share_mod
+        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
+        self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
+        self.attn = MultiHeadAttention(
+            channels,
+            num_heads=num_heads,
+            attn_mode=attn_mode,
+            window_size=window_size,
+            shift_window=shift_window,
+            qkv_bias=qkv_bias,
+            use_rope=use_rope,
+            qk_rms_norm=qk_rms_norm,
+        )
+        self.mlp = FeedForwardNet(
+            channels,
+            mlp_ratio=mlp_ratio,
+        )
+        if not share_mod:
+            self.adaLN_modulation = nn.Sequential(
+                nn.SiLU(),
+                nn.Linear(channels, 6 * channels, bias=True)
+            )
+
+    def _forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
+        if self.share_mod:
+            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
+        else:
+            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
+        h = self.norm1(x)
+        h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
+        h = self.attn(h)
+        h = h * gate_msa.unsqueeze(1)
+        x = x + h
+        h = self.norm2(x)
+        h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
+        h = self.mlp(h)
+        h = h * gate_mlp.unsqueeze(1)
+        x = x + h
+        return x
+
+    def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
+        if self.use_checkpoint:
+            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
+        else:
+            return self._forward(x, mod)
+
+
+class ModulatedTransformerCrossBlock(nn.Module):
+    """
+    Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
+    """
+    def __init__(
+        self,
+        channels: int,
+        ctx_channels: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        attn_mode: Literal["full", "windowed"] = "full",
+        window_size: Optional[int] = None,
+        shift_window: Optional[Tuple[int, int, int]] = None,
+        use_checkpoint: bool = False,
+        use_rope: bool = False,
+        qk_rms_norm: bool = False,
+        qk_rms_norm_cross: bool = False,
+        qkv_bias: bool = True,
+        share_mod: bool = False,
+    ):
+        super().__init__()
+        self.use_checkpoint = use_checkpoint
+        self.share_mod = share_mod
+        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
+        self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
+        self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
+        self.self_attn = MultiHeadAttention(
+            channels,
+            num_heads=num_heads,
+            type="self",
+            attn_mode=attn_mode,
+            window_size=window_size,
+            shift_window=shift_window,
+            qkv_bias=qkv_bias,
+            use_rope=use_rope,
+            qk_rms_norm=qk_rms_norm,
+        )
+        self.cross_attn = MultiHeadAttention(
+            channels,
+            ctx_channels=ctx_channels,
+            num_heads=num_heads,
+            type="cross",
+            attn_mode="full",
+            qkv_bias=qkv_bias,
+            qk_rms_norm=qk_rms_norm_cross,
+        )
+        self.mlp = FeedForwardNet(
+            channels,
+            mlp_ratio=mlp_ratio,
+        )
+        if not share_mod:
+            self.adaLN_modulation = nn.Sequential(
+                nn.SiLU(),
+                nn.Linear(channels, 6 * channels, bias=True)
+            )
+
+    def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor):
+        if self.share_mod:
+            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
+        else:
+            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
+        h = self.norm1(x)
+        h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
+        h = self.self_attn(h)
+        h = h * gate_msa.unsqueeze(1)
+        x = x + h
+        h = self.norm2(x)
+        h = self.cross_attn(h, context)
+        x = x + h
+        h = self.norm3(x)
+        h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
+        h = self.mlp(h)
+        h = h * gate_mlp.unsqueeze(1)
+        x = x + h
+        return x
+
+    def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor):
+        if self.use_checkpoint:
+            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
+        else:
+            return self._forward(x, mod, context)
+        
\ No newline at end of file
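The share_mod flag decides where the six modulation vectors come from: with share_mod=False each block owns its own SiLU + Linear projection and takes the raw condition, with share_mod=True the caller projects the condition once to 6 * channels and hands the same tensor to every block. A hedged sketch of the shared variant (hypothetical sizes; assumes the repo's attention backend is available):

import torch
import torch.nn as nn
from trellis.modules.transformer.modulated import ModulatedTransformerBlock

channels, num_heads, depth = 512, 8, 4
shared_adaLN = nn.Sequential(nn.SiLU(), nn.Linear(channels, 6 * channels))  # one projection for all blocks
blocks = nn.ModuleList([
    ModulatedTransformerBlock(channels, num_heads, share_mod=True) for _ in range(depth)
])

x = torch.randn(2, 64, channels)    # (B, L, C) tokens
t_emb = torch.randn(2, channels)    # condition, e.g. a timestep embedding
mod = shared_adaLN(t_emb)           # (B, 6 * C), chunked inside each block
for blk in blocks:
    x = blk(x, mod)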
diff --git a/trellis/modules/utils.py b/trellis/modules/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..f0afb1b6c767aa2ad00bad96649fb30315e696ea
--- /dev/null
+++ b/trellis/modules/utils.py
@@ -0,0 +1,54 @@
+import torch.nn as nn
+from ..modules import sparse as sp
+
+FP16_MODULES = (
+    nn.Conv1d,
+    nn.Conv2d,
+    nn.Conv3d,
+    nn.ConvTranspose1d,
+    nn.ConvTranspose2d,
+    nn.ConvTranspose3d,
+    nn.Linear,
+    sp.SparseConv3d,
+    sp.SparseInverseConv3d,
+    sp.SparseLinear,
+)
+
+def convert_module_to_f16(l):
+    """
+    Convert primitive modules to float16.
+    """
+    if isinstance(l, FP16_MODULES):
+        for p in l.parameters():
+            p.data = p.data.half()
+
+
+def convert_module_to_f32(l):
+    """
+    Convert primitive modules to float32, undoing convert_module_to_f16().
+    """
+    if isinstance(l, FP16_MODULES):
+        for p in l.parameters():
+            p.data = p.data.float()
+
+
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+def scale_module(module, scale):
+    """
+    Scale the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().mul_(scale)
+    return module
+
+
+def modulate(x, shift, scale):
+    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
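zero_module is the usual trick for making a new branch start as an identity mapping (its output projection begins at zero), and modulate is the same shift/scale arithmetic the modulated transformer blocks apply. A small sketch, assuming the package (and one of its sparse backends) imports cleanly:

import torch
import torch.nn as nn
from trellis.modules.utils import zero_module, modulate

proj = zero_module(nn.Linear(16, 16))   # weight and bias start at zero
x = torch.randn(2, 10, 16)              # (B, L, C)
assert torch.all(proj(x) == 0)          # the branch contributes nothing at init

shift = torch.zeros(2, 16)
scale = torch.full((2, 16), 0.5)
y = modulate(x, shift, scale)           # x * (1 + scale) + shift, broadcast over L
assert torch.allclose(y, 1.5 * x)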
diff --git a/trellis/pipelines/__init__.py b/trellis/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9e8548b894aeb3d354c739320ed3288be9c7b0e
--- /dev/null
+++ b/trellis/pipelines/__init__.py
@@ -0,0 +1,24 @@
+from . import samplers
+from .trellis_image_to_3d import TrellisImageTo3DPipeline
+
+
+def from_pretrained(path: str):
+    """
+    Load a pipeline from a model folder or a Hugging Face model hub.
+
+    Args:
+        path: The path to the model. Can be either local path or a Hugging Face model name.
+    """
+    import os
+    import json
+    is_local = os.path.exists(f"{path}/pipeline.json")
+
+    if is_local:
+        config_file = f"{path}/pipeline.json"
+    else:
+        from huggingface_hub import hf_hub_download
+        config_file = hf_hub_download(path, "pipeline.json")
+
+    with open(config_file, 'r') as f:
+        config = json.load(f)
+    return globals()[config['name']].from_pretrained(path)
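pipeline.json carries a name field that selects the pipeline class from this module's globals, so the same entry point works for a local folder and a Hugging Face repo id. A hedged usage sketch (the path below is a placeholder):

from trellis import pipelines

# Any local folder or HF repo that contains pipeline.json plus the referenced model weights.
pipe = pipelines.from_pretrained("path/or/hf-repo-id")
pipe.cuda()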
diff --git a/trellis/pipelines/base.py b/trellis/pipelines/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a9e0df4ec5fb915d57d30189cac854e3f095620
--- /dev/null
+++ b/trellis/pipelines/base.py
@@ -0,0 +1,66 @@
+from typing import *
+import torch
+import torch.nn as nn
+from .. import models
+
+
+class Pipeline:
+    """
+    A base class for pipelines.
+    """
+    def __init__(
+        self,
+        models: dict[str, nn.Module] = None,
+    ):
+        if models is None:
+            return
+        self.models = models
+        for model in self.models.values():
+            model.eval()
+
+    @staticmethod
+    def from_pretrained(path: str) -> "Pipeline":
+        """
+        Load a pretrained model.
+        """
+        import os
+        import json
+        is_local = os.path.exists(f"{path}/pipeline.json")
+
+        if is_local:
+            config_file = f"{path}/pipeline.json"
+        else:
+            from huggingface_hub import hf_hub_download
+            config_file = hf_hub_download(path, "pipeline.json")
+
+        with open(config_file, 'r') as f:
+            args = json.load(f)['args']
+
+        _models = {
+            k: models.from_pretrained(f"{path}/{v}")
+            for k, v in args['models'].items()
+        }
+
+        new_pipeline = Pipeline(_models)
+        new_pipeline._pretrained_args = args
+        return new_pipeline
+
+    @property
+    def device(self) -> torch.device:
+        for model in self.models.values():
+            if hasattr(model, 'device'):
+                return model.device
+        for model in self.models.values():
+            if hasattr(model, 'parameters'):
+                return next(model.parameters()).device
+        raise RuntimeError("No device found.")
+
+    def to(self, device: torch.device) -> None:
+        for model in self.models.values():
+            model.to(device)
+
+    def cuda(self) -> None:
+        self.to(torch.device("cuda"))
+
+    def cpu(self) -> None:
+        self.to(torch.device("cpu"))
diff --git a/trellis/pipelines/samplers/__init__.py b/trellis/pipelines/samplers/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..54d412fc5d8eb662081a92a56ad078243988c2f9
--- /dev/null
+++ b/trellis/pipelines/samplers/__init__.py
@@ -0,0 +1,2 @@
+from .base import Sampler
+from .flow_euler import FlowEulerSampler, FlowEulerCfgSampler, FlowEulerGuidanceIntervalSampler
\ No newline at end of file
diff --git a/trellis/pipelines/samplers/base.py b/trellis/pipelines/samplers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..1966ce787009a5ee0c1ed06dce491525ff1dbcbf
--- /dev/null
+++ b/trellis/pipelines/samplers/base.py
@@ -0,0 +1,20 @@
+from typing import *
+from abc import ABC, abstractmethod
+
+
+class Sampler(ABC):
+    """
+    A base class for samplers.
+    """
+
+    @abstractmethod
+    def sample(
+        self,
+        model,
+        **kwargs
+    ):
+        """
+        Sample from a model.
+        """
+        pass
+    
\ No newline at end of file
diff --git a/trellis/pipelines/samplers/classifier_free_guidance_mixin.py b/trellis/pipelines/samplers/classifier_free_guidance_mixin.py
new file mode 100644
index 0000000000000000000000000000000000000000..5701b25f5d7a2197612eb256f8ee13e8c489da1f
--- /dev/null
+++ b/trellis/pipelines/samplers/classifier_free_guidance_mixin.py
@@ -0,0 +1,12 @@
+from typing import *
+
+
+class ClassifierFreeGuidanceSamplerMixin:
+    """
+    A mixin class for samplers that apply classifier-free guidance.
+    """
+
+    def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, **kwargs):
+        pred = super()._inference_model(model, x_t, t, cond, **kwargs)
+        neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs)
+        return (1 + cfg_strength) * pred - cfg_strength * neg_pred
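The mixin overrides _inference_model so guidance composes with whatever base sampler sits next in the MRO via super(); the result is the standard classifier-free-guidance extrapolation. A tiny numeric check of the formula:

import torch

pred = torch.tensor([1.0, 2.0])        # conditional prediction
neg_pred = torch.tensor([0.5, 0.5])    # unconditional / negative prediction
cfg_strength = 3.0
guided = (1 + cfg_strength) * pred - cfg_strength * neg_pred
print(guided)                          # tensor([2.5000, 6.5000])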
diff --git a/trellis/pipelines/samplers/flow_euler.py b/trellis/pipelines/samplers/flow_euler.py
new file mode 100644
index 0000000000000000000000000000000000000000..d79124cf1b07515e8f0b88684e271028b1e3a71d
--- /dev/null
+++ b/trellis/pipelines/samplers/flow_euler.py
@@ -0,0 +1,199 @@
+from typing import *
+import torch
+import numpy as np
+from tqdm import tqdm
+from easydict import EasyDict as edict
+from .base import Sampler
+from .classifier_free_guidance_mixin import ClassifierFreeGuidanceSamplerMixin
+from .guidance_interval_mixin import GuidanceIntervalSamplerMixin
+
+
+class FlowEulerSampler(Sampler):
+    """
+    Generate samples from a flow-matching model using Euler sampling.
+
+    Args:
+        sigma_min: The minimum scale of noise in flow.
+    """
+    def __init__(
+        self,
+        sigma_min: float,
+    ):
+        self.sigma_min = sigma_min
+
+    def _eps_to_xstart(self, x_t, t, eps):
+        assert x_t.shape == eps.shape
+        return (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * eps) / (1 - t)
+
+    def _xstart_to_eps(self, x_t, t, x_0):
+        assert x_t.shape == x_0.shape
+        return (x_t - (1 - t) * x_0) / (self.sigma_min + (1 - self.sigma_min) * t)
+
+    def _v_to_xstart_eps(self, x_t, t, v):
+        assert x_t.shape == v.shape
+        eps = (1 - t) * v + x_t
+        x_0 = (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * v
+        return x_0, eps
+
+    def _inference_model(self, model, x_t, t, cond=None, **kwargs):
+        t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32)
+        return model(x_t, t, cond, **kwargs)
+
+    def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs):
+        pred_v = self._inference_model(model, x_t, t, cond, **kwargs)
+        pred_x_0, pred_eps = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v)
+        return pred_x_0, pred_eps, pred_v
+
+    @torch.no_grad()
+    def sample_once(
+        self,
+        model,
+        x_t,
+        t: float,
+        t_prev: float,
+        cond: Optional[Any] = None,
+        **kwargs
+    ):
+        """
+        Sample x_{t-1} from the model using Euler method.
+        
+        Args:
+            model: The model to sample from.
+            x_t: The [N x C x ...] tensor of noisy inputs at time t.
+            t: The current timestep.
+            t_prev: The previous timestep.
+            cond: conditional information.
+            **kwargs: Additional arguments for model inference.
+
+        Returns:
+            a dict containing the following
+            - 'pred_x_prev': x_{t-1}.
+            - 'pred_x_0': a prediction of x_0.
+        """
+        pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs)
+        pred_x_prev = x_t - (t - t_prev) * pred_v
+        return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0})
+
+    @torch.no_grad()
+    def sample(
+        self,
+        model,
+        noise,
+        cond: Optional[Any] = None,
+        steps: int = 50,
+        rescale_t: float = 1.0,
+        verbose: bool = True,
+        **kwargs
+    ):
+        """
+        Generate samples from the model using Euler method.
+        
+        Args:
+            model: The model to sample from.
+            noise: The initial noise tensor.
+            cond: conditional information.
+            steps: The number of steps to sample.
+            rescale_t: The rescale factor for t.
+            verbose: If True, show a progress bar.
+            **kwargs: Additional arguments for model inference.
+
+        Returns:
+            a dict containing the following
+            - 'samples': the model samples.
+            - 'pred_x_t': a list of prediction of x_t.
+            - 'pred_x_0': a list of prediction of x_0.
+        """
+        sample = noise
+        t_seq = np.linspace(1, 0, steps + 1)
+        t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
+        t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps))
+        ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []})
+        for t, t_prev in tqdm(t_pairs, desc="Sampling", disable=not verbose):
+            out = self.sample_once(model, sample, t, t_prev, cond, **kwargs)
+            sample = out.pred_x_prev
+            ret.pred_x_t.append(out.pred_x_prev)
+            ret.pred_x_0.append(out.pred_x_0)
+        ret.samples = sample
+        return ret
+
+
+class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler):
+    """
+    Generate samples from a flow-matching model using Euler sampling with classifier-free guidance.
+    """
+    @torch.no_grad()
+    def sample(
+        self,
+        model,
+        noise,
+        cond,
+        neg_cond,
+        steps: int = 50,
+        rescale_t: float = 1.0,
+        cfg_strength: float = 3.0,
+        verbose: bool = True,
+        **kwargs
+    ):
+        """
+        Generate samples from the model using Euler method.
+        
+        Args:
+            model: The model to sample from.
+            noise: The initial noise tensor.
+            cond: conditional information.
+            neg_cond: negative conditional information.
+            steps: The number of steps to sample.
+            rescale_t: The rescale factor for t.
+            cfg_strength: The strength of classifier-free guidance.
+            verbose: If True, show a progress bar.
+            **kwargs: Additional arguments for model inference.
+
+        Returns:
+            a dict containing the following
+            - 'samples': the model samples.
+            - 'pred_x_t': a list of prediction of x_t.
+            - 'pred_x_0': a list of prediction of x_0.
+        """
+        return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, **kwargs)
+
+
+class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSampler):
+    """
+    Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and interval.
+    """
+    @torch.no_grad()
+    def sample(
+        self,
+        model,
+        noise,
+        cond,
+        neg_cond,
+        steps: int = 50,
+        rescale_t: float = 1.0,
+        cfg_strength: float = 3.0,
+        cfg_interval: Tuple[float, float] = (0.0, 1.0),
+        verbose: bool = True,
+        **kwargs
+    ):
+        """
+        Generate samples from the model using Euler method.
+        
+        Args:
+            model: The model to sample from.
+            noise: The initial noise tensor.
+            cond: conditional information.
+            neg_cond: negative conditional information.
+            steps: The number of steps to sample.
+            rescale_t: The rescale factor for t.
+            cfg_strength: The strength of classifier-free guidance.
+            cfg_interval: The interval for classifier-free guidance.
+            verbose: If True, show a progress bar.
+            **kwargs: Additional arguments for model inference.
+
+        Returns:
+            a dict containing the following
+            - 'samples': the model samples.
+            - 'pred_x_t': a list of prediction of x_t.
+            - 'pred_x_0': a list of prediction of x_0.
+        """
+        return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs)
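Two details worth making concrete: the Euler update is simply x_prev = x_t - (t - t_prev) * v, with t running from 1 (pure noise) to 0 (data), and rescale_t warps the uniform schedule as t' = r * t / (1 + (r - 1) * t), so for r > 1 the steps cluster near t = 1, i.e. more of the budget is spent at high noise levels. A short sketch:

import numpy as np
import torch

steps, rescale_t = 5, 3.0
t_seq = np.linspace(1, 0, steps + 1)
t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
print(np.round(t_seq, 3))              # [1.    0.923 0.818 0.667 0.429 0.   ]

# One Euler step along the predicted velocity field (v would come from the model).
x_t = torch.randn(1, 4)
v = torch.randn(1, 4)
t, t_prev = float(t_seq[0]), float(t_seq[1])
x_prev = x_t - (t - t_prev) * v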
diff --git a/trellis/pipelines/samplers/guidance_interval_mixin.py b/trellis/pipelines/samplers/guidance_interval_mixin.py
new file mode 100644
index 0000000000000000000000000000000000000000..7074a4d5fea20a8f799416aa6571faca4f9eea06
--- /dev/null
+++ b/trellis/pipelines/samplers/guidance_interval_mixin.py
@@ -0,0 +1,15 @@
+from typing import *
+
+
+class GuidanceIntervalSamplerMixin:
+    """
+    A mixin class for samplers that apply classifier-free guidance with interval.
+    """
+
+    def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs):
+        if cfg_interval[0] <= t <= cfg_interval[1]:
+            pred = super()._inference_model(model, x_t, t, cond, **kwargs)
+            neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs)
+            return (1 + cfg_strength) * pred - cfg_strength * neg_pred
+        else:
+            return super()._inference_model(model, x_t, t, cond, **kwargs)
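Since t decreases from 1 toward 0 during sampling, an interval such as (0.5, 1.0) means guidance is only applied during the early, high-noise part of the trajectory and skipped for the final refinement steps (the interval values below are just an illustration):

cfg_interval = (0.5, 1.0)
for t in [1.0, 0.8, 0.6, 0.4, 0.2]:
    print(t, cfg_interval[0] <= t <= cfg_interval[1])   # True, True, True, False, False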
diff --git a/trellis/pipelines/trellis_image_to_3d.py b/trellis/pipelines/trellis_image_to_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..59d5894d870b8c03e014c0c29bb47873c11c8c84
--- /dev/null
+++ b/trellis/pipelines/trellis_image_to_3d.py
@@ -0,0 +1,281 @@
+from typing import *
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from tqdm import tqdm
+from easydict import EasyDict as edict
+from torchvision import transforms
+from PIL import Image
+import rembg
+from .base import Pipeline
+from . import samplers
+from ..modules import sparse as sp
+from ..representations import Gaussian, Strivec, MeshExtractResult
+
+
+class TrellisImageTo3DPipeline(Pipeline):
+    """
+    Pipeline for inferring Trellis image-to-3D models.
+
+    Args:
+        models (dict[str, nn.Module]): The models to use in the pipeline.
+        sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure.
+        slat_sampler (samplers.Sampler): The sampler for the structured latent.
+        slat_normalization (dict): The normalization parameters for the structured latent.
+        image_cond_model (str): The name of the image conditioning model.
+    """
+    def __init__(
+        self,
+        models: dict[str, nn.Module] = None,
+        sparse_structure_sampler: samplers.Sampler = None,
+        slat_sampler: samplers.Sampler = None,
+        slat_normalization: dict = None,
+        image_cond_model: str = None,
+    ):
+        if models is None:
+            return
+        super().__init__(models)
+        self.sparse_structure_sampler = sparse_structure_sampler
+        self.slat_sampler = slat_sampler
+        self.sparse_structure_sampler_params = {}
+        self.slat_sampler_params = {}
+        self.slat_normalization = slat_normalization
+        self.rembg_session = None
+        self._init_image_cond_model(image_cond_model)
+
+    @staticmethod
+    def from_pretrained(path: str) -> "TrellisImageTo3DPipeline":
+        """
+        Load a pretrained model.
+
+        Args:
+            path (str): The path to the model. Can be either local path or a Hugging Face repository.
+        """
+        pipeline = super(TrellisImageTo3DPipeline, TrellisImageTo3DPipeline).from_pretrained(path)
+        new_pipeline = TrellisImageTo3DPipeline()
+        new_pipeline.__dict__ = pipeline.__dict__
+        args = pipeline._pretrained_args
+
+        new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
+        new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
+
+        new_pipeline.slat_sampler = getattr(samplers, args['slat_sampler']['name'])(**args['slat_sampler']['args'])
+        new_pipeline.slat_sampler_params = args['slat_sampler']['params']
+
+        new_pipeline.slat_normalization = args['slat_normalization']
+
+        new_pipeline._init_image_cond_model(args['image_cond_model'])
+
+        return new_pipeline
+    
+    def _init_image_cond_model(self, name: str):
+        """
+        Initialize the image conditioning model.
+        """
+        dinov2_model = torch.hub.load('facebookresearch/dinov2', name, pretrained=True)
+        dinov2_model.eval()
+        self.models['image_cond_model'] = dinov2_model
+        transform = transforms.Compose([
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+        ])
+        self.image_cond_model_transform = transform
+
+    def preprocess_image(self, input: Image.Image) -> Image.Image:
+        """
+        Preprocess the input image.
+        """
+        # If the image already has a meaningful alpha channel, use it directly; otherwise, remove the background.
+        has_alpha = False
+        if input.mode == 'RGBA':
+            alpha = np.array(input)[:, :, 3]
+            if not np.all(alpha == 255):
+                has_alpha = True
+        if has_alpha:
+            output = input
+        else:
+            input = input.convert('RGB')
+            max_size = max(input.size)
+            scale = min(1, 1024 / max_size)
+            if scale < 1:
+                input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
+            if getattr(self, 'rembg_session', None) is None:
+                self.rembg_session = rembg.new_session('u2net')
+            output = rembg.remove(input, session=self.rembg_session)
+        output_np = np.array(output)
+        alpha = output_np[:, :, 3]
+        bbox = np.argwhere(alpha > 0.8 * 255)
+        bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
+        center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
+        size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
+        size = int(size * 1.2)
+        bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
+        output = output.crop(bbox)  # type: ignore
+        output = output.resize((518, 518), Image.Resampling.LANCZOS)
+        output = np.array(output).astype(np.float32) / 255
+        output = output[:, :, :3] * output[:, :, 3:4]
+        output = Image.fromarray((output * 255).astype(np.uint8))
+        return output
+
+    @torch.no_grad()
+    def encode_image(self, image: Union[torch.Tensor, list[Image.Image]]) -> torch.Tensor:
+        """
+        Encode the image.
+
+        Args:
+            image (Union[torch.Tensor, list[Image.Image]]): The image to encode
+
+        Returns:
+            torch.Tensor: The encoded features.
+        """
+        if isinstance(image, torch.Tensor):
+            assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
+        elif isinstance(image, list):
+            assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
+            image = [i.resize((518, 518), Image.Resampling.LANCZOS) for i in image]
+            image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
+            image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
+            image = torch.stack(image).to(self.device)
+        else:
+            raise ValueError(f"Unsupported type of image: {type(image)}")
+        
+        image = self.image_cond_model_transform(image).to(self.device)
+        features = self.models['image_cond_model'](image, is_training=True)['x_prenorm']
+        patchtokens = F.layer_norm(features, features.shape[-1:])
+        return patchtokens
+        
+    def get_cond(self, image: Union[torch.Tensor, list[Image.Image]]) -> dict:
+        """
+        Get the conditioning information for the model.
+
+        Args:
+            image (Union[torch.Tensor, list[Image.Image]]): The image prompts.
+
+        Returns:
+            dict: The conditioning information
+        """
+        cond = self.encode_image(image)
+        neg_cond = torch.zeros_like(cond)
+        return {
+            'cond': cond,
+            'neg_cond': neg_cond,
+        }
+
+    def sample_sparse_structure(
+        self,
+        cond: dict,
+        num_samples: int = 1,
+        sampler_params: dict = {},
+    ) -> torch.Tensor:
+        """
+        Sample sparse structures with the given conditioning.
+        
+        Args:
+            cond (dict): The conditioning information.
+            num_samples (int): The number of samples to generate.
+            sampler_params (dict): Additional parameters for the sampler.
+        """
+        # Sample occupancy latent
+        flow_model = self.models['sparse_structure_flow_model']
+        reso = flow_model.resolution
+        noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device)
+        sampler_params = {**self.sparse_structure_sampler_params, **sampler_params}
+        z_s = self.sparse_structure_sampler.sample(
+            flow_model,
+            noise,
+            **cond,
+            **sampler_params,
+            verbose=True
+        ).samples
+        
+        # Decode occupancy latent
+        decoder = self.models['sparse_structure_decoder']
+        coords = torch.argwhere(decoder(z_s)>0)[:, [0, 2, 3, 4]].int()
+
+        return coords
+
+    def decode_slat(
+        self,
+        slat: sp.SparseTensor,
+        formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
+    ) -> dict:
+        """
+        Decode the structured latent.
+
+        Args:
+            slat (sp.SparseTensor): The structured latent.
+            formats (List[str]): The formats to decode the structured latent to.
+
+        Returns:
+            dict: The decoded structured latent.
+        """
+        ret = {}
+        if 'mesh' in formats:
+            ret['mesh'] = self.models['slat_decoder_mesh'](slat)
+        if 'gaussian' in formats:
+            ret['gaussian'] = self.models['slat_decoder_gs'](slat)
+        if 'radiance_field' in formats:
+            ret['radiance_field'] = self.models['slat_decoder_rf'](slat)
+        return ret
+    
+    def sample_slat(
+        self,
+        cond: dict,
+        coords: torch.Tensor,
+        sampler_params: dict = {},
+    ) -> sp.SparseTensor:
+        """
+        Sample structured latent with the given conditioning.
+        
+        Args:
+            cond (dict): The conditioning information.
+            coords (torch.Tensor): The coordinates of the sparse structure.
+            sampler_params (dict): Additional parameters for the sampler.
+        """
+        # Sample structured latent
+        flow_model = self.models['slat_flow_model']
+        noise = sp.SparseTensor(
+            feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
+            coords=coords,
+        )
+        sampler_params = {**self.slat_sampler_params, **sampler_params}
+        slat = self.slat_sampler.sample(
+            flow_model,
+            noise,
+            **cond,
+            **sampler_params,
+            verbose=True
+        ).samples
+
+        std = torch.tensor(self.slat_normalization['std'])[None].to(slat.device)
+        mean = torch.tensor(self.slat_normalization['mean'])[None].to(slat.device)
+        slat = slat * std + mean
+        
+        return slat
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        image: Image.Image,
+        num_samples: int = 1,
+        sparse_structure_sampler_params: dict = {},
+        slat_sampler_params: dict = {},
+        formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
+        preprocess_image: bool = True,
+    ) -> dict:
+        """
+        Run the pipeline.
+
+        Args:
+            image (Image.Image): The image prompt.
+            num_samples (int): The number of samples to generate.
+            sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
+            slat_sampler_params (dict): Additional parameters for the structured latent sampler.
+            formats (List[str]): The output formats to decode ('mesh', 'gaussian', 'radiance_field').
+            preprocess_image (bool): Whether to preprocess the image.
+        """
+        if preprocess_image:
+            image = self.preprocess_image(image)
+        cond = self.get_cond([image])
+        coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
+        slat = self.sample_slat(cond, coords, slat_sampler_params)
+        return self.decode_slat(slat, formats)
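Per-call sampler overrides are merged on top of the defaults loaded from pipeline.json ({**defaults, **overrides}), so individual keys can be changed without editing the config. A hedged sketch; the key names assume the FlowEuler samplers above, and pipeline / image are assumed to be an already-loaded TrellisImageTo3DPipeline and a PIL image:

outputs = pipeline(
    image,
    sparse_structure_sampler_params={"steps": 12, "cfg_strength": 7.5},
    slat_sampler_params={"steps": 12, "cfg_strength": 3.0},
)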
diff --git a/trellis/renderers/__init__.py b/trellis/renderers/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..0339355c56b8d17f72e926650d140a658452fbe9
--- /dev/null
+++ b/trellis/renderers/__init__.py
@@ -0,0 +1,31 @@
+import importlib
+
+__attributes = {
+    'OctreeRenderer': 'octree_renderer',
+    'GaussianRenderer': 'gaussian_render',
+    'MeshRenderer': 'mesh_renderer',
+}
+
+__submodules = []
+
+__all__ = list(__attributes.keys()) + __submodules
+
+def __getattr__(name):
+    if name not in globals():
+        if name in __attributes:
+            module_name = __attributes[name]
+            module = importlib.import_module(f".{module_name}", __name__)
+            globals()[name] = getattr(module, name)
+        elif name in __submodules:
+            module = importlib.import_module(f".{name}", __name__)
+            globals()[name] = module
+        else:
+            raise AttributeError(f"module {__name__} has no attribute {name}")
+    return globals()[name]
+
+
+# For Pylance
+if __name__ == '__main__':
+    from .octree_renderer import OctreeRenderer
+    from .gaussian_render import GaussianRenderer
+    from .mesh_renderer import MeshRenderer
\ No newline at end of file
diff --git a/trellis/renderers/gaussian_render.py b/trellis/renderers/gaussian_render.py
new file mode 100755
index 0000000000000000000000000000000000000000..57108e3cccf6aab8e3059431557c461de46aff1a
--- /dev/null
+++ b/trellis/renderers/gaussian_render.py
@@ -0,0 +1,231 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use 
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact  george.drettakis@inria.fr
+#
+
+import torch
+import math
+from easydict import EasyDict as edict
+import numpy as np
+from ..representations.gaussian import Gaussian
+from .sh_utils import eval_sh
+import torch.nn.functional as F
+
+
+def intrinsics_to_projection(
+        intrinsics: torch.Tensor,
+        near: float,
+        far: float,
+    ) -> torch.Tensor:
+    """
+    OpenCV intrinsics to OpenGL perspective matrix
+
+    Args:
+        intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix
+        near (float): near plane to clip
+        far (float): far plane to clip
+    Returns:
+        (torch.Tensor): [4, 4] OpenGL perspective matrix
+    """
+    fx, fy = intrinsics[0, 0], intrinsics[1, 1]
+    cx, cy = intrinsics[0, 2], intrinsics[1, 2]
+    ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
+    ret[0, 0] = 2 * fx
+    ret[1, 1] = 2 * fy
+    ret[0, 2] = 2 * cx - 1
+    ret[1, 2] = - 2 * cy + 1
+    ret[2, 2] = far / (far - near)
+    ret[2, 3] = near * far / (near - far)
+    ret[3, 2] = 1.
+    return ret
+
+
+def render(viewpoint_camera, pc : Gaussian, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
+    """
+    Render the scene. 
+    
+    Background tensor (bg_color) must be on GPU!
+    """
+    # lazy import
+    if 'GaussianRasterizer' not in globals():
+        from diff_gaussian_rasterization import GaussianRasterizer, GaussianRasterizationSettings
+    
+    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
+    screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
+    try:
+        screenspace_points.retain_grad()
+    except:
+        pass
+    # Set up rasterization configuration
+    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
+    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
+    
+    kernel_size = pipe.kernel_size
+    subpixel_offset = torch.zeros((int(viewpoint_camera.image_height), int(viewpoint_camera.image_width), 2), dtype=torch.float32, device="cuda")
+
+    raster_settings = GaussianRasterizationSettings(
+        image_height=int(viewpoint_camera.image_height),
+        image_width=int(viewpoint_camera.image_width),
+        tanfovx=tanfovx,
+        tanfovy=tanfovy,
+        kernel_size=kernel_size,
+        subpixel_offset=subpixel_offset,
+        bg=bg_color,
+        scale_modifier=scaling_modifier,
+        viewmatrix=viewpoint_camera.world_view_transform,
+        projmatrix=viewpoint_camera.full_proj_transform,
+        sh_degree=pc.active_sh_degree,
+        campos=viewpoint_camera.camera_center,
+        prefiltered=False,
+        debug=pipe.debug
+    )
+    
+    rasterizer = GaussianRasterizer(raster_settings=raster_settings)
+
+    means3D = pc.get_xyz
+    means2D = screenspace_points
+    opacity = pc.get_opacity
+
+    # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
+    # scaling / rotation by the rasterizer.
+    scales = None
+    rotations = None
+    cov3D_precomp = None
+    if pipe.compute_cov3D_python:
+        cov3D_precomp = pc.get_covariance(scaling_modifier)
+    else:
+        scales = pc.get_scaling
+        rotations = pc.get_rotation
+
+    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
+    # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
+    shs = None
+    colors_precomp = None
+    if override_color is None:
+        if pipe.convert_SHs_python:
+            shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
+            dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
+            dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
+            sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
+            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
+        else:
+            shs = pc.get_features
+    else:
+        colors_precomp = override_color
+
+    # Rasterize visible Gaussians to image, obtain their radii (on screen). 
+    rendered_image, radii = rasterizer(
+        means3D = means3D,
+        means2D = means2D,
+        shs = shs,
+        colors_precomp = colors_precomp,
+        opacities = opacity,
+        scales = scales,
+        rotations = rotations,
+        cov3D_precomp = cov3D_precomp
+    )
+
+    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
+    # They will be excluded from value updates used in the splitting criteria.
+    return edict({"render": rendered_image,
+            "viewspace_points": screenspace_points,
+            "visibility_filter" : radii > 0,
+            "radii": radii})
+
+
+class GaussianRenderer:
+    """
+    Renderer for the Gaussian representation.
+
+    Args:
+        rendering_options (dict): Rendering options.
+    """
+
+    def __init__(self, rendering_options={}) -> None:
+        self.pipe = edict({
+            "kernel_size": 0.1,
+            "convert_SHs_python": False,
+            "compute_cov3D_python": False,
+            "scale_modifier": 1.0,
+            "debug": False
+        })
+        self.rendering_options = edict({
+            "resolution": None,
+            "near": None,
+            "far": None,
+            "ssaa": 1,
+            "bg_color": 'random',
+        })
+        self.rendering_options.update(rendering_options)
+        self.bg_color = None
+    
+    def render(
+            self,
+            gaussian: Gaussian,
+            extrinsics: torch.Tensor,
+            intrinsics: torch.Tensor,
+            colors_overwrite: torch.Tensor = None
+        ) -> edict:
+        """
+        Render the Gaussian representation.
+
+        Args:
+            gaussian (Gaussian): The Gaussian model to render
+            extrinsics (torch.Tensor): (4, 4) camera extrinsics
+            intrinsics (torch.Tensor): (3, 3) camera intrinsics
+            colors_overwrite (torch.Tensor): (N, 3) override color
+
+        Returns:
+            edict containing:
+                color (torch.Tensor): (3, H, W) rendered color image
+        """
+        resolution = self.rendering_options["resolution"]
+        near = self.rendering_options["near"]
+        far = self.rendering_options["far"]
+        ssaa = self.rendering_options["ssaa"]
+        
+        if self.rendering_options["bg_color"] == 'random':
+            self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda")
+            if np.random.rand() < 0.5:
+                self.bg_color += 1
+        else:
+            self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda")
+
+        view = extrinsics
+        perspective = intrinsics_to_projection(intrinsics, near, far)
+        camera = torch.inverse(view)[:3, 3]
+        focalx = intrinsics[0, 0]
+        focaly = intrinsics[1, 1]
+        fovx = 2 * torch.atan(0.5 / focalx)
+        fovy = 2 * torch.atan(0.5 / focaly)
+            
+        camera_dict = edict({
+            "image_height": resolution * ssaa,
+            "image_width": resolution * ssaa,
+            "FoVx": fovx,
+            "FoVy": fovy,
+            "znear": near,
+            "zfar": far,
+            "world_view_transform": view.T.contiguous(),
+            "projection_matrix": perspective.T.contiguous(),
+            "full_proj_transform": (perspective @ view).T.contiguous(),
+            "camera_center": camera
+        })
+
+        # Render
+        render_ret = render(camera_dict, gaussian, self.pipe, self.bg_color, override_color=colors_overwrite, scaling_modifier=self.pipe.scale_modifier)
+
+        if ssaa > 1:
+            render_ret.render = F.interpolate(render_ret.render[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
+            
+        ret = edict({
+            'color': render_ret['render']
+        })
+        return ret
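Note that the intrinsics are expected in normalized form (focal lengths and principal point divided by the image size), which is why the projection puts 2 * fx, 2 * fy on its diagonal and the field of view is recovered as 2 * atan(0.5 / f). A worked check: a normalized focal length of sqrt(3)/2 corresponds to a 60-degree FoV.

import math

f = math.sqrt(3) / 2                  # normalized focal length (fx = fy)
fov = 2 * math.atan(0.5 / f)
print(round(math.degrees(fov), 1))    # 60.0
print(round(2 * f, 3))                # 1.732, the [0, 0] / [1, 1] entries of the projection matrix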
diff --git a/trellis/renderers/mesh_renderer.py b/trellis/renderers/mesh_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..837094cf8f2125b212d2bdd61a05d99fa39358a1
--- /dev/null
+++ b/trellis/renderers/mesh_renderer.py
@@ -0,0 +1,144 @@
+# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto.  Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
+import torch
+try:
+    import kaolin as kal
+    import nvdiffrast.torch as dr
+except ImportError:
+    print("Kaolin and nvdiffrast are not installed. Please install them to use the mesh renderer.")
+from easydict import EasyDict as edict
+from ..representations.mesh import MeshExtractResult
+import torch.nn.functional as F
+
+
+def intrinsics_to_projection(
+        intrinsics: torch.Tensor,
+        near: float,
+        far: float,
+    ) -> torch.Tensor:
+    """
+    OpenCV intrinsics to OpenGL perspective matrix
+
+    Args:
+        intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix
+        near (float): near plane to clip
+        far (float): far plane to clip
+    Returns:
+        (torch.Tensor): [4, 4] OpenGL perspective matrix
+    """
+    fx, fy = intrinsics[0, 0], intrinsics[1, 1]
+    cx, cy = intrinsics[0, 2], intrinsics[1, 2]
+    ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
+    ret[0, 0] = 2 * fx
+    ret[1, 1] = 2 * fy
+    ret[0, 2] = 2 * cx - 1
+    ret[1, 2] = - 2 * cy + 1
+    ret[2, 2] = far / (far - near)
+    ret[2, 3] = near * far / (near - far)
+    ret[3, 2] = 1.
+    return ret
+
+
+class MeshRenderer:
+    """
+    Renderer for the Mesh representation.
+
+    Args:
+        rendering_options (dict): Rendering options.
+        device (str): Device on which the nvdiffrast CUDA rasterization context is created.
+    """
+    def __init__(self, rendering_options={}, device='cuda'):
+        self.rendering_options = edict({
+            "resolution": None,
+            "near": None,
+            "far": None,
+            "ssaa": 1
+        })
+        self.rendering_options.update(rendering_options)
+        self.glctx = dr.RasterizeCudaContext(device=device)
+        self.device=device
+        
+    def render(
+            self,
+            mesh : MeshExtractResult,
+            extrinsics: torch.Tensor,
+            intrinsics: torch.Tensor,
+            return_types = ["mask", "normal", "depth"]
+        ) -> edict:
+        """
+        Render the mesh.
+
+        Args:
+            mesh : meshmodel
+            extrinsics (torch.Tensor): (4, 4) camera extrinsics
+            intrinsics (torch.Tensor): (3, 3) camera intrinsics
+            return_types (list): list of return types, can be "mask", "depth", "normal_map", "normal", "color"
+
+        Returns:
+            edict based on return_types containing:
+                color (torch.Tensor): [3, H, W] rendered color image
+                depth (torch.Tensor): [H, W] rendered depth image
+                normal (torch.Tensor): [3, H, W] rendered normal image
+                normal_map (torch.Tensor): [3, H, W] rendered normal map image
+                mask (torch.Tensor): [H, W] rendered mask image
+        """
+        resolution = self.rendering_options["resolution"]
+        near = self.rendering_options["near"]
+        far = self.rendering_options["far"]
+        ssaa = self.rendering_options["ssaa"]
+        
+        if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0:
+            default_img = torch.zeros((1, resolution, resolution, 3), dtype=torch.float32, device=self.device)
+            ret_dict = {k : default_img if k in ['normal', 'normal_map', 'color'] else default_img[..., :1] for k in return_types}
+            return ret_dict
+        
+        perspective = intrinsics_to_projection(intrinsics, near, far)
+        
+        RT = extrinsics.unsqueeze(0)
+        full_proj = (perspective @ extrinsics).unsqueeze(0)
+        
+        vertices = mesh.vertices.unsqueeze(0)
+
+        vertices_homo = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1)
+        vertices_camera = torch.bmm(vertices_homo, RT.transpose(-1, -2))
+        vertices_clip = torch.bmm(vertices_homo, full_proj.transpose(-1, -2))
+        faces_int = mesh.faces.int()
+        rast, _ = dr.rasterize(
+            self.glctx, vertices_clip, faces_int, (resolution * ssaa, resolution * ssaa))
+        
+        out_dict = edict()
+        for rtype in return_types:
+            img = None
+            if rtype == "mask":
+                img = dr.antialias((rast[..., -1:] > 0).float(), rast, vertices_clip, faces_int)
+            elif rtype == "depth":
+                img = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces_int)[0]
+                img = dr.antialias(img, rast, vertices_clip, faces_int)
+            elif rtype == "normal":
+                img = dr.interpolate(
+                    mesh.face_normal.reshape(1, -1, 3), rast,
+                    torch.arange(mesh.faces.shape[0] * 3, device=self.device, dtype=torch.int).reshape(-1, 3)
+                )[0]
+                img = dr.antialias(img, rast, vertices_clip, faces_int)
+                # map normals from [-1, 1] to [0, 1] for visualization
+                img = (img + 1) / 2
+            elif rtype == "normal_map":
+                img = dr.interpolate(mesh.vertex_attrs[:, 3:].contiguous(), rast, faces_int)[0]
+                img = dr.antialias(img, rast, vertices_clip, faces_int)
+            elif rtype == "color":
+                img = dr.interpolate(mesh.vertex_attrs[:, :3].contiguous(), rast, faces_int)[0]
+                img = dr.antialias(img, rast, vertices_clip, faces_int)
+
+            if ssaa > 1:
+                img = F.interpolate(img.permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False, antialias=True)
+                img = img.squeeze()
+            else:
+                img = img.permute(0, 3, 1, 2).squeeze()
+            out_dict[rtype] = img
+
+        return out_dict
diff --git a/trellis/renderers/octree_renderer.py b/trellis/renderers/octree_renderer.py
new file mode 100755
index 0000000000000000000000000000000000000000..136069cdb0645b5759d5d17f7815612a1dfc7bea
--- /dev/null
+++ b/trellis/renderers/octree_renderer.py
@@ -0,0 +1,300 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+import math
+import cv2
+from scipy.stats import qmc
+from easydict import EasyDict as edict
+from ..representations.octree import DfsOctree
+
+
+def intrinsics_to_projection(
+        intrinsics: torch.Tensor,
+        near: float,
+        far: float,
+    ) -> torch.Tensor:
+    """
+    OpenCV intrinsics to OpenGL perspective matrix
+
+    Args:
+        intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix
+        near (float): near plane to clip
+        far (float): far plane to clip
+    Returns:
+        (torch.Tensor): [4, 4] OpenGL perspective matrix
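+
+    Example (illustrative sketch; the values are assumptions, not defaults): with normalized
+    intrinsics fx = fy = 0.5 and principal point cx = cy = 0.5 (a 90-degree FoV camera):
+        K = torch.tensor([[0.5, 0.0, 0.5], [0.0, 0.5, 0.5], [0.0, 0.0, 1.0]])
+        proj = intrinsics_to_projection(K, near=0.1, far=10.0)  # (4, 4) OpenGL projection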
+    """
+    fx, fy = intrinsics[0, 0], intrinsics[1, 1]
+    cx, cy = intrinsics[0, 2], intrinsics[1, 2]
+    ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
+    ret[0, 0] = 2 * fx
+    ret[1, 1] = 2 * fy
+    ret[0, 2] = 2 * cx - 1
+    ret[1, 2] = - 2 * cy + 1
+    ret[2, 2] = far / (far - near)
+    ret[2, 3] = near * far / (near - far)
+    ret[3, 2] = 1.
+    return ret
+
+
+def render(viewpoint_camera, octree : DfsOctree, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, used_rank = None, colors_overwrite = None, aux=None, halton_sampler=None):
+    """
+    Render the scene. 
+    
+    Background tensor (bg_color) must be on GPU!
+    """
+    # lazy import
+    if 'OctreeTrivecRasterizer' not in globals():
+        from diffoctreerast import OctreeVoxelRasterizer, OctreeGaussianRasterizer, OctreeTrivecRasterizer, OctreeDecoupolyRasterizer
+    
+    # Set up rasterization configuration
+    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
+    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
+
+    raster_settings = edict(
+        image_height=int(viewpoint_camera.image_height),
+        image_width=int(viewpoint_camera.image_width),
+        tanfovx=tanfovx,
+        tanfovy=tanfovy,
+        bg=bg_color,
+        scale_modifier=scaling_modifier,
+        viewmatrix=viewpoint_camera.world_view_transform,
+        projmatrix=viewpoint_camera.full_proj_transform,
+        sh_degree=octree.active_sh_degree,
+        campos=viewpoint_camera.camera_center,
+        with_distloss=pipe.with_distloss,
+        jitter=pipe.jitter,
+        debug=pipe.debug,
+    )
+
+    positions = octree.get_xyz
+    if octree.primitive == "voxel":
+        densities = octree.get_density
+    elif octree.primitive == "gaussian":
+        opacities = octree.get_opacity
+    elif octree.primitive == "trivec":
+        trivecs = octree.get_trivec
+        densities = octree.get_density
+        raster_settings.density_shift = octree.density_shift
+    elif octree.primitive == "decoupoly":
+        decoupolys_V, decoupolys_g = octree.get_decoupoly
+        densities = octree.get_density
+        raster_settings.density_shift = octree.density_shift
+    else:
+        raise ValueError(f"Unknown primitive {octree.primitive}")
+    depths = octree.get_depth
+
+    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
+    # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
+    colors_precomp = None
+    shs = octree.get_features
+    if octree.primitive in ["voxel", "gaussian"] and colors_overwrite is not None:
+        colors_precomp = colors_overwrite
+        shs = None
+
+    ret = edict()
+
+    if octree.primitive == "voxel":
+        renderer = OctreeVoxelRasterizer(raster_settings=raster_settings)
+        rgb, depth, alpha, distloss = renderer(
+            positions = positions,
+            densities = densities,
+            shs = shs,
+            colors_precomp = colors_precomp,
+            depths = depths,
+            aabb = octree.aabb,
+            aux = aux,
+        )
+        ret['rgb'] = rgb
+        ret['depth'] = depth
+        ret['alpha'] = alpha
+        ret['distloss'] = distloss
+    elif octree.primitive == "gaussian":
+        renderer = OctreeGaussianRasterizer(raster_settings=raster_settings)
+        rgb, depth, alpha = renderer(
+            positions = positions,
+            opacities = opacities,
+            shs = shs,
+            colors_precomp = colors_precomp,
+            depths = depths,
+            aabb = octree.aabb,
+            aux = aux,
+        )
+        ret['rgb'] = rgb
+        ret['depth'] = depth
+        ret['alpha'] = alpha
+    elif octree.primitive == "trivec":
+        raster_settings.used_rank = used_rank if used_rank is not None else trivecs.shape[1]
+        renderer = OctreeTrivecRasterizer(raster_settings=raster_settings)
+        rgb, depth, alpha, percent_depth = renderer(
+            positions = positions,
+            trivecs = trivecs,
+            densities = densities,
+            shs = shs,
+            colors_precomp = colors_precomp,
+            colors_overwrite = colors_overwrite,
+            depths = depths,
+            aabb = octree.aabb,
+            aux = aux,
+            halton_sampler = halton_sampler,
+        )
+        ret['percent_depth'] = percent_depth
+        ret['rgb'] = rgb
+        ret['depth'] = depth
+        ret['alpha'] = alpha
+    elif octree.primitive == "decoupoly":
+        raster_settings.used_rank = used_rank if used_rank is not None else decoupolys_V.shape[1]
+        renderer = OctreeDecoupolyRasterizer(raster_settings=raster_settings)
+        rgb, depth, alpha = renderer(
+            positions = positions,
+            decoupolys_V = decoupolys_V,
+            decoupolys_g = decoupolys_g,
+            densities = densities,
+            shs = shs,
+            colors_precomp = colors_precomp,
+            depths = depths,
+            aabb = octree.aabb,
+            aux = aux,
+        )
+        ret['rgb'] = rgb
+        ret['depth'] = depth
+        ret['alpha'] = alpha
+    
+    return ret
+
+
+class OctreeRenderer:
+    """
+    Renderer for the Voxel representation.
+
+    Args:
+        rendering_options (dict): Rendering options.
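+
+    Example (illustrative sketch; the option values are assumptions, not defaults):
+        renderer = OctreeRenderer({"resolution": 512, "near": 0.8, "far": 1.6, "ssaa": 2})
+        renderer.rendering_options.bg_color = (0, 0, 0)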
+    """
+
+    def __init__(self, rendering_options={}) -> None:
+        try:
+            import diffoctreerast
+        except ImportError:
+            print("\033[93m[WARNING] diffoctreerast is not installed. The renderer will be disabled.\033[0m")
+            self.unsupported = True
+        else:
+            self.unsupported = False
+        
+        self.pipe = edict({
+            "with_distloss": False,
+            "with_aux": False,
+            "scale_modifier": 1.0,
+            "used_rank": None,
+            "jitter": False,
+            "debug": False,
+        })
+        self.rendering_options = edict({
+            "resolution": None,
+            "near": None,
+            "far": None,
+            "ssaa": 1,
+            "bg_color": 'random',
+        })
+        self.halton_sampler = qmc.Halton(2, scramble=False)
+        self.rendering_options.update(rendering_options)
+        self.bg_color = None
+    
+    def render(
+            self,
+            octree: DfsOctree,
+            extrinsics: torch.Tensor,
+            intrinsics: torch.Tensor,
+            colors_overwrite: torch.Tensor = None,
+        ) -> edict:
+        """
+        Render the octree.
+
+        Args:
+            octree (Octree): octree
+            extrinsics (torch.Tensor): (4, 4) camera extrinsics
+            intrinsics (torch.Tensor): (3, 3) camera intrinsics
+            colors_overwrite (torch.Tensor): (N, 3) override color
+
+        Returns:
+            edict containing:
+                color (torch.Tensor): (3, H, W) rendered color
+                depth (torch.Tensor): (H, W) rendered depth
+                alpha (torch.Tensor): (H, W) rendered alpha
+                distloss (Optional[torch.Tensor]): (H, W) rendered distance loss
+                percent_depth (Optional[torch.Tensor]): (H, W) rendered percent depth
+                aux (Optional[edict]): auxiliary tensors
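+
+        Example (illustrative sketch; `renderer` is an instance of this class whose
+        rendering_options provide resolution/near/far, and the camera tensors are caller-provided):
+            out = renderer.render(octree, extrinsics, intrinsics)
+            rgb = out.color  # (3, H, W)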
+        """
+        resolution = self.rendering_options["resolution"]
+        near = self.rendering_options["near"]
+        far = self.rendering_options["far"]
+        ssaa = self.rendering_options["ssaa"]
+        
+        if self.unsupported:
+            image = np.zeros((512, 512, 3), dtype=np.uint8)
+            text_bbox = cv2.getTextSize("Unsupported", cv2.FONT_HERSHEY_SIMPLEX, 2, 3)[0]
+            origin = (512 - text_bbox[0]) // 2, (512 - text_bbox[1]) // 2
+            image = cv2.putText(image, "Unsupported", origin, cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 255, 255), 3, cv2.LINE_AA)
+            return edict({
+                'color': torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255,
+            })
+        
+        if self.rendering_options["bg_color"] == 'random':
+            self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda")
+            if np.random.rand() < 0.5:
+                self.bg_color += 1
+        else:
+            self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda")
+
+        if self.pipe["with_aux"]:
+            aux = {
+                'grad_color2': torch.zeros((octree.num_leaf_nodes, 3), dtype=torch.float32, requires_grad=True, device="cuda") + 0,
+                'contributions': torch.zeros((octree.num_leaf_nodes, 1), dtype=torch.float32, requires_grad=True, device="cuda") + 0,
+            }
+            for k in aux.keys():
+                aux[k].requires_grad_()
+                aux[k].retain_grad()
+        else:
+            aux = None
+
+        view = extrinsics
+        perspective = intrinsics_to_projection(intrinsics, near, far)
+        camera = torch.inverse(view)[:3, 3]
+        focalx = intrinsics[0, 0]
+        focaly = intrinsics[1, 1]
+        fovx = 2 * torch.atan(0.5 / focalx)
+        fovy = 2 * torch.atan(0.5 / focaly)
+            
+        camera_dict = edict({
+            "image_height": resolution * ssaa,
+            "image_width": resolution * ssaa,
+            "FoVx": fovx,
+            "FoVy": fovy,
+            "znear": near,
+            "zfar": far,
+            "world_view_transform": view.T.contiguous(),
+            "projection_matrix": perspective.T.contiguous(),
+            "full_proj_transform": (perspective @ view).T.contiguous(),
+            "camera_center": camera
+        })
+
+        # Render
+        render_ret = render(camera_dict, octree, self.pipe, self.bg_color, aux=aux, colors_overwrite=colors_overwrite, scaling_modifier=self.pipe.scale_modifier, used_rank=self.pipe.used_rank, halton_sampler=self.halton_sampler)
+
+        if ssaa > 1:
+            render_ret.rgb = F.interpolate(render_ret.rgb[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
+            render_ret.depth = F.interpolate(render_ret.depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
+            render_ret.alpha = F.interpolate(render_ret.alpha[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
+            if hasattr(render_ret, 'percent_depth'):
+                render_ret.percent_depth = F.interpolate(render_ret.percent_depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
+
+        ret = edict({
+            'color': render_ret.rgb,
+            'depth': render_ret.depth,
+            'alpha': render_ret.alpha,
+        })
+        if self.pipe["with_distloss"] and 'distloss' in render_ret:
+            ret['distloss'] = render_ret.distloss
+        if self.pipe["with_aux"]:
+            ret['aux'] = aux
+        if hasattr(render_ret, 'percent_depth'):
+            ret['percent_depth'] = render_ret.percent_depth
+        return ret
diff --git a/trellis/renderers/sh_utils.py b/trellis/renderers/sh_utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..bbca7d192aa3a7edf8c5b2d24dee535eac765785
--- /dev/null
+++ b/trellis/renderers/sh_utils.py
@@ -0,0 +1,118 @@
+#  Copyright 2021 The PlenOctree Authors.
+#  Redistribution and use in source and binary forms, with or without
+#  modification, are permitted provided that the following conditions are met:
+#
+#  1. Redistributions of source code must retain the above copyright notice,
+#  this list of conditions and the following disclaimer.
+#
+#  2. Redistributions in binary form must reproduce the above copyright notice,
+#  this list of conditions and the following disclaimer in the documentation
+#  and/or other materials provided with the distribution.
+#
+#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+#  POSSIBILITY OF SUCH DAMAGE.
+
+import torch
+
+C0 = 0.28209479177387814
+C1 = 0.4886025119029199
+C2 = [
+    1.0925484305920792,
+    -1.0925484305920792,
+    0.31539156525252005,
+    -1.0925484305920792,
+    0.5462742152960396
+]
+C3 = [
+    -0.5900435899266435,
+    2.890611442640554,
+    -0.4570457994644658,
+    0.3731763325901154,
+    -0.4570457994644658,
+    1.445305721320277,
+    -0.5900435899266435
+]
+C4 = [
+    2.5033429417967046,
+    -1.7701307697799304,
+    0.9461746957575601,
+    -0.6690465435572892,
+    0.10578554691520431,
+    -0.6690465435572892,
+    0.47308734787878004,
+    -1.7701307697799304,
+    0.6258357354491761,
+]   
+
+
+def eval_sh(deg, sh, dirs):
+    """
+    Evaluate spherical harmonics at unit directions
+    using hardcoded SH polynomials.
+    Works with torch/np/jnp; '...' denotes zero or more batch dimensions.
+    Args:
+        deg: int SH degree. Currently 0-4 supported.
+        sh: tensor/ndarray of SH coeffs [..., C, (deg + 1) ** 2]
+        dirs: tensor/ndarray of unit directions [..., 3]
+    Returns:
+        [..., C]
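+
+    Example (illustrative): for deg = 0 the direction is ignored and the result reduces to
+    C0 * sh[..., 0]; RGB2SH / SH2RGB below apply this constant (plus a 0.5 offset) to convert
+    between RGB values and degree-0 SH coefficients.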
+    """
+    assert deg <= 4 and deg >= 0
+    coeff = (deg + 1) ** 2
+    assert sh.shape[-1] >= coeff
+
+    result = C0 * sh[..., 0]
+    if deg > 0:
+        x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
+        result = (result -
+                C1 * y * sh[..., 1] +
+                C1 * z * sh[..., 2] -
+                C1 * x * sh[..., 3])
+
+        if deg > 1:
+            xx, yy, zz = x * x, y * y, z * z
+            xy, yz, xz = x * y, y * z, x * z
+            result = (result +
+                    C2[0] * xy * sh[..., 4] +
+                    C2[1] * yz * sh[..., 5] +
+                    C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
+                    C2[3] * xz * sh[..., 7] +
+                    C2[4] * (xx - yy) * sh[..., 8])
+
+            if deg > 2:
+                result = (result +
+                C3[0] * y * (3 * xx - yy) * sh[..., 9] +
+                C3[1] * xy * z * sh[..., 10] +
+                C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
+                C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
+                C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
+                C3[5] * z * (xx - yy) * sh[..., 14] +
+                C3[6] * x * (xx - 3 * yy) * sh[..., 15])
+
+                if deg > 3:
+                    result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
+                            C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
+                            C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
+                            C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
+                            C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
+                            C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
+                            C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
+                            C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
+                            C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
+    return result
+
+def RGB2SH(rgb):
+    return (rgb - 0.5) / C0
+
+def SH2RGB(sh):
+    return sh * C0 + 0.5
\ No newline at end of file
diff --git a/trellis/representations/__init__.py b/trellis/representations/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..549ffdb97e87181552e9b3e086766f873e4bfb5e
--- /dev/null
+++ b/trellis/representations/__init__.py
@@ -0,0 +1,4 @@
+from .radiance_field import Strivec
+from .octree import DfsOctree as Octree
+from .gaussian import Gaussian
+from .mesh import MeshExtractResult
diff --git a/trellis/representations/gaussian/__init__.py b/trellis/representations/gaussian/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e3de6e180bd732836af876d748255595be2d4d74
--- /dev/null
+++ b/trellis/representations/gaussian/__init__.py
@@ -0,0 +1 @@
+from .gaussian_model import Gaussian
\ No newline at end of file
diff --git a/trellis/representations/gaussian/gaussian_model.py b/trellis/representations/gaussian/gaussian_model.py
new file mode 100755
index 0000000000000000000000000000000000000000..8716bb5ed0b16b6d545d05b04191deeeee042bae
--- /dev/null
+++ b/trellis/representations/gaussian/gaussian_model.py
@@ -0,0 +1,185 @@
+import torch
+import numpy as np
+from plyfile import PlyData, PlyElement
+from .general_utils import inverse_sigmoid, strip_symmetric, build_scaling_rotation
+
+
+class Gaussian:
+    def __init__(
+            self, 
+            aabb : list,
+            sh_degree : int = 0,
+            mininum_kernel_size : float = 0.0,
+            scaling_bias : float = 0.01,
+            opacity_bias : float = 0.1,
+            scaling_activation : str = "exp",
+            device='cuda'
+        ):
+        self.sh_degree = sh_degree
+        self.active_sh_degree = sh_degree
+        self.mininum_kernel_size = mininum_kernel_size 
+        self.scaling_bias = scaling_bias
+        self.opacity_bias = opacity_bias
+        self.scaling_activation_type = scaling_activation
+        self.device = device
+        self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
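+        # aabb is [x0, y0, z0, size_x, size_y, size_z]; _xyz is stored normalized to this box
+        # (see get_xyz / from_xyz below).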
+        self.setup_functions()
+
+        self._xyz = None
+        self._features_dc = None
+        self._features_rest = None
+        self._scaling = None
+        self._rotation = None
+        self._opacity = None
+
+    def setup_functions(self):
+        def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
+            L = build_scaling_rotation(scaling_modifier * scaling, rotation)
+            actual_covariance = L @ L.transpose(1, 2)
+            symm = strip_symmetric(actual_covariance)
+            return symm
+        
+        if self.scaling_activation_type == "exp":
+            self.scaling_activation = torch.exp
+            self.inverse_scaling_activation = torch.log
+        elif self.scaling_activation_type == "softplus":
+            self.scaling_activation = torch.nn.functional.softplus
+            self.inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x))
+
+        self.covariance_activation = build_covariance_from_scaling_rotation
+
+        self.opacity_activation = torch.sigmoid
+        self.inverse_opacity_activation = inverse_sigmoid
+
+        self.rotation_activation = torch.nn.functional.normalize
+        
+        self.scale_bias = self.inverse_scaling_activation(torch.tensor(self.scaling_bias)).to(self.device)
+        self.rots_bias = torch.zeros((4,), device=self.device)
+        self.rots_bias[0] = 1
+        self.opacity_bias = self.inverse_opacity_activation(torch.tensor(self.opacity_bias)).to(self.device)
+
+    @property
+    def get_scaling(self):
+        scales = self.scaling_activation(self._scaling + self.scale_bias)
+        scales = torch.square(scales) + self.mininum_kernel_size ** 2
+        scales = torch.sqrt(scales)
+        return scales
+    
+    @property
+    def get_rotation(self):
+        return self.rotation_activation(self._rotation + self.rots_bias[None, :])
+    
+    @property
+    def get_xyz(self):
+        return self._xyz * self.aabb[None, 3:] + self.aabb[None, :3]
+    
+    @property
+    def get_features(self):
+        return torch.cat((self._features_dc, self._features_rest), dim=2) if self._features_rest is not None else self._features_dc
+    
+    @property
+    def get_opacity(self):
+        return self.opacity_activation(self._opacity + self.opacity_bias)
+    
+    def get_covariance(self, scaling_modifier = 1):
+        return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation + self.rots_bias[None, :])
+    
+    def from_scaling(self, scales):
+        scales = torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2)
+        self._scaling = self.inverse_scaling_activation(scales) - self.scale_bias
+        
+    def from_rotation(self, rots):
+        self._rotation = rots - self.rots_bias[None, :]
+    
+    def from_xyz(self, xyz):
+        self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:]
+        
+    def from_features(self, features):
+        self._features_dc = features
+        
+    def from_opacity(self, opacities):
+        self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias
+
+    def construct_list_of_attributes(self):
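+        # Attribute order must match the column order concatenated in save_ply:
+        # xyz, normals, f_dc, opacity, scale, rotation.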
+        l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
+        # All channels except the 3 DC
+        for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
+            l.append('f_dc_{}'.format(i))
+        l.append('opacity')
+        for i in range(self._scaling.shape[1]):
+            l.append('scale_{}'.format(i))
+        for i in range(self._rotation.shape[1]):
+            l.append('rot_{}'.format(i))
+        return l
+
+    def save_ply(self, path):
+        xyz = self.get_xyz.detach().cpu().numpy()
+        normals = np.zeros_like(xyz)
+        f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
+        opacities = inverse_sigmoid(self.get_opacity).detach().cpu().numpy()
+        scale = torch.log(self.get_scaling).detach().cpu().numpy()
+        rotation = (self._rotation + self.rots_bias[None, :]).detach().cpu().numpy()
+
+        dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
+
+        elements = np.empty(xyz.shape[0], dtype=dtype_full)
+        attributes = np.concatenate((xyz, normals, f_dc, opacities, scale, rotation), axis=1)
+        elements[:] = list(map(tuple, attributes))
+        el = PlyElement.describe(elements, 'vertex')
+        PlyData([el]).write(path)
+
+    def load_ply(self, path):
+        plydata = PlyData.read(path)
+
+        xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
+                        np.asarray(plydata.elements[0]["y"]),
+                        np.asarray(plydata.elements[0]["z"])),  axis=1)
+        opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
+
+        features_dc = np.zeros((xyz.shape[0], 3, 1))
+        features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
+        features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
+        features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
+
+        if self.sh_degree > 0:
+            extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
+            extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
+            assert len(extra_f_names)==3*(self.sh_degree + 1) ** 2 - 3
+            features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
+            for idx, attr_name in enumerate(extra_f_names):
+                features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
+            # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
+            features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.sh_degree + 1) ** 2 - 1))
+
+        scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
+        scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
+        scales = np.zeros((xyz.shape[0], len(scale_names)))
+        for idx, attr_name in enumerate(scale_names):
+            scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
+
+        rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
+        rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
+        rots = np.zeros((xyz.shape[0], len(rot_names)))
+        for idx, attr_name in enumerate(rot_names):
+            rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
+            
+        # convert to actual gaussian attributes
+        xyz = torch.tensor(xyz, dtype=torch.float, device=self.device)
+        features_dc = torch.tensor(features_dc, dtype=torch.float, device=self.device).transpose(1, 2).contiguous()
+        if self.sh_degree > 0:
+            features_extra = torch.tensor(features_extra, dtype=torch.float, device=self.device).transpose(1, 2).contiguous()
+        opacities = torch.sigmoid(torch.tensor(opacities, dtype=torch.float, device=self.device))
+        scales = torch.exp(torch.tensor(scales, dtype=torch.float, device=self.device))
+        rots = torch.tensor(rots, dtype=torch.float, device=self.device)
+        
+        # convert to _hidden attributes
+        self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:]
+        self._features_dc = features_dc
+        if self.sh_degree > 0:
+            self._features_rest = features_extra
+        else:
+            self._features_rest = None
+        self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias
+        self._scaling = self.inverse_scaling_activation(torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2)) - self.scale_bias
+        self._rotation = rots - self.rots_bias[None, :]
+        
\ No newline at end of file
diff --git a/trellis/representations/gaussian/general_utils.py b/trellis/representations/gaussian/general_utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..541c0825229a2d86e84460b765879f86f724a59d
--- /dev/null
+++ b/trellis/representations/gaussian/general_utils.py
@@ -0,0 +1,133 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use 
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact  george.drettakis@inria.fr
+#
+
+import torch
+import sys
+from datetime import datetime
+import numpy as np
+import random
+
+def inverse_sigmoid(x):
+    return torch.log(x/(1-x))
+
+def PILtoTorch(pil_image, resolution):
+    resized_image_PIL = pil_image.resize(resolution)
+    resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
+    if len(resized_image.shape) == 3:
+        return resized_image.permute(2, 0, 1)
+    else:
+        return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
+
+def get_expon_lr_func(
+    lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
+):
+    """
+    Copied from Plenoxels
+
+    Continuous learning rate decay function. Adapted from JaxNeRF
+    The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
+    is log-linearly interpolated elsewhere (equivalent to exponential decay).
+    If lr_delay_steps>0 then the learning rate will be scaled by some smooth
+    function of lr_delay_mult, such that the initial learning rate is
+    lr_init*lr_delay_mult at the beginning of optimization but will be eased back
+    to the normal learning rate when steps>lr_delay_steps.
+    :param lr_init: float, learning rate at step 0.
+    :param lr_final: float, learning rate at step max_steps.
+    :param lr_delay_steps: int, number of warm-up steps over which the delay multiplier is eased out.
+    :param lr_delay_mult: float, multiplier applied at step 0 when lr_delay_steps > 0.
+    :param max_steps: int, the number of steps during optimization.
+    :return: a function mapping the current step to a learning rate.
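+
+    Example (illustrative sketch; the numbers are assumptions):
+        lr_fn = get_expon_lr_func(lr_init=1e-2, lr_final=1e-4, max_steps=30000)
+        lr_at_step = lr_fn(1000)  # log-linearly interpolated between 1e-2 and 1e-4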
+    """
+
+    def helper(step):
+        if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
+            # Disable this parameter
+            return 0.0
+        if lr_delay_steps > 0:
+            # A kind of reverse cosine decay.
+            delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
+                0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
+            )
+        else:
+            delay_rate = 1.0
+        t = np.clip(step / max_steps, 0, 1)
+        log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
+        return delay_rate * log_lerp
+
+    return helper
+
+def strip_lowerdiag(L):
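+    # Packs the six unique entries of each symmetric 3x3 matrix into an (N, 6) tensor
+    # in the order (xx, xy, xz, yy, yz, zz).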
+    uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
+
+    uncertainty[:, 0] = L[:, 0, 0]
+    uncertainty[:, 1] = L[:, 0, 1]
+    uncertainty[:, 2] = L[:, 0, 2]
+    uncertainty[:, 3] = L[:, 1, 1]
+    uncertainty[:, 4] = L[:, 1, 2]
+    uncertainty[:, 5] = L[:, 2, 2]
+    return uncertainty
+
+def strip_symmetric(sym):
+    return strip_lowerdiag(sym)
+
+def build_rotation(r):
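+    # r: (N, 4) quaternions in (w, x, y, z) order; rows are normalized before conversion.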
+    norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
+
+    q = r / norm[:, None]
+
+    R = torch.zeros((q.size(0), 3, 3), device='cuda')
+
+    r = q[:, 0]
+    x = q[:, 1]
+    y = q[:, 2]
+    z = q[:, 3]
+
+    R[:, 0, 0] = 1 - 2 * (y*y + z*z)
+    R[:, 0, 1] = 2 * (x*y - r*z)
+    R[:, 0, 2] = 2 * (x*z + r*y)
+    R[:, 1, 0] = 2 * (x*y + r*z)
+    R[:, 1, 1] = 1 - 2 * (x*x + z*z)
+    R[:, 1, 2] = 2 * (y*z - r*x)
+    R[:, 2, 0] = 2 * (x*z - r*y)
+    R[:, 2, 1] = 2 * (y*z + r*x)
+    R[:, 2, 2] = 1 - 2 * (x*x + y*y)
+    return R
+
+def build_scaling_rotation(s, r):
+    L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
+    R = build_rotation(r)
+
+    L[:,0,0] = s[:,0]
+    L[:,1,1] = s[:,1]
+    L[:,2,2] = s[:,2]
+
+    L = R @ L
+    return L
+
+def safe_state(silent):
+    old_f = sys.stdout
+    class F:
+        def __init__(self, silent):
+            self.silent = silent
+
+        def write(self, x):
+            if not self.silent:
+                if x.endswith("\n"):
+                    old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
+                else:
+                    old_f.write(x)
+
+        def flush(self):
+            old_f.flush()
+
+    sys.stdout = F(silent)
+
+    random.seed(0)
+    np.random.seed(0)
+    torch.manual_seed(0)
+    torch.cuda.set_device(torch.device("cuda:0"))
diff --git a/trellis/representations/mesh/__init__.py b/trellis/representations/mesh/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38cf35c0853d11cf09bdc228a87ee9d0b2f34b62
--- /dev/null
+++ b/trellis/representations/mesh/__init__.py
@@ -0,0 +1 @@
+from .cube2mesh import SparseFeatures2Mesh, MeshExtractResult
diff --git a/trellis/representations/mesh/cube2mesh.py b/trellis/representations/mesh/cube2mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4e32b51adc7755a5d6bdfa38e9f6b898a6aa7f8
--- /dev/null
+++ b/trellis/representations/mesh/cube2mesh.py
@@ -0,0 +1,153 @@
+# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto.  Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
+import torch
+from ...modules.sparse import SparseTensor
+from easydict import EasyDict as edict
+from .utils_cube import *
+try:
+    from .flexicube import FlexiCubes
+except Exception:
+    print("Please install kaolin and diso to use the mesh extractor.")
+
+
+class MeshExtractResult:
+    def __init__(self,
+        vertices,
+        faces,
+        vertex_attrs=None,
+        res=64
+    ):
+        self.vertices = vertices
+        self.faces = faces.long()
+        self.vertex_attrs = vertex_attrs
+        self.face_normal = self.comput_face_normals(vertices, faces)
+        self.res = res
+        self.success = (vertices.shape[0] != 0 and faces.shape[0] != 0)
+
+        # training only
+        self.tsdf_v = None
+        self.tsdf_s = None
+        self.reg_loss = None
+        
+    def comput_face_normals(self, verts, faces):
+        i0 = faces[..., 0].long()
+        i1 = faces[..., 1].long()
+        i2 = faces[..., 2].long()
+
+        v0 = verts[i0, :]
+        v1 = verts[i1, :]
+        v2 = verts[i2, :]
+        face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
+        face_normals = torch.nn.functional.normalize(face_normals, dim=1)
+        # print(face_normals.min(), face_normals.max(), face_normals.shape)
+        return face_normals[:, None, :].repeat(1, 3, 1)
+                
+    def comput_v_normals(self, verts, faces):
+        i0 = faces[..., 0].long()
+        i1 = faces[..., 1].long()
+        i2 = faces[..., 2].long()
+
+        v0 = verts[i0, :]
+        v1 = verts[i1, :]
+        v2 = verts[i2, :]
+        face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
+        v_normals = torch.zeros_like(verts)
+        v_normals.scatter_add_(0, i0[..., None].repeat(1, 3), face_normals)
+        v_normals.scatter_add_(0, i1[..., None].repeat(1, 3), face_normals)
+        v_normals.scatter_add_(0, i2[..., None].repeat(1, 3), face_normals)
+
+        v_normals = torch.nn.functional.normalize(v_normals, dim=1)
+        return v_normals   
+
+
+class SparseFeatures2Mesh:
+    def __init__(self, device="cuda", res=64, use_color=True):
+        '''
+        A model that extracts a mesh from sparse feature structures using FlexiCubes.
+        '''
+        super().__init__()
+        self.device=device
+        self.res = res
+        self.mesh_extractor = FlexiCubes(device=device)
+        self.sdf_bias = -1.0 / res
+        verts, cube = construct_dense_grid(self.res, self.device)
+        self.reg_c = cube.to(self.device)
+        self.reg_v = verts.to(self.device)
+        self.use_color = use_color
+        self._calc_layout()
+    
+    def _calc_layout(self):
+        LAYOUTS = {
+            'sdf': {'shape': (8, 1), 'size': 8},
+            'deform': {'shape': (8, 3), 'size': 8 * 3},
+            'weights': {'shape': (21,), 'size': 21}
+        }
+        if self.use_color:
+            # 6-channel color attribute: 3 color channels plus a 3-channel normal map.
+            LAYOUTS['color'] = {'shape': (8, 6,), 'size': 8 * 6}
+        self.layouts = edict(LAYOUTS)
+        start = 0
+        for k, v in self.layouts.items():
+            v['range'] = (start, start + v['size'])
+            start += v['size']
+        self.feats_channels = start
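+        # With use_color the resulting channel ranges are: sdf [0, 8), deform [8, 32),
+        # weights [32, 53), color [53, 101), i.e. feats_channels = 101 (53 without color).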
+        
+    def get_layout(self, feats : torch.Tensor, name : str):
+        if name not in self.layouts:
+            return None
+        return feats[:, self.layouts[name]['range'][0]:self.layouts[name]['range'][1]].reshape(-1, *self.layouts[name]['shape'])
+    
+    def __call__(self, cubefeats : SparseTensor, training=False):
+        """
+        Generates a mesh from the given sparse voxel features.
+        Args:
+            cubefeats (SparseTensor): per-cube features laid out as in _calc_layout
+                (SDF, deformation, optional 6-channel color, and FlexiCubes weights).
+            training (bool): if True, also compute the training-only outputs (reg_loss, tsdf_v, tsdf_s).
+        Returns:
+            MeshExtractResult: the extracted mesh; when training, reg_loss, tsdf_v and tsdf_s are populated.
+        """
+        # add sdf bias to verts_attrs
+        coords = cubefeats.coords[:, 1:]
+        feats = cubefeats.feats
+        
+        sdf, deform, color, weights = [self.get_layout(feats, name) for name in ['sdf', 'deform', 'color', 'weights']]
+        sdf += self.sdf_bias
+        v_attrs = [sdf, deform, color] if self.use_color else [sdf, deform]
+        v_pos, v_attrs, reg_loss = sparse_cube2verts(coords, torch.cat(v_attrs, dim=-1), training=training)
+        v_attrs_d = get_dense_attrs(v_pos, v_attrs, res=self.res+1, sdf_init=True)
+        weights_d = get_dense_attrs(coords, weights, res=self.res, sdf_init=False)
+        if self.use_color:
+            sdf_d, deform_d, colors_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4], v_attrs_d[..., 4:]
+        else:
+            sdf_d, deform_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4]
+            colors_d = None
+            
+        x_nx3 = get_defomed_verts(self.reg_v, deform_d, self.res)
+        
+        vertices, faces, L_dev, colors = self.mesh_extractor(
+            voxelgrid_vertices=x_nx3,
+            scalar_field=sdf_d,
+            cube_idx=self.reg_c,
+            resolution=self.res,
+            beta=weights_d[:, :12],
+            alpha=weights_d[:, 12:20],
+            gamma_f=weights_d[:, 20],
+            voxelgrid_colors=colors_d,
+            training=training)
+        
+        mesh = MeshExtractResult(vertices=vertices, faces=faces, vertex_attrs=colors, res=self.res)
+        if training:
+            if mesh.success:
+                reg_loss += L_dev.mean() * 0.5
+            reg_loss += (weights[:,:20]).abs().mean() * 0.2
+            mesh.reg_loss = reg_loss
+            mesh.tsdf_v = get_defomed_verts(v_pos, v_attrs[:, 1:4], self.res)
+            mesh.tsdf_s = v_attrs[:, 0]
+        return mesh
diff --git a/trellis/representations/mesh/flexicube.py b/trellis/representations/mesh/flexicube.py
new file mode 100644
index 0000000000000000000000000000000000000000..63c786b32bf53775a45f2bb4c343b009b81fec9c
--- /dev/null
+++ b/trellis/representations/mesh/flexicube.py
@@ -0,0 +1,384 @@
+# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto.  Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
+
+import torch
+from .tables import *
+from kaolin.utils.testing import check_tensor
+
+__all__ = [
+    'FlexiCubes'
+]
+
+
+class FlexiCubes:
+    def __init__(self, device="cuda"):
+
+        self.device = device
+        self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False)
+        self.num_vd_table = torch.tensor(num_vd_table,
+                                         dtype=torch.long, device=device, requires_grad=False)
+        self.check_table = torch.tensor(
+            check_table,
+            dtype=torch.long, device=device, requires_grad=False)
+
+        self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False)
+        self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
+        self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
+        self.quad_split_train = torch.tensor(
+            [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False)
+
+        self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [
+                                         1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device)
+        self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False))
+        self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
+                                       2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False)
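+        # The 12 cube edges stored as 24 corner indices (consecutive pairs); gathered per cube
+        # edge in _identify_surf_edges and _compute_vd.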
+
+        self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1],
+                                           dtype=torch.long, device=device)
+        self.dir_faces_table = torch.tensor([
+            [[5, 4], [3, 2], [4, 5], [2, 3]],
+            [[5, 4], [1, 0], [4, 5], [0, 1]],
+            [[3, 2], [1, 0], [2, 3], [0, 1]]
+        ], dtype=torch.long, device=device)
+        self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device)
+
+    def __call__(self, voxelgrid_vertices, scalar_field, cube_idx, resolution, qef_reg_scale=1e-3,
+                 weight_scale=0.99, beta=None, alpha=None, gamma_f=None, voxelgrid_colors=None, training=False):
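+        # Extracts a triangle mesh from a scalar field sampled on a dense voxel grid, following
+        # the FlexiCubes formulation: per-cube weights (beta/alpha/gamma_f) steer dual vertex
+        # placement and quad splitting; returns (vertices, faces, L_dev, per-vertex colors).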
+        assert torch.is_tensor(voxelgrid_vertices) and \
+            check_tensor(voxelgrid_vertices, (None, 3), throw=False), \
+            "'voxelgrid_vertices' should be a tensor of shape (num_vertices, 3)"
+        num_vertices = voxelgrid_vertices.shape[0]
+        assert torch.is_tensor(scalar_field) and \
+            check_tensor(scalar_field, (num_vertices,), throw=False), \
+            "'scalar_field' should be a tensor of shape (num_vertices,)"
+        assert torch.is_tensor(cube_idx) and \
+            check_tensor(cube_idx, (None, 8), throw=False), \
+            "'cube_idx' should be a tensor of shape (num_cubes, 8)"
+        num_cubes = cube_idx.shape[0]
+        assert beta is None or (
+            torch.is_tensor(beta) and
+            check_tensor(beta, (num_cubes, 12), throw=False)
+        ), "'beta' should be a tensor of shape (num_cubes, 12)"
+        assert alpha is None or (
+            torch.is_tensor(alpha) and
+            check_tensor(alpha, (num_cubes, 8), throw=False)
+        ), "'alpha' should be a tensor of shape (num_cubes, 8)"
+        assert gamma_f is None or (
+            torch.is_tensor(gamma_f) and
+            check_tensor(gamma_f, (num_cubes,), throw=False)
+        ), "'gamma_f' should be a tensor of shape (num_cubes,)"
+
+        surf_cubes, occ_fx8 = self._identify_surf_cubes(scalar_field, cube_idx)
+        if surf_cubes.sum() == 0:
+            return (
+                torch.zeros((0, 3), device=self.device),
+                torch.zeros((0, 3), dtype=torch.long, device=self.device),
+                torch.zeros((0), device=self.device),
+                torch.zeros((0, voxelgrid_colors.shape[-1]), device=self.device) if voxelgrid_colors is not None else None
+            )
+        beta, alpha, gamma_f = self._normalize_weights(
+            beta, alpha, gamma_f, surf_cubes, weight_scale)
+        
+        if voxelgrid_colors is not None:
+            voxelgrid_colors = torch.sigmoid(voxelgrid_colors)
+
+        case_ids = self._get_case_id(occ_fx8, surf_cubes, resolution)
+
+        surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(
+            scalar_field, cube_idx, surf_cubes
+        )
+
+        vd, L_dev, vd_gamma, vd_idx_map, vd_color = self._compute_vd(
+            voxelgrid_vertices, cube_idx[surf_cubes], surf_edges, scalar_field,
+            case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors)
+        vertices, faces, s_edges, edge_indices, vertices_color = self._triangulate(
+            scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map,
+            vd_idx_map, surf_edges_mask, training, vd_color)
+        return vertices, faces, L_dev, vertices_color
+
+    def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):
+        """
+        Regularizer L_dev as in Equation 8
+        """
+        dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1)
+        mean_l2 = torch.zeros_like(vd[:, 0])
+        mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float()
+        mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs()
+        return mad
+
+    def _normalize_weights(self, beta, alpha, gamma_f, surf_cubes, weight_scale):
+        """
+        Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones.
+        """
+        n_cubes = surf_cubes.shape[0]
+
+        if beta is not None:
+            beta = (torch.tanh(beta) * weight_scale + 1)
+        else:
+            beta = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device)
+
+        if alpha is not None:
+            alpha = (torch.tanh(alpha) * weight_scale + 1)
+        else:
+            alpha = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device)
+
+        if gamma_f is not None:
+            gamma_f = torch.sigmoid(gamma_f) * weight_scale + (1 - weight_scale) / 2
+        else:
+            gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device)
+
+        return beta[surf_cubes], alpha[surf_cubes], gamma_f[surf_cubes]
+
+    @torch.no_grad()
+    def _get_case_id(self, occ_fx8, surf_cubes, res):
+        """
+        Obtains the ID of topology cases based on cell corner occupancy. This function resolves the 
+        ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the 
+        supplementary material. It should be noted that this function assumes a regular grid.
+        """
+        case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1)
+
+        problem_config = self.check_table.to(self.device)[case_ids]
+        to_check = problem_config[..., 0] == 1
+        problem_config = problem_config[to_check]
+        if not isinstance(res, (list, tuple)):
+            res = [res, res, res]
+
+        # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array,
+        # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes).
+        # This allows efficient checking on adjacent cubes.
+        problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long)
+        vol_idx = torch.nonzero(problem_config_full[..., 0] == 0)  # N, 3
+        vol_idx_problem = vol_idx[surf_cubes][to_check]
+        problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config
+        vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4]
+
+        within_range = (
+            vol_idx_problem_adj[..., 0] >= 0) & (
+            vol_idx_problem_adj[..., 0] < res[0]) & (
+            vol_idx_problem_adj[..., 1] >= 0) & (
+            vol_idx_problem_adj[..., 1] < res[1]) & (
+            vol_idx_problem_adj[..., 2] >= 0) & (
+            vol_idx_problem_adj[..., 2] < res[2])
+
+        vol_idx_problem = vol_idx_problem[within_range]
+        vol_idx_problem_adj = vol_idx_problem_adj[within_range]
+        problem_config = problem_config[within_range]
+        problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0],
+                                                 vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]]
+        # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.
+        to_invert = (problem_config_adj[..., 0] == 1)
+        idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert]
+        case_ids.index_put_((idx,), problem_config[to_invert][..., -1])
+        return case_ids
+
+    @torch.no_grad()
+    def _identify_surf_edges(self, scalar_field, cube_idx, surf_cubes):
+        """
+        Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge 
+        can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge 
+        and marks the cube edges with this index.
+        """
+        occ_n = scalar_field < 0
+        all_edges = cube_idx[surf_cubes][:, self.cube_edges].reshape(-1, 2)
+        unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
+
+        unique_edges = unique_edges.long()
+        mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
+
+        surf_edges_mask = mask_edges[_idx_map]
+        counts = counts[_idx_map]
+
+        mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_idx.device) * -1
+        mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_idx.device)
+        # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index
+        # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1.
+        idx_map = mapping[_idx_map]
+        surf_edges = unique_edges[mask_edges]
+        return surf_edges, idx_map, counts, surf_edges_mask
+
+    @torch.no_grad()
+    def _identify_surf_cubes(self, scalar_field, cube_idx):
+        """
+        Identifies grid cubes that intersect with the underlying surface by checking if the signs at 
+        all corners are not identical.
+        """
+        occ_n = scalar_field < 0
+        occ_fx8 = occ_n[cube_idx.reshape(-1)].reshape(-1, 8)
+        _occ_sum = torch.sum(occ_fx8, -1)
+        surf_cubes = (_occ_sum > 0) & (_occ_sum < 8)
+        return surf_cubes, occ_fx8
+
+    def _linear_interp(self, edges_weight, edges_x):
+        """
+        Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'.
+        """
+        edge_dim = edges_weight.dim() - 2
+        assert edges_weight.shape[edge_dim] == 2
+        edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), -
+                                 torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)]
+                                 , edge_dim)
+        denominator = edges_weight.sum(edge_dim)
+        ue = (edges_x * edges_weight).sum(edge_dim) / denominator
+        return ue
+
+    def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3, qef_reg_scale):
+        p_bxnx3 = p_bxnx3.reshape(-1, 7, 3)
+        norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3)
+        c_bx3 = c_bx3.reshape(-1, 3)
+        A = norm_bxnx3
+        B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True)
+
+        A_reg = (torch.eye(3, device=p_bxnx3.device) * qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1)
+        B_reg = (qef_reg_scale * c_bx3).unsqueeze(-1)
+        A = torch.cat([A, A_reg], 1)
+        B = torch.cat([B, B_reg], 1)
+        dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1)
+        return dual_verts
+
+    def _compute_vd(self, voxelgrid_vertices, surf_cubes_fx8, surf_edges, scalar_field,
+                    case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors):
+        """
+        Computes the location of dual vertices as described in Section 4.2
+        """
+        alpha_nx12x2 = torch.index_select(input=alpha, index=self.cube_edges, dim=1).reshape(-1, 12, 2)
+        surf_edges_x = torch.index_select(input=voxelgrid_vertices, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3)
+        surf_edges_s = torch.index_select(input=scalar_field, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1)
+        zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x)
+        
+        if voxelgrid_colors is not None:
+            C = voxelgrid_colors.shape[-1]
+            surf_edges_c = torch.index_select(input=voxelgrid_colors, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, C)
+
+        idx_map = idx_map.reshape(-1, 12)
+        num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0)
+        edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], []
+        
+        # if color is not None:
+        #     vd_color = []
+
+        total_num_vd = 0
+        vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False)
+
+        for num in torch.unique(num_vd):
+            cur_cubes = (num_vd == num)  # consider cubes with the same numbers of vd emitted (for batching)
+            curr_num_vd = cur_cubes.sum() * num
+            curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7)
+            curr_edge_group_to_vd = torch.arange(
+                curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd
+            total_num_vd += curr_num_vd
+            curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[
+                cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group)
+
+            curr_mask = (curr_edge_group != -1)
+            edge_group.append(torch.masked_select(curr_edge_group, curr_mask))
+            edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask))
+            edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask))
+            vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True))
+            vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1))
+            # if color is not None:
+            #     vd_color.append(color[cur_cubes].unsqueeze(1).repeat(1, num, 1).reshape(-1, 3))
+            
+        edge_group = torch.cat(edge_group)
+        edge_group_to_vd = torch.cat(edge_group_to_vd)
+        edge_group_to_cube = torch.cat(edge_group_to_cube)
+        vd_num_edges = torch.cat(vd_num_edges)
+        vd_gamma = torch.cat(vd_gamma)
+        # if color is not None:
+        #     vd_color = torch.cat(vd_color)
+        # else:
+        #     vd_color = None
+
+        vd = torch.zeros((total_num_vd, 3), device=self.device)
+        beta_sum = torch.zeros((total_num_vd, 1), device=self.device)
+
+        idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group)
+
+        x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3)
+        s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1)
+        
+
+        zero_crossing_group = torch.index_select(
+            input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3)
+
+        alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0,
+                                            index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1)
+        ue_group = self._linear_interp(s_group * alpha_group, x_group)
+
+        beta_group = torch.gather(input=beta.reshape(-1), dim=0,
+                                    index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1)
+        beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group)
+        vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum
+        
+        # Interpolate colors with the same weighting scheme used for the dual vertices.
+        if voxelgrid_colors is not None:
+            vd_color = torch.zeros((total_num_vd, C), device=self.device)
+            c_group = torch.index_select(input=surf_edges_c, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, C)
+            uc_group = self._linear_interp(s_group * alpha_group, c_group)
+            vd_color = vd_color.index_add_(0, index=edge_group_to_vd, source=uc_group * beta_group) / beta_sum
+        else:
+            vd_color = None
+        
+        L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges)
+
+        v_idx = torch.arange(vd.shape[0], device=self.device)  # + total_num_vd
+
+        vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube *
+                                                      12 + edge_group, src=v_idx[edge_group_to_vd])
+
+        return vd, L_dev, vd_gamma, vd_idx_map, vd_color
+
+    def _triangulate(self, scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, vd_color):
+        """
+        Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into 
+        triangles based on the gamma parameter, as described in Section 4.3.
+        """
+        with torch.no_grad():
+            group_mask = (edge_counts == 4) & surf_edges_mask  # surface edges shared by 4 cubes.
+            group = idx_map.reshape(-1)[group_mask]
+            vd_idx = vd_idx_map[group_mask]
+            edge_indices, indices = torch.sort(group, stable=True)
+            quad_vd_idx = vd_idx[indices].reshape(-1, 4)
+
+            # Ensure all face directions point towards the positive SDF to maintain consistent winding.
+            s_edges = scalar_field[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2)
+            flip_mask = s_edges[:, 0] > 0
+            quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]],
+                                     quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]]))
+
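+        # The quad is split along one of its two diagonals based on which pair of opposite gammas
+        # (0-2 vs. 1-3) has the larger product; self.quad_split_1/2/train are assumed to hold the
+        # corresponding triangle index patterns (initialized elsewhere in the class).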
+        quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4)
+        gamma_02 = quad_gamma[:, 0] * quad_gamma[:, 2]
+        gamma_13 = quad_gamma[:, 1] * quad_gamma[:, 3]
+        if not training:
+            mask = (gamma_02 > gamma_13)
+            faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device)
+            faces[mask] = quad_vd_idx[mask][:, self.quad_split_1]
+            faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2]
+            faces = faces.reshape(-1, 3)
+        else:
+            vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
+            vd_02 = (vd_quad[:, 0] + vd_quad[:, 2]) / 2
+            vd_13 = (vd_quad[:, 1] + vd_quad[:, 3]) / 2
+            weight_sum = (gamma_02 + gamma_13) + 1e-8
+            vd_center = (vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1)
+            
+            if vd_color is not None:
+                color_quad = torch.index_select(input=vd_color, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, vd_color.shape[-1])
+                color_02 = (color_quad[:, 0] + color_quad[:, 2]) / 2
+                color_13 = (color_quad[:, 1] + color_quad[:, 3]) / 2
+                color_center = (color_02 * gamma_02.unsqueeze(-1) + color_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1)
+                vd_color = torch.cat([vd_color, color_center])
+
+            vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0]
+            vd = torch.cat([vd, vd_center])
+            faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2)
+            faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3)
+        return vd, faces, s_edges, edge_indices, vd_color
diff --git a/trellis/representations/mesh/tables.py b/trellis/representations/mesh/tables.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c02dd7f4133aef487f623c02b11e3075cab0916
--- /dev/null
+++ b/trellis/representations/mesh/tables.py
@@ -0,0 +1,791 @@
+# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto.  Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
+dmc_table = [
+[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]],
+[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]],
+[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]]
+]
+num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2,
+2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2,
+1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1,
+1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2,
+2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2,
+3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1,
+2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1,
+1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2,
+1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1,
+1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
+check_table = [
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 1, 0, 0, 194],
+[1, -1, 0, 0, 193],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 1, 0, 164],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, -1, 0, 161],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 0, 1, 152],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 0, 1, 145],
+[1, 0, 0, 1, 144],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 0, -1, 137],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 1, 0, 133],
+[1, 0, 1, 0, 132],
+[1, 1, 0, 0, 131],
+[1, 1, 0, 0, 130],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 0, 1, 100],
+[0, 0, 0, 0, 0],
+[1, 0, 0, 1, 98],
+[0, 0, 0, 0, 0],
+[1, 0, 0, 1, 96],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 1, 0, 88],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, -1, 0, 82],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 1, 0, 74],
+[0, 0, 0, 0, 0],
+[1, 0, 1, 0, 72],
+[0, 0, 0, 0, 0],
+[1, 0, 0, -1, 70],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, -1, 0, 0, 67],
+[0, 0, 0, 0, 0],
+[1, -1, 0, 0, 65],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 1, 0, 0, 56],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, -1, 0, 0, 52],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 1, 0, 0, 44],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 1, 0, 0, 40],
+[0, 0, 0, 0, 0],
+[1, 0, 0, -1, 38],
+[1, 0, -1, 0, 37],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, -1, 0, 33],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, -1, 0, 0, 28],
+[0, 0, 0, 0, 0],
+[1, 0, -1, 0, 26],
+[1, 0, 0, -1, 25],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, -1, 0, 0, 20],
+[0, 0, 0, 0, 0],
+[1, 0, -1, 0, 18],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 0, -1, 9],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 0, -1, 6],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0]
+]
+tet_table = [
+[-1, -1, -1, -1, -1, -1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[4, 4, 4, 4, 4, 4],
+[0, 0, 0, 0, 0, 0],
+[4, 0, 0, 4, 4, -1],
+[1, 1, 1, 1, 1, 1],
+[4, 4, 4, 4, 4, 4],
+[0, 4, 0, 4, 4, -1],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[5, 5, 5, 5, 5, 5],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 0, 2, -1, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[2, -1, 2, 4, 4, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 0, 2, 4, 4, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 4, 2, 4, 4, 2],
+[0, 4, 0, 4, 4, 0],
+[2, 0, 2, 0, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 5, 2, 5, 5, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 0, 2, 0, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[1, 1, 1, 1, 1, 1],
+[0, 1, 1, -1, 0, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[4, 1, 1, 4, 4, 1],
+[0, 1, 1, 0, 0, 1],
+[4, 0, 0, 4, 4, 0],
+[2, 2, 2, 2, 2, 2],
+[-1, 1, 1, 4, 4, 1],
+[0, 1, 1, 4, 4, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[5, 1, 1, 5, 5, 1],
+[0, 1, 1, 0, 0, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[8, 8, 8, 8, 8, 8],
+[1, 1, 1, 4, 4, 1],
+[0, 0, 0, 0, 0, 0],
+[4, 0, 0, 4, 4, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 4, 4, 1],
+[0, 4, 0, 4, 4, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 5, 5, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[5, 5, 5, 5, 5, 5],
+[6, 6, 6, 6, 6, 6],
+[6, -1, 0, 6, 0, 6],
+[6, 0, 0, 6, 0, 6],
+[6, 1, 1, 6, 1, 6],
+[4, 4, 4, 4, 4, 4],
+[0, 0, 0, 0, 0, 0],
+[4, 0, 0, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[6, 4, -1, 6, 4, 6],
+[6, 4, 0, 6, 4, 6],
+[6, 0, 0, 6, 0, 6],
+[6, 1, 1, 6, 1, 6],
+[5, 5, 5, 5, 5, 5],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 0, 2, 2, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 0, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 4, 2, 2, 4, 2],
+[0, 4, 0, 4, 4, 0],
+[2, 0, 2, 2, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[6, 1, 1, 6, -1, 6],
+[6, 1, 1, 6, 0, 6],
+[6, 0, 0, 6, 0, 6],
+[6, 2, 2, 6, 2, 6],
+[4, 1, 1, 4, 4, 1],
+[0, 1, 1, 0, 0, 1],
+[4, 0, 0, 4, 4, 4],
+[2, 2, 2, 2, 2, 2],
+[6, 1, 1, 6, 4, 6],
+[6, 1, 1, 6, 4, 6],
+[6, 0, 0, 6, 0, 6],
+[6, 2, 2, 6, 2, 6],
+[5, 1, 1, 5, 5, 1],
+[0, 1, 1, 0, 0, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[6, 6, 6, 6, 6, 6],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 1, 4, 1],
+[0, 4, 0, 4, 4, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 5, 0, 5, 0, 5],
+[5, 5, 5, 5, 5, 5],
+[5, 5, 5, 5, 5, 5],
+[0, 5, 0, 5, 0, 5],
+[-1, 5, 0, 5, 0, 5],
+[1, 5, 1, 5, 1, 5],
+[4, 5, -1, 5, 4, 5],
+[0, 5, 0, 5, 0, 5],
+[4, 5, 0, 5, 4, 5],
+[1, 5, 1, 5, 1, 5],
+[4, 4, 4, 4, 4, 4],
+[0, 4, 0, 4, 4, 4],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[6, 6, 6, 6, 6, 6],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[2, 5, 2, 5, -1, 5],
+[0, 5, 0, 5, 0, 5],
+[2, 5, 2, 5, 0, 5],
+[1, 5, 1, 5, 1, 5],
+[2, 5, 2, 5, 4, 5],
+[0, 5, 0, 5, 0, 5],
+[2, 5, 2, 5, 4, 5],
+[1, 5, 1, 5, 1, 5],
+[2, 4, 2, 4, 4, 2],
+[0, 4, 0, 4, 4, 4],
+[2, 0, 2, 0, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 6, 2, 6, 6, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 0, 2, 0, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[1, 1, 1, 1, 1, 1],
+[0, 1, 1, 1, 0, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[4, 1, 1, 1, 4, 1],
+[0, 1, 1, 1, 0, 1],
+[4, 0, 0, 4, 4, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[5, 5, 5, 5, 5, 5],
+[1, 1, 1, 1, 4, 1],
+[0, 0, 0, 0, 0, 0],
+[4, 0, 0, 4, 4, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[6, 0, 0, 6, 0, 6],
+[0, 0, 0, 0, 0, 0],
+[6, 6, 6, 6, 6, 6],
+[5, 5, 5, 5, 5, 5],
+[5, 5, 0, 5, 0, 5],
+[5, 5, 0, 5, 0, 5],
+[5, 5, 1, 5, 1, 5],
+[4, 4, 4, 4, 4, 4],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 0, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[4, 4, 4, 4, 4, 4],
+[4, 4, 0, 4, 4, 4],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[8, 8, 8, 8, 8, 8],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 1, 1, 4, 4, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[1, 1, 1, 1, 1, 1],
+[1, 1, 1, 1, 0, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[2, 4, 2, 4, 4, 2],
+[1, 1, 1, 1, 1, 1],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[5, 5, 5, 5, 5, 5],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[12, 12, 12, 12, 12, 12]
+]
\ No newline at end of file
diff --git a/trellis/representations/mesh/utils_cube.py b/trellis/representations/mesh/utils_cube.py
new file mode 100644
index 0000000000000000000000000000000000000000..23913c97bb2d57dfa0384667c69f9860ea0a4155
--- /dev/null
+++ b/trellis/representations/mesh/utils_cube.py
@@ -0,0 +1,61 @@
+import torch
+cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [
+        1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.int)
+cube_neighbor = torch.tensor([[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, 0, 1], [0, 0, -1]])
+cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
+                2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, requires_grad=False)
+     
+def construct_dense_grid(res, device='cuda'):
+    '''construct a dense grid based on resolution'''
+    res_v = res + 1
+    vertsid = torch.arange(res_v ** 3, device=device)
+    coordsid = vertsid.reshape(res_v, res_v, res_v)[:res, :res, :res].flatten()
+    cube_corners_bias = (cube_corners[:, 0] * res_v + cube_corners[:, 1]) * res_v + cube_corners[:, 2]
+    cube_fx8 = (coordsid.unsqueeze(1) + cube_corners_bias.unsqueeze(0).to(device))
+    verts = torch.stack([vertsid // (res_v ** 2), (vertsid // res_v) % res_v, vertsid % res_v], dim=1)
+    return verts, cube_fx8
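+
+# Example (illustrative, shapes only): with res=2,
+#   verts, cube_fx8 = construct_dense_grid(2, device='cpu')
+#   verts.shape    == (27, 3)  # (res+1)**3 lattice points with integer coordinates
+#   cube_fx8.shape == (8, 8)   # res**3 cubes, each listing its 8 corner vertex ids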
+
+
+def construct_voxel_grid(coords):
+    verts = (cube_corners.unsqueeze(0).to(coords) + coords.unsqueeze(1)).reshape(-1, 3)
+    verts_unique, inverse_indices = torch.unique(verts, dim=0, return_inverse=True)
+    cubes = inverse_indices.reshape(-1, 8)
+    return verts_unique, cubes
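+
+# Note: adjacent voxels share corner vertices; torch.unique's inverse indices give each cube
+# the row of its deduplicated corners, so `cubes` is an [N, 8] index into `verts_unique`.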
+
+
+def cubes_to_verts(num_verts, cubes, value, reduce='mean'):
+    """
+    Args:
+        cubes [Nx8]: vertex indices for each of the N cubes
+        value [Nx8xM]: per-cube, per-corner values to be scattered
+    Operation (values scattered to the same vertex are combined by `reduce`, 'mean' by default):
+        reduced[cubes[i][j]][k] <- reduce(value[i][j][k])
+    """
+    M = value.shape[2] # number of channels
+    reduced = torch.zeros(num_verts, M, device=cubes.device)
+    return torch.scatter_reduce(reduced, 0, 
+        cubes.unsqueeze(-1).expand(-1, -1, M).flatten(0, 1), 
+        value.flatten(0, 1), reduce=reduce, include_self=False)
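+
+# Example (illustrative): for a single cube whose corner indices are 0..7, each vertex receives
+# exactly one contribution, so with reduce='mean':
+#   cubes_to_verts(8, torch.arange(8).reshape(1, 8), value)[j] == value[0, j]   # value: [1, 8, M]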
+    
+def sparse_cube2verts(coords, feats, training=True):
+    """
+    Convert per-cube features to per-vertex features on the sparse voxel grid.
+    During training, also return a consistency loss between the original per-cube features
+    and the features gathered back from the shared vertices.
+    """
+    new_coords, cubes = construct_voxel_grid(coords)
+    new_feats = cubes_to_verts(new_coords.shape[0], cubes, feats)
+    if training:
+        con_loss = torch.mean((feats - new_feats[cubes]) ** 2)
+    else:
+        con_loss = 0.0
+    return new_coords, new_feats, con_loss
+    
+
+def get_dense_attrs(coords : torch.Tensor, feats : torch.Tensor, res : int, sdf_init=True):
+    """
+    Scatter sparse per-voxel features into a dense (res**3, F) array.
+    When sdf_init is True, the first (SDF) channel of unoccupied cells is initialized to +1 (outside).
+    """
+    F = feats.shape[-1]
+    dense_attrs = torch.zeros([res] * 3 + [F], device=feats.device)
+    if sdf_init:
+        dense_attrs[..., 0] = 1 # initial outside sdf value
+    dense_attrs[coords[:, 0], coords[:, 1], coords[:, 2], :] = feats
+    return dense_attrs.reshape(-1, F)
+
+
+def get_defomed_verts(v_pos : torch.Tensor, deform : torch.Tensor, res):
+    """
+    Map integer lattice positions in [0, res] to [-0.5, 0.5] and add a tanh-bounded per-vertex
+    offset whose magnitude stays strictly below half a cell, so deformed vertices cannot cross cell boundaries.
+    """
+    return v_pos / res - 0.5 + (1 - 1e-8) / (res * 2) * torch.tanh(deform)
+        
\ No newline at end of file
diff --git a/trellis/representations/octree/__init__.py b/trellis/representations/octree/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..f66a39a5a7498e2e99fe9d94d663796b3bc157b5
--- /dev/null
+++ b/trellis/representations/octree/__init__.py
@@ -0,0 +1 @@
+from .octree_dfs import DfsOctree
\ No newline at end of file
diff --git a/trellis/representations/octree/octree_dfs.py b/trellis/representations/octree/octree_dfs.py
new file mode 100755
index 0000000000000000000000000000000000000000..9d1f7898f30414f304953cfb2d51d00511ec8325
--- /dev/null
+++ b/trellis/representations/octree/octree_dfs.py
@@ -0,0 +1,362 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+DEFAULT_TRIVEC_CONFIG = {
+    'dim': 8,
+    'rank': 8,
+}
+
+DEFAULT_VOXEL_CONFIG = {
+    'solid': False,
+}
+
+DEFAULT_DECOPOLY_CONFIG = {
+    'degree': 8,
+    'rank': 16,
+}
+
+
+class DfsOctree:
+    """
+    Sparse Voxel Octree (SVO) implementation for PyTorch.
+    Using Depth-First Search (DFS) order to store the octree.
+    DFS order suits rendering and ray tracing.
+
+    The structure and the data are stored separately.
+    The structure is stored as a contiguous array; each element is a 3*32-bit descriptor.
+    |-----------------------------------------|
+    |      0:3 bits      |      4:31 bits     |
+    |      leaf num      |       unused       |
+    |-----------------------------------------|
+    |               0:31  bits                |
+    |                child ptr                |
+    |-----------------------------------------|
+    |               0:31  bits                |
+    |                data ptr                 |
+    |-----------------------------------------|
+    Each element represents a non-leaf node in the octree.
+    The valid mask is used to indicate whether the children are valid.
+    The leaf mask is used to indicate whether the children are leaf nodes.
+    The child ptr points to the first non-leaf child; the descriptors of non-leaf children are stored contiguously starting at the child ptr.
+    The data ptr points to the data of the leaf children; leaf children's data are stored contiguously starting at the data ptr.
+
+    There are also auxiliary arrays to store the additional structural information to facilitate parallel processing.
+      - Position: the position of the octree nodes.
+      - Depth: the depth of the octree nodes.
+
+    Args:
+        depth (int): the depth of the octree.
+    """
+
+    def __init__(
+            self,
+            depth,
+            aabb=[0,0,0,1,1,1],
+            sh_degree=2,
+            primitive='voxel',
+            primitive_config={},
+            device='cuda',
+        ):
+        self.max_depth = depth
+        self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
+        self.device = device
+        self.sh_degree = sh_degree
+        self.active_sh_degree = sh_degree
+        self.primitive = primitive
+        self.primitive_config = primitive_config
+
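+        # Initial tree (illustrative reading of the descriptor below): a single non-leaf root
+        # [leaf num = 8, child ptr = 1, data ptr = 0], i.e. all 8 children are depth-1 leaves
+        # whose data occupy rows 0..7 of `position`/`depth`; the child ptr is unused because
+        # there are no non-leaf children.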
+        self.structure = torch.tensor([[8, 1, 0]], dtype=torch.int32, device=self.device)
+        self.position = torch.zeros((8, 3), dtype=torch.float32, device=self.device)
+        self.depth = torch.zeros((8, 1), dtype=torch.uint8, device=self.device)
+        self.position[:, 0] = torch.tensor([0.25, 0.75, 0.25, 0.75, 0.25, 0.75, 0.25, 0.75], device=self.device)
+        self.position[:, 1] = torch.tensor([0.25, 0.25, 0.75, 0.75, 0.25, 0.25, 0.75, 0.75], device=self.device)
+        self.position[:, 2] = torch.tensor([0.25, 0.25, 0.25, 0.25, 0.75, 0.75, 0.75, 0.75], device=self.device)
+        self.depth[:, 0] = 1
+
+        self.data = ['position', 'depth']
+        self.param_names = []
+
+        if primitive == 'voxel':
+            self.features_dc = torch.zeros((8, 1, 3), dtype=torch.float32, device=self.device)
+            self.features_ac = torch.zeros((8, (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device)
+            self.data += ['features_dc', 'features_ac']
+            self.param_names += ['features_dc', 'features_ac']
+            if not primitive_config.get('solid', False):
+                self.density = torch.zeros((8, 1), dtype=torch.float32, device=self.device)
+                self.data.append('density')
+                self.param_names.append('density')
+        elif primitive == 'gaussian':
+            self.features_dc = torch.zeros((8, 1, 3), dtype=torch.float32, device=self.device)
+            self.features_ac = torch.zeros((8, (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device)
+            self.opacity = torch.zeros((8, 1), dtype=torch.float32, device=self.device)
+            self.data += ['features_dc', 'features_ac', 'opacity']
+            self.param_names += ['features_dc', 'features_ac', 'opacity']
+        elif primitive == 'trivec':
+            self.trivec = torch.zeros((8, primitive_config['rank'], 3, primitive_config['dim']), dtype=torch.float32, device=self.device)
+            self.density = torch.zeros((8, primitive_config['rank']), dtype=torch.float32, device=self.device)
+            self.features_dc = torch.zeros((8, primitive_config['rank'], 1, 3), dtype=torch.float32, device=self.device)
+            self.features_ac = torch.zeros((8, primitive_config['rank'], (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device)
+            self.density_shift = 0
+            self.data += ['trivec', 'density', 'features_dc', 'features_ac']
+            self.param_names += ['trivec', 'density', 'features_dc', 'features_ac']
+        elif primitive == 'decoupoly':
+            self.decoupoly_V = torch.zeros((8, primitive_config['rank'], 3), dtype=torch.float32, device=self.device)
+            self.decoupoly_g = torch.zeros((8, primitive_config['rank'], primitive_config['degree']), dtype=torch.float32, device=self.device)
+            self.density = torch.zeros((8, primitive_config['rank']), dtype=torch.float32, device=self.device)
+            self.features_dc = torch.zeros((8, primitive_config['rank'], 1, 3), dtype=torch.float32, device=self.device)
+            self.features_ac = torch.zeros((8, primitive_config['rank'], (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device)
+            self.density_shift = 0
+            self.data += ['decoupoly_V', 'decoupoly_g', 'density', 'features_dc', 'features_ac']
+            self.param_names += ['decoupoly_V', 'decoupoly_g', 'density', 'features_dc', 'features_ac']
+
+        self.setup_functions()
+
+    def setup_functions(self):
+        self.density_activation = (lambda x: torch.exp(x - 2)) if self.primitive != 'trivec' else (lambda x: x)
+        self.opacity_activation = lambda x: torch.sigmoid(x - 6)
+        self.inverse_opacity_activation = lambda x: torch.log(x / (1 - x)) + 6
+        self.color_activation = lambda x: torch.sigmoid(x)
+
+    @property
+    def num_non_leaf_nodes(self):
+        return self.structure.shape[0]
+    
+    @property
+    def num_leaf_nodes(self):
+        return self.depth.shape[0]
+
+    @property
+    def cur_depth(self):
+        return self.depth.max().item()
+    
+    @property
+    def occupancy(self):
+        return self.num_leaf_nodes / 8 ** self.cur_depth
+    
+    @property
+    def get_xyz(self):
+        return self.position
+
+    @property
+    def get_depth(self):
+        return self.depth
+
+    @property
+    def get_density(self):
+        if self.primitive == 'voxel' and self.primitive_config.get('solid', False):
+            return torch.full((self.position.shape[0], 1), 1000, dtype=torch.float32, device=self.device)
+        return self.density_activation(self.density)
+    
+    @property
+    def get_opacity(self):
+        return self.opacity_activation(self.density)
+
+    @property
+    def get_trivec(self):
+        return self.trivec
+
+    @property
+    def get_decoupoly(self):
+        return F.normalize(self.decoupoly_V, dim=-1), self.decoupoly_g
+
+    @property
+    def get_color(self):
+        return self.color_activation(self.colors)
+
+    @property
+    def get_features(self):
+        if self.sh_degree == 0:
+            return self.features_dc
+        return torch.cat([self.features_dc, self.features_ac], dim=-2)
+
+    def state_dict(self):
+        ret = {'structure': self.structure, 'position': self.position, 'depth': self.depth, 'sh_degree': self.sh_degree, 'active_sh_degree': self.active_sh_degree, 'primitive_config': self.primitive_config, 'primitive': self.primitive}
+        if hasattr(self, 'density_shift'):
+            ret['density_shift'] = self.density_shift
+        for data in set(self.data + self.param_names):
+            if not isinstance(getattr(self, data), nn.Module):
+                ret[data] = getattr(self, data)
+            else:
+                ret[data] = getattr(self, data).state_dict()
+        return ret
+
+    def load_state_dict(self, state_dict):
+        keys = list(set(self.data + self.param_names + list(state_dict.keys()) + ['structure', 'position', 'depth']))
+        for key in keys:
+            if key not in state_dict:
+                print(f"Warning: key {key} not found in the state_dict.")
+                continue
+            try:
+                if not isinstance(getattr(self, key), nn.Module):
+                    setattr(self, key, state_dict[key])
+                else:
+                    getattr(self, key).load_state_dict(state_dict[key])
+            except Exception as e:
+                print(e)
+                raise ValueError(f"Error loading key {key}.")
+
+    def gather_from_leaf_children(self, data):
+        """
+        Gather the data from the leaf children.
+
+        Args:
+            data (torch.Tensor): the data to gather. The first dimension should be the number of leaf nodes.
+        """
+        leaf_cnt = self.structure[:, 0]
+        leaf_cnt_masks = [leaf_cnt == i for i in range(1, 9)]
+        ret = torch.zeros((self.num_non_leaf_nodes,), dtype=data.dtype, device=self.device)
+        for i in range(8):
+            if leaf_cnt_masks[i].sum() == 0:
+                continue
+            start = self.structure[leaf_cnt_masks[i], 2]
+            for j in range(i+1):
+                ret[leaf_cnt_masks[i]] += data[start + j]
+        return ret
+
+    def gather_from_non_leaf_children(self, data):
+        """
+        Gather the data from the non-leaf children.
+
+        Args:
+            data (torch.Tensor): the data to gather. The first dimension should be the number of non-leaf nodes.
+        """
+        non_leaf_cnt = 8 - self.structure[:, 0]
+        non_leaf_cnt_masks = [non_leaf_cnt == i for i in range(1, 9)]
+        ret = torch.zeros_like(data, device=self.device)
+        for i in range(8):
+            if non_leaf_cnt_masks[i].sum() == 0:
+                continue
+            start = self.structure[non_leaf_cnt_masks[i], 1]
+            for j in range(i+1):
+                ret[non_leaf_cnt_masks[i]] += data[start + j]
+        return ret
+
+    def structure_control(self, mask):
+        """
+        Control the structure of the octree.
+
+        Args:
+            mask (torch.Tensor): the mask to control the structure. 1 for subdivide, -1 for merge, 0 for keep.
+        """
+        # Don't subdivide when the depth is already the maximum.
+        mask[self.depth.squeeze() == self.max_depth] = torch.clamp_max(mask[self.depth.squeeze() == self.max_depth], 0)
+        # Don't merge when the depth is the minimum.
+        mask[self.depth.squeeze() == 1] = torch.clamp_min(mask[self.depth.squeeze() == 1], 0)
+
+        # Gather control mask
+        structre_ctrl = self.gather_from_leaf_children(mask)
+        structre_ctrl[structre_ctrl==-8] = -1
+
+        new_leaf_num = self.structure[:, 0].clone()
+        # Modify the leaf num.
+        structre_valid = structre_ctrl >= 0
+        new_leaf_num[structre_valid] -= structre_ctrl[structre_valid]                               # Add the new nodes.
+        structre_delete = structre_ctrl < 0
+        merged_nodes = self.gather_from_non_leaf_children(structre_delete.int())
+        new_leaf_num += merged_nodes                                                                # Delete the merged nodes.
+
+        # Update the structure array to allocate new nodes.
+        mem_offset = torch.zeros((self.num_non_leaf_nodes + 1,), dtype=torch.int32, device=self.device)
+        mem_offset.index_add_(0, self.structure[structre_valid, 1], structre_ctrl[structre_valid])  # Add the new nodes.
+        mem_offset[:-1] -= structre_delete.int()                                                    # Delete the merged nodes.
+        new_structre_idx = torch.arange(0, self.num_non_leaf_nodes + 1, dtype=torch.int32, device=self.device) + mem_offset.cumsum(0)
+        new_structure_length = new_structre_idx[-1].item()
+        new_structre_idx = new_structre_idx[:-1]
+        new_structure = torch.empty((new_structure_length, 3), dtype=torch.int32, device=self.device)
+        new_structure[new_structre_idx[structre_valid], 0] = new_leaf_num[structre_valid]
+
+        # Initialize the new nodes.
+        new_node_mask = torch.ones((new_structure_length,), dtype=torch.bool, device=self.device)
+        new_node_mask[new_structre_idx[structre_valid]] = False
+        new_structure[new_node_mask, 0] = 8                                                         # Initialize to all leaf nodes.
+        new_node_num = new_node_mask.sum().item()
+
+        # Rebuild child ptr.
+        non_leaf_cnt = 8 - new_structure[:, 0]
+        new_child_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), non_leaf_cnt.cumsum(0)[:-1]])
+        new_structure[:, 1] = new_child_ptr + 1
+
+        # Rebuild data ptr with old data.
+        leaf_cnt = torch.zeros((new_structure_length,), dtype=torch.int32, device=self.device)
+        leaf_cnt.index_add_(0, new_structre_idx, self.structure[:, 0])
+        old_data_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), leaf_cnt.cumsum(0)[:-1]])
+
+        # Update the data array
+        subdivide_mask = mask == 1
+        merge_mask = mask == -1
+        data_valid = ~(subdivide_mask | merge_mask)
+        mem_offset = torch.zeros((self.num_leaf_nodes + 1,), dtype=torch.int32, device=self.device)
+        mem_offset.index_add_(0, old_data_ptr[new_node_mask], torch.full((new_node_num,), 8, dtype=torch.int32, device=self.device))    # Add data array for new nodes
+        mem_offset[:-1] -= subdivide_mask.int()                                                                                         # Delete data elements for subdivide nodes
+        mem_offset[:-1] -= merge_mask.int()                                                                                             # Delete data elements for merge nodes
+        mem_offset.index_add_(0, self.structure[structre_valid, 2], merged_nodes[structre_valid])                                       # Add data elements for merge nodes
+        new_data_idx = torch.arange(0, self.num_leaf_nodes + 1, dtype=torch.int32, device=self.device) + mem_offset.cumsum(0)
+        new_data_length = new_data_idx[-1].item()
+        new_data_idx = new_data_idx[:-1]
+        new_data = {data: torch.empty((new_data_length,) + getattr(self, data).shape[1:], dtype=getattr(self, data).dtype, device=self.device) for data in self.data}
+        for data in self.data:
+            new_data[data][new_data_idx[data_valid]] = getattr(self, data)[data_valid]
+
+        # Rebuild data ptr
+        leaf_cnt = new_structure[:, 0]
+        new_data_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), leaf_cnt.cumsum(0)[:-1]])
+        new_structure[:, 2] = new_data_ptr
+
+        # Initialize the new data array
+        ## For subdivide nodes
+        if subdivide_mask.sum() > 0:
+            subdivide_data_ptr = new_structure[new_node_mask, 2]
+            for data in self.data:
+                for i in range(8):
+                    if data == 'position':
+                        offset = torch.tensor([i // 4, (i // 2) % 2, i % 2], dtype=torch.float32, device=self.device) - 0.5
+                        scale = 2 ** (-1.0 - self.depth[subdivide_mask])
+                        new_data['position'][subdivide_data_ptr + i] = self.position[subdivide_mask] + offset * scale
+                    elif data == 'depth':
+                        new_data['depth'][subdivide_data_ptr + i] = self.depth[subdivide_mask] + 1
+                    elif data == 'opacity':
+                        new_data['opacity'][subdivide_data_ptr + i] = self.inverse_opacity_activation(torch.sqrt(self.opacity_activation(self.opacity[subdivide_mask])))
+                    elif data == 'trivec':
+                        offset = torch.tensor([i // 4, (i // 2) % 2, i % 2], dtype=torch.float32, device=self.device) * 0.5
+                        coord = (torch.linspace(0, 0.5, self.trivec.shape[-1], dtype=torch.float32, device=self.device)[None] + offset[:, None]).reshape(1, 3, self.trivec.shape[-1], 1)
+                        axis = torch.linspace(0, 1, 3, dtype=torch.float32, device=self.device).reshape(1, 3, 1, 1).repeat(1, 1, self.trivec.shape[-1], 1)
+                        coord = torch.stack([coord, axis], dim=3).reshape(1, 3, self.trivec.shape[-1], 2).expand(self.trivec[subdivide_mask].shape[0], -1, -1, -1) * 2 - 1
+                        new_data['trivec'][subdivide_data_ptr + i] = F.grid_sample(self.trivec[subdivide_mask], coord, align_corners=True)
+                    else:
+                        new_data[data][subdivide_data_ptr + i] = getattr(self, data)[subdivide_mask]
+        ## For merge nodes
+        if merge_mask.sum() > 0:
+            merge_data_ptr = torch.empty((merged_nodes.sum().item(),), dtype=torch.int32, device=self.device)
+            merge_nodes_cumsum = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), merged_nodes.cumsum(0)[:-1]])
+            for i in range(8):
+                merge_data_ptr[merge_nodes_cumsum[merged_nodes > i] + i] = new_structure[new_structre_idx[merged_nodes > i], 2] + i
+            old_merge_data_ptr = self.structure[structre_delete, 2]
+            for data in self.data:
+                if data == 'position':
+                    scale = 2 ** (1.0 - self.depth[old_merge_data_ptr])
+                    new_data['position'][merge_data_ptr] = ((self.position[old_merge_data_ptr] + 0.5) / scale).floor() * scale + 0.5 * scale - 0.5
+                elif data == 'depth':
+                    new_data['depth'][merge_data_ptr] = self.depth[old_merge_data_ptr] - 1
+                elif data == 'opacity':
+                    new_data['opacity'][merge_data_ptr] = self.inverse_opacity_activation(self.opacity_activation(self.opacity[old_merge_data_ptr])**2)
+                elif data == 'trivec':
+                    new_data['trivec'][merge_data_ptr] = self.trivec[old_merge_data_ptr]
+                else:
+                    new_data[data][merge_data_ptr] = getattr(self, data)[old_merge_data_ptr]
+
+        # Update the structure and data array
+        self.structure = new_structure
+        for data in self.data:
+            setattr(self, data, new_data[data])
+
+        # Save data array control temp variables
+        self.data_rearrange_buffer = {
+            'subdivide_mask': subdivide_mask,
+            'merge_mask': merge_mask,
+            'data_valid': data_valid,
+            'new_data_idx': new_data_idx,
+            'new_data_length': new_data_length,
+            'new_data': new_data
+        } 
diff --git a/trellis/representations/radiance_field/__init__.py b/trellis/representations/radiance_field/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..b72a1b7e76b509ee5a5e6979858eb17b4158a151
--- /dev/null
+++ b/trellis/representations/radiance_field/__init__.py
@@ -0,0 +1 @@
+from .strivec import Strivec
\ No newline at end of file
diff --git a/trellis/representations/radiance_field/strivec.py b/trellis/representations/radiance_field/strivec.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fc4b749786d934dae82864b560baccd91fcabbc
--- /dev/null
+++ b/trellis/representations/radiance_field/strivec.py
@@ -0,0 +1,28 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from ..octree import DfsOctree as Octree
+
+
+class Strivec(Octree):
+    def __init__(
+        self,
+        resolution: int,
+        aabb: list,
+        sh_degree: int = 0,
+        rank: int = 8,
+        dim: int = 8,
+        device: str = "cuda",
+    ):
+        assert np.log2(resolution) % 1 == 0, "Resolution must be a power of 2"
+        self.resolution = resolution
+        depth = int(np.round(np.log2(resolution)))
+        super().__init__(
+            depth=depth,
+            aabb=aabb,
+            sh_degree=sh_degree,
+            primitive="trivec",
+            primitive_config={"rank": rank, "dim": dim},
+            device=device,
+        )
diff --git a/trellis/utils/__init__.py b/trellis/utils/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/trellis/utils/general_utils.py b/trellis/utils/general_utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..3b454d9c75521e33466055fe37c3fc1e37180a79
--- /dev/null
+++ b/trellis/utils/general_utils.py
@@ -0,0 +1,187 @@
+import numpy as np
+import cv2
+import torch
+
+
+# Dictionary utils
+def _dict_merge(dicta, dictb, prefix=''):
+    """
+    Merge two dictionaries.
+    """
+    assert isinstance(dicta, dict), 'input must be a dictionary'
+    assert isinstance(dictb, dict), 'input must be a dictionary'
+    dict_ = {}
+    all_keys = set(dicta.keys()).union(set(dictb.keys()))
+    for key in all_keys:
+        if key in dicta.keys() and key in dictb.keys():
+            if isinstance(dicta[key], dict) and isinstance(dictb[key], dict):
+                dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}')
+            else:
+                raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}')
+        elif key in dicta.keys():
+            dict_[key] = dicta[key]
+        else:
+            dict_[key] = dictb[key]
+    return dict_
+
+
+def dict_merge(dicta, dictb):
+    """
+    Merge two dictionaries.
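+
+    Example:
+        dict_merge({'a': {'b': 1}}, {'a': {'c': 2}, 'd': 3}) -> {'a': {'b': 1, 'c': 2}, 'd': 3}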
+    """
+    return _dict_merge(dicta, dictb, prefix='')
+
+
+def dict_foreach(dic, func, special_func={}):
+    """
+    Recursively apply a function to all non-dictionary leaf values in a dictionary.
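+    The dictionary is modified in place and returned.
+
+    Example:
+        dict_foreach({'a': 1, 'b': {'c': 2}}, lambda x: x * 2) -> {'a': 2, 'b': {'c': 4}}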
+    """
+    assert isinstance(dic, dict), 'input must be a dictionary'
+    for key in dic.keys():
+        if isinstance(dic[key], dict):
+            dic[key] = dict_foreach(dic[key], func, special_func)
+        else:
+            if key in special_func.keys():
+                dic[key] = special_func[key](dic[key])
+            else:
+                dic[key] = func(dic[key])
+    return dic
+
+
+def dict_reduce(dicts, func, special_func={}):
+    """
+    Reduce a list of dictionaries. Leaf values must be scalars.
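+
+    Example:
+        dict_reduce([{'a': 1, 'b': {'c': 2}}, {'a': 3, 'b': {'c': 4}}], lambda vs: sum(vs) / len(vs))
+        -> {'a': 2.0, 'b': {'c': 3.0}}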
+    """
+    assert isinstance(dicts, list), 'input must be a list of dictionaries'
+    assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries'
+    assert len(dicts) > 0, 'input must be a non-empty list of dictionaries'
+    all_keys = set([key for dict_ in dicts for key in dict_.keys()])
+    reduced_dict = {}
+    for key in all_keys:
+        vlist = [dict_[key] for dict_ in dicts if key in dict_.keys()]
+        if isinstance(vlist[0], dict):
+            reduced_dict[key] = dict_reduce(vlist, func, special_func)
+        else:
+            if key in special_func.keys():
+                reduced_dict[key] = special_func[key](vlist)
+            else:
+                reduced_dict[key] = func(vlist)
+    return reduced_dict
+
+
+def dict_any(dic, func):
+    """
+    Return True if func evaluates to True on any non-dictionary leaf value of a (nested) dictionary.
+    """
+    assert isinstance(dic, dict), 'input must be a dictionary'
+    for key in dic.keys():
+        if isinstance(dic[key], dict):
+            if dict_any(dic[key], func):
+                return True
+        else:
+            if func(dic[key]):
+                return True
+    return False
+
+
+def dict_all(dic, func):
+    """
+    Return True only if func evaluates to True on every non-dictionary leaf value of a (nested) dictionary.
+    """
+    assert isinstance(dic, dict), 'input must be a dictionary'
+    for key in dic.keys():
+        if isinstance(dic[key], dict):
+            if not dict_all(dic[key], func):
+                return False
+        else:
+            if not func(dic[key]):
+                return False
+    return True
+
+
+def dict_flatten(dic, sep='.'):
+    """
+    Flatten a nested dictionary into a dictionary with no nested dictionaries.
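+
+    Example:
+        dict_flatten({'a': {'b': 1, 'c': {'d': 2}}, 'e': 3}) -> {'a.b': 1, 'a.c.d': 2, 'e': 3}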
+    """
+    assert isinstance(dic, dict), 'input must be a dictionary'
+    flat_dict = {}
+    for key in dic.keys():
+        if isinstance(dic[key], dict):
+            sub_dict = dict_flatten(dic[key], sep=sep)
+            for sub_key in sub_dict.keys():
+                flat_dict[str(key) + sep + str(sub_key)] = sub_dict[sub_key]
+        else:
+            flat_dict[key] = dic[key]
+    return flat_dict
+
+
+def make_grid(images, nrow=None, ncol=None, aspect_ratio=None):
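+    """
+    Tile a list of equally-sized (H, W, C) images into a single grid image.
+    If neither nrow nor ncol is given, the grid is chosen to be roughly square
+    (or to match aspect_ratio when provided); unused cells are left black (zero).
+    """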
+    num_images = len(images)
+    if nrow is None and ncol is None:
+        if aspect_ratio is not None:
+            nrow = int(np.round(np.sqrt(num_images / aspect_ratio)))
+        else:
+            nrow = int(np.sqrt(num_images))
+        ncol = (num_images + nrow - 1) // nrow
+    elif nrow is None and ncol is not None:
+        nrow = (num_images + ncol - 1) // ncol
+    elif nrow is not None and ncol is None:
+        ncol = (num_images + nrow - 1) // nrow
+    else:
+        assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images'
+        
+    grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype)
+    for i, img in enumerate(images):
+        row = i // ncol
+        col = i % ncol
+        grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img
+    return grid
+
+
+def notes_on_image(img, notes=None):
+    img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0)
+    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+    if notes is not None:
+        img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1)
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    return img
+
+
+def save_image_with_notes(img, path, notes=None):
+    """
+    Save an image with notes.
+    """
+    if isinstance(img, torch.Tensor):
+        img = img.cpu().numpy().transpose(1, 2, 0)
+    if img.dtype == np.float32 or img.dtype == np.float64:
+        img = np.clip(img * 255, 0, 255).astype(np.uint8)
+    img = notes_on_image(img, notes)
+    cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
+
+
+# debug utils
+
+def atol(x, y):
+    """
+    Absolute tolerance.
+    """
+    return torch.abs(x - y)
+
+
+def rtol(x, y):
+    """
+    Relative tolerance.
+    """
+    return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12)
+
+
+# print utils
+def indent(s, n=4):
+    """
+    Indent a string.
+    """
+    lines = s.split('\n')
+    for i in range(1, len(lines)):
+        lines[i] = ' ' * n + lines[i]
+    return '\n'.join(lines)
+
diff --git a/trellis/utils/postprocessing_utils.py b/trellis/utils/postprocessing_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0709033d2f5b784767b028b51aa5d4014f9db34
--- /dev/null
+++ b/trellis/utils/postprocessing_utils.py
@@ -0,0 +1,458 @@
+from typing import *
+import numpy as np
+import torch
+import utils3d
+import nvdiffrast.torch as dr
+from tqdm import tqdm
+import trimesh
+import trimesh.visual
+import xatlas
+import pyvista as pv
+from pymeshfix import _meshfix
+import igraph
+import cv2
+from PIL import Image
+from .random_utils import sphere_hammersley_sequence
+from .render_utils import render_multiview
+from ..representations import Strivec, Gaussian, MeshExtractResult
+
+
+@torch.no_grad()
+def _fill_holes(
+    verts,
+    faces,
+    max_hole_size=0.04,
+    max_hole_nbe=32,
+    resolution=128,
+    num_views=500,
+    debug=False,
+    verbose=False
+):
+    """
+    Rasterize a mesh from multiple views and remove invisible faces.
+    Also includes postprocessing to:
+        1. Remove connected components that are have low visibility.
+        2. Mincut to remove faces at the inner side of the mesh connected to the outer side with a small hole.
+
+    Args:
+        verts (torch.Tensor): Vertices of the mesh. Shape (V, 3).
+        faces (torch.Tensor): Faces of the mesh. Shape (F, 3).
+        max_hole_size (float): Maximum area of a hole to fill.
+        resolution (int): Resolution of the rasterization.
+        num_views (int): Number of views to rasterize the mesh.
+        verbose (bool): Whether to print progress.
+    """
+    # Construct cameras
+    yaws = []
+    pitchs = []
+    for i in range(num_views):
+        y, p = sphere_hammersley_sequence(i, num_views)
+        yaws.append(y)
+        pitchs.append(p)
+    yaws = torch.tensor(yaws).cuda()
+    pitchs = torch.tensor(pitchs).cuda()
+    radius = 2.0
+    fov = torch.deg2rad(torch.tensor(40)).cuda()
+    projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3)
+    views = []
+    for (yaw, pitch) in zip(yaws, pitchs):
+        orig = torch.tensor([
+            torch.sin(yaw) * torch.cos(pitch),
+            torch.cos(yaw) * torch.cos(pitch),
+            torch.sin(pitch),
+        ]).cuda().float() * radius
+        view = utils3d.torch.view_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
+        views.append(view)
+    views = torch.stack(views, dim=0)
+
+    # Rasterize
+    visblity = torch.zeros(faces.shape[0], dtype=torch.int32, device=verts.device)
+    rastctx = utils3d.torch.RastContext(backend='cuda')
+    for i in tqdm(range(views.shape[0]), total=views.shape[0], disable=not verbose, desc='Rasterizing'):
+        view = views[i]
+        buffers = utils3d.torch.rasterize_triangle_faces(
+            rastctx, verts[None], faces, resolution, resolution, view=view, projection=projection
+        )
+        face_id = buffers['face_id'][0][buffers['mask'][0] > 0.95] - 1
+        face_id = torch.unique(face_id).long()
+        visblity[face_id] += 1
+    visblity = visblity.float() / num_views
+    
+    # Mincut
+    ## construct outer faces
+    edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces)
+    boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1)
+    connected_components = utils3d.torch.compute_connected_components(faces, edges, face2edge)
+    outer_face_indices = torch.zeros(faces.shape[0], dtype=torch.bool, device=faces.device)
+    for i in range(len(connected_components)):
+        outer_face_indices[connected_components[i]] = visblity[connected_components[i]] > min(max(visblity[connected_components[i]].quantile(0.75).item(), 0.25), 0.5)
+    outer_face_indices = outer_face_indices.nonzero().reshape(-1)
+    
+    ## construct inner faces
+    inner_face_indices = torch.nonzero(visblity == 0).reshape(-1)
+    if verbose:
+        tqdm.write(f'Found {inner_face_indices.shape[0]} invisible faces')
+    if inner_face_indices.shape[0] == 0:
+        return verts, faces
+    
+    ## Construct dual graph (faces as nodes, edges as edges)
+    dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(face2edge)
+    dual_edge2edge = edges[dual_edge2edge]
+    dual_edges_weights = torch.norm(verts[dual_edge2edge[:, 0]] - verts[dual_edge2edge[:, 1]], dim=1)
+    if verbose:
+        tqdm.write(f'Dual graph: {dual_edges.shape[0]} edges')
+
+    ## solve mincut problem
+    ### construct main graph
+    g = igraph.Graph()
+    g.add_vertices(faces.shape[0])
+    g.add_edges(dual_edges.cpu().numpy())
+    g.es['weight'] = dual_edges_weights.cpu().numpy()
+    
+    ### source and target
+    g.add_vertex('s')
+    g.add_vertex('t')
+    
+    ### connect invisible faces to source
+    g.add_edges([(f, 's') for f in inner_face_indices], attributes={'weight': torch.ones(inner_face_indices.shape[0], dtype=torch.float32).cpu().numpy()})
+    
+    ### connect outer faces to target
+    g.add_edges([(f, 't') for f in outer_face_indices], attributes={'weight': torch.ones(outer_face_indices.shape[0], dtype=torch.float32).cpu().numpy()})
+                
+    ### solve mincut
+    cut = g.mincut('s', 't', (np.array(g.es['weight']) * 1000).tolist())
+    remove_face_indices = torch.tensor([v for v in cut.partition[0] if v < faces.shape[0]], dtype=torch.long, device=faces.device)
+    if verbose:
+        tqdm.write('Mincut solved, checking the cut')
+    
+    ### check if the cut is valid with each connected component
+    to_remove_cc = utils3d.torch.compute_connected_components(faces[remove_face_indices])
+    if debug:
+        tqdm.write(f'Number of connected components of the cut: {len(to_remove_cc)}')
+    valid_remove_cc = []
+    cutting_edges = []
+    for cc in to_remove_cc:
+        #### check if the connected component has low visibility
+        visblity_median = visblity[remove_face_indices[cc]].median()
+        if debug:
+            tqdm.write(f'visblity_median: {visblity_median}')
+        if visblity_median > 0.25:
+            continue
+        
+        #### check if the cutting loop is small enough
+        cc_edge_indices, cc_edges_degree = torch.unique(face2edge[remove_face_indices[cc]], return_counts=True)
+        cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1]
+        cc_new_boundary_edge_indices = cc_boundary_edge_indices[~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)]
+        if len(cc_new_boundary_edge_indices) > 0:
+            cc_new_boundary_edge_cc = utils3d.torch.compute_edge_connected_components(edges[cc_new_boundary_edge_indices])
+            cc_new_boundary_edges_cc_center = [verts[edges[cc_new_boundary_edge_indices[edge_cc]]].mean(dim=1).mean(dim=0) for edge_cc in cc_new_boundary_edge_cc]
+            cc_new_boundary_edges_cc_area = []
+            for i, edge_cc in enumerate(cc_new_boundary_edge_cc):
+                _e1 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]] - cc_new_boundary_edges_cc_center[i]
+                _e2 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]] - cc_new_boundary_edges_cc_center[i]
+                cc_new_boundary_edges_cc_area.append(torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum() * 0.5)
+            if debug:
+                cutting_edges.append(cc_new_boundary_edge_indices)
+                tqdm.write(f'Area of the cutting loop: {cc_new_boundary_edges_cc_area}')
+            if any([l > max_hole_size for l in cc_new_boundary_edges_cc_area]):
+                continue
+            
+        valid_remove_cc.append(cc)
+        
+    if debug:
+        face_v = verts[faces].mean(dim=1).cpu().numpy()
+        vis_dual_edges = dual_edges.cpu().numpy()
+        vis_colors = np.zeros((faces.shape[0], 3), dtype=np.uint8)
+        vis_colors[inner_face_indices.cpu().numpy()] = [0, 0, 255]
+        vis_colors[outer_face_indices.cpu().numpy()] = [0, 255, 0]
+        vis_colors[remove_face_indices.cpu().numpy()] = [255, 0, 255]
+        if len(valid_remove_cc) > 0:
+            vis_colors[remove_face_indices[torch.cat(valid_remove_cc)].cpu().numpy()] = [255, 0, 0]
+        utils3d.io.write_ply('dbg_dual.ply', face_v, edges=vis_dual_edges, vertex_colors=vis_colors)
+        
+        vis_verts = verts.cpu().numpy()
+        vis_edges = edges[torch.cat(cutting_edges)].cpu().numpy()
+        utils3d.io.write_ply('dbg_cut.ply', vis_verts, edges=vis_edges)
+        
+    
+    if len(valid_remove_cc) > 0:
+        remove_face_indices = remove_face_indices[torch.cat(valid_remove_cc)]
+        mask = torch.ones(faces.shape[0], dtype=torch.bool, device=faces.device)
+        mask[remove_face_indices] = 0
+        faces = faces[mask]
+        faces, verts = utils3d.torch.remove_unreferenced_vertices(faces, verts)
+        if verbose:
+            tqdm.write(f'Removed {(~mask).sum()} faces by mincut')
+    else:
+        if verbose:
+            tqdm.write(f'Removed 0 faces by mincut')
+            
+    mesh = _meshfix.PyTMesh()
+    mesh.load_array(verts.cpu().numpy(), faces.cpu().numpy())
+    mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True)
+    verts, faces = mesh.return_arrays()
+    verts, faces = torch.tensor(verts, device='cuda', dtype=torch.float32), torch.tensor(faces, device='cuda', dtype=torch.int32)
+
+    return verts, faces
+
+
+def postprocess_mesh(
+    vertices: np.array,
+    faces: np.array,
+    simplify: bool = True,
+    simplify_ratio: float = 0.9,
+    fill_holes: bool = True,
+    fill_holes_max_hole_size: float = 0.04,
+    fill_holes_max_hole_nbe: int = 32,
+    fill_holes_resolution: int = 1024,
+    fill_holes_num_views: int = 1000,
+    debug: bool = False,
+    verbose: bool = False,
+):
+    """
+    Postprocess a mesh by simplifying, removing invisible faces, and removing isolated pieces.
+
+    Args:
+        vertices (np.array): Vertices of the mesh. Shape (V, 3).
+        faces (np.array): Faces of the mesh. Shape (F, 3).
+        simplify (bool): Whether to simplify the mesh, using quadric edge collapse.
+        simplify_ratio (float): Ratio of faces to keep after simplification.
+        fill_holes (bool): Whether to fill holes in the mesh.
+        fill_holes_max_hole_size (float): Maximum area of a hole to fill.
+        fill_holes_max_hole_nbe (int): Maximum number of boundary edges of a hole to fill.
+        fill_holes_resolution (int): Resolution of the rasterization.
+        fill_holes_num_views (int): Number of views to rasterize the mesh.
+        debug (bool): Whether to enable debug output in hole filling.
+        verbose (bool): Whether to print progress.
+    """
+
+    if verbose:
+        tqdm.write(f'Before postprocess: {vertices.shape[0]} vertices, {faces.shape[0]} faces')
+
+    # Simplify
+    if simplify and simplify_ratio > 0:
+        mesh = pv.PolyData(vertices, np.concatenate([np.full((faces.shape[0], 1), 3), faces], axis=1))
+        mesh = mesh.decimate(simplify_ratio, progress_bar=verbose)
+        vertices, faces = mesh.points, mesh.faces.reshape(-1, 4)[:, 1:]
+        if verbose:
+            tqdm.write(f'After decimate: {vertices.shape[0]} vertices, {faces.shape[0]} faces')
+
+    # Remove invisible faces
+    if fill_holes:
+        vertices, faces = torch.tensor(vertices).cuda(), torch.tensor(faces.astype(np.int32)).cuda()
+        vertices, faces = _fill_holes(
+            vertices, faces,
+            max_hole_size=fill_holes_max_hole_size,
+            max_hole_nbe=fill_holes_max_hole_nbe,
+            resolution=fill_holes_resolution,
+            num_views=fill_holes_num_views,
+            debug=debug,
+            verbose=verbose,
+        )
+        vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy()
+        if verbose:
+            tqdm.write(f'After remove invisible faces: {vertices.shape[0]} vertices, {faces.shape[0]} faces')
+
+    return vertices, faces
+
+
+def parametrize_mesh(vertices: np.array, faces: np.array):
+    """
+    Parametrize a mesh to a texture space, using xatlas.
+
+    Args:
+        vertices (np.array): Vertices of the mesh. Shape (V, 3).
+        faces (np.array): Faces of the mesh. Shape (F, 3).
+    """
+
+    vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
+
+    vertices = vertices[vmapping]
+    faces = indices
+
+    return vertices, faces, uvs
+
+
+def bake_texture(
+    vertices: np.array,
+    faces: np.array,
+    uvs: np.array,
+    observations: List[np.array],
+    masks: List[np.array],
+    extrinsics: List[np.array],
+    intrinsics: List[np.array],
+    texture_size: int = 2048,
+    near: float = 0.1,
+    far: float = 10.0,
+    mode: Literal['fast', 'opt'] = 'opt',
+    lambda_tv: float = 1e-2,
+    verbose: bool = False,
+):
+    """
+    Bake texture to a mesh from multiple observations.
+
+    Args:
+        vertices (np.array): Vertices of the mesh. Shape (V, 3).
+        faces (np.array): Faces of the mesh. Shape (F, 3).
+        uvs (np.array): UV coordinates of the mesh. Shape (V, 2).
+        observations (List[np.array]): List of observations. Each observation is a 2D image. Shape (H, W, 3).
+        masks (List[np.array]): List of masks. Each mask is a 2D image. Shape (H, W).
+        extrinsics (List[np.array]): List of extrinsics. Shape (4, 4).
+        intrinsics (List[np.array]): List of intrinsics. Shape (3, 3).
+        texture_size (int): Size of the texture.
+        near (float): Near plane of the camera.
+        far (float): Far plane of the camera.
+        mode (Literal['fast', 'opt']): Mode of texture baking.
+        lambda_tv (float): Weight of total variation loss in optimization.
+        verbose (bool): Whether to print progress.
+    """
+    vertices = torch.tensor(vertices).cuda()
+    faces = torch.tensor(faces.astype(np.int32)).cuda()
+    uvs = torch.tensor(uvs).cuda()
+    observations = [torch.tensor(obs / 255.0).float().cuda() for obs in observations]
+    masks = [torch.tensor(m>0).bool().cuda() for m in masks]
+    views = [utils3d.torch.extrinsics_to_view(torch.tensor(extr).cuda()) for extr in extrinsics]
+    projections = [utils3d.torch.intrinsics_to_perspective(torch.tensor(intr).cuda(), near, far) for intr in intrinsics]
+
+    if mode == 'fast':
+        texture = torch.zeros((texture_size * texture_size, 3), dtype=torch.float32).cuda()
+        texture_weights = torch.zeros((texture_size * texture_size), dtype=torch.float32).cuda()
+        rastctx = utils3d.torch.RastContext(backend='cuda')
+        for observation, obs_mask, view, projection in tqdm(zip(observations, masks, views, projections), total=len(observations), disable=not verbose, desc='Texture baking (fast)'):
+            with torch.no_grad():
+                rast = utils3d.torch.rasterize_triangle_faces(
+                    rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection
+                )
+                uv_map = rast['uv'][0].detach().flip(0)
+                mask = rast['mask'][0].detach().bool() & masks[0]
+            
+            # nearest neighbor interpolation
+            uv_map = (uv_map * texture_size).floor().long()
+            obs = observation[mask]
+            uv_map = uv_map[mask]
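+            # Flatten (u, v) into a texel index; V is flipped (row = texture_size - v - 1) to match image row order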
+            idx = uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
+            texture = texture.scatter_add(0, idx.view(-1, 1).expand(-1, 3), obs)
+            texture_weights = texture_weights.scatter_add(0, idx, torch.ones((obs.shape[0]), dtype=torch.float32, device=texture.device))
+
+        mask = texture_weights > 0
+        texture[mask] /= texture_weights[mask][:, None]
+        texture = np.clip(texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255, 0, 255).astype(np.uint8)
+
+        # inpaint
+        mask = (texture_weights == 0).cpu().numpy().astype(np.uint8).reshape(texture_size, texture_size)
+        texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
+
+    elif mode == 'opt':
+        rastctx = utils3d.torch.RastContext(backend='cuda')
+        observations = [obs.flip(0) for obs in observations]
+        masks = [m.flip(0) for m in masks]
+        _uv = []
+        _uv_dr = []
+        for observation, view, projection in tqdm(zip(observations, views, projections), total=len(views), disable=not verbose, desc='Texture baking (opt): UV'):
+            with torch.no_grad():
+                rast = utils3d.torch.rasterize_triangle_faces(
+                    rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection
+                )
+                _uv.append(rast['uv'].detach())
+                _uv_dr.append(rast['uv_dr'].detach())
+
+        texture = torch.nn.Parameter(torch.zeros((1, texture_size, texture_size, 3), dtype=torch.float32).cuda())
+        optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)
+
+        def exp_annealing(optimizer, step, total_steps, start_lr, end_lr):
+            return start_lr * (end_lr / start_lr) ** (step / total_steps)
+
+        def cosine_annealing(optimizer, step, total_steps, start_lr, end_lr):
+            return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps))
+        
+        def tv_loss(texture):
+            return torch.nn.functional.l1_loss(texture[:, :-1, :, :], texture[:, 1:, :, :]) + \
+                   torch.nn.functional.l1_loss(texture[:, :, :-1, :], texture[:, :, 1:, :])
+    
+        total_steps = 2500
+        with tqdm(total=total_steps, disable=not verbose, desc='Texture baking (opt): optimizing') as pbar:
+            for step in range(total_steps):
+                optimizer.zero_grad()
+                selected = np.random.randint(0, len(views))
+                uv, uv_dr, observation, mask = _uv[selected], _uv_dr[selected], observations[selected], masks[selected]
+                render = dr.texture(texture, uv, uv_dr)[0]
+                loss = torch.nn.functional.l1_loss(render[mask], observation[mask])
+                if lambda_tv > 0:
+                    loss += lambda_tv * tv_loss(texture)
+                loss.backward()
+                optimizer.step()
+                # annealing
+                optimizer.param_groups[0]['lr'] = cosine_annealing(optimizer, step, total_steps, 1e-2, 1e-5)
+                pbar.set_postfix({'loss': loss.item()})
+                pbar.update()
+        texture = np.clip(texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)
+        mask = 1 - utils3d.torch.rasterize_triangle_faces(
+            rastctx, (uvs * 2 - 1)[None], faces, texture_size, texture_size
+        )['mask'][0].detach().cpu().numpy().astype(np.uint8)
+        texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
+    else:
+        raise ValueError(f'Unknown mode: {mode}')
+
+    return texture
+
+
+def to_glb(
+    app_rep: Union[Strivec, Gaussian],
+    mesh: MeshExtractResult,
+    simplify: float = 0.95,
+    fill_holes: bool = True,
+    fill_holes_max_size: float = 0.04,
+    texture_size: int = 1024,
+    debug: bool = False,
+    verbose: bool = True,
+) -> trimesh.Trimesh:
+    """
+    Convert a generated asset to a glb file.
+
+    Args:
+        app_rep (Union[Strivec, Gaussian]): Appearance representation.
+        mesh (MeshExtractResult): Extracted mesh.
+        simplify (float): Ratio of faces to remove in simplification.
+        fill_holes (bool): Whether to fill holes in the mesh.
+        fill_holes_max_size (float): Maximum area of a hole to fill.
+        texture_size (int): Size of the texture.
+        debug (bool): Whether to print debug information.
+        verbose (bool): Whether to print progress.
+    """
+    vertices = mesh.vertices.cpu().numpy()
+    faces = mesh.faces.cpu().numpy()
+    
+    # mesh postprocess
+    vertices, faces = postprocess_mesh(
+        vertices, faces,
+        simplify=simplify > 0,
+        simplify_ratio=simplify,
+        fill_holes=fill_holes,
+        fill_holes_max_hole_size=fill_holes_max_size,
+        fill_holes_max_hole_nbe=int(250 * np.sqrt(1-simplify)),
+        fill_holes_resolution=1024,
+        fill_holes_num_views=1000,
+        debug=debug,
+        verbose=verbose,
+    )
+
+    # parametrize mesh
+    vertices, faces, uvs = parametrize_mesh(vertices, faces)
+
+    # bake texture
+    observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100)
+    masks = [np.any(observation > 0, axis=-1) for observation in observations]
+    extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))]
+    intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))]
+    texture = bake_texture(
+        vertices, faces, uvs,
+        observations, masks, extrinsics, intrinsics,
+        texture_size=texture_size, mode='opt',
+        lambda_tv=0.01,
+        verbose=True
+    )
+    texture = Image.fromarray(texture)
+
+    # rotate mesh (from z-up to y-up)
+    vertices = vertices @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
+    mesh = trimesh.Trimesh(vertices, faces, visual=trimesh.visual.TextureVisuals(uv=uvs, image=texture))
+    return mesh
diff --git a/trellis/utils/random_utils.py b/trellis/utils/random_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b668c277b51f4930991912a80573adc79364028
--- /dev/null
+++ b/trellis/utils/random_utils.py
@@ -0,0 +1,30 @@
+import numpy as np
+
+PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
+
+def radical_inverse(base, n):
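+    # Van der Corput radical inverse: mirror the base-`base` digits of n around the radix point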
+    val = 0
+    inv_base = 1.0 / base
+    inv_base_n = inv_base
+    while n > 0:
+        digit = n % base
+        val += digit * inv_base_n
+        n //= base
+        inv_base_n *= inv_base
+    return val
+
+def halton_sequence(dim, n):
+    return [radical_inverse(PRIMES[d], n) for d in range(dim)]
+
+def hammersley_sequence(dim, n, num_samples):
+    return [n / num_samples] + halton_sequence(dim - 1, n)
+
+def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False):
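+    # Map a 2D Hammersley sample onto the sphere: returns [phi (azimuth), theta (elevation in [-pi/2, pi/2])]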
+    u, v = hammersley_sequence(2, n, num_samples)
+    u += offset[0] / num_samples
+    v += offset[1]
+    if remap:
+        u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
+    theta = np.arccos(1 - 2 * u) - np.pi / 2
+    phi = v * 2 * np.pi
+    return [phi, theta]
\ No newline at end of file
diff --git a/trellis/utils/render_utils.py b/trellis/utils/render_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8187c84f305d51540e88ae5b634a484a74c16e95
--- /dev/null
+++ b/trellis/utils/render_utils.py
@@ -0,0 +1,116 @@
+import torch
+import numpy as np
+from tqdm import tqdm
+import utils3d
+from PIL import Image
+
+from ..renderers import OctreeRenderer, GaussianRenderer, MeshRenderer
+from ..representations import Octree, Gaussian, MeshExtractResult
+from ..modules import sparse as sp
+from .random_utils import sphere_hammersley_sequence
+
+
+def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs):
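+    """
+    Build camera extrinsics and intrinsics for cameras orbiting and looking at the origin.
+    yaws/pitchs may be scalars or lists; rs and fovs (in degrees) are broadcast to the same length.
+    Returns a single (extrinsics, intrinsics) pair for scalar input, or lists otherwise.
+    """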
+    is_list = isinstance(yaws, list)
+    if not is_list:
+        yaws = [yaws]
+        pitchs = [pitchs]
+    if not isinstance(rs, list):
+        rs = [rs] * len(yaws)
+    if not isinstance(fovs, list):
+        fovs = [fovs] * len(yaws)
+    extrinsics = []
+    intrinsics = []
+    for yaw, pitch, r, fov in zip(yaws, pitchs, rs, fovs):
+        fov = torch.deg2rad(torch.tensor(float(fov))).cuda()
+        yaw = torch.tensor(float(yaw)).cuda()
+        pitch = torch.tensor(float(pitch)).cuda()
+        orig = torch.tensor([
+            torch.sin(yaw) * torch.cos(pitch),
+            torch.cos(yaw) * torch.cos(pitch),
+            torch.sin(pitch),
+        ]).cuda() * r
+        extr = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
+        intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
+        extrinsics.append(extr)
+        intrinsics.append(intr)
+    if not is_list:
+        extrinsics = extrinsics[0]
+        intrinsics = intrinsics[0]
+    return extrinsics, intrinsics
+
+
+def render_frames(sample, extrinsics, intrinsics, options={}, colors_overwrite=None, verbose=True, **kwargs):
+    if isinstance(sample, Octree):
+        renderer = OctreeRenderer()
+        renderer.rendering_options.resolution = options.get('resolution', 512)
+        renderer.rendering_options.near = options.get('near', 0.8)
+        renderer.rendering_options.far = options.get('far', 1.6)
+        renderer.rendering_options.bg_color = options.get('bg_color', (0, 0, 0))
+        renderer.rendering_options.ssaa = options.get('ssaa', 4)
+        renderer.pipe.primitive = sample.primitive
+    elif isinstance(sample, Gaussian):
+        renderer = GaussianRenderer()
+        renderer.rendering_options.resolution = options.get('resolution', 512)
+        renderer.rendering_options.near = options.get('near', 0.8)
+        renderer.rendering_options.far = options.get('far', 1.6)
+        renderer.rendering_options.bg_color = options.get('bg_color', (0, 0, 0))
+        renderer.rendering_options.ssaa = options.get('ssaa', 1)
+        renderer.pipe.kernel_size = kwargs.get('kernel_size', 0.1)
+        renderer.pipe.use_mip_gaussian = True
+    elif isinstance(sample, MeshExtractResult):
+        renderer = MeshRenderer()
+        renderer.rendering_options.resolution = options.get('resolution', 512)
+        renderer.rendering_options.near = options.get('near', 1)
+        renderer.rendering_options.far = options.get('far', 100)
+        renderer.rendering_options.ssaa = options.get('ssaa', 4)
+    else:
+        raise ValueError(f'Unsupported sample type: {type(sample)}')
+    
+    rets = {}
+    for extr, intr in tqdm(zip(extrinsics, intrinsics), total=len(extrinsics), desc='Rendering', disable=not verbose):
+        if not isinstance(sample, MeshExtractResult):
+            res = renderer.render(sample, extr, intr, colors_overwrite=colors_overwrite)
+            if 'color' not in rets: rets['color'] = []
+            if 'depth' not in rets: rets['depth'] = []
+            rets['color'].append(np.clip(res['color'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
+            if 'percent_depth' in res:
+                rets['depth'].append(res['percent_depth'].detach().cpu().numpy())
+            elif 'depth' in res:
+                rets['depth'].append(res['depth'].detach().cpu().numpy())
+            else:
+                rets['depth'].append(None)
+        else:
+            res = renderer.render(sample, extr, intr)
+            if 'normal' not in rets: rets['normal'] = []
+            rets['normal'].append(np.clip(res['normal'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
+    return rets
+
+
+def render_video(sample, resolution=512, bg_color=(0, 0, 0), num_frames=300, r=2, fov=40, **kwargs):
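+    """
+    Render a turntable video of the sample: one full yaw revolution with a sinusoidally varying pitch.
+    """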
+    yaws = torch.linspace(0, 2 * np.pi, num_frames)
+    pitch = 0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * np.pi, num_frames))
+    yaws = yaws.tolist()
+    pitch = pitch.tolist()
+    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitch, r, fov)
+    return render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
+
+
+def render_multiview(sample, resolution=512, nviews=30):
+    r = 2
+    fov = 40
+    cams = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)]
+    yaws = [cam[0] for cam in cams]
+    pitchs = [cam[1] for cam in cams]
+    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, fov)
+    res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)})
+    return res['color'], extrinsics, intrinsics
+
+
+def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 180 * np.pi, 20 / 180 * np.pi), r=10, fov=8, **kwargs):
+    yaw = [0, np.pi/2, np.pi, 3*np.pi/2]
+    yaw_offset = offset[0]
+    yaw = [y + yaw_offset for y in yaw]
+    pitch = [offset[1] for _ in range(4)]
+    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, r, fov)
+    return render_frames(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
diff --git a/wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl b/wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl
new file mode 100644
index 0000000000000000000000000000000000000000..5a7f93e9f56294b4e4079c7c6f32902f3c5890a8
Binary files /dev/null and b/wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl differ
diff --git a/wheels/nvdiffrast-0.3.3-py3-none-any.whl b/wheels/nvdiffrast-0.3.3-py3-none-any.whl
new file mode 100644
index 0000000000000000000000000000000000000000..cb56d6d359394e2829cd6a433f6f938a49ee733f
Binary files /dev/null and b/wheels/nvdiffrast-0.3.3-py3-none-any.whl differ