cronos3k commited on
Commit
342cabd
·
verified ·
1 Parent(s): 9173005

Update app.py

Browse files

this should fix things

Files changed (1) hide show
  1. app.py +56 -138
app.py CHANGED
@@ -16,24 +16,25 @@ from trellis.pipelines import TrellisImageTo3DPipeline
16
  from trellis.representations import Gaussian, MeshExtractResult
17
  from trellis.utils import render_utils, postprocessing_utils
18
 
19
- # Constants
20
  MAX_SEED = np.iinfo(np.int32).max
21
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
22
  os.makedirs(TMP_DIR, exist_ok=True)
23
 
24
- # Session Management Functions
25
  def start_session(req: gr.Request):
26
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
27
  print(f'Creating user directory: {user_dir}')
28
  os.makedirs(user_dir, exist_ok=True)
29
 
 
30
  def end_session(req: gr.Request):
31
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
32
  print(f'Removing user directory: {user_dir}')
33
  shutil.rmtree(user_dir)
34
 
35
- # Image Preprocessing Function
36
- def preprocess_image(image: Image.Image) -> Image.Image:
37
  """
38
  Preprocess the input image.
39
 
@@ -41,12 +42,13 @@ def preprocess_image(image: Image.Image) -> Image.Image:
41
  image (Image.Image): The input image.
42
 
43
  Returns:
 
44
  Image.Image: The preprocessed image.
45
  """
46
  processed_image = pipeline.preprocess_image(image)
47
  return processed_image
48
 
49
- # State Packing and Unpacking Functions
50
  def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
