cavargas10 committed
Commit e92b7dd · verified · 1 parent: e25010f

Update app.py

Files changed (1)
  1. app.py +51 -47
app.py CHANGED
@@ -1,7 +1,6 @@
 import gradio as gr
 import spaces
 from gradio_litmodel3d import LitModel3D
-
 import os
 import shutil
 os.environ['SPCONV_ALGO'] = 'native'
@@ -15,32 +14,33 @@ from trellis.pipelines import TrellisImageTo3DPipeline
 from trellis.representations import Gaussian, MeshExtractResult
 from trellis.utils import render_utils, postprocessing_utils
 
-
 MAX_SEED = np.iinfo(np.int32).max
 TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
 os.makedirs(TMP_DIR, exist_ok=True)
 
-
 def start_session(req: gr.Request):
     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
     os.makedirs(user_dir, exist_ok=True)
-
-
+
 def end_session(req: gr.Request):
     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
     shutil.rmtree(user_dir)
 
-
 def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
+    """
+    Preprocess a list of input images.
+    Args:
+        images (List[Tuple[Image.Image, str]]): The input images.
+    Returns:
+        List[Image.Image]: The preprocessed images.
+    """
     images = [image[0] for image in images]
     processed_images = [pipeline.preprocess_image(image) for image in images]
     return processed_images
 
-
 def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
     return {
         'gaussian': {
-            **gs.init_params,
             '_xyz': gs._xyz.cpu().numpy(),
             '_features_dc': gs._features_dc.cpu().numpy(),
             '_scaling': gs._scaling.cpu().numpy(),
@@ -52,9 +52,8 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
             'faces': mesh.faces.cpu().numpy(),
         },
     }
-
-
-def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
+
+def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
     gs = Gaussian(
         aabb=state['gaussian']['aabb'],
         sh_degree=state['gaussian']['sh_degree'],
@@ -68,19 +67,18 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
     gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
     gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
     gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
-
     mesh = edict(
         vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
         faces=torch.tensor(state['mesh']['faces'], device='cuda'),
     )
-
     return gs, mesh
 
-
 def get_seed(randomize_seed: bool, seed: int) -> int:
+    """
+    Get the random seed.
+    """
     return np.random.randint(0, MAX_SEED) if randomize_seed else seed
 
-
 @spaces.GPU
 def image_to_3d(
     multiimages: List[Tuple[Image.Image, str]],
@@ -92,22 +90,25 @@ def image_to_3d(
     multiimage_algo: Literal["multidiffusion", "stochastic"],
     req: gr.Request,
 ) -> Tuple[dict, str]:
+    """
+    Convert multiple images to a 3D model.
+    """
     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
-    outputs = pipeline.run_multi_image(
-        [image[0] for image in multiimages],
-        seed=seed,
-        formats=["gaussian", "mesh"],
-        preprocess_image=False,
-        sparse_structure_sampler_params={
-            "steps": ss_sampling_steps,
-            "cfg_strength": ss_guidance_strength,
-        },
-        slat_sampler_params={
-            "steps": slat_sampling_steps,
-            "cfg_strength": slat_guidance_strength,
-        },
-        mode=multiimage_algo,
-    )
+    outputs = pipeline.run_multi_image(
+        [image[0] for image in multiimages],
+        seed=seed,
+        formats=["gaussian", "mesh"],
+        preprocess_image=False,
+        sparse_structure_sampler_params={
+            "steps": ss_sampling_steps,
+            "cfg_strength": ss_guidance_strength,
+        },
+        slat_sampler_params={
+            "steps": slat_sampling_steps,
+            "cfg_strength": slat_guidance_strength,
+        },
+        mode=multiimage_algo,
+    )
     video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
     video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
     video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
@@ -117,7 +118,6 @@
     torch.cuda.empty_cache()
     return state, video_path
 
-
 @spaces.GPU(duration=90)
 def extract_glb(
     state: dict,
@@ -125,6 +125,9 @@ def extract_glb(
     texture_size: int,
     req: gr.Request,
 ) -> Tuple[str, str]:
+    """
+    Extract a GLB file from the 3D model.
+    """
     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
     gs, mesh = unpack_state(state)
     glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
@@ -133,9 +136,11 @@
     torch.cuda.empty_cache()
     return glb_path, glb_path
 
-
 @spaces.GPU
 def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
+    """
+    Extract a Gaussian file from the 3D model.
+    """
     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
     gs, _ = unpack_state(state)
     gaussian_path = os.path.join(user_dir, 'sample.ply')
@@ -143,18 +148,17 @@ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
     torch.cuda.empty_cache()
     return gaussian_path, gaussian_path
 
-
 def prepare_multi_example() -> List[Image.Image]:
     multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
     images = []
     for case in multi_case:
-        _images = []
+        views = []
         for i in range(1, 4):
             img = Image.open(f'assets/example_multi_image/{case}_{i}.png')
             W, H = img.size
             img = img.resize((int(W / H * 512), 512))
-            _images.append(np.array(img))
-        images.append(Image.fromarray(np.concatenate(_images, axis=1)))
+            views.append(img)
+        images.append(views)
     return images
 
 with gr.Blocks(delete_cache=(600, 600)) as demo:
@@ -208,30 +212,30 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
             with gr.Row():
                 download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
                 download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
-
+
     output_buf = gr.State()
-
+
     # Example images at the bottom of the page
     with gr.Row(visible=True) as multiimage_example:
         examples_multi = gr.Examples(
            examples=prepare_multi_example(),
            inputs=[multiimage_prompt],
-            fn=split_image,
+            fn=lambda x: x,
            outputs=[multiimage_prompt],
            run_on_click=True,
            examples_per_page=8,
        )
-
+
    # Handlers
    demo.load(start_session)
    demo.unload(end_session)
-
+
    multiimage_prompt.upload(
        preprocess_images,
        inputs=[multiimage_prompt],
        outputs=[multiimage_prompt],
    )
-
+
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
@@ -244,12 +248,12 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
        lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
        outputs=[extract_glb_btn, extract_gs_btn],
    )
-
+
    video_output.clear(
        lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
        outputs=[extract_glb_btn, extract_gs_btn],
    )
-
+
    extract_glb_btn.click(
        extract_glb,
        inputs=[output_buf, mesh_simplify, texture_size],
@@ -258,7 +262,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
        lambda: gr.Button(interactive=True),
        outputs=[download_glb],
    )
-
+
    extract_gs_btn.click(
        extract_gaussian,
        inputs=[output_buf],
@@ -267,7 +271,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
        lambda: gr.Button(interactive=True),
        outputs=[download_gs],
    )
-
+
    model_output.clear(
        lambda: gr.Button(interactive=False),
        outputs=[download_glb],
@@ -281,4 +285,4 @@ if __name__ == "__main__":
        pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))    # Preload rembg
    except:
        pass
-    demo.launch(show_error=True)
+    demo.launch(show_error=True)
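
Side note on the pack_state / unpack_state pair touched in this diff: the value kept in gr.State is just a dict of CPU NumPy arrays, rebuilt as CUDA tensors on the way back. Below is a minimal, self-contained sketch of that round-trip pattern, not the app's actual code: plain dicts stand in for trellis' Gaussian and MeshExtractResult, and the device falls back to CPU so the snippet runs without a GPU (both are assumptions for illustration; app.py hard-codes device='cuda').

    import numpy as np
    import torch

    # The app itself uses device='cuda'; this fallback only lets the sketch run anywhere.
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    def pack(tensors: dict) -> dict:
        # Mirrors the pack_state() idea: GPU tensors -> CPU NumPy arrays that gr.State can hold.
        return {k: v.detach().cpu().numpy() for k, v in tensors.items()}

    def unpack(arrays: dict) -> dict:
        # Mirrors the unpack_state() idea: NumPy arrays -> tensors on the target device.
        return {k: torch.tensor(v, device=DEVICE) for k, v in arrays.items()}

    # Hypothetical payload with the same key style as the Gaussian attributes in the diff.
    state = pack({'_xyz': torch.rand(1024, 3), '_opacity': torch.rand(1024, 1)})
    restored = unpack(state)
    print({k: (tuple(v.shape), str(v.device)) for k, v in restored.items()})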