cronos3k commited on
Commit
b62fa8c
·
verified ·
1 Parent(s): 9a9f462

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -56
app.py CHANGED
@@ -31,16 +31,6 @@ def end_session(req: gr.Request):
31
  shutil.rmtree(user_dir)
32
 
33
  def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
34
- """
35
- Preprocess the input image.
36
-
37
- Args:
38
- image (Image.Image): The input image.
39
-
40
- Returns:
41
- str: uuid of the trial.
42
- Image.Image: The preprocessed image.
43
- """
44
  processed_image = pipeline.preprocess_image(image)
45
  return processed_image
46
 
@@ -84,9 +74,6 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
84
  return gs, mesh, state['trial_id']
85
 
86
  def get_seed(randomize_seed: bool, seed: int) -> int:
87
- """
88
- Get the random seed.
89
- """
90
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
91
 
92
  def image_to_3d(
@@ -97,7 +84,7 @@ def image_to_3d(
97
  slat_guidance_strength: float,
98
  slat_sampling_steps: int,
99
  req: gr.Request,
100
- ) -> Tuple[dict, str, str, str]:
101
  """
102
  Convert an image to a 3D model.
103
  """
@@ -116,49 +103,67 @@ def image_to_3d(
116
  "cfg_strength": slat_guidance_strength,
117
  },
118
  )
 
 
119
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
120
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
121
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
122
  trial_id = str(uuid.uuid4())
123
  video_path = os.path.join(user_dir, f"{trial_id}.mp4")
124
  imageio.mimsave(video_path, video, fps=15)
125
-
126
  # Save full quality GLB
127
- full_glb = postprocessing_utils.to_glb(
128
  outputs['gaussian'][0],
129
  outputs['mesh'][0],
130
- simplify=0.0, # No simplification
131
- texture_size=2048, # Maximum texture resolution
 
 
132
  verbose=False
133
  )
134
  full_glb_path = os.path.join(user_dir, f"{trial_id}_full.glb")
135
- full_glb.export(full_glb_path)
136
 
 
137
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
138
- return state, video_path, model_output, full_glb_path
 
139
 
140
- def extract_glb(
141
  state: dict,
142
  mesh_simplify: float,
143
  texture_size: int,
144
  req: gr.Request,
145
  ) -> Tuple[str, str]:
146
  """
147
- Extract a GLB file from the 3D model.
148
  """
149
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
150
  gs, mesh, trial_id = unpack_state(state)
151
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
152
- glb_path = os.path.join(user_dir, f"{trial_id}.glb")
153
- glb.export(glb_path)
154
- return glb_path, glb_path
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  with gr.Blocks(delete_cache=(600, 600)) as demo:
157
  gr.Markdown("""
158
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
159
  * Upload an image and click "Generate" to create a 3D asset
160
- * You can download the full quality GLB immediately after generation
161
- * Or create a reduced size version using the GLB Extraction Settings
 
162
  """)
163
 
164
  with gr.Row():
@@ -179,15 +184,17 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
179
 
180
  generate_btn = gr.Button("Generate")
181
 
182
- with gr.Accordion(label="GLB Extraction Settings", open=False):
183
- mesh_simplify = gr.Slider(0.0, 0.98, label="Simplify", value=0.95, step=0.01)
 
184
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
185
 
186
- extract_glb_btn = gr.Button("Extract Reduced GLB", interactive=False)
187
 
188
  with gr.Column():
189
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
190
  model_output = LitModel3D(label="3D Model Preview", exposure=20.0, height=300)
 
191
  with gr.Row():
192
  download_full = gr.DownloadButton(label="Download Full-Quality GLB", interactive=False)
193
  download_reduced = gr.DownloadButton(label="Download Reduced GLB", interactive=False)
@@ -223,28 +230,4 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
223
  inputs=[randomize_seed, seed],
224
  outputs=[seed],
225
  ).then(
226
- image_to_3d,
227
- inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
228
- outputs=[output_buf, video_output, model_output, download_full],
229
- ).then(
230
- lambda: (gr.Button(interactive=True), gr.Button(interactive=True), gr.Button(interactive=False)),
231
- outputs=[download_full, extract_glb_btn, download_reduced],
232
- )
233
-
234
- extract_glb_btn.click(
235
- extract_glb,
236
- inputs=[output_buf, mesh_simplify, texture_size],
237
- outputs=[model_output, download_reduced],
238
- ).then(
239
- lambda: gr.Button(interactive=True),
240
- outputs=[download_reduced],
241
- )
242
-
243
- if __name__ == "__main__":
244
- pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
245
- pipeline.cuda()
246
- try:
247
- pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
248
- except:
249
- pass
250
- demo.launch()
 
