import logging
import os
from typing import Literal, Union
import cv2
import numpy as np
import rembg
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from segment_anything import (
SamAutomaticMaskGenerator,
SamPredictor,
sam_model_registry,
)
from asset3d_gen.utils.process_media import filter_small_connected_components
from asset3d_gen.validators.quality_checkers import ImageSegChecker
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
__all__ = [
"resize_pil",
"trellis_preprocess",
"SAMRemover",
"SAMPredictor",
"RembgRemover",
"get_segmented_image",
]
def resize_pil(image: Image.Image, max_size: int = 1024) -> Image.Image:
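    """Downscale the image so that its longest side is at most `max_size`."""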
    largest_side = max(image.size)
    scale = min(1, max_size / largest_side)
if scale < 1:
new_size = (int(image.width * scale), int(image.height * scale))
image = image.resize(new_size, Image.Resampling.LANCZOS)
return image
def trellis_preprocess(image: Image.Image) -> Image.Image:
"""Process the input image as trellis done."""
image_np = np.array(image)
alpha = image_np[:, :, 3]
bbox = np.argwhere(alpha > 0.8 * 255)
bbox = (
np.min(bbox[:, 1]),
np.min(bbox[:, 0]),
np.max(bbox[:, 1]),
np.max(bbox[:, 0]),
)
center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
size = int(size * 1.2)
bbox = (
center[0] - size // 2,
center[1] - size // 2,
center[0] + size // 2,
center[1] + size // 2,
)
image = image.crop(bbox)
image = image.resize((518, 518), Image.Resampling.LANCZOS)
image = np.array(image).astype(np.float32) / 255
image = image[:, :, :3] * image[:, :, 3:4]
image = Image.fromarray((image * 255).astype(np.uint8))
return image
class SAMRemover(object):
"""Loading SAM models and performing background removal on images.
Attributes:
checkpoint (str): Path to the model checkpoint.
model_type (str): Type of the SAM model to load (default: "vit_h").
area_ratio (float): Area ratio filtering small connected components.
"""
def __init__(
self,
checkpoint: str = None,
model_type: str = "vit_h",
area_ratio: float = 15,
):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model_type = model_type
self.area_ratio = area_ratio
if checkpoint is None:
suffix = "sam"
model_path = snapshot_download(
repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
)
checkpoint = os.path.join(
model_path, suffix, "sam_vit_h_4b8939.pth"
)
self.mask_generator = self._load_sam_model(checkpoint)
def _load_sam_model(self, checkpoint: str) -> SamAutomaticMaskGenerator:
sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
sam.to(device=self.device)
return SamAutomaticMaskGenerator(sam)
def __call__(
self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
) -> Image.Image:
"""Removes the background from an image using the SAM model.
Args:
image (Union[str, Image.Image, np.ndarray]): Input image,
can be a file path, PIL Image, or numpy array.
save_path (str): Path to save the output image (default: None).
Returns:
Image.Image: The image with background removed,
including an alpha channel.
"""
# Convert input to numpy array
if isinstance(image, str):
image = Image.open(image)
elif isinstance(image, np.ndarray):
image = Image.fromarray(image).convert("RGB")
image = resize_pil(image)
image = np.array(image.convert("RGB"))
# Generate masks
masks = self.mask_generator.generate(image)
masks = sorted(masks, key=lambda x: x["area"], reverse=True)
if not masks:
logger.warning(
"Segmentation failed: No mask generated, return raw image."
)
output_image = Image.fromarray(image, mode="RGB")
else:
# Use the largest mask
best_mask = masks[0]["segmentation"]
mask = (best_mask * 255).astype(np.uint8)
mask = filter_small_connected_components(
mask, area_ratio=self.area_ratio
)
# Apply the mask to remove the background
background_removed = cv2.bitwise_and(image, image, mask=mask)
output_image = np.dstack((background_removed, mask))
output_image = Image.fromarray(output_image, mode="RGBA")
if save_path is not None:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
output_image.save(save_path)
return output_image
class SAMPredictor(object):
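    """Point-prompted segmentation with SAM's interactive predictor."""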
def __init__(
self,
checkpoint: str = None,
model_type: str = "vit_h",
binary_thresh: float = 0.1,
device: str = "cuda",
):
self.device = device
self.model_type = model_type
if checkpoint is None:
suffix = "sam"
model_path = snapshot_download(
repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
)
checkpoint = os.path.join(
model_path, suffix, "sam_vit_h_4b8939.pth"
)
self.predictor = self._load_sam_model(checkpoint)
self.binary_thresh = binary_thresh
def _load_sam_model(self, checkpoint: str) -> SamPredictor:
sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
sam.to(device=self.device)
return SamPredictor(sam)
def preprocess_image(self, image: Image.Image) -> np.ndarray:
if isinstance(image, str):
image = Image.open(image)
elif isinstance(image, np.ndarray):
image = Image.fromarray(image).convert("RGB")
image = resize_pil(image)
image = np.array(image.convert("RGB"))
return image
def generate_masks(
self,
image: np.ndarray,
selected_points: list[list[int]],
    ) -> list[tuple[np.ndarray, str]]:
if len(selected_points) == 0:
return []
points = (
torch.Tensor([p for p, _ in selected_points])
.to(self.predictor.device)
.unsqueeze(1)
)
labels = (
torch.Tensor([int(l) for _, l in selected_points])
.to(self.predictor.device)
.unsqueeze(1)
)
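        # Map point prompts into SAM's resized input coordinate frame.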
transformed_points = self.predictor.transform.apply_coords_torch(
points, image.shape[:2]
)
masks, scores, _ = self.predictor.predict_torch(
point_coords=transformed_points,
point_labels=labels,
multimask_output=True,
)
valid_mask = masks[:, torch.argmax(scores, dim=1)]
masks_pos = valid_mask[labels[:, 0] == 1, 0].cpu().detach().numpy()
masks_neg = valid_mask[labels[:, 0] == 0, 0].cpu().detach().numpy()
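        # If only positive or only negative prompts exist, use an empty mask
        # as the placeholder for the missing kind.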
if len(masks_neg) == 0:
masks_neg = np.zeros_like(masks_pos)
if len(masks_pos) == 0:
masks_pos = np.zeros_like(masks_neg)
masks_neg = masks_neg.max(axis=0, keepdims=True)
masks_pos = masks_pos.max(axis=0, keepdims=True)
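        # Keep regions hit by positive prompts but not by negative ones.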
valid_mask = (masks_pos.astype(int) - masks_neg.astype(int)).clip(0, 1)
binary_mask = (valid_mask > self.binary_thresh).astype(np.int32)
return [(mask, f"mask_{i}") for i, mask in enumerate(binary_mask)]
def get_segmented_image(
self, image: np.ndarray, masks: list[tuple[np.ndarray, str]]
) -> Image.Image:
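        """Merge the masks into an alpha channel and return an RGBA image."""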
seg_image = Image.fromarray(image, mode="RGB")
alpha_channel = np.zeros(
(seg_image.height, seg_image.width), dtype=np.uint8
)
for mask, _ in masks:
# Use the maximum to combine multiple masks
alpha_channel = np.maximum(alpha_channel, mask)
alpha_channel = np.clip(alpha_channel, 0, 1)
alpha_channel = (alpha_channel * 255).astype(np.uint8)
alpha_image = Image.fromarray(alpha_channel, mode="L")
r, g, b = seg_image.split()
seg_image = Image.merge("RGBA", (r, g, b, alpha_image))
return seg_image
def __call__(
self,
image: Union[str, Image.Image, np.ndarray],
selected_points: list[list[int]],
) -> Image.Image:
image = self.preprocess_image(image)
self.predictor.set_image(image)
masks = self.generate_masks(image, selected_points)
return self.get_segmented_image(image, masks)
class RembgRemover(object):
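    """Background remover backed by rembg's pretrained "u2net" session."""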
def __init__(self):
self.rembg_session = rembg.new_session("u2net")
def __call__(
self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
) -> Image.Image:
if isinstance(image, str):
image = Image.open(image)
elif isinstance(image, np.ndarray):
image = Image.fromarray(image)
image = resize_pil(image)
output_image = rembg.remove(image, session=self.rembg_session)
if save_path is not None:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
output_image.save(save_path)
return output_image
def invert_rgba_pil(
image: Image.Image, mask: Image.Image, save_path: str = None
) -> Image.Image:
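    """Attach the inverted mask to the image as its alpha channel."""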
mask = (255 - np.array(mask))[..., None]
image_array = np.concatenate([np.array(image), mask], axis=-1)
inverted_image = Image.fromarray(image_array, "RGBA")
if save_path is not None:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
inverted_image.save(save_path)
return inverted_image
def get_segmented_image(
image: Image.Image,
sam_remover: SAMRemover,
rbg_remover: RembgRemover,
seg_checker: ImageSegChecker = None,
save_path: str = None,
mode: Literal["loose", "strict"] = "loose",
) -> Image.Image:
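    """Segment an image, trying SAM first and falling back to rembg.

    Candidate results are checked in order: the SAM mask, the inverted SAM
    mask, and the rembg mask; the first one accepted by `seg_checker` is
    used. If none passes, the raw image is returned ("loose" mode) or a
    RuntimeError is raised ("strict" mode). The result is finally cropped
    and resized by `trellis_preprocess`.
    """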
def _is_valid_seg(raw_img: Image.Image, seg_img: Image.Image) -> bool:
if seg_checker is None:
return True
return raw_img.mode == "RGBA" and seg_checker([raw_img, seg_img])[0]
out_sam = f"{save_path}_sam.png" if save_path else None
out_sam_inv = f"{save_path}_sam_inv.png" if save_path else None
out_rbg = f"{save_path}_rbg.png" if save_path else None
seg_image = sam_remover(image, out_sam)
seg_image = seg_image.convert("RGBA")
_, _, _, alpha = seg_image.split()
seg_image_inv = invert_rgba_pil(image.convert("RGB"), alpha, out_sam_inv)
seg_image_rbg = rbg_remover(image, out_rbg)
final_image = None
if _is_valid_seg(image, seg_image):
final_image = seg_image
elif _is_valid_seg(image, seg_image_inv):
final_image = seg_image_inv
elif _is_valid_seg(image, seg_image_rbg):
logger.warning(f"Failed to segment by `SAM`, retry with `rembg`.")
final_image = seg_image_rbg
else:
if mode == "strict":
raise RuntimeError(
f"Failed to segment by `SAM` or `rembg`, abort."
)
logger.warning("Failed to segment by SAM or rembg, use raw image.")
final_image = image.convert("RGBA")
if save_path:
final_image.save(save_path)
final_image = trellis_preprocess(final_image)
return final_image
if __name__ == "__main__":
input_image = "outputs/text2image/demo_objects/electrical/sample_0.jpg"
output_image = "sample_0_seg2.png"
# input_image = "outputs/text2image/tmp/coffee_machine.jpeg"
# output_image = "outputs/text2image/tmp/coffee_machine_seg.png"
# input_image = "outputs/text2image/tmp/bucket.jpeg"
# output_image = "outputs/text2image/tmp/bucket_seg.png"
    sam_remover = SAMRemover(
        # checkpoint="/horizon-bucket/robot_lab/users/xinjie.wang/weights/sam/sam_vit_h_4b8939.pth", # noqa
        model_type="vit_h",
    )
    rbg_remover = RembgRemover()
    # clean_image = rbg_remover(input_image)
    # clean_image.save(output_image)
    get_segmented_image(
        Image.open(input_image),
        sam_remover,
        rbg_remover,
        None,
        "./test_seg.png",
    )