cronos3k committed on
Commit
258ea5a
·
verified ·
1 Parent(s): a481d7a

Update app.py

Browse files

let's try a different approach

Files changed (1) hide show
  1. app.py +39 -229
app.py CHANGED
@@ -16,96 +16,26 @@ from trellis.pipelines import TrellisImageTo3DPipeline
16
  from trellis.representations import Gaussian, MeshExtractResult
17
  from trellis.utils import render_utils, postprocessing_utils
18
 
19
- # Constants
20
  MAX_SEED = np.iinfo(np.int32).max
21
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
22
  os.makedirs(TMP_DIR, exist_ok=True)
23
 
24
- # Session Management Functions
25
  def start_session(req: gr.Request):
26
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
27
  print(f'Creating user directory: {user_dir}')
28
  os.makedirs(user_dir, exist_ok=True)
29
-
30
  def end_session(req: gr.Request):
31
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
32
  print(f'Removing user directory: {user_dir}')
33
  shutil.rmtree(user_dir)
34
 
35
- # Image Preprocessing Function
36
- def preprocess_image(image: Image.Image) -> Image.Image:
37
- """
38
- Preprocess the input image.
39
-
40
- Args:
41
- image (Image.Image): The input image.
42
-
43
- Returns:
44
- Image.Image: The preprocessed image.
45
- """
46
- # Validate image
47
- if image is None:
48
- raise ValueError("No image provided.")
49
- if image.mode != "RGBA":
50
- image = image.convert("RGBA")
51
  processed_image = pipeline.preprocess_image(image)
52
  return processed_image
53
 
54
- # State Packing and Unpacking Functions
55
- def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
56
- return {
57
- 'gaussian': {
58
- **gs.init_params,
59
- '_xyz': gs._xyz.cpu().numpy(),
60
- '_features_dc': gs._features_dc.cpu().numpy(),
61
- '_scaling': gs._scaling.cpu().numpy(),
62
- '_rotation': gs._rotation.cpu().numpy(),
63
- '_opacity': gs._opacity.cpu().numpy(),
64
- },
65
- 'mesh': {
66
- 'vertices': mesh.vertices.cpu().numpy(),
67
- 'faces': mesh.faces.cpu().numpy(),
68
- },
69
- 'trial_id': trial_id,
70
- }
71
-
72
- def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
73
- gs = Gaussian(
74
- aabb=state['gaussian']['aabb'],
75
- sh_degree=state['gaussian']['sh_degree'],
76
- mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
77
- scaling_bias=state['gaussian']['scaling_bias'],
78
- opacity_bias=state['gaussian']['opacity_bias'],
79
- scaling_activation=state['gaussian']['scaling_activation'],
80
- )
81
- gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
82
- gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
83
- gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
84
- gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
85
- gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
86
-
87
- mesh = edict(
88
- vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
89
- faces=torch.tensor(state['mesh']['faces'], device='cuda'),
90
- )
91
-
92
- return gs, mesh, state['trial_id']
93
-
94
- # Seed Management Function
95
- def get_seed(randomize_seed: bool, seed: int) -> int:
96
- """
97
- Get the random seed.
98
-
99
- Args:
100
- randomize_seed (bool): Whether to randomize the seed.
101
- seed (int): The provided seed value.
102
-
103
- Returns:
104
- int: The final seed to use.
105
- """
106
- return np.random.randint(0, MAX_SEED) if randomize_seed else seed
107
-
108
- # Core 3D Generation Function
109
  @spaces.GPU
