cronos3k commited on
Commit
876b58d
·
verified ·
1 Parent(s): 9b222ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -185
app.py CHANGED
@@ -1,80 +1,4 @@
1
- import gradio as gr
2
- import os
3
- import shutil
4
- os.environ['SPCONV_ALGO'] = 'native'
5
- from typing import *
6
- import torch
7
- import numpy as np
8
- import imageio
9
- import uuid
10
- from easydict import EasyDict as edict
11
- from PIL import Image
12
- from trellis.pipelines import TrellisImageTo3DPipeline
13
- from trellis.representations import Gaussian, MeshExtractResult
14
- from trellis.utils import render_utils, postprocessing_utils
15
- from gradio_litmodel3d import LitModel3D
16
-
17
-
18
- MAX_SEED = np.iinfo(np.int32).max
19
- TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
20
- os.makedirs(TMP_DIR, exist_ok=True)
21
-
22
-
23
- def start_session(req: gr.Request):
24
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
25
- print(f'Creating user directory: {user_dir}')
26
- os.makedirs(user_dir, exist_ok=True)
27
-
28
- def end_session(req: gr.Request):
29
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
30
- print(f'Removing user directory: {user_dir}')
31
- shutil.rmtree(user_dir)
32
-
33
- def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
34
- processed_image = pipeline.preprocess_image(image)
35
- return processed_image
36
-
37
- def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
38
- return {
39
- 'gaussian': {
40
- **gs.init_params,
41
- '_xyz': gs._xyz.cpu().numpy(),
42
- '_features_dc': gs._features_dc.cpu().numpy(),
43
- '_scaling': gs._scaling.cpu().numpy(),
44
- '_rotation': gs._rotation.cpu().numpy(),
45
- '_opacity': gs._opacity.cpu().numpy(),
46
- },
47
- 'mesh': {
48
- 'vertices': mesh.vertices.cpu().numpy(),
49
- 'faces': mesh.faces.cpu().numpy(),
50
- },
51
- 'trial_id': trial_id,
52
- }
53
-
54
- def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
55
- gs = Gaussian(
56
- aabb=state['gaussian']['aabb'],
57
- sh_degree=state['gaussian']['sh_degree'],
58
- mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
59
- scaling_bias=state['gaussian']['scaling_bias'],
60
- opacity_bias=state['gaussian']['opacity_bias'],
61
- scaling_activation=state['gaussian']['scaling_activation'],
62
- )
63
- gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
64
- gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
65
- gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
66
- gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
67
- gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
68
-
69
- mesh = edict(
70
- vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
71
- faces=torch.tensor(state['mesh']['faces'], device='cuda'),
72
- )
73
-
74
- return gs, mesh, state['trial_id']
75
-
76
- def get_seed(randomize_seed: bool, seed: int) -> int:
77
- return np.random.randint(0, MAX_SEED) if randomize_seed else seed
78
 
