Surn commited on
Commit
03fb64e
·
1 Parent(s): 8001a73

Upgrade Trellis to handle landscapes

Browse files
Files changed (2) hide show
  1. app.py +5 -4
  2. trellis/pipelines/trellis_image_to_3d.py +11 -6
app.py CHANGED
@@ -962,7 +962,7 @@ def generate_3d_asset_part1(depth_image_source, randomize_seed, seed, input_imag
962
  final_seed = np.random.randint(0, constants.MAX_SEED) if randomize_seed else seed
963
  # Process the image for depth estimation
964
  depth_img = depth_process_image(image_path, resized_width=1536, z_scale=336)
965
- depth_img = resize_image_with_aspect_ratio(depth_img, 1536, 1536)
966
 
967
  user_dir = os.path.join(constants.TMPDIR, str(req.session_hash))
968
  depth_img = save_image_to_temp_png(depth_img, user_dir, f"{output_name}_depth")
@@ -975,6 +975,7 @@ def generate_3d_asset_part2(depth_img, image_path, output_name, seed, steps, mod
975
  image_raw = Image.open(image_path).convert("RGB")
976
  resized_image = resize_image_with_aspect_ratio(image_raw, model_resolution, model_resolution)
977
  depth_img = Image.open(depth_img).convert("RGBA")
 
978
  if TRELLIS_PIPELINE is None:
979
  gr.Warning(f"Trellis Pipeline is not initialized: {TRELLIS_PIPELINE.device()}")
980
  return [None, None, depth_img]
@@ -982,7 +983,7 @@ def generate_3d_asset_part2(depth_img, image_path, output_name, seed, steps, mod
982
  # Preprocess and run the Trellis pipeline with fixed sampler settings
983
  try:
984
  TRELLIS_PIPELINE.cuda()
985
- processed_image = TRELLIS_PIPELINE.preprocess_image(resized_image, max_resolution=model_resolution)
986
  outputs = TRELLIS_PIPELINE.run(
987
  processed_image,
988
  seed=seed,
@@ -1503,7 +1504,7 @@ with gr.Blocks(css_paths="style_20250314.css", title=title, theme='Surn/beeuty',
1503
 
1504
  is_multiimage = gr.State(False)
1505
  output_buf = gr.State()
1506
- ddd_image_path = gr.State("./images/images/Beeuty-1.png")
1507
  ddd_file_name = gr.State("Hexagon_file")
1508
  with gr.Row():
1509
  gr.Examples(examples=[
@@ -1761,7 +1762,7 @@ if __name__ == "__main__":
1761
  TRELLIS_PIPELINE = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
1762
  TRELLIS_PIPELINE.to(device)
1763
  try:
1764
- TRELLIS_PIPELINE.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
1765
  except:
1766
  pass
1767
  hexaGrid.queue(default_concurrency_limit=1,max_size=12,api_open=False)
 
962
  final_seed = np.random.randint(0, constants.MAX_SEED) if randomize_seed else seed
963
  # Process the image for depth estimation
964
  depth_img = depth_process_image(image_path, resized_width=1536, z_scale=336)
965
+ #depth_img = resize_image_with_aspect_ratio(depth_img, 1536, 1536)
966
 
967
  user_dir = os.path.join(constants.TMPDIR, str(req.session_hash))
968
  depth_img = save_image_to_temp_png(depth_img, user_dir, f"{output_name}_depth")
 
975
  image_raw = Image.open(image_path).convert("RGB")
976
  resized_image = resize_image_with_aspect_ratio(image_raw, model_resolution, model_resolution)
977
  depth_img = Image.open(depth_img).convert("RGBA")
978
+
979
  if TRELLIS_PIPELINE is None:
980
  gr.Warning(f"Trellis Pipeline is not initialized: {TRELLIS_PIPELINE.device()}")
981
  return [None, None, depth_img]
 
983
  # Preprocess and run the Trellis pipeline with fixed sampler settings
984
  try:
985
  TRELLIS_PIPELINE.cuda()
986
+ processed_image = TRELLIS_PIPELINE.preprocess_image(resized_image, max_resolution=model_resolution, remove_bg = False)
987
  outputs = TRELLIS_PIPELINE.run(
988
  processed_image,
989
  seed=seed,
 
1504
 
1505
  is_multiimage = gr.State(False)
1506
  output_buf = gr.State()
1507
+ ddd_image_path = gr.State("./images/images/Bee-test-2.png")
1508
  ddd_file_name = gr.State("Hexagon_file")
1509
  with gr.Row():
1510
  gr.Examples(examples=[
 
1762
  TRELLIS_PIPELINE = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
1763
  TRELLIS_PIPELINE.to(device)
1764
  try:
1765
+ TRELLIS_PIPELINE.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)), 512) # Preload rembg
1766
  except:
1767
  pass
1768
  hexaGrid.queue(default_concurrency_limit=1,max_size=12,api_open=False)
trellis/pipelines/trellis_image_to_3d.py CHANGED
@@ -82,7 +82,7 @@ class TrellisImageTo3DPipeline(Pipeline):
82
  ])
83
  self.image_cond_model_transform = transform
84
 
85
- def preprocess_image(self, input: Image.Image, max_resolution: int =1024) -> Image.Image:
86
  """
87
  Preprocess the input image.
88
  """
@@ -100,9 +100,12 @@ class TrellisImageTo3DPipeline(Pipeline):
100
  scale = min(1, max_resolution / max_size)
101
  if scale < 1:
102
  input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
103
- if getattr(self, 'rembg_session', None) is None:
104
- self.rembg_session = rembg.new_session('u2net')
105
- output = rembg.remove(input, session=self.rembg_session)
 
 
 
106
  output_np = np.array(output)
107
  alpha = output_np[:, :, 3]
108
  bbox = np.argwhere(alpha > 0.8 * 255)
@@ -264,6 +267,7 @@ class TrellisImageTo3DPipeline(Pipeline):
264
  slat_sampler_params: dict = {},
265
  formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
266
  preprocess_image: bool = True,
 
267
  ) -> dict:
268
  """
269
  Run the pipeline.
@@ -276,7 +280,7 @@ class TrellisImageTo3DPipeline(Pipeline):
276
  preprocess_image (bool): Whether to preprocess the image.
277
  """
278
  if preprocess_image:
279
- image = self.preprocess_image(image)
280
  cond = self.get_cond([image])
281
  torch.manual_seed(seed)
282
  coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
@@ -351,6 +355,7 @@ class TrellisImageTo3DPipeline(Pipeline):
351
  formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
352
  preprocess_image: bool = True,
353
  mode: Literal['stochastic', 'multidiffusion'] = 'stochastic',
 
354
  ) -> dict:
355
  """
356
  Run the pipeline with multiple images as condition
@@ -363,7 +368,7 @@ class TrellisImageTo3DPipeline(Pipeline):
363
  preprocess_image (bool): Whether to preprocess the image.
364
  """
365
  if preprocess_image:
366
- images = [self.preprocess_image(image) for image in images]
367
  cond = self.get_cond(images)
368
  cond['neg_cond'] = cond['neg_cond'][:1]
369
  torch.manual_seed(seed)
 
82
  ])
83
  self.image_cond_model_transform = transform
84
 
85
+ def preprocess_image(self, input: Image.Image, max_resolution: int =1024, remove_bg: bool = True) -> Image.Image:
86
  """
87
  Preprocess the input image.
88
  """
 
100
  scale = min(1, max_resolution / max_size)
101
  if scale < 1:
102
  input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
103
+ if remove_bg:
104
+ if getattr(self, 'rembg_session', None) is None:
105
+ self.rembg_session = rembg.new_session('u2net')
106
+ output = rembg.remove(input, session=self.rembg_session)
107
+ else:
108
+ output = input.convert('RGBA')
109
  output_np = np.array(output)
110
  alpha = output_np[:, :, 3]
111
  bbox = np.argwhere(alpha > 0.8 * 255)
 
267
  slat_sampler_params: dict = {},
268
  formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
269
  preprocess_image: bool = True,
270
+ remove_bg: bool = True,
271
  ) -> dict:
272
  """
273
  Run the pipeline.
 
280
  preprocess_image (bool): Whether to preprocess the image.
281
  """
282
  if preprocess_image:
283
+ image = self.preprocess_image(image, remove_bg=remove_bg)
284
  cond = self.get_cond([image])
285
  torch.manual_seed(seed)
286
  coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
 
355
  formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
356
  preprocess_image: bool = True,
357
  mode: Literal['stochastic', 'multidiffusion'] = 'stochastic',
358
+ remove_bg: bool = True,
359
  ) -> dict:
360
  """
361
  Run the pipeline with multiple images as condition
 
368
  preprocess_image (bool): Whether to preprocess the image.
369
  """
370
  if preprocess_image:
371
+ images = [self.preprocess_image(image,remove_bg=remove_bg) for image in images]
372
  cond = self.get_cond(images)
373
  cond['neg_cond'] = cond['neg_cond'][:1]
374
  torch.manual_seed(seed)