cronos3k commited on
Commit
e2e71cc
·
verified ·
1 Parent(s): b2289be

Update app.py

Browse files

fixing timeout issues

Files changed (1) hide show
  1. app.py +24 -23
app.py CHANGED
@@ -16,24 +16,23 @@ from trellis.pipelines import TrellisImageTo3DPipeline
16
  from trellis.representations import Gaussian, MeshExtractResult
17
  from trellis.utils import render_utils, postprocessing_utils
18
 
19
-
20
  MAX_SEED = np.iinfo(np.int32).max
21
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
22
  os.makedirs(TMP_DIR, exist_ok=True)
23
 
24
-
25
  def start_session(req: gr.Request):
26
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
27
  print(f'Creating user directory: {user_dir}')
28
  os.makedirs(user_dir, exist_ok=True)
29
 
30
-
31
  def end_session(req: gr.Request):
32
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
33
  print(f'Removing user directory: {user_dir}')
34
  shutil.rmtree(user_dir)
35
 
36
-
37
  def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
38
  """
39
  Preprocess the input image.
@@ -46,9 +45,10 @@ def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
46
  Image.Image: The preprocessed image.
47
  """
48
  processed_image = pipeline.preprocess_image(image)
49
- return processed_image
50
-
51
 
 
52
  def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