110
  def image_to_3d(
111
  image: Image.Image,
@@ -115,21 +45,12 @@ def image_to_3d(
115
  slat_guidance_strength: float,
116
  slat_sampling_steps: int,
117
  req: gr.Request,
118
- ) -> Tuple[dict, str]:
119
  """
120
- Convert an image to a 3D model.
121
-
122
- Args:
123
- image (Image.Image): The input image.
124
- seed (int): The random seed.
125
- ss_guidance_strength (float): The guidance strength for sparse structure generation.
126
- ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
127
- slat_guidance_strength (float): The guidance strength for structured latent generation.
128
- slat_sampling_steps (int): The number of sampling steps for structured latent generation.
129
- req (gr.Request): Gradio request object.
130
-
131
  Returns:
132
- Tuple[dict, str]: The state dictionary and the path to the generated video.
133
  """
134
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
135
  outputs = pipeline.run(
@@ -146,134 +67,61 @@ def image_to_3d(
146
  "cfg_strength": slat_guidance_strength,
147
  },
148
  )
 
 
149
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
150
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
151
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
152
- trial_id = uuid.uuid4()
153
  video_path = os.path.join(user_dir, f"{trial_id}.mp4")
154
  imageio.mimsave(video_path, video, fps=15)
155
- state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
156
- torch.cuda.empty_cache()
157
- return state, video_path
158
-
159
- # Existing GLB Extraction Function
160
- @spaces.GPU
161
- def extract_glb(
162
- state: dict,
163
- mesh_simplify: float,
164
- texture_size: int,
165
- req: gr.Request,
166
- ) -> Tuple[str, str]:
167
- """
168
- Extract a GLB file from the 3D model.
169
-
170
- Args:
171
- state (dict): The state of the generated 3D model.
172
- mesh_simplify (float): The mesh simplification factor.
173
- texture_size (int): The texture resolution.
174
- req (gr.Request): Gradio request object.
175
-
176
- Returns:
177
- Tuple[str, str]: The path to the extracted GLB file.
178
- """
179
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
180
- gs, mesh, trial_id = unpack_state(state)
181
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
182
- glb_path = os.path.join(user_dir, f"{trial_id}.glb")
183
- glb.export(glb_path)
184
- torch.cuda.empty_cache()
185
- return glb_path, glb_path
186
-
187
- # **Addition: High-Quality GLB Extraction Function**
188
- @spaces.GPU
189
- def extract_glb_high_quality(
190
- state: dict,
191
- req: gr.Request,
192
- ) -> Tuple[str, str]:
193
- """
194
- Extract a high-quality GLB file from the 3D model without polygon reduction.
195
-
196
- Args:
197
- state (dict): The state of the generated 3D model.
198
- req (gr.Request): Gradio request object.
199
-
200
- Returns:
201
- Tuple[str, str]: The path to the high-quality GLB file.
202
- """
203
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
204
- gs, mesh, trial_id = unpack_state(state)
205
- # Set simplify to 0.0 to disable polygon reduction
206
- # Set texture_size to 2048 for maximum texture quality
207
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=0.0, texture_size=2048, verbose=False)
208
- glb_path = os.path.join(user_dir, f"{trial_id}_high_quality.glb")
209
  glb.export(glb_path)
 
210
  torch.cuda.empty_cache()
211
- return glb_path, glb_path
212
 
213
- # Gradio Interface Definition
214
  with gr.Blocks(delete_cache=(600, 600)) as demo:
215
  gr.Markdown("""
216
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
217
- * Upload an image and click "Generate" to create a 3D asset. If the image has an alpha channel, it will be used as the mask. Otherwise, the background will be removed automatically.
218
- * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
219
- * **New:** Click "Download High Quality GLB" to download the GLB file without any polygon reduction and with maximum texture quality.
220
  """)
221
 
222
  with gr.Row():
223
  with gr.Column():
224
- # Image Input
225
  image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
226
 
227
- # Generation Settings Accordion
228
  with gr.Accordion(label="Generation Settings", open=False):
229
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
230
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
231
- gr.Markdown("### Stage 1: Sparse Structure Generation")
232
  with gr.Row():
233
  ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
234
  ss_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
235
- gr.Markdown("### Stage 2: Structured Latent Generation")
236
  with gr.Row():
237
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
238
  slat_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
239
 
240
- # Generate Button
241
  generate_btn = gr.Button("Generate")
242
-
243
- # GLB Extraction Settings Accordion
244
- with gr.Accordion(label="GLB Extraction Settings", open=False):
245
- mesh_simplify = gr.Slider(0.0, 0.98, label="Simplify", value=0.95, step=0.01)
246
- texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
247
-
248
- # Existing Extract GLB Button
249
- extract_glb_btn = gr.Button("Extract GLB", interactive=False)
250
-
251
- # **Addition: Download High Quality GLB Button**
252
- extract_glb_high_quality_btn = gr.Button("Download High Quality GLB", interactive=False)
253
 