31
  shutil.rmtree(user_dir)
32
 
33
  def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
 
 
 
 
 
 
 
 
 
 
34
  processed_image = pipeline.preprocess_image(image)
35
  return processed_image
36
 
 
74
  return gs, mesh, state['trial_id']
75
 
76
  def get_seed(randomize_seed: bool, seed: int) -> int:
 
 
 
77
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
78
 
79
  def image_to_3d(
 
84
  slat_guidance_strength: float,
85
  slat_sampling_steps: int,
86
  req: gr.Request,
87
+ ) -> Tuple[dict, str, str]:
88
  """
89
  Convert an image to a 3D model.
90
  """
 
103
  "cfg_strength": slat_guidance_strength,
104
  },
105
  )
106
+
107
+ # Generate video preview
108
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
109
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
110
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
111
  trial_id = str(uuid.uuid4())
112
  video_path = os.path.join(user_dir, f"{trial_id}.mp4")
113
  imageio.mimsave(video_path, video, fps=15)
114
+
115
  # Save full quality GLB
116
+ glb = postprocessing_utils.to_glb(
117
  outputs['gaussian'][0],
118
  outputs['mesh'][0],
119
+ simplify=0.0, # No simplification for full quality
120
+ fill_holes=True,
121
+ fill_holes_max_size=0.04,
122
+ texture_size=2048, # Maximum texture size
123
  verbose=False
124
  )
125
  full_glb_path = os.path.join(user_dir, f"{trial_id}_full.glb")
126
+ glb.export(full_glb_path)
127
 
128
+ # Pack state for potential reduced version
129
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
130
+
131
+ return state, video_path, full_glb_path
132
 
133
+ def extract_reduced_glb(
134
  state: dict,
135
  mesh_simplify: float,
136
  texture_size: int,
137
  req: gr.Request,
138
  ) -> Tuple[str, str]:
139
  """
140
+ Extract a reduced quality GLB file.
141
  """
142
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
143
  gs, mesh, trial_id = unpack_state(state)
144
+
145
+ # Create reduced quality GLB with user settings
146
+ glb = postprocessing_utils.to_glb(
147
+ gs,
148
+ mesh,
149
+ simplify=mesh_simplify,
150
+ fill_holes=True,
151
+ fill_holes_max_size=0.04,
152
+ texture_size=texture_size,
153
+ verbose=False
154
+ )
155
+ reduced_glb_path = os.path.join(user_dir, f"{trial_id}_reduced.glb")
156
+ glb.export(reduced_glb_path)
157
+
158
+ return reduced_glb_path, reduced_glb_path
159
 
160
  with gr.Blocks(delete_cache=(600, 600)) as demo:
161
  gr.Markdown("""
162
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
163
  * Upload an image and click "Generate" to create a 3D asset
164
+ * After generation:
165
+ * Download the full quality GLB (no mesh simplification, maximum texture resolution)
166
+ * Or create a reduced size version with customizable settings
167
  """)
168
 
169
  with gr.Row():
 
184
 
185
  generate_btn = gr.Button("Generate")
186
 
187
+ with gr.Accordion(label="Reduced GLB Settings", open=False):
188
+ mesh_simplify = gr.Slider(0.0, 0.98, label="Mesh Simplification", value=0.95, step=0.01,
189
+ info="Higher values = more reduction")
190
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
191
 
192
+ extract_reduced_btn = gr.Button("Extract Reduced GLB", interactive=False)
193
 
194
  with gr.Column():
195
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
196
  model_output = LitModel3D(label="3D Model Preview", exposure=20.0, height=300)
197
+ gr.Markdown("### Download Options")
198
  with gr.Row():
199
  download_full = gr.DownloadButton(label="Download Full-Quality GLB", interactive=False)
200
  download_reduced = gr.DownloadButton(label="Download Reduced GLB", interactive=False)
 
230
  inputs=[randomize_seed, seed],
231
  outputs=[seed],
232
  ).then(
233
+ image_to_3d,