53
  return {
54
  'gaussian': {
@@ -66,7 +66,6 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
66
  'trial_id': trial_id,
67
  }
68
 
69
-
70
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
71
  gs = Gaussian(
72
  aabb=state['gaussian']['aabb'],
@@ -89,16 +88,17 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
89
 
90
  return gs, mesh, state['trial_id']
91
 
92
-
93
  def get_seed(randomize_seed: bool, seed: int) -> int:
94
  """
95
  Get the random seed.
96
  """
97
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
98
 
99
-
100
  @spaces.GPU
101
  def image_to_3d(
 
102
  image: Image.Image,
103
  seed: int,
104
  ss_guidance_strength: float,
@@ -111,12 +111,14 @@ def image_to_3d(
111
  Convert an image to a 3D model.
112
 
113
  Args:
 
114
  image (Image.Image): The input image.
115
  seed (int): The random seed.
116
  ss_guidance_strength (float): The guidance strength for sparse structure generation.
117
  ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
118
  slat_guidance_strength (float): The guidance strength for structured latent generation.
119
  slat_sampling_steps (int): The number of sampling steps for structured latent generation.
 
120
 
121
  Returns:
122
  dict: The information of the generated 3D model.
@@ -140,14 +142,13 @@ def image_to_3d(
140
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
141
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
142
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
143
- trial_id = uuid.uuid4()
144
  video_path = os.path.join(user_dir, f"{trial_id}.mp4")
145
  imageio.mimsave(video_path, video, fps=15)
146
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
147
  torch.cuda.empty_cache()
148
  return state, video_path
149
 
150
-
151
  @spaces.GPU
152
  def extract_glb(
153
  state: dict,
@@ -162,6 +163,7 @@ def extract_glb(
162
  state (dict): The state of the generated 3D model.
163
  mesh_simplify (float): The mesh simplification factor.
164
  texture_size (int): The texture resolution.
 
165
 
166
  Returns:
167
  str: The path to the extracted GLB file.
@@ -174,7 +176,6 @@ def extract_glb(
174
  torch.cuda.empty_cache()
175
  return glb_path, glb_path
176
 
177
-
178
  # **Addition: High-Quality GLB Extraction Function**
179
  @spaces.GPU
180
  def extract_glb_high_quality(
@@ -201,7 +202,7 @@ def extract_glb_high_quality(
201
  torch.cuda.empty_cache()
202
  return glb_path, glb_path
203
 
204
-
205
  with gr.Blocks(delete_cache=(600, 600)) as demo:
206
  gr.Markdown("""
207
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
@@ -209,11 +210,11 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
209
  * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
210
  * **New:** Click "Download High Quality GLB" to download the GLB file without any polygon reduction and with maximum texture quality.
211
  """)
212
-
213
  with gr.Row():
214
  with gr.Column():
215
  image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
216
-
217
  with gr.Accordion(label="Generation Settings", open=False):
218
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
219
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
@@ -227,13 +228,13 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
227
  slat_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
228
 
229
  generate_btn = gr.Button("Generate")
230
-
231
  with gr.Accordion(label="GLB Extraction Settings", open=False):
232
  mesh_simplify = gr.Slider(0.0, 0.98, label="Simplify", value=0.95, step=0.01)
233
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
234
-
235
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
236
-
237
  # **Addition: Download High Quality GLB Button**
238
  extract_glb_high_quality_btn = gr.Button("Download High Quality GLB", interactive=False)
239
 
@@ -244,7 +245,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
244
  label="Download GLB",
245
  # Removed 'file_count' to prevent runtime error
246
  )
247
-
248
  # **Addition: Download High Quality GLB DownloadButton**
249
  download_high_quality_glb = gr.DownloadButton(
250
  label="Download High Quality GLB",
@@ -272,7 +273,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
272
  # Handlers
273
  demo.load(start_session)
274
  demo.unload(end_session)
275
-
276
  image_prompt.upload(
277
  preprocess_image,
278
  inputs=[image_prompt],
@@ -285,7 +286,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
285
  outputs=[seed],
286
  ).then(
287
  image_to_3d,
288
- inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
289
  outputs=[output_buf, video_output],
290
  ).then(
291
  # Enable the Extract GLB and Download High Quality GLB buttons after generation
@@ -324,7 +325,6 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
324
  outputs=[download_glb, download_high_quality_glb],
325
  )
326
 
327
-
328
  # **Addition: Configure Gradio's Queue to Handle Long GPU Operations**
329
  if __name__ == "__main__":
330
  pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
@@ -334,4 +334,5 @@ if __name__ == "__main__":
334
  except:
335
  pass
336
  # Configure Gradio's queue with appropriate settings
337
- demo.queue(concurrency_count=1, max_size=10).launch()
 
 
16
  from trellis.representations import Gaussian, MeshExtractResult
17
  from trellis.utils import render_utils, postprocessing_utils
18
 
19
+ # Constants
20
  MAX_SEED = np.iinfo(np.int32).max
21
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
22
  os.makedirs(TMP_DIR, exist_ok=True)
23
 
24
+ # Session Management Functions
25
  def start_session(req: gr.Request):
26
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
27
  print(f'Creating user directory: {user_dir}')
28
  os.makedirs(user_dir, exist_ok=True)
29
 
 
30
  def end_session(req: gr.Request):
31
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
32
  print(f'Removing user directory: {user_dir}')
33
  shutil.rmtree(user_dir)
34
 
35
+ # Image Preprocessing Function
36
  def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
37
  """
38
  Preprocess the input image.
 
45
  Image.Image: The preprocessed image.
46
  """
47
  processed_image = pipeline.preprocess_image(image)
48
+ trial_id = str(uuid.uuid4())
49
+ return trial_id, processed_image
50
 
51
+ # State Packing and Unpacking Functions
52
  def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
53
  return {
54
  'gaussian': {
 
66
  'trial_id': trial_id,
67
  }
68
 
 
69
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
70
  gs = Gaussian(
71
  aabb=state['gaussian']['aabb'],
 
88
 
89
  return gs, mesh, state['trial_id']
90
 
91
+ # Seed Management Function
92
  def get_seed(randomize_seed: bool, seed: int) -> int:
93
  """
94
  Get the random seed.
95
  """
96
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
97
 
98
+ # Core 3D Generation Function
99
  @spaces.GPU
100
  def image_to_3d(
101
+ trial_id: str,
102
  image: Image.Image,
103
  seed: int,
104
  ss_guidance_strength: float,
 
111
  Convert an image to a 3D model.
112
 
113
  Args:
114
+ trial_id (str): The UUID of the trial.
115
  image (Image.Image): The input image.
116
  seed (int): The random seed.
117
  ss_guidance_strength (float): The guidance strength for sparse structure generation.
118
  ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
119
  slat_guidance_strength (float): The guidance strength for structured latent generation.
120
  slat_sampling_steps (int): The number of sampling steps for structured latent generation.
121
+ req (gr.Request): Gradio request object.
122
 
123
  Returns:
124
  dict: The information of the generated 3D model.
 
142
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
143
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
144
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
 
145
  video_path = os.path.join(user_dir, f"{trial_id}.mp4")
146
  imageio.mimsave(video_path, video, fps=15)
147
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
148
  torch.cuda.empty_cache()
149
  return state, video_path
150
 
151
+ # Existing GLB Extraction Function
152
  @spaces.GPU
153
  def extract_glb(
154
  state: dict,
 
163
  state (dict): The state of the generated 3D model.
164
  mesh_simplify (float): The mesh simplification factor.
165
  texture_size (int): The texture resolution.
166
+ req (gr.Request): Gradio request object.
167
 
168
  Returns:
169
  str: The path to the extracted GLB file.
 
176
  torch.cuda.empty_cache()
177
  return glb_path, glb_path
178
 
 
179
  # **Addition: High-Quality GLB Extraction Function**
180
  @spaces.GPU
181
  def extract_glb_high_quality(
 
202
  torch.cuda.empty_cache()
203
  return glb_path, glb_path
204
 
205
+ # Gradio Interface Definition
206
  with gr.Blocks(delete_cache=(600, 600)) as demo:
207
  gr.Markdown("""
208
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
 
210
  * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
211
  * **New:** Click "Download High Quality GLB" to download the GLB file without any polygon reduction and with maximum texture quality.
212
  """)
213
+
214
  with gr.Row():
215
  with gr.Column():
216
  image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
217
+
218
  with gr.Accordion(label="Generation Settings", open=False):
219
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
220
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
 
228
  slat_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
229
 
230
  generate_btn = gr.Button("Generate")
231
+
232
  with gr.Accordion(label="GLB Extraction Settings", open=False):
233
  mesh_simplify = gr.Slider(0.0, 0.98, label="Simplify", value=0.95, step=0.01)
234
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
235
+
236
  extract_glb_btn = gr.Button("Extract GLB", interactive=False)
237
+
238
  # **Addition: Download High Quality GLB Button**
239
  extract_glb_high_quality_btn = gr.Button("Download High Quality GLB", interactive=False)
240
 
 
245
  label="Download GLB",
246
  # Removed 'file_count' to prevent runtime error
247
  )
248
+
249
  # **Addition: Download High Quality GLB DownloadButton**
250
  download_high_quality_glb = gr.DownloadButton(
251
  label="Download High Quality GLB",
 
273
  # Handlers
274
  demo.load(start_session)
275
  demo.unload(end_session)
276
+
277
  image_prompt.upload(
278
  preprocess_image,
279
  inputs=[image_prompt],
 
286
  outputs=[seed],
287
  ).then(
288
  image_to_3d,
289
+ inputs=[output_buf, image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
290
  outputs=[output_buf, video_output],
291
  ).then(
292
  # Enable the Extract GLB and Download High Quality GLB buttons after generation
 
325
  outputs=[download_glb, download_high_quality_glb],
326
  )
327
 
 
328
  # **Addition: Configure Gradio's Queue to Handle Long GPU Operations**
329
  if __name__ == "__main__":
330
  pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
 
334
  except:
335
  pass
336
  # Configure Gradio's queue with appropriate settings
337
+ # Removed 'concurrency_count' and 'timeout' as they are deprecated
338
+ demo.queue().launch()