cronos3k commited on
Commit
c260ece
·
verified ·
1 Parent(s): 857bb76

Update app.py

Browse files

making the high quality source mesh available for download

Files changed (1) hide show
  1. app.py +195 -41
app.py CHANGED
@@ -17,24 +17,27 @@ from trellis.representations import Gaussian, MeshExtractResult
17
  from trellis.utils import render_utils, postprocessing_utils
18
 
19
 
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
22
  os.makedirs(TMP_DIR, exist_ok=True)
23
 
24
 
 
25
  def start_session(req: gr.Request):
26
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
27
  print(f'Creating user directory: {user_dir}')
28
  os.makedirs(user_dir, exist_ok=True)
29
-
30
-
31
  def end_session(req: gr.Request):
32
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
33
  print(f'Removing user directory: {user_dir}')
34
  shutil.rmtree(user_dir)
35
 
36
 
37
- def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
 
38
  """
39
  Preprocess the input image.
40
 
@@ -42,13 +45,13 @@ def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
42
  image (Image.Image): The input image.
43
 
44
  Returns:
45
- str: uuid of the trial.
46
  Image.Image: The preprocessed image.
47
  """
48
  processed_image = pipeline.preprocess_image(image)
49
  return processed_image
50
 
51
 
 
52
  def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
53
  return {
54
  'gaussian': {
@@ -65,8 +68,8 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
65
  },
66
  'trial_id': trial_id,
67
  }
