cronos3k commited on
Commit
0b2d68e
·
verified ·
1 Parent(s): c260ece

Update app.py

Browse files

addressing a error

Files changed (1) hide show
  1. app.py +6 -20
app.py CHANGED
@@ -16,26 +16,22 @@ from trellis.pipelines import TrellisImageTo3DPipeline
16
  from trellis.representations import Gaussian, MeshExtractResult
17
  from trellis.utils import render_utils, postprocessing_utils
18
 
19
-
20
  # Constants
21
  MAX_SEED = np.iinfo(np.int32).max
22
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
23
  os.makedirs(TMP_DIR, exist_ok=True)
24
 
25
-
26
  # Session Management Functions
27
  def start_session(req: gr.Request):
28
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
29
  print(f'Creating user directory: {user_dir}')
30
  os.makedirs(user_dir, exist_ok=True)
31
 
32
-
33
  def end_session(req: gr.Request):
34
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
35
  print(f'Removing user directory: {user_dir}')
36
  shutil.rmtree(user_dir)
37
 
38
-
39
  # Image Preprocessing Function
40
  def preprocess_image(image: Image.Image) -> Image.Image:
41
  """
@@ -50,7 +46,6 @@ def preprocess_image(image: Image.Image) -> Image.Image:
50
  processed_image = pipeline.preprocess_image(image)
51
  return processed_image
52
 
53
-
54
  # State Packing and Unpacking Functions
55
  def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
56
  return {
@@ -69,7 +64,6 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
69
  'trial_id': trial_id,
70
  }
71
 
72
-
73
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
74
  gs = Gaussian(
75
  aabb=state['gaussian']['aabb'],
@@ -92,7 +86,6 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
92
 
93
  return gs, mesh, state['trial_id']
94
 
95
-
96
  # Seed Management Function
97
  def get_seed(randomize_seed: bool, seed: int) -> int:
98
  """
@@ -107,7 +100,6 @@ def get_seed(randomize_seed: bool, seed: int) -> int:
107
  """
108
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
109
 
110
-
111
  # Core 3D Generation Function
112
  @spaces.GPU
113
  def image_to_3d(
@@ -159,7 +151,6 @@ def image_to_3d(
159
  torch.cuda.empty_cache()
160
  return state, video_path
161
 
162
-
163
  # Existing GLB Extraction Function
164
  @spaces.GPU
165
  def extract_glb(
@@ -188,7 +179,6 @@ def extract_glb(
188
  torch.cuda.empty_cache()
189
  return glb_path, glb_path
190
 
191
-
192
  # New High-Quality GLB Extraction Function
193
  @spaces.GPU
194
  def extract_glb_high_quality(
@@ -215,7 +205,6 @@ def extract_glb_high_quality(
215
  torch.cuda.empty_cache()
216
  return glb_path, glb_path
217
 
218
-
219
  # Gradio Interface Definition
220
  with gr.Blocks(delete_cache=(600, 600)) as demo:
221
  gr.Markdown("""
@@ -325,14 +314,12 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
325
  # Existing Download GLB Button
326
  download_glb = gr.DownloadButton(
327
  label="Download GLB",
328
- file_name="model.glb",
329
- interactive=False
330
  )
331
  # New Download High Quality GLB Button
332
  download_high_quality_glb = gr.DownloadButton(
333
  label="Download High Quality GLB",
334
- file_name="model_high_quality.glb",
335
- interactive=False
336
  )
337
 
338
  # State Variables
@@ -382,7 +369,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
382
  ],
383
  outputs=[output_buf, video_output],
384
  ).then(
385
- lambda: gr.Button.update(interactive=True),
386
  outputs=[extract_glb_btn, extract_glb_high_quality_btn],
387
  )
388
 
@@ -392,7 +379,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
392
  inputs=[output_buf, mesh_simplify, texture_size],
393
  outputs=[model_output, download_glb],
394
  ).then(
395
- lambda: gr.Button.update(interactive=True),
396
  outputs=[download_glb],
397
  )
398
 
@@ -402,11 +389,11 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
402
  inputs=[output_buf],
403
  outputs=[model_output, glb_high_quality_path_state],
