cronos3k commited on
Commit
d635e38
·
verified ·
1 Parent(s): b880652

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +145 -4
app.py CHANGED
@@ -20,8 +20,63 @@ MAX_SEED = np.iinfo(np.int32).max
20
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
21
  os.makedirs(TMP_DIR, exist_ok=True)
22
 
23
- # Rest of the utility functions remain the same...
24
- [Previous utility functions: start_session, end_session, preprocess_image, pack_state, unpack_state, get_seed]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def image_to_3d(
27
  image: Image.Image,
@@ -167,8 +222,94 @@ def extract_reduced_glb(
167
  torch.cuda.empty_cache()
168
  raise gr.Error(f"GLB reduction failed: {str(e)}")
169
 
170
- # Rest of the UI code and demo definition remains the same...
171
- [Previous UI code with Blocks, event handlers, etc.]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
  if __name__ == "__main__":
174
  # Set some CUDA memory management options
 
20
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
21
  os.makedirs(TMP_DIR, exist_ok=True)
22
 
23
+
24
+ def start_session(req: gr.Request):
25
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
26
+ print(f'Creating user directory: {user_dir}')
27
+ os.makedirs(user_dir, exist_ok=True)
28
+
29
+ def end_session(req: gr.Request):
30
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
31
+ print(f'Removing user directory: {user_dir}')
32
+ shutil.rmtree(user_dir)
33
+
34
+ def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
35
+ processed_image = pipeline.preprocess_image(image)
36
+ return processed_image
37
+
38
+ def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
39
+ return {
40
+ 'gaussian': {
41
+ **gs.init_params,
42
+ '_xyz': gs._xyz.cpu().numpy(),
43
+ '_features_dc': gs._features_dc.cpu().numpy(),
44
+ '_scaling': gs._scaling.cpu().numpy(),
45
+ '_rotation': gs._rotation.cpu().numpy(),
46
+ '_opacity': gs._opacity.cpu().numpy(),
47
+ },
48
+ 'mesh': {
49
+ 'vertices': mesh.vertices.cpu().numpy(),
50
+ 'faces': mesh.faces.cpu().numpy(),
51
+ },
52
+ 'trial_id': trial_id,
53
+ }
54
+
55
+ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
56
+ gs = Gaussian(
57
+ aabb=state['gaussian']['aabb'],
58
+ sh_degree=state['gaussian']['sh_degree'],
59
+ mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
60
+ scaling_bias=state['gaussian']['scaling_bias'],
61
+ opacity_bias=state['gaussian']['opacity_bias'],
62
+ scaling_activation=state['gaussian']['scaling_activation'],
63
+ )
64
+ gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
65
+ gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
66
+ gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
67
+ gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
68
+ gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
69
+
70
+ mesh = edict(
71
+ vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
72
+ faces=torch.tensor(state['mesh']['faces'], device='cuda'),
73
+ )
74
+
75
+ return gs, mesh, state['trial_id']
76
+
77
+ def get_seed(randomize_seed: bool, seed: int) -> int:
78
+ """Get the random seed."""
79
+ return np.random.randint(0, MAX_SEED) if randomize_seed else seed
80
 
81
  def image_to_3d(
82
  image: Image.Image,
 
222
  torch.cuda.empty_cache()
223
  raise gr.Error(f"GLB reduction failed: {str(e)}")
224
 
225
+ with gr.Blocks(delete_cache=(600, 600)) as demo:
226
+ gr.Markdown("""
227
+ ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
228
+ * Upload an image and click "Generate" to create a 3D model
229
+ * You can download either:
230
+ * The full-quality GLB file (larger size, highest quality)
231
+ * A reduced version with customizable quality settings (smaller size)
232
+ """)
233
+
234
+ with gr.Row():
235
+ with gr.Column():
236
+ image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
237
+
238
+ with gr.Accordion(label="Generation Settings", open=False):
239
+ seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
240
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
241
+ gr.Markdown("Stage 1: Sparse Structure Generation")
242
+ with gr.Row():
243
+ ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
244
+ ss_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
245
+ gr.Markdown("Stage 2: Structured Latent Generation")
246
+ with gr.Row():
247
+ slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
248
+ slat_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
249
+
250
+ generate_btn = gr.Button("Generate")
251
+
252
+ with gr.Accordion(label="Reduced GLB Settings", open=False):
253
+ mesh_simplify = gr.Slider(0.0, 0.98, label="Mesh Simplification", value=0.95, step=0.01)
254
+ texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
255
+
256
+ extract_reduced_btn = gr.Button("Extract Reduced GLB", interactive=False)
257
+
258
+ with gr.Column():
259
+ video_output = gr.Video(label="Generated 3D Asset Preview", autoplay=True, loop=True, height=300)
260
+ model_output = LitModel3D(label="3D Model Preview", exposure=20.0, height=300)
261
+ gr.Markdown("### Download Options")
262
+ with gr.Row():
263
+ download_full = gr.DownloadButton(label="Download Full-Quality GLB", interactive=False)
264
+ download_reduced = gr.DownloadButton(label="Download Reduced GLB", interactive=False)
265
+
266
+ output_buf = gr.State()
267
+
268
+ # Example images
269
+ with gr.Row():
270
+ examples = gr.Examples(
271
+ examples=[
272
+ f'assets/example_image/{image}'
273
+ for image in os.listdir("assets/example_image")
274
+ ],
275
+ inputs=[image_prompt],
276
+ fn=preprocess_image,
277
+ outputs=[image_prompt],
278
+ run_on_click=True,
279
+ examples_per_page=64,
280
+ )
281
+
282
+ # Event handlers
283
+ demo.load(start_session)
284
+ demo.unload(end_session)
285
+
286
+ image_prompt.upload(
287
+ preprocess_image,
288
+ inputs=[image_prompt],
289
+ outputs=[image_prompt],
290
+ )
291
+
292
+ generate_btn.click(
293
+ get_seed,
294
+ inputs=[randomize_seed, seed],
295
+ outputs=[seed],
296
+ ).then(
297
+ image_to_3d,
298
+ inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
299
+ outputs=[output_buf, video_output, model_output, download_full],
300
+ ).then(
301
+ lambda: (gr.Button(interactive=True), gr.Button(interactive=True), gr.Button(interactive=False)),
302
+ outputs=[download_full, extract_reduced_btn, download_reduced],
303
+ )
304
+
305
+ extract_reduced_btn.click(
306
+ extract_reduced_glb,
307
+ inputs=[output_buf, mesh_simplify, texture_size],
308
+ outputs=[model_output, download_reduced],
309
+ ).then(
310
+ lambda: gr.Button(interactive=True),
311
+ outputs=[download_reduced],
312
+ )
313
 
314
  if __name__ == "__main__":
315
  # Set some CUDA memory management options