254
  with gr.Column():
255
- # Video Output
256
- video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
257
- # 3D Model Display
258
- model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
259
- # Existing Download GLB Button
260
- download_glb = gr.DownloadButton(
261
- label="Download GLB",
262
- # Removed 'file_count' to prevent runtime error
263
- )
264
 
265
- # **Addition: Download High Quality GLB DownloadButton**
266
- download_high_quality_glb = gr.DownloadButton(
267
- label="Download High Quality GLB",
268
- # Removed 'file_count' to prevent runtime error
269
- )
270
-
271
- # State Variables
272
- output_buf = gr.State()
273
- glb_path_state = gr.State() # For standard GLB
274
- glb_high_quality_path_state = gr.State() # For high-quality GLB
275
-
276
- # Example Images
277
  with gr.Row():
278
  examples = gr.Examples(
279
  examples=[
@@ -287,7 +135,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
287
  examples_per_page=64,
288
  )
289
 
290
- # Event Handlers
291
  demo.load(start_session)
292
  demo.unload(end_session)
293
 
@@ -301,59 +149,21 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
301
  get_seed,
302
  inputs=[randomize_seed, seed],
303
  outputs=[seed],
304
- concurrency_limit=1 # Set concurrency limit for Generate
305
  ).then(
306
  image_to_3d,
307
- inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, gr.Request()],
308
- outputs=[output_buf, video_output],
309
- concurrency_limit=1 # Set concurrency limit for image_to_3d
310
- ).then(
311
- # Enable the Extract GLB and Download High Quality GLB buttons after generation
312
- lambda: (True, True),
313
- outputs=[extract_glb_btn, extract_glb_high_quality_btn],
314
- )
315
-
316
- video_output.clear(
317
- lambda: (False, False),
318
- outputs=[extract_glb_btn, extract_glb_high_quality_btn],
319
- )
320
-
321
- extract_glb_btn.click(
322
- extract_glb,
323
- inputs=[output_buf, mesh_simplify, texture_size, gr.Request()],
324
- outputs=[model_output, glb_path_state],
325
- concurrency_limit=1 # Set concurrency limit for extract_glb
326
  ).then(
327
- lambda glb_path: glb_path if glb_path else "",
328
- inputs=[glb_path_state],
329
  outputs=[download_glb],
330
  )
331
 
332
- # **Addition: High-Quality GLB Extraction and Download**
333
- extract_glb_high_quality_btn.click(
334
- extract_glb_high_quality,
335
- inputs=[output_buf, gr.Request()],
336
- outputs=[model_output, glb_high_quality_path_state],
337
- concurrency_limit=1 # Set concurrency limit for extract_glb_high_quality
338
- ).then(
339
- lambda glb_path: glb_path if glb_path else "",
340
- inputs=[glb_high_quality_path_state],
341
- outputs=[download_high_quality_glb],
342
- )
343
-
344
- model_output.clear(
345
- lambda: (gr.File.update(value=None), gr.File.update(value=None)),
346
- outputs=[download_glb, download_high_quality_glb],
347
- )
348
-
349
  # Launch the Gradio app
350
  if __name__ == "__main__":
351
  pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
352
  pipeline.cuda()
353
  try:
354
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
355
- except Exception as e:
356
- print(f"Preloading rembg failed: {e}")
357
-
358
- # Configure Gradio's queue without deprecated parameters
359
- demo.queue().launch()
 
16
  from trellis.representations import Gaussian, MeshExtractResult
17
  from trellis.utils import render_utils, postprocessing_utils
18
 
19
+
20
  MAX_SEED = np.iinfo(np.int32).max
21
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
22
  os.makedirs(TMP_DIR, exist_ok=True)
23
 
24
+
25
  def start_session(req: gr.Request):
26
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
27
  print(f'Creating user directory: {user_dir}')
28
  os.makedirs(user_dir, exist_ok=True)
29
+
30
  def end_session(req: gr.Request):
31
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
32
  print(f'Removing user directory: {user_dir}')
33
  shutil.rmtree(user_dir)
34
 