68
-
69
-
70
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
71
  gs = Gaussian(
72
  aabb=state['gaussian']['aabb'],
@@ -90,13 +93,22 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
90
  return gs, mesh, state['trial_id']
91
 
92
 
 
93
  def get_seed(randomize_seed: bool, seed: int) -> int:
94
  """
95
  Get the random seed.
 
 
 
 
 
 
 
96
  """
97
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
98
 
99
 
 
100
  @spaces.GPU
101
  def image_to_3d(
102
  image: Image.Image,
@@ -117,10 +129,10 @@ def image_to_3d(
117
  ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
118
  slat_guidance_strength (float): The guidance strength for structured latent generation.
119
  slat_sampling_steps (int): The number of sampling steps for structured latent generation.
 
120
 
121
  Returns:
122
- dict: The information of the generated 3D model.
123
- str: The path to the video of the 3D model.
124
  """
125
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
126
  outputs = pipeline.run(
@@ -148,6 +160,7 @@ def image_to_3d(
148
  return state, video_path
149
 
150
 
 
151
  @spaces.GPU
152
  def extract_glb(
153
  state: dict,
@@ -162,9 +175,10 @@ def extract_glb(
162
  state (dict): The state of the generated 3D model.
163
  mesh_simplify (float): The mesh simplification factor.
164
  texture_size (int): The texture resolution.
 
165
 
166
  Returns:
167
- str: The path to the extracted GLB file.
168
  """
169
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
170
  gs, mesh, trial_id = unpack_state(state)
@@ -175,45 +189,158 @@ def extract_glb(
175
  return glb_path, glb_path
176
 
177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  with gr.Blocks(delete_cache=(600, 600)) as demo:
179
  gr.Markdown("""
180
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
181
- * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
182
  * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
 
183
  """)
184
-
185
  with gr.Row():
186
  with gr.Column():
187
- image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
 
 
 
 
 
 
 
188
 
 
189
  with gr.Accordion(label="Generation Settings", open=False):
190
- seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
191
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
192
- gr.Markdown("Stage 1: Sparse Structure Generation")
 
 
 
 
 
 
 
 
 
193
  with gr.Row():
194
- ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
195
- ss_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
196
- gr.Markdown("Stage 2: Structured Latent Generation")
 
 
 
 
 
 
 
 
 
 
 
 
197
  with gr.Row():
198
- slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
199
- slat_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
200
-
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  generate_btn = gr.Button("Generate")
202
 
 
203
  with gr.Accordion(label="GLB Extraction Settings", open=False):
204
- mesh_simplify = gr.Slider(0.0, 0.98, label="Simplify", value=0.95, step=0.01)
205
- texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
 
 
 
 
 
 
 
 
 
 
 
 
206
 
 
207
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
 
 
 
208
 
209
  with gr.Column():
210
- video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
211
- model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
212
- download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
213
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  output_buf = gr.State()
 
 
215
 
216
- # Example images at the bottom of the page
217
  with gr.Row():
218
  examples = gr.Examples(
219
  examples=[
@@ -227,51 +354,78 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
227
  examples_per_page=64,
228
  )
229
 
230
- # Handlers
231
  demo.load(start_session)
232
  demo.unload(end_session)
233
 
 
234
  image_prompt.upload(
235
  preprocess_image,
236
  inputs=[image_prompt],
237
  outputs=[image_prompt],
238
  )
239
 
 
240
  generate_btn.click(
241
  get_seed,
242
  inputs=[randomize_seed, seed],
243
  outputs=[seed],
244
  ).then(
245
  image_to_3d,
246
- inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
 
 
 
 
 
 
 
247
  outputs=[output_buf, video_output],
248
  ).then(
249
- lambda: gr.Button(interactive=True),
250
- outputs=[extract_glb_btn],
251
- )
252
-
253
- video_output.clear(
254
- lambda: gr.Button(interactive=False),
255
- outputs=[extract_glb_btn],
256
  )
257
 
 
258
  extract_glb_btn.click(
259
  extract_glb,
260
  inputs=[output_buf, mesh_simplify, texture_size],
261
  outputs=[model_output, download_glb],
262
  ).then(
263
- lambda: gr.Button(interactive=True),
264
  outputs=[download_glb],
265
  )
266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  model_output.clear(
268
- lambda: gr.Button(interactive=False),
269
- outputs=[download_glb],
270
  )
271
-
272
 
273
  # Launch the Gradio app
274
  if __name__ == "__main__":
 
275
  pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
276
  pipeline.cuda()
277
  try:
 
17
  from trellis.utils import render_utils, postprocessing_utils
18
 
19
 
20
+ # Constants
21
  MAX_SEED = np.iinfo(np.int32).max
22
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
23
  os.makedirs(TMP_DIR, exist_ok=True)
24
 
25
 
26
+ # Session Management Functions
27
  def start_session(req: gr.Request):
28
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
29
  print(f'Creating user directory: {user_dir}')
30
  os.makedirs(user_dir, exist_ok=True)
31
+
32
+
33
  def end_session(req: gr.Request):
34
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
35
  print(f'Removing user directory: {user_dir}')
36
  shutil.rmtree(user_dir)
37
 
38
 
39
+ # Image Preprocessing Function
40
+ def preprocess_image(image: Image.Image) -> Image.Image:
41
  """
42
  Preprocess the input image.
43
 
 
45
  image (Image.Image): The input image.
46
 
47
  Returns:
 
48
  Image.Image: The preprocessed image.
49
  """
50
  processed_image = pipeline.preprocess_image(image)
51
  return processed_image
52
 
53
 
54
+ # State Packing and Unpacking Functions
55
  def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
56
  return {
57
  'gaussian': {
 
68
  },
69
  'trial_id': trial_id,
70
  }
71
+
72
+
73
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
74
  gs = Gaussian(
75
  aabb=state['gaussian']['aabb'],
 
93
  return gs, mesh, state['trial_id']
94
 
95
 
96
+ # Seed Management Function
97
  def get_seed(randomize_seed: bool, seed: int) -> int:
98
  """
99
  Get the random seed.
100
+
101
+ Args:
102
+ randomize_seed (bool): Whether to randomize the seed.
103
+ seed (int): The provided seed value.
104
+
105
+ Returns:
106
+ int: The final seed to use.
107
  """
108
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
109
 
110
 
111
+ # Core 3D Generation Function
112
  @spaces.GPU
113
  def image_to_3d(
114
  image: Image.Image,
 
129
  ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
130
  slat_guidance_strength (float): The guidance strength for structured latent generation.
131
  slat_sampling_steps (int): The number of sampling steps for structured latent generation.
132
+ req (gr.Request): Gradio request object.
133
 
134
  Returns:
135
+ Tuple[dict, str]: The state dictionary and the path to the generated video.
 
136
  """
137
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
138
  outputs = pipeline.run(
 
160
  return state, video_path
161
 
162
 
163
+ # Existing GLB Extraction Function
164
  @spaces.GPU
165
  def extract_glb(
166
  state: dict,
 
175
  state (dict): The state of the generated 3D model.
176
  mesh_simplify (float): The mesh simplification factor.
177
  texture_size (int): The texture resolution.
178
+ req (gr.Request): Gradio request object.
179
 
180
  Returns:
181
+ Tuple[str, str]: The path to the extracted GLB file.
182
  """
183
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
184
  gs, mesh, trial_id = unpack_state(state)
 
189
  return glb_path, glb_path
190
 
191
 
192
+ # New High-Quality GLB Extraction Function
193
+ @spaces.GPU
194
+ def extract_glb_high_quality(
195
+ state: dict,
196
+ req: gr.Request,
197
+ ) -> Tuple[str, str]:
198
+ """
199
+ Extract a high-quality GLB file from the 3D model without polygon reduction.
200
+
201
+ Args:
202
+ state (dict): The state of the generated 3D model.
203
+ req (gr.Request): Gradio request object.
204
+
205
+ Returns:
206
+ Tuple[str, str]: The path to the high-quality GLB file.
207
+ """
208
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
209
+ gs, mesh, trial_id = unpack_state(state)
210
+ # Set simplify to 0.0 to disable polygon reduction
211
+ # Set texture_size to 2048 for maximum texture quality
212
+ glb = postprocessing_utils.to_glb(gs, mesh, simplify=0.0, texture_size=2048, verbose=False)
213
+ glb_path = os.path.join(user_dir, f"{trial_id}_high_quality.glb")
214
+ glb.export(glb_path)
215
+ torch.cuda.empty_cache()
216
+ return glb_path, glb_path
217
+
218
+
219
+ # Gradio Interface Definition
220
  with gr.Blocks(delete_cache=(600, 600)) as demo:
221
  gr.Markdown("""
222
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
223
+ * Upload an image and click "Generate" to create a 3D asset. If the image has an alpha channel, it will be used as the mask. Otherwise, the background will be removed automatically.
224
  * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
225
+ * **New:** Click "Download High Quality GLB" to download the GLB file without any polygon reduction and with maximum texture quality.
226
  """)
227
+
228
  with gr.Row():
229
  with gr.Column():
230
+ # Image Input
231
+ image_prompt = gr.Image(
232
+ label="Image Prompt",
233
+ format="png",
234
+ image_mode="RGBA",
235
+ type="pil",
236
+ height=300
237
+ )
238
 
239
+ # Generation Settings Accordion
240
  with gr.Accordion(label="Generation Settings", open=False):
241
+ seed = gr.Slider(
242
+ 0,
243
+ MAX_SEED,
244
+ label="Seed",
245
+ value=0,
246
+ step=1
247
+ )
248
+ randomize_seed = gr.Checkbox(
249
+ label="Randomize Seed",
250
+ value=True
251
+ )
252
+ gr.Markdown("### Stage 1: Sparse Structure Generation")
253
  with gr.Row():
254
+ ss_guidance_strength = gr.Slider(
255
+ 0.0,
256
+ 10.0,
257
+ label="Guidance Strength",
258
+ value=7.5,
259
+ step=0.1
260
+ )
261
+ ss_sampling_steps = gr.Slider(
262
+ 1,
263
+ 500,
264
+ label="Sampling Steps",
265
+ value=12,
266
+ step=1
267
+ )
268
+ gr.Markdown("### Stage 2: Structured Latent Generation")
269
  with gr.Row():
270
+ slat_guidance_strength = gr.Slider(
271
+ 0.0,
272
+ 10.0,
273
+ label="Guidance Strength",
274
+ value=3.0,
275
+ step=0.1
276
+ )
277
+ slat_sampling_steps = gr.Slider(
278
+ 1,
279
+ 500,
280
+ label="Sampling Steps",
281
+ value=12,
282
+ step=1
283
+ )
284
+
285
+ # Generate Button
286
  generate_btn = gr.Button("Generate")
287
 
288
+ # GLB Extraction Settings Accordion
289
  with gr.Accordion(label="GLB Extraction Settings", open=False):
290
+ mesh_simplify = gr.Slider(
291
+ 0.0,
292
+ 0.98,
293
+ label="Simplify",
294
+ value=0.95,
295
+ step=0.01
296
+ )
297
+ texture_size = gr.Slider(
298
+ 512,
299
+ 2048,
300
+ label="Texture Size",
301
+ value=1024,
302
+ step=512
303
+ )
304
 
305
+ # Existing Extract GLB Button
306
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
307
+
308
+ # New Extract High Quality GLB Button
309
+ extract_glb_high_quality_btn = gr.Button("Download High Quality GLB", interactive=False)
310
 
311
  with gr.Column():
312
+ # Video Output
313
+ video_output = gr.Video(
314
+ label="Generated 3D Asset",
315
+ autoplay=True,
316
+ loop=True,
317
+ height=300
318
+ )
319
+ # 3D Model Display
320
+ model_output = LitModel3D(
321
+ label="Extracted GLB",
322
+ exposure=20.0,
323
+ height=300
324
+ )
325
+ # Existing Download GLB Button
326
+ download_glb = gr.DownloadButton(
327
+ label="Download GLB",
328
+ file_name="model.glb",
329
+ interactive=False
330
+ )
331
+ # New Download High Quality GLB Button
332
+ download_high_quality_glb = gr.DownloadButton(
333
+ label="Download High Quality GLB",
334
+ file_name="model_high_quality.glb",
335
+ interactive=False
336
+ )
337
+
338
+ # State Variables
339
  output_buf = gr.State()
340
+ glb_path_state = gr.State() # For standard GLB
341
+ glb_high_quality_path_state = gr.State() # For high-quality GLB
342
 
343
+ # Example Images
344
  with gr.Row():
345
  examples = gr.Examples(
346
  examples=[
 
354
  examples_per_page=64,
355
  )
356
 
357
+ # Event Handlers
358
  demo.load(start_session)
359
  demo.unload(end_session)
360
 
361
+ # Image Upload Handler
362
  image_prompt.upload(
363
  preprocess_image,
364
  inputs=[image_prompt],
365
  outputs=[image_prompt],
366
  )
367
 
368
+ # Generate Button Click Handler
369
  generate_btn.click(
370
  get_seed,
371
  inputs=[randomize_seed, seed],
372
  outputs=[seed],
373
  ).then(
374
  image_to_3d,
375
+ inputs=[
376
+ image_prompt,
377
+ seed,
378
+ ss_guidance_strength,
379
+ ss_sampling_steps,
380
+ slat_guidance_strength,
381
+ slat_sampling_steps
382
+ ],
383
  outputs=[output_buf, video_output],
384
  ).then(
385
+ lambda: gr.Button.update(interactive=True),
386
+ outputs=[extract_glb_btn, extract_glb_high_quality_btn],
 
 
 
 
 
387
  )
388
 
389
+ # Existing Extract GLB Button Click Handler
390
  extract_glb_btn.click(
391
  extract_glb,
392
  inputs=[output_buf, mesh_simplify, texture_size],
393
  outputs=[model_output, download_glb],
394
  ).then(
395
+ lambda: gr.Button.update(interactive=True),
396
  outputs=[download_glb],
397
  )
398
 
399
+ # New Extract High Quality GLB Button Click Handler
400
+ extract_glb_high_quality_btn.click(
401
+ extract_glb_high_quality,
402
+ inputs=[output_buf],
403
+ outputs=[model_output, glb_high_quality_path_state],
404
+ ).then(
405
+ lambda glb_path: {"value": glb_path} if glb_path else None,
406
+ inputs=[glb_high_quality_path_state],
407
+ outputs=[download_high_quality_glb],
408
+ ).then(
409
+ lambda: gr.Button.update(interactive=True),
410
+ outputs=[download_high_quality_glb],
411
+ )
412
+
413
+ # Handle Clearing of Video Output
414
+ video_output.clear(
415
+ lambda: (gr.Button.update(interactive=False), gr.Button.update(interactive=False)),
416
+ outputs=[extract_glb_btn, extract_glb_high_quality_btn],
417
+ )
418
+
419
+ # Handle Clearing of Model Output
420
  model_output.clear(
421
+ lambda: (gr.File.update(value=None), gr.File.update(value=None)),
422
+ outputs=[download_glb, download_high_quality_glb],
423
  )
424
+
425
 
426
  # Launch the Gradio app
427
  if __name__ == "__main__":
428
+ # Initialize the pipeline
429
  pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
430
  pipeline.cuda()
431
  try: