cronos3k committed on
Commit
b41662d
·
verified ·
1 Parent(s): 75b28df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -17
app.py CHANGED
@@ -100,11 +100,13 @@ def image_to_3d(
100
  slat_guidance_strength: float,
101
  slat_sampling_steps: int,
102
  req: gr.Request,
103
- ) -> Tuple[dict, str, str]:
104
  """
105
- Convert an image to a 3D model and save full-quality GLB.
106
  """
107
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
 
108
  outputs = pipeline.run(
109
  image,
110
  seed=seed,
@@ -120,29 +122,82 @@ def image_to_3d(
120
  },
121
  )
122
 
123
- # Generate video preview
124
- video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
125
- video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
127
  trial_id = str(uuid.uuid4())
128
  video_path = os.path.join(user_dir, f"{trial_id}.mp4")
129
  imageio.mimsave(video_path, video, fps=15)
 
 
 
 
 
 
 
 
 
130
 
131
- # Create full quality GLB while we have the data in memory
132
- full_glb_path = os.path.join(user_dir, f"{trial_id}_full.glb")
 
 
 
 
 
 
 
 
 
 
 
 
133
  glb = postprocessing_utils.to_glb(
134
- outputs['gaussian'][0],
135
- outputs['mesh'][0],
136
  simplify=0.0, # No simplification
137
  fill_holes=True,
138
  fill_holes_max_size=0.04,
139
  texture_size=2048, # Maximum texture resolution
140
- verbose=False
141
  )
142
- glb.export(full_glb_path)
 
143
 
144
- state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
145
- return state, video_path, full_glb_path
 
146
 
147
  @spaces.GPU
148
  def extract_glb(
@@ -152,7 +207,7 @@ def extract_glb(
152
  req: gr.Request,
153
  ) -> Tuple[str, str]:
154
  """
155
- Extract a reduced-quality GLB file from the 3D model.
156
  """
157
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
158
  gs, mesh, trial_id = unpack_state(state)
@@ -165,8 +220,9 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
165
  gr.Markdown("""
166
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
167
  * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
168
- * The full-quality GLB will be available immediately after generation.
169
- * You can also create a reduced size version using the GLB Extraction Settings.
 
170
  """)
171
 
172
  with gr.Row():
@@ -233,12 +289,21 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
233
  ).then(
234
  image_to_3d,
235
  inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
236
- outputs=[output_buf, video_output, download_full],
237
  ).then(
238
  lambda: [gr.Button(interactive=True), gr.Button(interactive=True), gr.Button(interactive=False)],
239
  outputs=[download_full, extract_glb_btn, download_reduced],
240
  )
241
 
 
 
 
 
 
 
 
 
 
242
  extract_glb_btn.click(
243
  extract_glb,
244
  inputs=[output_buf, mesh_simplify, texture_size],
 
100
  slat_guidance_strength: float,
101
  slat_sampling_steps: int,
102
  req: gr.Request,
103
+ ) -> Tuple[dict, str]:
104
  """
105
+ Convert an image to a 3D model with memory management.
106
  """
107
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
108
+
109
+ # Generate base outputs
110
  outputs = pipeline.run(
111
  image,
112
  seed=seed,
 
122
  },
123
  )
124
 
125
+ # Clear CUDA cache after model generation
126
+ torch.cuda.empty_cache()
127
+
128
+ # Generate video preview in smaller batches
129
+ video = []
130
+ video_geo = []
131
+ batch_size = 30 # Process 30 frames at a time
132
+ num_frames = 120
133
+
134
+ for i in range(0, num_frames, batch_size):
135
+ end_idx = min(i + batch_size, num_frames)
136
+ curr_frames = end_idx - i
137
+
138
+ # Generate color frames
139
+ batch_frames = render_utils.render_video(
140
+ outputs['gaussian'][0],
141
+ num_frames=curr_frames,
142
+ start_frame=i
143
+ )['color']
144
+ video.extend(batch_frames)
145
+
146
+ # Generate geometry frames
147
+ batch_geo = render_utils.render_video(
148
+ outputs['mesh'][0],
149
+ num_frames=curr_frames,
150
+ start_frame=i
151
+ )['normal']
152
+ video_geo.extend(batch_geo)
153
+
154
+ # Clear cache after each batch
155
+ torch.cuda.empty_cache()
156
+
157
+ # Combine and save video
158
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
159
  trial_id = str(uuid.uuid4())
160
  video_path = os.path.join(user_dir, f"{trial_id}.mp4")
161
  imageio.mimsave(video_path, video, fps=15)
162
+
163
+ # Clear memory
164
+ del video
165
+ del video_geo
166
+ torch.cuda.empty_cache()
167
+
168
+ # Pack state and return
169
+ state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
170
+ return state, video_path
171
 
172
+ @spaces.GPU
173
+ def export_full_quality_glb(
174
+ state: dict,
175
+ req: gr.Request,
176
+ ) -> Tuple[str, str]:
177
+ """
178
+ Export a full-quality GLB file with memory management.
179
+ """
180
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
181
+ gs, mesh, trial_id = unpack_state(state)
182
+
183
+ # Clear cache before starting
184
+ torch.cuda.empty_cache()
185
+
186
  glb = postprocessing_utils.to_glb(
187
+ gs,
188
+ mesh,
189
  simplify=0.0, # No simplification
190
  fill_holes=True,
191
  fill_holes_max_size=0.04,
192
  texture_size=2048, # Maximum texture resolution
193
+ verbose=True # Show progress
194
  )
195
+ glb_path = os.path.join(user_dir, f"{trial_id}_full.glb")
196
+ glb.export(glb_path)
197
 
198
+ # Clear cache after finishing
199
+ torch.cuda.empty_cache()
200
+ return glb_path, glb_path
201
 
202
  @spaces.GPU
203
  def extract_glb(
 
207
  req: gr.Request,
208
  ) -> Tuple[str, str]:
209
  """
210
+ Extract a GLB file from the 3D model.
211
  """
212
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
213
  gs, mesh, trial_id = unpack_state(state)
 
220
  gr.Markdown("""
221
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
222
  * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
223
+ * After generation:
224
+ * Click "Download Full-Quality GLB" for maximum quality
225
+ * Or use GLB Extraction Settings for a reduced size version
226
  """)
227
 
228
  with gr.Row():
 
289
  ).then(
290
  image_to_3d,
291
  inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
292
+ outputs=[output_buf, video_output],
293
  ).then(
294
  lambda: [gr.Button(interactive=True), gr.Button(interactive=True), gr.Button(interactive=False)],
295
  outputs=[download_full, extract_glb_btn, download_reduced],
296
  )
297
 
298
+ download_full.click(
299
+ export_full_quality_glb,
300
+ inputs=[output_buf],
301
+ outputs=[model_output, download_full],
302
+ ).then(
303
+ lambda: gr.Button(interactive=True),
304
+ outputs=[download_full],
305
+ )
306
+
307
  extract_glb_btn.click(
308
  extract_glb,
309
  inputs=[output_buf, mesh_simplify, texture_size],