JeffreyXiang commited on
Commit
cd41f5f
·
1 Parent(s): 4cd032d

Manage cache use gradio

Browse files
Files changed (1) hide show
  1. app.py +61 -54
app.py CHANGED
@@ -3,6 +3,7 @@ import spaces
3
  from gradio_litmodel3d import LitModel3D
4
 
5
  import os
 
6
  os.environ['SPCONV_ALGO'] = 'native'
7
  from typing import *
8
  import torch
@@ -17,11 +18,22 @@ from trellis.utils import render_utils, postprocessing_utils
17
 
18
 
19
  MAX_SEED = np.iinfo(np.int32).max
20
- TMP_DIR = "/tmp/Trellis-demo"
21
-
22
  os.makedirs(TMP_DIR, exist_ok=True)
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
26
  """
27
  Preprocess the input image.
@@ -33,10 +45,8 @@ def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
33
  str: uuid of the trial.
34
  Image.Image: The preprocessed image.
35
  """
36
- trial_id = str(uuid.uuid4())
37
  processed_image = pipeline.preprocess_image(image)
38
- processed_image.save(f"{TMP_DIR}/{trial_id}.png")
39
- return trial_id, processed_image
40
 
41
 
42
  def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
@@ -80,15 +90,29 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
80
  return gs, mesh, state['trial_id']
81
 
82
 
 
 
 
 
 
 
 
83
  @spaces.GPU
84
- def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_strength: float, ss_sampling_steps: int, slat_guidance_strength: float, slat_sampling_steps: int) -> Tuple[dict, str]:
 
 
 
 
 
 
 
 
85
  """
86
  Convert an image to a 3D model.
87
 
88
  Args:
89
- trial_id (str): The uuid of the trial.
90
  seed (int): The random seed.
91
- randomize_seed (bool): Whether to randomize the seed.
92
  ss_guidance_strength (float): The guidance strength for sparse structure generation.
93
  ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
94
  slat_guidance_strength (float): The guidance strength for structured latent generation.
@@ -98,10 +122,9 @@ def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_stre
98
  dict: The information of the generated 3D model.
99
  str: The path to the video of the 3D model.
100
  """
101
- if randomize_seed:
102
- seed = np.random.randint(0, MAX_SEED)
103
  outputs = pipeline.run(
104
- Image.open(f"{TMP_DIR}/{trial_id}.png"),
105
  seed=seed,
106
  formats=["gaussian", "mesh"],
107
  preprocess_image=False,
@@ -118,15 +141,20 @@ def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_stre
118
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
119
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
120
  trial_id = uuid.uuid4()
121
- video_path = f"{TMP_DIR}/{trial_id}.mp4"
122
- os.makedirs(os.path.dirname(video_path), exist_ok=True)
123
  imageio.mimsave(video_path, video, fps=15)
124
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
 
125
  return state, video_path
126
 
127
 
128
  @spaces.GPU
129
- def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[str, str]:
 
 
 
 
 
130
  """
131
  Extract a GLB file from the 3D model.
132
 
@@ -138,22 +166,16 @@ def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[s
138
  Returns:
139
  str: The path to the extracted GLB file.
140
  """
 
141
  gs, mesh, trial_id = unpack_state(state)
142
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
143
- glb_path = f"{TMP_DIR}/{trial_id}.glb"
144
  glb.export(glb_path)
 
145
  return glb_path, glb_path
146
 
147
 
148
- def activate_button() -> gr.Button:
149
- return gr.Button(interactive=True)
150
-
151
-
152
- def deactivate_button() -> gr.Button:
153
- return gr.Button(interactive=False)
154
-
155
-
156
- with gr.Blocks() as demo:
157
  gr.Markdown("""
158
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
159
  * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
@@ -162,7 +184,7 @@ with gr.Blocks() as demo:
162
 
163
  with gr.Row():
164
  with gr.Column():
165
- image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil", height=300)
166
 
167
  with gr.Accordion(label="Generation Settings", open=False):
168
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
@@ -189,7 +211,6 @@ with gr.Blocks() as demo:
189
  model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
190
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
191
 
192
- trial_id = gr.Textbox(visible=False)
193
  output_buf = gr.State()
194
 
195
  # Example images at the bottom of the page
@@ -201,33 +222,36 @@ with gr.Blocks() as demo:
201
  ],
202
  inputs=[image_prompt],
203
  fn=preprocess_image,
204
- outputs=[trial_id, image_prompt],
205
  run_on_click=True,
206
  examples_per_page=64,
207
  )
208
 
209
  # Handlers
 
 
 
210
  image_prompt.upload(
211
  preprocess_image,
212
  inputs=[image_prompt],
213
- outputs=[trial_id, image_prompt],
214
- )
215
- image_prompt.clear(
216
- lambda: '',
217
- outputs=[trial_id],
218
  )
219
 
220
  generate_btn.click(
 
 
 
 
221
  image_to_3d,
222
- inputs=[trial_id, seed, randomize_seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
223
  outputs=[output_buf, video_output],
224
  ).then(
225
- activate_button,
226
  outputs=[extract_glb_btn],
227
  )
228
 
229
  video_output.clear(
230
- deactivate_button,
231
  outputs=[extract_glb_btn],
232
  )
233
 
@@ -236,33 +260,16 @@ with gr.Blocks() as demo:
236
  inputs=[output_buf, mesh_simplify, texture_size],
237
  outputs=[model_output, download_glb],
238
  ).then(
239
- activate_button,
240
  outputs=[download_glb],
241
  )
242
 
243
  model_output.clear(
244
- deactivate_button,
245
  outputs=[download_glb],
246
  )
247
 
248
 
249
- # Cleans up the temporary directory every 10 minutes
250
- import threading
251
- import time
252
-
253
- def cleanup_tmp_dir():
254
- while True:
255
- if os.path.exists(TMP_DIR):
256
- for file in os.listdir(TMP_DIR):
257
- # remove files older than 10 minutes
258
- if time.time() - os.path.getmtime(os.path.join(TMP_DIR, file)) > 600:
259
- os.remove(os.path.join(TMP_DIR, file))
260
- time.sleep(600)
261
-
262
- cleanup_thread = threading.Thread(target=cleanup_tmp_dir)
263
- cleanup_thread.start()
264
-
265
-
266
  # Launch the Gradio app
267
  if __name__ == "__main__":
268
  pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
 
3
  from gradio_litmodel3d import LitModel3D
4
 
5
  import os
6
+ import shutil
7
  os.environ['SPCONV_ALGO'] = 'native'
8
  from typing import *
9
  import torch
 
18
 
19
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
 
22
  os.makedirs(TMP_DIR, exist_ok=True)
23
 
24
 
25
+ def start_session(req: gr.Request):
26
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
27
+ print(f'Creating user directory: {user_dir}')
28
+ os.makedirs(user_dir, exist_ok=True)
29
+
30
+
31
+ def end_session(req: gr.Request):
32
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
33
+ print(f'Removing user directory: {user_dir}')
34
+ shutil.rmtree(user_dir)
35
+
36
+
37
  def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
38
  """
39
  Preprocess the input image.
 
45
  str: uuid of the trial.
46
  Image.Image: The preprocessed image.
47
  """
 
48
  processed_image = pipeline.preprocess_image(image)
49
+ return processed_image
 
50
 
51
 
52
  def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
 
90
  return gs, mesh, state['trial_id']
91
 
92
 
93
+ def get_seed(randomize_seed: bool, seed: int) -> int:
94
+ """
95
+ Get the random seed.
96
+ """
97
+ return np.random.randint(0, MAX_SEED) if randomize_seed else seed
98
+
99
+
100
  @spaces.GPU
101
+ def image_to_3d(
102
+ image: Image.Image,
103
+ seed: int,
104
+ ss_guidance_strength: float,
105
+ ss_sampling_steps: int,
106
+ slat_guidance_strength: float,
107
+ slat_sampling_steps: int,
108
+ req: gr.Request,
109
+ ) -> Tuple[dict, str]:
110
  """
111
  Convert an image to a 3D model.
112
 
113
  Args:
114
+ image (Image.Image): The input image.
115
  seed (int): The random seed.
 
116
  ss_guidance_strength (float): The guidance strength for sparse structure generation.
117
  ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
118
  slat_guidance_strength (float): The guidance strength for structured latent generation.
 
122
  dict: The information of the generated 3D model.
123
  str: The path to the video of the 3D model.
124
  """
125
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
126
  outputs = pipeline.run(
127
+ image,
128
  seed=seed,
129
  formats=["gaussian", "mesh"],
130
  preprocess_image=False,
 
141
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
142
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
143
  trial_id = uuid.uuid4()
144
+ video_path = os.path.join(user_dir, f"{trial_id}.mp4")
 
145
  imageio.mimsave(video_path, video, fps=15)
146
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
147
+ torch.cuda.empty_cache()
148
  return state, video_path
149
 
150
 
151
  @spaces.GPU
152
+ def extract_glb(
153
+ state: dict,
154
+ mesh_simplify: float,
155
+ texture_size: int,
156
+ req: gr.Request,
157
+ ) -> Tuple[str, str]:
158
  """
159
  Extract a GLB file from the 3D model.
160
 
 
166
  Returns:
167
  str: The path to the extracted GLB file.
168
  """
169
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
170
  gs, mesh, trial_id = unpack_state(state)
171
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
172
+ glb_path = os.path.join(user_dir, f"{trial_id}.glb")
173
  glb.export(glb_path)
174
+ torch.cuda.empty_cache()
175
  return glb_path, glb_path
176
 
177
 
178
+ with gr.Blocks(delete_cache=(600, 600)) as demo:
 
 
 
 
 
 
 
 
179
  gr.Markdown("""
180
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
181
  * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
 
184
 
185
  with gr.Row():
186
  with gr.Column():
187
+ image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
188
 
189
  with gr.Accordion(label="Generation Settings", open=False):
190
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
 
211
  model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
212
  download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
213
 
 
214
  output_buf = gr.State()
215
 
216
  # Example images at the bottom of the page
 
222
  ],
223
  inputs=[image_prompt],
224
  fn=preprocess_image,
225
+ outputs=[image_prompt],
226
  run_on_click=True,
227
  examples_per_page=64,
228
  )
