yejunliang23 commited on
Commit
e1c929d
·
verified ·
1 Parent(s): 755cb36

Update trellis/pipelines/trellis_image_to_3d.py

Browse files
trellis/pipelines/trellis_image_to_3d.py CHANGED
@@ -79,6 +79,62 @@ class TrellisImageTo3DPipeline(Pipeline):
79
  ])
80
  self.image_cond_model_transform = transform
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  def preprocess_image(self, input: Image.Image) -> Image.Image:
83
  """
84
  Preprocess the input image.
 
79
  ])
80
  self.image_cond_model_transform = transform
81
 
82
+ def preprocess_image_white(self, input: Image.Image) -> Image.Image:
83
+ avg_sw = 0.626
84
+ avg_sh = 0.608
85
+ has_alpha = False
86
+ if input.mode == 'RGBA':
87
+ alpha = np.array(input)[:, :, 3]
88
+ if not np.all(alpha == 255):
89
+ has_alpha = True
90
+
91
+ if has_alpha:
92
+ fg = input.convert('RGBA')
93
+ else:
94
+ img = input.convert('RGB')
95
+ max_size = max(img.size)
96
+ scale0 = min(1, 1024 / max_size)
97
+ if scale0 < 1:
98
+ img = img.resize(
99
+ (int(img.width * scale0), int(img.height * scale0)),
100
+ Image.Resampling.LANCZOS
101
+ )
102
+ if getattr(self, 'rembg_session', None) is None:
103
+ self.rembg_session = rembg.new_session('u2net')
104
+ fg = rembg.remove(img, session=self.rembg_session)
105
+
106
+ # —— 2. 找包围盒并裁剪(±20%) —— #
107
+ arr = np.array(fg)
108
+ alpha = arr[:, :, 3]
109
+ ys, xs = np.where(alpha > 0.8 * 255)
110
+ x0, y0 = xs.min(), ys.min()
111
+ x1, y1 = xs.max(), ys.max()
112
+ # 原始宽高
113
+ w0, h0 = x1 - x0, y1 - y0
114
+ # 中心 & 放大 20%
115
+ cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
116
+ L = max(w0, h0) * 1.2
117
+ x0n = int(cx - L/2); y0n = int(cy - L/2)
118
+ x1n = int(cx + L/2); y1n = int(cy + L/2)
119
+ fg = fg.crop((x0n, y0n, x1n, y1n))
120
+
121
+ # —— 3. 按 avg_sw/avg_sh 调整前景尺寸 —— #
122
+ # 假设你在类里定义了 avg_sw, avg_sh = 前景占宽/高的目标比例(0~1)
123
+ W, H = 512, 512
124
+ w, h = fg.size
125
+ target_w = avg_sw * W
126
+ target_h = avg_sh * H
127
+ scale1 = min(target_w / w, target_h / h)
128
+ new_w, new_h = int(w * scale1), int(h * scale1)
129
+ fg_resized = fg.resize((new_w, new_h), Image.Resampling.LANCZOS)
130
+
131
+ # —— 4. 白底画布上居中贴前景,输出 RGB —— #
132
+ canvas = Image.new('RGBA', (W, H), (255, 255, 255, 255))
133
+ x_off = (W - new_w) // 2
134
+ y_off = (H - new_h) // 2
135
+ canvas.paste(fg_resized, (x_off, y_off), fg_resized)
136
+ return canvas.convert('RGB')
137
+
138
  def preprocess_image(self, input: Image.Image) -> Image.Image:
139
  """
140
  Preprocess the input image.