cronos3k commited on
Commit
75b28df
·
verified ·
1 Parent(s): fcdcd8d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -40
app.py CHANGED
@@ -100,9 +100,9 @@ def image_to_3d(
100
  slat_guidance_strength: float,
101
  slat_sampling_steps: int,
102
  req: gr.Request,
103
- ) -> Tuple[dict, str]:
104
  """
105
- Convert an image to a 3D model.
106
  """
107
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
108
  outputs = pipeline.run(
@@ -119,39 +119,30 @@ def image_to_3d(
119
  "cfg_strength": slat_guidance_strength,
120
  },
121
  )
122
- video = render_utils.render_video(outputs['gaussian'][0], num_frames=20)['color']
123
- video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=20)['normal']
 
 
124
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
125
  trial_id = str(uuid.uuid4())
126
  video_path = os.path.join(user_dir, f"{trial_id}.mp4")
127
  imageio.mimsave(video_path, video, fps=15)
128
-
129
- state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
130
- return state, video_path
131
 
132
- @spaces.GPU
133
- def export_full_quality_glb(
134
- state: dict,
135
- req: gr.Request,
136
- ) -> Tuple[str, str]:
137
- """
138
- Export a full-quality GLB file from the 3D model state.
139
- """
140
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
141
- gs, mesh, trial_id = unpack_state(state)
142
-
143
  glb = postprocessing_utils.to_glb(
144
- gs,
145
- mesh,
146
  simplify=0.0, # No simplification
147
  fill_holes=True,
148
  fill_holes_max_size=0.04,
149
  texture_size=2048, # Maximum texture resolution
150
  verbose=False
151
  )
152
- glb_path = os.path.join(user_dir, f"{trial_id}_full.glb")
153
- glb.export(glb_path)
154
- return glb_path, glb_path
 
155
 
156
  @spaces.GPU
157
  def extract_glb(
@@ -161,7 +152,7 @@ def extract_glb(
161
  req: gr.Request,
162
  ) -> Tuple[str, str]:
163
  """
164
- Extract a GLB file from the 3D model.
165
  """
166
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
167
  gs, mesh, trial_id = unpack_state(state)
@@ -174,9 +165,8 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
174
  gr.Markdown("""
175
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
176
  * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
177
- * After generation:
178
- * Extract full-quality GLB for maximum detail
179
- * Or use the GLB Extraction Settings for a reduced size version
180
  """)
181
 
182
  with gr.Row():
@@ -196,7 +186,6 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
196
  slat_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
197
 
198
  generate_btn = gr.Button("Generate")
199
- extract_full_btn = gr.Button("Extract Full GLB", interactive=False)
200
 
201
  with gr.Accordion(label="GLB Extraction Settings", open=False):
202
  mesh_simplify = gr.Slider(0.0, 0.98, label="Simplify", value=0.95, step=0.01)
@@ -244,19 +233,10 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
244
  ).then(
245
  image_to_3d,
246
  inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
247
- outputs=[output_buf, video_output],
248
  ).then(
249
- lambda: [gr.Button(interactive=True), gr.Button(interactive=True)],
250
- outputs=[extract_full_btn, extract_glb_btn],
251
- )
252
-
253
- extract_full_btn.click(
254
- export_full_quality_glb,
255
- inputs=[output_buf],
256
- outputs=[model_output, download_full],
257
- ).then(
258
- lambda: gr.Button(interactive=True),
259
- outputs=[download_full],
260
  )
261
 
262
  extract_glb_btn.click(
 
100
  slat_guidance_strength: float,
101
  slat_sampling_steps: int,
102
  req: gr.Request,
103
+ ) -> Tuple[dict, str, str]:
104
  """
105
+ Convert an image to a 3D model and save full-quality GLB.
106
  """
107
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
108
  outputs = pipeline.run(
 
119
  "cfg_strength": slat_guidance_strength,
120
  },
121
  )
122
+
123
+ # Generate video preview
124
+ video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
125
+ video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
126
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
127
  trial_id = str(uuid.uuid4())
128
  video_path = os.path.join(user_dir, f"{trial_id}.mp4")
129
  imageio.mimsave(video_path, video, fps=15)
 
 
 
130
 
131
+ # Create full quality GLB while we have the data in memory
132
+ full_glb_path = os.path.join(user_dir, f"{trial_id}_full.glb")
 
 
 
 
 
 
 
 
 
133
  glb = postprocessing_utils.to_glb(
134
+ outputs['gaussian'][0],
135
+ outputs['mesh'][0],
136
  simplify=0.0, # No simplification
137
  fill_holes=True,
138
  fill_holes_max_size=0.04,
139
  texture_size=2048, # Maximum texture resolution
140
  verbose=False
141
  )
142
+ glb.export(full_glb_path)
143
+
144
+ state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
145
+ return state, video_path, full_glb_path
146
 
147
  @spaces.GPU
148
  def extract_glb(
 
152
  req: gr.Request,
153
  ) -> Tuple[str, str]:
154
  """
155
+ Extract a reduced-quality GLB file from the 3D model.
156
  """
157
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
158
  gs, mesh, trial_id = unpack_state(state)
 
165
  gr.Markdown("""
166
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
167
  * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
168
+ * The full-quality GLB will be available immediately after generation.
169
+ * You can also create a reduced size version using the GLB Extraction Settings.
 
170
  """)
171
 
172
  with gr.Row():
 
186
  slat_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
187
 
188
  generate_btn = gr.Button("Generate")
 
189
 
190
  with gr.Accordion(label="GLB Extraction Settings", open=False):
191
  mesh_simplify = gr.Slider(0.0, 0.98, label="Simplify", value=0.95, step=0.01)
 
233
  ).then(
234
  image_to_3d,
235
  inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
236
+ outputs=[output_buf, video_output, download_full],
237
  ).then(
238
+ lambda: [gr.Button(interactive=True), gr.Button(interactive=True), gr.Button(interactive=False)],
239
+ outputs=[download_full, extract_glb_btn, download_reduced],
 
 
 
 
 
 
 
 
 
240
  )
241
 
242
  extract_glb_btn.click(