229
 
230
  # Handlers
231
+ demo.load(start_session)
232
+ demo.unload(end_session)
233
+
234
  image_prompt.upload(
235
  preprocess_image,
236
  inputs=[image_prompt],
237
+ outputs=[image_prompt],
 
 
 
 
238
  )
239
 
240
  generate_btn.click(
241
+ get_seed,
242
+ inputs=[randomize_seed, seed],
243
+ outputs=[seed],
244
+ ).then(
245
  image_to_3d,
246
+ inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
247
  outputs=[output_buf, video_output],
248
  ).then(
249
+ lambda: gr.Button(interactive=True),
250
  outputs=[extract_glb_btn],
251
  )
252
 
253
  video_output.clear(
254
+ lambda: gr.Button(interactive=False),
255
  outputs=[extract_glb_btn],
256
  )
257
 
 
260
  inputs=[output_buf, mesh_simplify, texture_size],
261
  outputs=[model_output, download_glb],
262
  ).then(
263
+ lambda: gr.Button(interactive=True),
264
  outputs=[download_glb],
265
  )
266
 
267
  model_output.clear(
268
+ lambda: gr.Button(interactive=False),
269
  outputs=[download_glb],
270
  )
271
 
272
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  # Launch the Gradio app
274
  if __name__ == "__main__":
275
  pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")