51
  return {
52
  'gaussian': {
@@ -64,6 +66,7 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
64
  'trial_id': trial_id,
65
  }
66
 
 
67
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
68
  gs = Gaussian(
69
  aabb=state['gaussian']['aabb'],
@@ -86,21 +89,14 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
86
 
87
  return gs, mesh, state['trial_id']
88
 
89
- # Seed Management Function
90
  def get_seed(randomize_seed: bool, seed: int) -> int:
91
  """
92
  Get the random seed.
93
-
94
- Args:
95
- randomize_seed (bool): Whether to randomize the seed.
96
- seed (int): The provided seed value.
97
-
98
- Returns:
99
- int: The final seed to use.
100
  """
101
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
102
 
103
- # Core 3D Generation Function
104
  @spaces.GPU
105
  def image_to_3d(
106
  image: Image.Image,
@@ -121,10 +117,10 @@ def image_to_3d(
121
  ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
122
  slat_guidance_strength (float): The guidance strength for structured latent generation.
123
  slat_sampling_steps (int): The number of sampling steps for structured latent generation.
124
- req (gr.Request): Gradio request object.
125
 
126
  Returns:
127
- Tuple[dict, str]: The state dictionary and the path to the generated video.
 
128
  """
129
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
130
  outputs = pipeline.run(
@@ -151,7 +147,7 @@ def image_to_3d(
151
  torch.cuda.empty_cache()
152
  return state, video_path
153
 
154
- # Existing GLB Extraction Function
155
  @spaces.GPU
156
  def extract_glb(
157
  state: dict,
@@ -166,10 +162,9 @@ def extract_glb(
166
  state (dict): The state of the generated 3D model.
167
  mesh_simplify (float): The mesh simplification factor.
168
  texture_size (int): The texture resolution.
169
- req (gr.Request): Gradio request object.
170
 
171
  Returns:
172
- Tuple[str, str]: The path to the extracted GLB file.
173
  """
174
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
175
  gs, mesh, trial_id = unpack_state(state)
@@ -179,7 +174,8 @@ def extract_glb(
179
  torch.cuda.empty_cache()
180
  return glb_path, glb_path
181
 
182
- # **New High-Quality GLB Extraction Function**
 
183
  @spaces.GPU
184
  def extract_glb_high_quality(
185
  state: dict,
@@ -205,7 +201,7 @@ def extract_glb_high_quality(
205
  torch.cuda.empty_cache()
206
  return glb_path, glb_path
207
 
208
- # Gradio Interface Definition
209
  with gr.Blocks(delete_cache=(600, 600)) as demo:
210
  gr.Markdown("""
211
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
@@ -213,121 +209,53 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
213
  * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
214
  * **New:** Click "Download High Quality GLB" to download the GLB file without any polygon reduction and with maximum texture quality.
215
  """)
216
-
217
  with gr.Row():
218
  with gr.Column():
219
- # Image Input
220
- image_prompt = gr.Image(
221
- label="Image Prompt",
222
- format="png",
223
- image_mode="RGBA",
224
- type="pil",
225
- height=300
226
- )
227
 
228
- # Generation Settings Accordion
229
  with gr.Accordion(label="Generation Settings", open=False):
230
- seed = gr.Slider(
231
- 0,
232
- MAX_SEED,
233
- label="Seed",
234
- value=0,
235
- step=1
236
- )
237
- randomize_seed = gr.Checkbox(
238
- label="Randomize Seed",
239
- value=True
240
- )
241
- gr.Markdown("### Stage 1: Sparse Structure Generation")
242
  with gr.Row():
243
- ss_guidance_strength = gr.Slider(
244
- 0.0,
245
- 10.0,
246
- label="Guidance Strength",
247
- value=7.5,
248
- step=0.1
249
- )
250
- ss_sampling_steps = gr.Slider(
251
- 1,
252
- 500,
253
- label="Sampling Steps",
254
- value=12,
255
- step=1
256
- )
257
- gr.Markdown("### Stage 2: Structured Latent Generation")
258
  with gr.Row():
259
- slat_guidance_strength = gr.Slider(
260
- 0.0,
261
- 10.0,
262
- label="Guidance Strength",
263
- value=3.0,
264
- step=0.1
265
- )
266
- slat_sampling_steps = gr.Slider(
267
- 1,
268
- 500,
269
- label="Sampling Steps",
270
- value=12,
271
- step=1
272
- )
273
-
274
- # Generate Button
275
  generate_btn = gr.Button("Generate")
276
 
277
- # GLB Extraction Settings Accordion
278
  with gr.Accordion(label="GLB Extraction Settings", open=False):
279
- mesh_simplify = gr.Slider(
280
- 0.0,
281
- 0.98,
282
- label="Simplify",
283
- value=0.95,
284
- step=0.01
285
- )
286
- texture_size = gr.Slider(
287
- 512,
288
- 2048,
289
- label="Texture Size",
290
- value=1024,
291
- step=512
292
- )
293
 
294
- # Existing Extract GLB Button
295
- extract_glb_btn = gr.Button("Extract GLB", interactive=True)
296
 
297
- # **New Download High Quality GLB Button**
298
- download_high_quality_glb_btn = gr.Button("Download High Quality GLB", interactive=True)
299
 
300
  with gr.Column():
301
- # Video Output
302
- video_output = gr.Video(
303
- label="Generated 3D Asset",
304
- autoplay=True,
305
- loop=True,
306
- height=300
307
- )
308
- # 3D Model Display
309
- model_output = LitModel3D(
310
- label="Extracted GLB",
311
- exposure=20.0,
312
- height=300
313
- )
314
- # Existing Download GLB Button
315
  download_glb = gr.DownloadButton(
316
  label="Download GLB",
317
- file_count="single",
318
  )
319
- # **New Download High Quality GLB Button**
 
320
  download_high_quality_glb = gr.DownloadButton(
321
  label="Download High Quality GLB",
322
- file_count="single",
323
  )
324
 
325
- # State Variables
326
  output_buf = gr.State()
327
  glb_path_state = gr.State() # For standard GLB
328
  glb_high_quality_path_state = gr.State() # For high-quality GLB
329
 
330
- # Example Images
331
  with gr.Row():
332
  examples = gr.Examples(
333
  examples=[
@@ -341,39 +269,35 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
341
  examples_per_page=64,
342
  )
343
 
344
- # Event Handlers
345
  demo.load(start_session)
346
  demo.unload(end_session)
347
 
348
- # Image Upload Handler
349
  image_prompt.upload(
350
  preprocess_image,
351
  inputs=[image_prompt],
352
  outputs=[image_prompt],
353
  )
354
 
355
- # Generate Button Click Handler
356
  generate_btn.click(
357
  get_seed,
358
  inputs=[randomize_seed, seed],
359
  outputs=[seed],
360
  ).then(
361
  image_to_3d,
362
- inputs=[
363
- image_prompt,
364
- seed,
365
- ss_guidance_strength,
366
- ss_sampling_steps,
367
- slat_guidance_strength,
368
- slat_sampling_steps
369
- ],
370
  outputs=[output_buf, video_output],
371
  ).then(
372
- lambda: gr.Button.update(interactive=True),
373
- outputs=[extract_glb_btn, download_high_quality_glb_btn],
 
 
 
 
 
 
374
  )
375
 
376
- # Existing Extract GLB Button Click Handler
377
  extract_glb_btn.click(
378
  extract_glb,
379
  inputs=[output_buf, mesh_simplify, texture_size],
@@ -384,8 +308,8 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
384
  outputs=[download_glb],
385
  )
386
 
387
- # **New Download High Quality GLB Button Click Handler**
388
- download_high_quality_glb_btn.click(
389
  extract_glb_high_quality,
390
  inputs=[output_buf],
391
  outputs=[model_output, glb_high_quality_path_state],
@@ -395,25 +319,19 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
395
  outputs=[download_high_quality_glb],
396
  )
397
 
398
- # Handle Clearing of Video Output
399
- video_output.clear(
400
- lambda: (gr.Button.update(interactive=True), gr.Button.update(interactive=True)),
401
- outputs=[extract_glb_btn, download_high_quality_glb_btn],
402
- )
403
-
404
- # Handle Clearing of Model Output
405
  model_output.clear(
406
  lambda: (gr.File.update(value=None), gr.File.update(value=None)),
407
  outputs=[download_glb, download_high_quality_glb],
408
  )
409
 
410
- # Launch the Gradio app
 
411
  if __name__ == "__main__":
412
- # Initialize the pipeline
413
  pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
414
  pipeline.cuda()
415
  try:
416
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
417
  except:
418
  pass
419
- demo.queue(concurrency_count=1, max_size=10).launch()
 
 
16
  from trellis.representations import Gaussian, MeshExtractResult
17
  from trellis.utils import render_utils, postprocessing_utils
18
 
19
+
20
  MAX_SEED = np.iinfo(np.int32).max
21
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
22
  os.makedirs(TMP_DIR, exist_ok=True)
23
 
24
+
25
  def start_session(req: gr.Request):
26
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
27
  print(f'Creating user directory: {user_dir}')
28
  os.makedirs(user_dir, exist_ok=True)
29
 
30
+
31
  def end_session(req: gr.Request):
32
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
33
  print(f'Removing user directory: {user_dir}')
34
  shutil.rmtree(user_dir)
35
 
36
+
37
+ def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
38
  """
39
  Preprocess the input image.
40
 
 
42
  image (Image.Image): The input image.
43
 
44
  Returns:
45
+ str: uuid of the trial.
46
  Image.Image: The preprocessed image.
47
  """
48
  processed_image = pipeline.preprocess_image(image)
49
  return processed_image
50
 
51
+
52
  def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
53
  return {
54
  'gaussian': {
 
66
  'trial_id': trial_id,
67
  }
68
 
69
+
70
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
71
  gs = Gaussian(
72
  aabb=state['gaussian']['aabb'],
 
89
 
90
  return gs, mesh, state['trial_id']
91
 
92
+
93
  def get_seed(randomize_seed: bool, seed: int) -> int:
94
  """
95
  Get the random seed.
 
 
 
 
 
 
 
96
  """
97
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
98
 
99
+
100
  @spaces.GPU
101
  def image_to_3d(
102
  image: Image.Image,
 
117
  ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
118
  slat_guidance_strength (float): The guidance strength for structured latent generation.
119
  slat_sampling_steps (int): The number of sampling steps for structured latent generation.
 
120
 
121
  Returns:
122
+ dict: The information of the generated 3D model.
123
+ str: The path to the video of the 3D model.
124
  """
125
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
126
  outputs = pipeline.run(
 
147
  torch.cuda.empty_cache()
148
  return state, video_path
149
 
150
+
151
  @spaces.GPU
152
  def extract_glb(
153
  state: dict,
 
162
  state (dict): The state of the generated 3D model.
163
  mesh_simplify (float): The mesh simplification factor.
164
  texture_size (int): The texture resolution.
 
165
 
166
  Returns:
167
+ str: The path to the extracted GLB file.
168
  """
169
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
170
  gs, mesh, trial_id = unpack_state(state)
 
174
  torch.cuda.empty_cache()
175
  return glb_path, glb_path
176
 
177
+
178
+ # **Addition: High-Quality GLB Extraction Function**
179
  @spaces.GPU
180
  def extract_glb_high_quality(
181
  state: dict,
 
201
  torch.cuda.empty_cache()
202
  return glb_path, glb_path
203
 
204
+
205
  with gr.Blocks(delete_cache=(600, 600)) as demo:
206
  gr.Markdown("""
207
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
 
209
  * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
210
  * **New:** Click "Download High Quality GLB" to download the GLB file without any polygon reduction and with maximum texture quality.
211
  """)
212
+
213
  with gr.Row():
214
  with gr.Column():
215
+ image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
 
 
 
 
 
 
 
216
 
 
217
  with gr.Accordion(label="Generation Settings", open=False):
218
+ seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
219
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
220
+ gr.Markdown("Stage 1: Sparse Structure Generation")
 
 
 
 
 
 
 
 
 
221
  with gr.Row():
222
+ ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
223
+ ss_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
224
+ gr.Markdown("Stage 2: Structured Latent Generation")
 
 
 
 
 
 
 
 
 
 
 
 
225
  with gr.Row():
226
+ slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
227
+ slat_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
228
+
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  generate_btn = gr.Button("Generate")
230
 
 
231
  with gr.Accordion(label="GLB Extraction Settings", open=False):
232
+ mesh_simplify = gr.Slider(0.0, 0.98, label="Simplify", value=0.95, step=0.01)
233
+ texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
+ extract_glb_btn = gr.Button("Extract GLB", interactive=False)
 
236
 
237
+ # **Addition: Download High Quality GLB Button**
238
+ extract_glb_high_quality_btn = gr.Button("Download High Quality GLB", interactive=False)
239
 
240
  with gr.Column():
241
+ video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
242
+ model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
 
 
 
 
 
 
 
 
 
 
 
 
243
  download_glb = gr.DownloadButton(
244
  label="Download GLB",
245
+ # Removed 'file_count' to prevent runtime error
246
  )
247
+
248
+ # **Addition: Download High Quality GLB DownloadButton**
249
  download_high_quality_glb = gr.DownloadButton(
250
  label="Download High Quality GLB",
251
+ # Removed 'file_count' to prevent runtime error
252
  )
253
 
 
254
  output_buf = gr.State()
255
  glb_path_state = gr.State() # For standard GLB
256
  glb_high_quality_path_state = gr.State() # For high-quality GLB
257
 
258
+ # Example images at the bottom of the page
259
  with gr.Row():
260
  examples = gr.Examples(
261
  examples=[
 
269
  examples_per_page=64,
270
  )
271
 
272
+ # Handlers
273
  demo.load(start_session)
274
  demo.unload(end_session)
275
 
 
276
  image_prompt.upload(
277
  preprocess_image,
278
  inputs=[image_prompt],
279
  outputs=[image_prompt],
280
  )
281
 
 
282
  generate_btn.click(
283
  get_seed,
284
  inputs=[randomize_seed, seed],
285
  outputs=[seed],
286
  ).then(
287
  image_to_3d,
288
+ inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
 
 
 
 
 
 
 
289
  outputs=[output_buf, video_output],
290
  ).then(
291
+ # Enable the Extract GLB and Download High Quality GLB buttons after generation
292
+ lambda: (gr.Button.update(interactive=True), gr.Button.update(interactive=True)),
293
+ outputs=[extract_glb_btn, extract_glb_high_quality_btn],
294
+ )
295
+
296
+ video_output.clear(
297
+ lambda: (gr.Button.update(interactive=False), gr.Button.update(interactive=False)),
298
+ outputs=[extract_glb_btn, extract_glb_high_quality_btn],
299
  )
300
 
 
301
  extract_glb_btn.click(
302
  extract_glb,
303
  inputs=[output_buf, mesh_simplify, texture_size],
 
308
  outputs=[download_glb],
309
  )
310
 
311
+ # **Addition: High-Quality GLB Extraction and Download**
312
+ extract_glb_high_quality_btn.click(
313
  extract_glb_high_quality,
314
  inputs=[output_buf],
315
  outputs=[model_output, glb_high_quality_path_state],
 
319
  outputs=[download_high_quality_glb],
320
  )
321
 
 
 
 
 
 
 
 
322
  model_output.clear(
323
  lambda: (gr.File.update(value=None), gr.File.update(value=None)),
324
  outputs=[download_glb, download_high_quality_glb],
325
  )
326
 
327
+
328
+ # **Addition: Configure Gradio's Queue to Handle Long GPU Operations**
329
  if __name__ == "__main__":
 
330
  pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
331
  pipeline.cuda()
332
  try:
333
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
334
  except:
335
  pass
336
+ # Configure Gradio's queue with appropriate settings
337
+ demo.queue(concurrency_count=1, max_size=10, timeout=600).launch()