404
  ).then(
405
- lambda glb_path: {"value": glb_path} if glb_path else None,
406
  inputs=[glb_high_quality_path_state],
407
  outputs=[download_high_quality_glb],
408
  ).then(
409
- lambda: gr.Button.update(interactive=True),
410
  outputs=[download_high_quality_glb],
411
  )
412
 
@@ -422,7 +409,6 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
422
  outputs=[download_glb, download_high_quality_glb],
423
  )
424
 
425
-
426
  # Launch the Gradio app
427
  if __name__ == "__main__":
428
  # Initialize the pipeline
 
16
  from trellis.representations import Gaussian, MeshExtractResult
17
  from trellis.utils import render_utils, postprocessing_utils
18
 
 
19
  # Constants
20
  MAX_SEED = np.iinfo(np.int32).max
21
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
22
  os.makedirs(TMP_DIR, exist_ok=True)
23
 
 
24
  # Session Management Functions
25
  def start_session(req: gr.Request):
26
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
27
  print(f'Creating user directory: {user_dir}')
28
  os.makedirs(user_dir, exist_ok=True)
29
 
 
30
  def end_session(req: gr.Request):
31
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
32
  print(f'Removing user directory: {user_dir}')
33
  shutil.rmtree(user_dir)
34
 
 
35
  # Image Preprocessing Function
36
  def preprocess_image(image: Image.Image) -> Image.Image:
37
  """
 
46
  processed_image = pipeline.preprocess_image(image)
47
  return processed_image
48
 
 
49
  # State Packing and Unpacking Functions
50
  def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
51
  return {
 
64
  'trial_id': trial_id,
65
  }
66
 
 
67
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
68
  gs = Gaussian(
69
  aabb=state['gaussian']['aabb'],
 
86
 
87
  return gs, mesh, state['trial_id']
88
 
 
89
  # Seed Management Function
90
  def get_seed(randomize_seed: bool, seed: int) -> int:
91
  """
 
100
  """
101
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
102
 
 
103
  # Core 3D Generation Function
104
  @spaces.GPU
105
  def image_to_3d(
 
151
  torch.cuda.empty_cache()
152
  return state, video_path
153
 
 
154
  # Existing GLB Extraction Function
155
  @spaces.GPU
156
  def extract_glb(
 
179
  torch.cuda.empty_cache()
180
  return glb_path, glb_path
181
 
 
182
  # New High-Quality GLB Extraction Function
183
  @spaces.GPU
184
  def extract_glb_high_quality(
 
205
  torch.cuda.empty_cache()
206
  return glb_path, glb_path
207
 
 
208
  # Gradio Interface Definition
209
  with gr.Blocks(delete_cache=(600, 600)) as demo:
210
  gr.Markdown("""
 
314
  # Existing Download GLB Button
315
  download_glb = gr.DownloadButton(
316
  label="Download GLB",
317
+ interactive=False # Initially disabled
 
318
  )
319
  # New Download High Quality GLB Button
320
  download_high_quality_glb = gr.DownloadButton(
321
  label="Download High Quality GLB",
322
+ interactive=False # Initially disabled
 
323
  )
324
 
325
  # State Variables
 
369
  ],
370
  outputs=[output_buf, video_output],
371
  ).then(
372
+ lambda: (gr.Button.update(interactive=True), gr.Button.update(interactive=True)),
373
  outputs=[extract_glb_btn, extract_glb_high_quality_btn],
374
  )
375
 
 
379
  inputs=[output_buf, mesh_simplify, texture_size],
380
  outputs=[model_output, download_glb],
381
  ).then(
382
+ lambda: gr.DownloadButton.update(interactive=True),
383
  outputs=[download_glb],
384
  )
385
 
 
389
  inputs=[output_buf],
390
  outputs=[model_output, glb_high_quality_path_state],
391
  ).then(
392
+ lambda glb_path: glb_path, # Pass the file path directly
393
  inputs=[glb_high_quality_path_state],
394
  outputs=[download_high_quality_glb],
395
  ).then(
396
+ lambda: gr.DownloadButton.update(interactive=True),
397
  outputs=[download_high_quality_glb],
398
  )
399
 
 
409
  outputs=[download_glb, download_high_quality_glb],
410
  )
411
 
 
412
  # Launch the Gradio app
413
  if __name__ == "__main__":
414
  # Initialize the pipeline