79
  def image_to_3d(
80
  image: Image.Image,
@@ -84,7 +8,7 @@ def image_to_3d(
84
  slat_guidance_strength: float,
85
  slat_sampling_steps: int,
86
  req: gr.Request,
87
- ) -> Tuple[dict, str, str, str]:
88
  """
89
  Convert an image to a 3D model.
90
  """
@@ -125,96 +49,9 @@ def image_to_3d(
125
  glb.export(full_glb_path)
126
 
127
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
128
- return state, video_path, model_output, full_glb_path
129
 
130
- def extract_glb(
131
- state: dict,
132
- mesh_simplify: float,
133
- texture_size: int,
134
- req: gr.Request,
135
- ) -> Tuple[str, str]:
136
- """
137
- Extract a reduced GLB file from the 3D model.
138
- """
139
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
140
- gs, mesh, trial_id = unpack_state(state)
141
- glb = postprocessing_utils.to_glb(
142
- gs, mesh,
143
- simplify=mesh_simplify,
144
- fill_holes=True,
145
- fill_holes_max_size=0.04,
146
- texture_size=texture_size,
147
- verbose=False
148
- )
149
- glb_path = os.path.join(user_dir, f"{trial_id}_reduced.glb")
150
- glb.export(glb_path)
151
- return glb_path, glb_path
152
-
153
- with gr.Blocks(delete_cache=(600, 600)) as demo:
154
- gr.Markdown("""
155
- ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
156
- * Upload an image and click "Generate" to create a 3D asset
157
- * After generation, you can:
158
- * Download the full quality GLB immediately
159
- * Create a reduced size version with the extraction settings below
160
- """)
161
-
162
- with gr.Row():
163
- with gr.Column():
164
- image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
165
-
166
- with gr.Accordion(label="Generation Settings", open=False):
167
- seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
168
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
169
- gr.Markdown("Stage 1: Sparse Structure Generation")
170
- with gr.Row():
171
- ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
172
- ss_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
173
- gr.Markdown("Stage 2: Structured Latent Generation")
174
- with gr.Row():
175
- slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
176
- slat_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
177
-
178
- generate_btn = gr.Button("Generate")
179
-
180
- with gr.Accordion(label="GLB Extraction Settings", open=False):
181
- mesh_simplify = gr.Slider(0.0, 0.98, label="Simplify", value=0.95, step=0.01)
182
- texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
183
-
184
- extract_glb_btn = gr.Button("Extract Reduced GLB", interactive=False)
185
-
186
- with gr.Column():
187
- video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
188
- model_output = LitModel3D(label="3D Model Preview", exposure=20.0, height=300)
189
- with gr.Row():
190
- download_full = gr.DownloadButton(label="Download Full-Quality GLB", interactive=False)
191
- download_reduced = gr.DownloadButton(label="Download Reduced GLB", interactive=False)
192
-
193
- output_buf = gr.State()
194
-
195
- # Example images
196
- with gr.Row():
197
- examples = gr.Examples(
198
- examples=[
199
- f'assets/example_image/{image}'
200
- for image in os.listdir("assets/example_image")
201
- ],
202
- inputs=[image_prompt],
203
- fn=preprocess_image,
204
- outputs=[image_prompt],
205
- run_on_click=True,
206
- examples_per_page=64,
207
- )
208
-
209
- # Event handlers
210
- demo.load(start_session)
211
- demo.unload(end_session)
212
-
213
- image_prompt.upload(
214
- preprocess_image,
215
- inputs=[image_prompt],
216
- outputs=[image_prompt],
217
- )
218
 
219
  generate_btn.click(
220
  get_seed,
@@ -223,26 +60,10 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
223
  ).then(
224
  image_to_3d,
225
  inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
226
- outputs=[output_buf, video_output, model_output, download_full],
227
  ).then(
228
  lambda: [gr.Button(interactive=True), gr.Button(interactive=True), gr.Button(interactive=False)],
229
  outputs=[download_full, extract_glb_btn, download_reduced],
230
  )
231
 
232
- extract_glb_btn.click(
233
- extract_glb,
234
- inputs=[output_buf, mesh_simplify, texture_size],
235
- outputs=[model_output, download_reduced],
236
- ).then(
237
- lambda: gr.Button(interactive=True),
238
- outputs=[download_reduced],
239
- )
240
-
241
- if __name__ == "__main__":
242
- pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
243
- pipeline.cuda()
244
- try:
245
- pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
246
- except:
247
- pass
248
- demo.launch()
 
1
+ # [Previous imports and utility functions remain exactly the same until image_to_3d]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  def image_to_3d(
4
  image: Image.Image,
 
8
  slat_guidance_strength: float,
9
  slat_sampling_steps: int,
10
  req: gr.Request,
11
+ ) -> Tuple[dict, str, str]:
12
  """
13
  Convert an image to a 3D model.
14
  """
 
49
  glb.export(full_glb_path)
50
 
51
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
52
+ return state, video_path, full_glb_path
53
 
54
+ # [Rest of the code remains exactly the same, except for the event handler which needs to be updated]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  generate_btn.click(
57
  get_seed,
 
60
  ).then(
61
  image_to_3d,
62
  inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
63
+ outputs=[output_buf, video_output, download_full],
64
  ).then(
65
  lambda: [gr.Button(interactive=True), gr.Button(interactive=True), gr.Button(interactive=False)],
66
  outputs=[download_full, extract_glb_btn, download_reduced],
67
  )
68
 
69
+ # [Rest of the code remains exactly the same]