eleelenawa committed on
Commit
4e5eb34
·
verified ·
1 Parent(s): b7b00e2

update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -50
app.py CHANGED
@@ -15,22 +15,28 @@ from trellis.pipelines import TrellisImageTo3DPipeline
15
  from trellis.representations import Gaussian, MeshExtractResult
16
  from trellis.utils import render_utils, postprocessing_utils
17
 
18
-
19
  MAX_SEED = np.iinfo(np.int32).max
20
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
21
  os.makedirs(TMP_DIR, exist_ok=True)
22
 
 
 
 
 
 
 
 
 
23
 
24
  def start_session(req: gr.Request):
25
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
26
  os.makedirs(user_dir, exist_ok=True)
27
 
28
-
29
  def end_session(req: gr.Request):
30
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
31
  shutil.rmtree(user_dir)
32
 
33
-
34
  def preprocess_image(image: Image.Image) -> Image.Image:
35
  """
36
  Preprocess the input image.
@@ -44,7 +50,6 @@ def preprocess_image(image: Image.Image) -> Image.Image:
44
  processed_image = pipeline.preprocess_image(image)
45
  return processed_image
46
 
47
-
48
  def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
49
  """
50
  Preprocess a list of input images.
@@ -59,7 +64,6 @@ def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image
59
  processed_images = [pipeline.preprocess_image(image) for image in images]
60
  return processed_images
61
 
62
-
63
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
64
  return {
65
  'gaussian': {
@@ -76,7 +80,6 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
76
  },
77
  }
78
 
