import logging import os from typing import Literal, Union import cv2 import numpy as np import rembg import torch from huggingface_hub import snapshot_download from PIL import Image from segment_anything import ( SamAutomaticMaskGenerator, SamPredictor, sam_model_registry, ) from asset3d_gen.utils.process_media import filter_small_connected_components from asset3d_gen.validators.quality_checkers import ImageSegChecker logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) __all__ = [ "resize_pil", "trellis_preprocess", "SAMRemover", "SAMPredictor", "RembgRemover", "get_segmented_image", ] def resize_pil(image: Image.Image, max_size: int = 1024) -> Image.Image: max_size = max(image.size) scale = min(1, 1024 / max_size) if scale < 1: new_size = (int(image.width * scale), int(image.height * scale)) image = image.resize(new_size, Image.Resampling.LANCZOS) return image def trellis_preprocess(image: Image.Image) -> Image.Image: """Process the input image as trellis done.""" image_np = np.array(image) alpha = image_np[:, :, 3] bbox = np.argwhere(alpha > 0.8 * 255) bbox = ( np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]), ) center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) size = int(size * 1.2) bbox = ( center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2, ) image = image.crop(bbox) image = image.resize((518, 518), Image.Resampling.LANCZOS) image = np.array(image).astype(np.float32) / 255 image = image[:, :, :3] * image[:, :, 3:4] image = Image.fromarray((image * 255).astype(np.uint8)) return image class SAMRemover(object): """Loading SAM models and performing background removal on images. Attributes: checkpoint (str): Path to the model checkpoint. model_type (str): Type of the SAM model to load (default: "vit_h"). area_ratio (float): Area ratio filtering small connected components. """ def __init__( self, checkpoint: str = None, model_type: str = "vit_h", area_ratio: float = 15, ): self.device = "cuda" if torch.cuda.is_available() else "cpu" self.model_type = model_type self.area_ratio = area_ratio if checkpoint is None: suffix = "sam" model_path = snapshot_download( repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*" ) checkpoint = os.path.join( model_path, suffix, "sam_vit_h_4b8939.pth" ) self.mask_generator = self._load_sam_model(checkpoint) def _load_sam_model(self, checkpoint: str) -> SamAutomaticMaskGenerator: sam = sam_model_registry[self.model_type](checkpoint=checkpoint) sam.to(device=self.device) return SamAutomaticMaskGenerator(sam) def __call__( self, image: Union[str, Image.Image, np.ndarray], save_path: str = None ) -> Image.Image: """Removes the background from an image using the SAM model. Args: image (Union[str, Image.Image, np.ndarray]): Input image, can be a file path, PIL Image, or numpy array. save_path (str): Path to save the output image (default: None). Returns: Image.Image: The image with background removed, including an alpha channel. """ # Convert input to numpy array if isinstance(image, str): image = Image.open(image) elif isinstance(image, np.ndarray): image = Image.fromarray(image).convert("RGB") image = resize_pil(image) image = np.array(image.convert("RGB")) # Generate masks masks = self.mask_generator.generate(image) masks = sorted(masks, key=lambda x: x["area"], reverse=True) if not masks: logger.warning( "Segmentation failed: No mask generated, return raw image." ) output_image = Image.fromarray(image, mode="RGB") else: # Use the largest mask best_mask = masks[0]["segmentation"] mask = (best_mask * 255).astype(np.uint8) mask = filter_small_connected_components( mask, area_ratio=self.area_ratio ) # Apply the mask to remove the background background_removed = cv2.bitwise_and(image, image, mask=mask) output_image = np.dstack((background_removed, mask)) output_image = Image.fromarray(output_image, mode="RGBA") if save_path is not None: os.makedirs(os.path.dirname(save_path), exist_ok=True) output_image.save(save_path) return output_image class SAMPredictor(object): def __init__( self, checkpoint: str = None, model_type: str = "vit_h", binary_thresh: float = 0.1, device: str = "cuda", ): self.device = device self.model_type = model_type if checkpoint is None: suffix = "sam" model_path = snapshot_download( repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*" ) checkpoint = os.path.join( model_path, suffix, "sam_vit_h_4b8939.pth" ) self.predictor = self._load_sam_model(checkpoint) self.binary_thresh = binary_thresh def _load_sam_model(self, checkpoint: str) -> SamPredictor: sam = sam_model_registry[self.model_type](checkpoint=checkpoint) sam.to(device=self.device) return SamPredictor(sam) def preprocess_image(self, image: Image.Image) -> np.ndarray: if isinstance(image, str): image = Image.open(image) elif isinstance(image, np.ndarray): image = Image.fromarray(image).convert("RGB") image = resize_pil(image) image = np.array(image.convert("RGB")) return image def generate_masks( self, image: np.ndarray, selected_points: list[list[int]], ) -> np.ndarray: if len(selected_points) == 0: return [] points = ( torch.Tensor([p for p, _ in selected_points]) .to(self.predictor.device) .unsqueeze(1) ) labels = ( torch.Tensor([int(l) for _, l in selected_points]) .to(self.predictor.device) .unsqueeze(1) ) transformed_points = self.predictor.transform.apply_coords_torch( points, image.shape[:2] ) masks, scores, _ = self.predictor.predict_torch( point_coords=transformed_points, point_labels=labels, multimask_output=True, ) valid_mask = masks[:, torch.argmax(scores, dim=1)] masks_pos = valid_mask[labels[:, 0] == 1, 0].cpu().detach().numpy() masks_neg = valid_mask[labels[:, 0] == 0, 0].cpu().detach().numpy() if len(masks_neg) == 0: masks_neg = np.zeros_like(masks_pos) if len(masks_pos) == 0: masks_pos = np.zeros_like(masks_neg) masks_neg = masks_neg.max(axis=0, keepdims=True) masks_pos = masks_pos.max(axis=0, keepdims=True) valid_mask = (masks_pos.astype(int) - masks_neg.astype(int)).clip(0, 1) binary_mask = (valid_mask > self.binary_thresh).astype(np.int32) return [(mask, f"mask_{i}") for i, mask in enumerate(binary_mask)] def get_segmented_image( self, image: np.ndarray, masks: list[tuple[np.ndarray, str]] ) -> Image.Image: seg_image = Image.fromarray(image, mode="RGB") alpha_channel = np.zeros( (seg_image.height, seg_image.width), dtype=np.uint8 ) for mask, _ in masks: # Use the maximum to combine multiple masks alpha_channel = np.maximum(alpha_channel, mask) alpha_channel = np.clip(alpha_channel, 0, 1) alpha_channel = (alpha_channel * 255).astype(np.uint8) alpha_image = Image.fromarray(alpha_channel, mode="L") r, g, b = seg_image.split() seg_image = Image.merge("RGBA", (r, g, b, alpha_image)) return seg_image def __call__( self, image: Union[str, Image.Image, np.ndarray], selected_points: list[list[int]], ) -> Image.Image: image = self.preprocess_image(image) self.predictor.set_image(image) masks = self.generate_masks(image, selected_points) return self.get_segmented_image(image, masks) class RembgRemover(object): def __init__(self): self.rembg_session = rembg.new_session("u2net") def __call__( self, image: Union[str, Image.Image, np.ndarray], save_path: str = None ) -> Image.Image: if isinstance(image, str): image = Image.open(image) elif isinstance(image, np.ndarray): image = Image.fromarray(image) image = resize_pil(image) output_image = rembg.remove(image, session=self.rembg_session) if save_path is not None: os.makedirs(os.path.dirname(save_path), exist_ok=True) output_image.save(save_path) return output_image def invert_rgba_pil( image: Image.Image, mask: Image.Image, save_path: str = None ) -> Image.Image: mask = (255 - np.array(mask))[..., None] image_array = np.concatenate([np.array(image), mask], axis=-1) inverted_image = Image.fromarray(image_array, "RGBA") if save_path is not None: os.makedirs(os.path.dirname(save_path), exist_ok=True) inverted_image.save(save_path) return inverted_image def get_segmented_image( image: Image.Image, sam_remover: SAMRemover, rbg_remover: RembgRemover, seg_checker: ImageSegChecker = None, save_path: str = None, mode: Literal["loose", "strict"] = "loose", ) -> Image.Image: def _is_valid_seg(raw_img: Image.Image, seg_img: Image.Image) -> bool: if seg_checker is None: return True return raw_img.mode == "RGBA" and seg_checker([raw_img, seg_img])[0] out_sam = f"{save_path}_sam.png" if save_path else None out_sam_inv = f"{save_path}_sam_inv.png" if save_path else None out_rbg = f"{save_path}_rbg.png" if save_path else None seg_image = sam_remover(image, out_sam) seg_image = seg_image.convert("RGBA") _, _, _, alpha = seg_image.split() seg_image_inv = invert_rgba_pil(image.convert("RGB"), alpha, out_sam_inv) seg_image_rbg = rbg_remover(image, out_rbg) final_image = None if _is_valid_seg(image, seg_image): final_image = seg_image elif _is_valid_seg(image, seg_image_inv): final_image = seg_image_inv elif _is_valid_seg(image, seg_image_rbg): logger.warning(f"Failed to segment by `SAM`, retry with `rembg`.") final_image = seg_image_rbg else: if mode == "strict": raise RuntimeError( f"Failed to segment by `SAM` or `rembg`, abort." ) logger.warning("Failed to segment by SAM or rembg, use raw image.") final_image = image.convert("RGBA") if save_path: final_image.save(save_path) final_image = trellis_preprocess(final_image) return final_image if __name__ == "__main__": input_image = "outputs/text2image/demo_objects/electrical/sample_0.jpg" output_image = "sample_0_seg2.png" # input_image = "outputs/text2image/tmp/coffee_machine.jpeg" # output_image = "outputs/text2image/tmp/coffee_machine_seg.png" # input_image = "outputs/text2image/tmp/bucket.jpeg" # output_image = "outputs/text2image/tmp/bucket_seg.png" remover = SAMRemover( # checkpoint="/horizon-bucket/robot_lab/users/xinjie.wang/weights/sam/sam_vit_h_4b8939.pth", # noqa model_type="vit_h", ) remover = RembgRemover() # clean_image = remover(input_image) # clean_image.save(output_image) get_segmented_image( Image.open(input_image), remover, remover, None, "./test_seg.png" )