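"""Gradio app: image-to-3D asset generation with TRELLIS.

Takes a single input image, runs the two-stage TRELLIS pipeline on a ZeroGPU
worker, and returns a turntable video preview plus a full-quality GLB file
for download.
"""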
import gradio as gr
import spaces
from gradio_litmodel3d import LitModel3D
import os
import shutil
os.environ['SPCONV_ALGO'] = 'native'  # use spconv's native algorithm (skips auto-tuning at startup)
from typing import Tuple
import torch
import numpy as np
import imageio
import uuid
from easydict import EasyDict as edict
from PIL import Image
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.representations import Gaussian, MeshExtractResult
from trellis.utils import render_utils, postprocessing_utils
MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)
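# Each browser session gets its own working directory under TMP_DIR;
# it is created in start_session() below and removed in end_session().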
def start_session(req: gr.Request):
    """Create the per-session working directory when a user loads the page."""
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    print(f'Creating user directory: {user_dir}')
    os.makedirs(user_dir, exist_ok=True)


def end_session(req: gr.Request):
    """Remove the per-session working directory when the user leaves."""
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    print(f'Removing user directory: {user_dir}')
    shutil.rmtree(user_dir)

def preprocess_image(image: Image.Image) -> Image.Image:
    """Run the pipeline's own preprocessing (e.g. background removal) on the input image."""
    processed_image = pipeline.preprocess_image(image)
    return processed_image

def get_seed(randomize_seed: bool, seed: int) -> int:
    """Get the random seed."""
    return np.random.randint(0, MAX_SEED) if randomize_seed else seed

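# @spaces.GPU allocates a ZeroGPU worker only for the duration of this call.
# The TRELLIS pipeline runs in two stages: sparse structure sampling, then
# structured latent (SLAT) sampling; guidance strength and step count for each
# stage come from the UI sliders.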
@spaces.GPU
def image_to_3d(
    image: Image.Image,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    req: gr.Request,
) -> Tuple[str, str, str]:
    """
    Convert an image to a 3D model and save both a video preview and a full-quality GLB.

    Returns:
        Tuple[str, str, str]: (video_path, glb_path, download_path)
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    outputs = pipeline.run(
        image,
        seed=seed,
        formats=["gaussian", "mesh"],
        preprocess_image=False,  # already preprocessed on upload
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps,
            "cfg_strength": ss_guidance_strength,
        },
        slat_sampler_params={
            "steps": slat_sampling_steps,
            "cfg_strength": slat_guidance_strength,
        },
    )

    # Generate and save the video preview (Gaussian color render side by side with mesh normals)
    video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
    video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
    video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
    trial_id = str(uuid.uuid4())
    video_path = os.path.join(user_dir, f"{trial_id}.mp4")
    imageio.mimsave(video_path, video, fps=15)

    # Save the full-quality GLB directly from the generated mesh
    glb = postprocessing_utils.to_glb(
        outputs['gaussian'][0],
        outputs['mesh'][0],
        simplify=0.0,       # No simplification
        texture_size=2048,  # Maximum texture resolution
        verbose=False,
    )
    glb_path = os.path.join(user_dir, f"{trial_id}_full.glb")
    glb.export(glb_path)

    torch.cuda.empty_cache()
    return video_path, glb_path, glb_path

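# Build the Gradio UI. delete_cache=(600, 600) tells Gradio to purge cached files
# older than 600 seconds every 600 seconds.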
with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown("""
    ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
    * Upload an image and click "Generate" to create a high-quality 3D model
    * Once generation is complete, you can download the full-quality GLB file
    * The model will be in maximum quality with no reduction applied
    """)

    with gr.Row():
        with gr.Column():
            image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)

            with gr.Accordion(label="Generation Settings", open=False):
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
                gr.Markdown("Stage 2: Structured Latent Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                    slat_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)

            generate_btn = gr.Button("Generate")

        with gr.Column():
            video_output = gr.Video(label="Generated 3D Asset Preview", autoplay=True, loop=True, height=300)
            model_output = LitModel3D(label="3D Model Preview", exposure=20.0, height=300)
            download_glb = gr.DownloadButton(label="Download Full-Quality GLB", interactive=False)

    # Example images
    with gr.Row():
        examples = gr.Examples(
            examples=[
                f'assets/example_image/{image}'
                for image in os.listdir("assets/example_image")
            ],
            inputs=[image_prompt],
            fn=preprocess_image,
            outputs=[image_prompt],
            run_on_click=True,
            examples_per_page=64,
        )
    # Event handlers
    demo.load(start_session)
    demo.unload(end_session)

    image_prompt.upload(
        preprocess_image,
        inputs=[image_prompt],
        outputs=[image_prompt],
    )

    # Chain: resolve the seed, run generation, then enable the download button
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        image_to_3d,
        inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
        outputs=[video_output, model_output, download_glb],
    ).then(
        lambda: gr.Button(interactive=True),
        outputs=[download_glb],
    )
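# Startup: load the pretrained TRELLIS pipeline and move it to the GPU before serving requests.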
if __name__ == "__main__":
    pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
    pipeline.cuda()
    try:
        # Warm up the rembg background-removal model with a dummy image
        pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
    except Exception:
        pass
    # Launch the Gradio app
    demo.launch()