79
-
80
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
81
  gs = Gaussian(
82
  aabb=state['gaussian']['aabb'],
@@ -99,14 +102,12 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
99
 
100
  return gs, mesh
101
 
102
-
103
  def get_seed(randomize_seed: bool, seed: int) -> int:
104
  """
105
  Get the random seed.
106
  """
107
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
108
 
109
-
110
  @spaces.GPU
111
  def image_to_3d(
112
  image: Image.Image,
@@ -122,21 +123,6 @@ def image_to_3d(
122
  ) -> Tuple[dict, str]:
123
  """
124
  Convert an image to a 3D model.
125
-
126
- Args:
127
- image (Image.Image): The input image.
128
- multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode.
129
- is_multiimage (bool): Whether is in multi-image mode.
130
- seed (int): The random seed.
131
- ss_guidance_strength (float): The guidance strength for sparse structure generation.
132
- ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
133
- slat_guidance_strength (float): The guidance strength for structured latent generation.
134
- slat_sampling_steps (int): The number of sampling steps for structured latent generation.
135
- multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
136
-
137
- Returns:
138
- dict: The information of the generated 3D model.
139
- str: The path to the video of the 3D model.
140
  """
141
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
142
  if not is_multiimage:
@@ -179,7 +165,6 @@ def image_to_3d(
179
  torch.cuda.empty_cache()
180
  return state, video_path
181
 
182
-
183
  @spaces.GPU(duration=90)
184
  def extract_glb(
185
  state: dict,
@@ -189,14 +174,6 @@ def extract_glb(
189
  ) -> Tuple[str, str]:
190
  """
191
  Extract a GLB file from the 3D model.
192
-
193
- Args:
194
- state (dict): The state of the generated 3D model.
195
- mesh_simplify (float): The mesh simplification factor.
196
- texture_size (int): The texture resolution.
197
-
198
- Returns:
199
- str: The path to the extracted GLB file.
200
  """
201
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
202
  gs, mesh = unpack_state(state)
@@ -206,17 +183,10 @@ def extract_glb(
206
  torch.cuda.empty_cache()
207
  return glb_path, glb_path
208
 
209
-
210
  @spaces.GPU
211
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
212
  """
213
  Extract a Gaussian file from the 3D model.
214
-
215
- Args:
216
- state (dict): The state of the generated 3D model.
217
-
218
- Returns:
219
- str: The path to the extracted Gaussian file.
220
  """
221
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
222
  gs, _ = unpack_state(state)
@@ -225,7 +195,6 @@ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
225
  torch.cuda.empty_cache()
226
  return gaussian_path, gaussian_path
227
 
228
-
229
  def prepare_multi_example() -> List[Image.Image]:
230
  multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
231
  images = []
@@ -239,7 +208,6 @@ def prepare_multi_example() -> List[Image.Image]:
239
  images.append(Image.fromarray(np.concatenate(_images, axis=1)))
240
  return images
241
 
242
-
243
  def split_image(image: Image.Image) -> List[Image.Image]:
244
  """
245
  Split an image into multiple views.
@@ -254,7 +222,6 @@ def split_image(image: Image.Image) -> List[Image.Image]:
254
  images.append(Image.fromarray(image[:, s:e+1]))
255
  return [preprocess_image(image) for image in images]
256
 
257
-
258
  with gr.Blocks(delete_cache=(600, 600)) as demo:
259
  gr.Markdown("""
260
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
@@ -401,14 +368,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
401
  lambda: gr.Button(interactive=False),
402
  outputs=[download_glb],
403
  )
404
-
405
 
406
  # Launch the Gradio app
407
  if __name__ == "__main__":
408
- pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
409
- pipeline.cuda()
410
- try:
411
- pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
412
- except:
413
- pass
414
- demo.launch()
 
15
  from trellis.representations import Gaussian, MeshExtractResult
16
  from trellis.utils import render_utils, postprocessing_utils
17
 
18
# Constants
# Largest value representable as a signed 32-bit integer; used as the seed bound.
MAX_SEED = np.iinfo(np.int32).max
# Scratch directory located next to this file; per-session folders live under it.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)

# Initialize the pipeline at module level: the handlers defined below refer to
# the global name `pipeline`, so it has to exist as soon as the module is loaded.
pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
pipeline.cuda()
try:
    # Preload rembg by pushing one blank 512x512 RGB image through preprocessing.
    pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
except Exception as exc:
    # The warm-up is best-effort, so a failure must not stop the app from
    # starting. Catch Exception (not a bare except) so KeyboardInterrupt and
    # SystemExit still propagate, and report the problem instead of hiding it.
    print(f"Warning: rembg preload failed and was skipped: {exc!r}")
31
 
32
  def start_session(req: gr.Request):
33
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
34
  os.makedirs(user_dir, exist_ok=True)
35
 
 
36
  def end_session(req: gr.Request):
37
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
38
  shutil.rmtree(user_dir)
39
 
 
40
  def preprocess_image(image: Image.Image) -> Image.Image:
41
  """
42
  Preprocess the input image.
 
50
  processed_image = pipeline.preprocess_image(image)
51
  return processed_image
52
 
 
53
  def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
54
  """
55
  Preprocess a list of input images.
 
64
  processed_images = [pipeline.preprocess_image(image) for image in images]
65
  return processed_images
66
 
 
67
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
68
  return {
69
  'gaussian': {
 
80
  },
81
  }
82
 
 
83
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
84
  gs = Gaussian(
85
  aabb=state['gaussian']['aabb'],
 
102
 
103
  return gs, mesh
104
 
 
105
  def get_seed(randomize_seed: bool, seed: int) -> int:
106
  """
107
  Get the random seed.
108
  """
109
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
110
 
 
111
  @spaces.GPU
112
  def image_to_3d(
113
  image: Image.Image,
 
123
  ) -> Tuple[dict, str]:
124
  """
125
  Convert an image to a 3D model.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  """
127
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
128
  if not is_multiimage:
 
165
  torch.cuda.empty_cache()
166
  return state, video_path
167
 
 
168
  @spaces.GPU(duration=90)
169
  def extract_glb(
170
  state: dict,
 
174
  ) -> Tuple[str, str]:
175
  """
176
  Extract a GLB file from the 3D model.
 
 
 
 
 
 
 
 
177
  """
178
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
179
  gs, mesh = unpack_state(state)
 
183
  torch.cuda.empty_cache()
184
  return glb_path, glb_path
185
 
 
186
  @spaces.GPU
187
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
188
  """
189
  Extract a Gaussian file from the 3D model.
 
 
 
 
 
 
190
  """
191
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
192
  gs, _ = unpack_state(state)
 
195
  torch.cuda.empty_cache()
196
  return gaussian_path, gaussian_path
197
 
 
198
  def prepare_multi_example() -> List[Image.Image]:
199
  multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
200
  images = []
 
208
  images.append(Image.fromarray(np.concatenate(_images, axis=1)))
209
  return images
210
 
 
211
  def split_image(image: Image.Image) -> List[Image.Image]:
212
  """
213
  Split an image into multiple views.
 
222
  images.append(Image.fromarray(image[:, s:e+1]))
223
  return [preprocess_image(image) for image in images]
224
 
 
225
  with gr.Blocks(delete_cache=(600, 600)) as demo:
226
  gr.Markdown("""
227
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
 
368
  lambda: gr.Button(interactive=False),
369
  outputs=[download_glb],
370
  )
 
371
 
372
  # Launch the Gradio app only when this file is executed as a script;
  # importing the module does not call demo.launch().
373
  if __name__ == "__main__":
374
+ demo.launch()