35
def preprocess_image(image: Image.Image) -> Image.Image:
    """Run the pipeline's image preprocessing on the input image.

    Thin wrapper around ``pipeline.preprocess_image``; the result is returned
    unchanged as a single value (not a tuple).

    Args:
        image: The input image.

    Returns:
        Whatever ``pipeline.preprocess_image`` returns for ``image``
        (presumably a PIL image -- confirm against the pipeline API).
    """
    # `pipeline` is the module-level global assigned in the __main__ guard.
    processed_image = pipeline.preprocess_image(image)
    return processed_image
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  @spaces.GPU
40
  def image_to_3d(
41
  image: Image.Image,
 
45
  slat_guidance_strength: float,
46
  slat_sampling_steps: int,
47
  req: gr.Request,
48
+ ) -> Tuple[str, str, str]:
49
  """
50
+ Convert an image to a 3D model and save both video preview and full-quality GLB.
51
+
 
 
 
 
 
 
 
 
 
52
  Returns:
53
+ Tuple[str, str, str]: (video_path, glb_path, download_path)
54
  """
55
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
56
  outputs = pipeline.run(
 
67
  "cfg_strength": slat_guidance_strength,
68
  },
69
  )
70
+
71
+ # Generate and save video preview
72
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
73
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
74
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
75
+ trial_id = str(uuid.uuid4())
76
  video_path = os.path.join(user_dir, f"{trial_id}.mp4")
77
  imageio.mimsave(video_path, video, fps=15)
78
+
79
+ # Save full-quality GLB directly from the generated mesh
80
+ glb = postprocessing_utils.to_glb(
81
+ outputs['gaussian'][0],
82
+ outputs['mesh'][0],
83
+ simplify=0.0, # No simplification
84
+ texture_size=2048, # Maximum texture resolution
85
+ verbose=False
86
+ )
87
+ glb_path = os.path.join(user_dir, f"{trial_id}_full.glb")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  glb.export(glb_path)
89
+
90
  torch.cuda.empty_cache()
91
+ return video_path, glb_path, glb_path
92
 
 
93
  with gr.Blocks(delete_cache=(600, 600)) as demo:
94
  gr.Markdown("""
95
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
96
+ * Upload an image and click "Generate" to create a high-quality 3D model
97
+ * Once generation is complete, you can download the full-quality GLB file
98
+ * The model will be in maximum quality with no reduction applied
99
  """)
100
 
101
  with gr.Row():
102
  with gr.Column():
 
103
  image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
104
 
 
105
  with gr.Accordion(label="Generation Settings", open=False):
106
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
107
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
108
+ gr.Markdown("Stage 1: Sparse Structure Generation")
109
  with gr.Row():
110
  ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
111
  ss_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
112
+ gr.Markdown("Stage 2: Structured Latent Generation")
113
  with gr.Row():
114
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
115
  slat_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
116
 
 
117
  generate_btn = gr.Button("Generate")
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  with gr.Column():
120
+ video_output = gr.Video(label="Generated 3D Asset Preview", autoplay=True, loop=True, height=300)
121
+ model_output = LitModel3D(label="3D Model Preview", exposure=20.0, height=300)
122
+ download_glb = gr.DownloadButton(label="Download Full-Quality GLB", interactive=False)
 
 
 
 
 
 
123
 
124
+ # Example images
 
 
 
 
 
 
 
 
 
 
 
125
  with gr.Row():
126
  examples = gr.Examples(
127
  examples=[
 
135
  examples_per_page=64,
136
  )
137
 
138
+ # Event handlers
139
  demo.load(start_session)
140
  demo.unload(end_session)
141
 
 
149
  get_seed,
150
  inputs=[randomize_seed, seed],
151
  outputs=[seed],
 
152
  ).then(
153
  image_to_3d,
154
+ inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
155
+ outputs=[video_output, model_output, download_glb],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  ).then(
157
+ lambda: gr.Button(interactive=True),
 
158
  outputs=[download_glb],
159
  )
160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  # Launch the Gradio app
162
  if __name__ == "__main__":
163
  pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
164
  pipeline.cuda()
165
  try:
166
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
167
+ except:
168
+ pass
169
+ demo.launch()