cronos3k commited on
Commit
a481d7a
·
verified ·
1 Parent(s): e2e71cc

Update app.py

Browse files

working on timeouts

Files changed (1) hide show
  1. app.py +52 -31
app.py CHANGED
@@ -33,7 +33,7 @@ def end_session(req: gr.Request):
33
  shutil.rmtree(user_dir)
34
 
35
  # Image Preprocessing Function
36
- def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
37
  """
38
  Preprocess the input image.
39
 
@@ -41,12 +41,15 @@ def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
41
  image (Image.Image): The input image.
42
 
43
  Returns:
44
- str: uuid of the trial.
45
  Image.Image: The preprocessed image.
46
  """
 
 
 
 
 
47
  processed_image = pipeline.preprocess_image(image)
48
- trial_id = str(uuid.uuid4())
49
- return trial_id, processed_image
50
 
51
  # State Packing and Unpacking Functions
52
  def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
@@ -92,13 +95,19 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
92
  def get_seed(randomize_seed: bool, seed: int) -> int:
93
  """
94
  Get the random seed.
 
 
 
 
 
 
 
95
  """
96
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
97
 
98
  # Core 3D Generation Function
99
  @spaces.GPU
100
  def image_to_3d(
101
- trial_id: str,
102
  image: Image.Image,
103
  seed: int,
104
  ss_guidance_strength: float,
@@ -111,7 +120,6 @@ def image_to_3d(
111
  Convert an image to a 3D model.
112
 
113
  Args:
114
- trial_id (str): The UUID of the trial.
115
  image (Image.Image): The input image.
116
  seed (int): The random seed.
117
  ss_guidance_strength (float): The guidance strength for sparse structure generation.
@@ -121,8 +129,7 @@ def image_to_3d(
121
  req (gr.Request): Gradio request object.
122
 
123
  Returns:
124
- dict: The information of the generated 3D model.
125
- str: The path to the video of the 3D model.
126
  """
127
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
128
  outputs = pipeline.run(
@@ -142,6 +149,7 @@ def image_to_3d(
142
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
143
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
144
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
 
145
  video_path = os.path.join(user_dir, f"{trial_id}.mp4")
146
  imageio.mimsave(video_path, video, fps=15)
147
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
@@ -166,7 +174,7 @@ def extract_glb(
166
  req (gr.Request): Gradio request object.
167
 
168
  Returns:
169
- str: The path to the extracted GLB file.
170
  """
171
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
172
  gs, mesh, trial_id = unpack_state(state)
@@ -210,53 +218,62 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
210
  * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
211
  * **New:** Click "Download High Quality GLB" to download the GLB file without any polygon reduction and with maximum texture quality.
212
  """)
213
-
214
  with gr.Row():
215
  with gr.Column():
 
216
  image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
217
-
 
218
  with gr.Accordion(label="Generation Settings", open=False):
219
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
220
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
221
- gr.Markdown("Stage 1: Sparse Structure Generation")
222
  with gr.Row():
223
  ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
224
  ss_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
225
- gr.Markdown("Stage 2: Structured Latent Generation")
226
  with gr.Row():
227
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
228
  slat_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
229
 
 
230
  generate_btn = gr.Button("Generate")
231
-
 
232
  with gr.Accordion(label="GLB Extraction Settings", open=False):
233
  mesh_simplify = gr.Slider(0.0, 0.98, label="Simplify", value=0.95, step=0.01)
234
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
235
-
 
236
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
237
-
238
  # **Addition: Download High Quality GLB Button**
239
  extract_glb_high_quality_btn = gr.Button("Download High Quality GLB", interactive=False)
240
 
241
  with gr.Column():
 
242
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
 
243
  model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
 
244
  download_glb = gr.DownloadButton(
245
  label="Download GLB",
246
  # Removed 'file_count' to prevent runtime error
247
  )
248
-
249
  # **Addition: Download High Quality GLB DownloadButton**
250
  download_high_quality_glb = gr.DownloadButton(
251
  label="Download High Quality GLB",
252
  # Removed 'file_count' to prevent runtime error
253
  )
254
-
 
255
  output_buf = gr.State()
256
  glb_path_state = gr.State() # For standard GLB
257
  glb_high_quality_path_state = gr.State() # For high-quality GLB
258
 
259
- # Example images at the bottom of the page
260
  with gr.Row():
261
  examples = gr.Examples(
262
  examples=[
@@ -270,10 +287,10 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
270
  examples_per_page=64,
271
  )
272
 
273
- # Handlers
274
  demo.load(start_session)
275
  demo.unload(end_session)
276
-
277
  image_prompt.upload(
278
  preprocess_image,
279
  inputs=[image_prompt],
@@ -284,25 +301,28 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
284
  get_seed,
285
  inputs=[randomize_seed, seed],
286
  outputs=[seed],
 
287
  ).then(
288
  image_to_3d,
289
- inputs=[output_buf, image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
290
  outputs=[output_buf, video_output],
 
291
  ).then(
292
  # Enable the Extract GLB and Download High Quality GLB buttons after generation
293
- lambda: (gr.Button.update(interactive=True), gr.Button.update(interactive=True)),
294
  outputs=[extract_glb_btn, extract_glb_high_quality_btn],
295
  )
296
 
297
  video_output.clear(
298
- lambda: (gr.Button.update(interactive=False), gr.Button.update(interactive=False)),
299
  outputs=[extract_glb_btn, extract_glb_high_quality_btn],
300
  )
301
 
302
  extract_glb_btn.click(
303
  extract_glb,
304
- inputs=[output_buf, mesh_simplify, texture_size],
305
  outputs=[model_output, glb_path_state],
 
306
  ).then(
307
  lambda glb_path: glb_path if glb_path else "",
308
  inputs=[glb_path_state],
@@ -312,8 +332,9 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
312
  # **Addition: High-Quality GLB Extraction and Download**
313
  extract_glb_high_quality_btn.click(
314
  extract_glb_high_quality,
315
- inputs=[output_buf],
316
  outputs=[model_output, glb_high_quality_path_state],
 
317
  ).then(
318
  lambda glb_path: glb_path if glb_path else "",
319
  inputs=[glb_high_quality_path_state],
@@ -325,14 +346,14 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
325
  outputs=[download_glb, download_high_quality_glb],
326
  )
327
 
328
- # **Addition: Configure Gradio's Queue to Handle Long GPU Operations**
329
  if __name__ == "__main__":
330
  pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
331
  pipeline.cuda()
332
  try:
333
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
334
- except:
335
- pass
336
- # Configure Gradio's queue with appropriate settings
337
- # Removed 'concurrency_count' and 'timeout' as they are deprecated
338
  demo.queue().launch()
 
33
  shutil.rmtree(user_dir)
34
 
35
  # Image Preprocessing Function
36
+ def preprocess_image(image: Image.Image) -> Image.Image:
37
  """
38
  Preprocess the input image.
39
 
 
41
  image (Image.Image): The input image.
42
 
43
  Returns:
 
44
  Image.Image: The preprocessed image.
45
  """
46
+ # Validate image
47
+ if image is None:
48
+ raise ValueError("No image provided.")
49
+ if image.mode != "RGBA":
50
+ image = image.convert("RGBA")
51
  processed_image = pipeline.preprocess_image(image)
52
+ return processed_image
 
53
 
54
  # State Packing and Unpacking Functions
55
  def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
 
95
  def get_seed(randomize_seed: bool, seed: int) -> int:
96
  """
97
  Get the random seed.
98
+
99
+ Args:
100
+ randomize_seed (bool): Whether to randomize the seed.
101
+ seed (int): The provided seed value.
102
+
103
+ Returns:
104
+ int: The final seed to use.
105
  """
106
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
107
 
108
  # Core 3D Generation Function
109
  @spaces.GPU
110
  def image_to_3d(
 
111
  image: Image.Image,
112
  seed: int,
113
  ss_guidance_strength: float,
 
120
  Convert an image to a 3D model.
121
 
122
  Args:
 
123
  image (Image.Image): The input image.
124
  seed (int): The random seed.
125
  ss_guidance_strength (float): The guidance strength for sparse structure generation.
 
129
  req (gr.Request): Gradio request object.
130
 
131
  Returns:
132
+ Tuple[dict, str]: The state dictionary and the path to the generated video.
 
133
  """
134
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
135
  outputs = pipeline.run(
 
149
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
150
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
151
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
152
+ trial_id = uuid.uuid4()
153
  video_path = os.path.join(user_dir, f"{trial_id}.mp4")
154
  imageio.mimsave(video_path, video, fps=15)
155
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
 
174
  req (gr.Request): Gradio request object.
175
 
176
  Returns:
177
+ Tuple[str, str]: The path to the extracted GLB file.
178
  """
179
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
180
  gs, mesh, trial_id = unpack_state(state)
 
218
  * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
219
  * **New:** Click "Download High Quality GLB" to download the GLB file without any polygon reduction and with maximum texture quality.
220
  """)
221
+
222
  with gr.Row():
223
  with gr.Column():
224
+ # Image Input
225
  image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
226
+
227
+ # Generation Settings Accordion
228
  with gr.Accordion(label="Generation Settings", open=False):
229
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
230
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
231
+ gr.Markdown("### Stage 1: Sparse Structure Generation")
232
  with gr.Row():
233
  ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
234
  ss_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
235
+ gr.Markdown("### Stage 2: Structured Latent Generation")
236
  with gr.Row():
237
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
238
  slat_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
239
 
240
+ # Generate Button
241
  generate_btn = gr.Button("Generate")
242
+
243
+ # GLB Extraction Settings Accordion
244
  with gr.Accordion(label="GLB Extraction Settings", open=False):
245
  mesh_simplify = gr.Slider(0.0, 0.98, label="Simplify", value=0.95, step=0.01)
246
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
247
+
248
+ # Existing Extract GLB Button
249
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
250
+
251
  # **Addition: Download High Quality GLB Button**
252
  extract_glb_high_quality_btn = gr.Button("Download High Quality GLB", interactive=False)
253
 
254
  with gr.Column():
255
+ # Video Output
256
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
257
+ # 3D Model Display
258
  model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
259
+ # Existing Download GLB Button
260
  download_glb = gr.DownloadButton(
261
  label="Download GLB",
262
  # Removed 'file_count' to prevent runtime error
263
  )
264
+
265
  # **Addition: Download High Quality GLB DownloadButton**
266
  download_high_quality_glb = gr.DownloadButton(
267
  label="Download High Quality GLB",
268
  # Removed 'file_count' to prevent runtime error
269
  )
270
+
271
+ # State Variables
272
  output_buf = gr.State()
273
  glb_path_state = gr.State() # For standard GLB
274
  glb_high_quality_path_state = gr.State() # For high-quality GLB
275
 
276
+ # Example Images
277
  with gr.Row():
278
  examples = gr.Examples(
279
  examples=[
 
287
  examples_per_page=64,
288
  )
289
 
290
+ # Event Handlers
291
  demo.load(start_session)
292
  demo.unload(end_session)
293
+
294
  image_prompt.upload(
295
  preprocess_image,
296
  inputs=[image_prompt],
 
301
  get_seed,
302
  inputs=[randomize_seed, seed],
303
  outputs=[seed],
304
+ concurrency_limit=1 # Set concurrency limit for Generate
305
  ).then(
306
  image_to_3d,
307
+ inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, gr.Request()],
308
  outputs=[output_buf, video_output],
309
+ concurrency_limit=1 # Set concurrency limit for image_to_3d
310
  ).then(
311
  # Enable the Extract GLB and Download High Quality GLB buttons after generation
312
+ lambda: (True, True),
313
  outputs=[extract_glb_btn, extract_glb_high_quality_btn],
314
  )
315
 
316
  video_output.clear(
317
+ lambda: (False, False),
318
  outputs=[extract_glb_btn, extract_glb_high_quality_btn],
319
  )
320
 
321
  extract_glb_btn.click(
322
  extract_glb,
323
+ inputs=[output_buf, mesh_simplify, texture_size, gr.Request()],
324
  outputs=[model_output, glb_path_state],
325
+ concurrency_limit=1 # Set concurrency limit for extract_glb
326
  ).then(
327
  lambda glb_path: glb_path if glb_path else "",
328
  inputs=[glb_path_state],
 
332
  # **Addition: High-Quality GLB Extraction and Download**
333
  extract_glb_high_quality_btn.click(
334
  extract_glb_high_quality,
335
+ inputs=[output_buf, gr.Request()],
336
  outputs=[model_output, glb_high_quality_path_state],
337
+ concurrency_limit=1 # Set concurrency limit for extract_glb_high_quality
338
  ).then(
339
  lambda glb_path: glb_path if glb_path else "",
340
  inputs=[glb_high_quality_path_state],
 
346
  outputs=[download_glb, download_high_quality_glb],
347
  )
348
 
349
+ # Launch the Gradio app
350
  if __name__ == "__main__":
351
  pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
352
  pipeline.cuda()
353
  try:
354
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
355
+ except Exception as e:
356
+ print(f"Preloading rembg failed: {e}")
357
+
358
+ # Configure Gradio's queue without deprecated parameters
359
  demo.queue().launch()