import os
from typing import Optional, Union
import spaces
import cv2
import numpy as np
import torch
from diffusers import (
EulerAncestralDiscreteScheduler,
StableDiffusionInstructPix2PixPipeline,
)
from huggingface_hub import snapshot_download
from PIL import Image
from asset3d_gen.models.segment_model import RembgRemover

__all__ = [
"DelightingModel",
]


class DelightingModel:
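    """Lighting removal ("delighting") for single images.

    Wraps the ``hunyuan3d-delight-v2-0`` weights from ``tencent/Hunyuan3D-2``
    in a ``StableDiffusionInstructPix2PixPipeline`` and optionally removes the
    background with rembg before inference.
    """
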
def __init__(
self,
        model_path: Optional[str] = None,
num_infer_step: int = 50,
mask_erosion_size: int = 3,
image_guide_scale: float = 1.5,
text_guide_scale: float = 1.0,
device: str = "cuda",
seed: int = 0,
) -> None:
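        """
        Args:
            model_path: Local path to the delight weights; downloaded from
                the Hugging Face Hub when None.
            num_infer_step: Number of denoising steps per inference.
            mask_erosion_size: Kernel size used to erode the alpha mask.
            image_guide_scale: InstructPix2Pix image guidance scale.
            text_guide_scale: Text (classifier-free) guidance scale.
            device: Device the pipeline runs on.
            seed: Seed for the diffusion generator.
        """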
self.image_guide_scale = image_guide_scale
self.text_guide_scale = text_guide_scale
self.num_infer_step = num_infer_step
self.mask_erosion_size = mask_erosion_size
self.kernel = np.ones(
(self.mask_erosion_size, self.mask_erosion_size), np.uint8
)
self.seed = seed
self.device = device
self.bg_remover = RembgRemover()
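        # Download the delight weights from the Hugging Face Hub when no
        # local path is given.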
if model_path is None:
suffix = "hunyuan3d-delight-v2-0"
model_path = snapshot_download(
repo_id="tencent/Hunyuan3D-2", allow_patterns=f"{suffix}/*"
)
model_path = os.path.join(model_path, suffix)
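        # The delight model is loaded as an InstructPix2Pix-style pipeline
        # in fp16.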
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
model_path,
torch_dtype=torch.float16,
safety_checker=None,
)
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipeline.scheduler.config
)
pipeline.set_progress_bar_config(disable=True)
pipeline.to(self.device, torch.float16)
# pipeline.enable_model_cpu_offload()
# pipeline.enable_xformers_memory_efficient_attention()
self.pipeline = pipeline

    def recenter_image(
self, image: Image.Image, border_ratio: float = 0.2
) -> Image.Image:
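        """Crop to the alpha bounding box and recenter on a square canvas.

        Images without an alpha channel (RGB or grayscale) are returned as
        RGB unchanged. ``border_ratio`` sets the transparent margin added
        around the cropped content.
        """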
if image.mode == "RGB":
return image
elif image.mode == "L":
image = image.convert("RGB")
return image
alpha_channel = np.array(image)[:, :, 3]
non_zero_indices = np.argwhere(alpha_channel > 0)
if non_zero_indices.size == 0:
raise ValueError("Image is fully transparent")
min_row, min_col = non_zero_indices.min(axis=0)
max_row, max_col = non_zero_indices.max(axis=0)
cropped_image = image.crop(
(min_col, min_row, max_col + 1, max_row + 1)
)
width, height = cropped_image.size
border_width = int(width * border_ratio)
border_height = int(height * border_ratio)
new_width = width + 2 * border_width
new_height = height + 2 * border_height
square_size = max(new_width, new_height)
new_image = Image.new(
"RGBA", (square_size, square_size), (255, 255, 255, 0)
)
paste_x = (square_size - new_width) // 2 + border_width
paste_y = (square_size - new_height) // 2 + border_height
new_image.paste(cropped_image, (paste_x, paste_y))
return new_image

    @spaces.GPU
@torch.no_grad()
def __call__(
self,
image: Union[str, np.ndarray, Image.Image],
preprocess: bool = False,
        target_wh: Optional[tuple[int, int]] = None,
) -> Image.Image:
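        """Delight ``image`` and return the result as an RGBA image.

        With ``preprocess=True`` the background is removed and the subject
        recentered before inference. Pixels outside the eroded alpha mask
        are painted white, as the delight pipeline expects a white
        background.
        """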
if isinstance(image, str):
image = Image.open(image)
elif isinstance(image, np.ndarray):
image = Image.fromarray(image)
if preprocess:
image = self.bg_remover(image)
image = self.recenter_image(image)
if target_wh is not None:
image = image.resize(target_wh)
else:
target_wh = image.size
image_array = np.array(image)
assert image_array.shape[-1] == 4, "Image must have alpha channel"
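        # Erode the alpha mask so fringe pixels along the silhouette are
        # treated as background by the white fill below.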
raw_alpha_channel = image_array[:, :, 3]
alpha_channel = cv2.erode(raw_alpha_channel, self.kernel, iterations=1)
image_array[alpha_channel == 0, :3] = 255 # must be white background
image_array[:, :, 3] = alpha_channel
image = self.pipeline(
prompt="",
image=Image.fromarray(image_array).convert("RGB"),
generator=torch.manual_seed(self.seed),
num_inference_steps=self.num_infer_step,
image_guidance_scale=self.image_guide_scale,
guidance_scale=self.text_guide_scale,
).images[0]
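        # Restore transparency on the pipeline's RGB output using the
        # eroded mask.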
alpha_channel = Image.fromarray(alpha_channel)
rgba_image = image.convert("RGBA").resize(target_wh)
rgba_image.putalpha(alpha_channel)
return rgba_image


if __name__ == "__main__":
delighting_model = DelightingModel(
# model_path="/horizon-bucket/robot_lab/users/xinjie.wang/weights/hunyuan3d-delight-v2-0" # noqa
)
image_path = "scripts/apps/assets/example_image/room_bottle_002.jpeg"
    image = delighting_model(
        image_path, preprocess=True, target_wh=(512, 512)
    )
image.save("delight.png")
# image_path = "asset3d_gen/scripts/test_robot.png"
# image = delighting_model(image_path)
# image.save("delighting_image_a2.png")