import logging
import os
from typing import Literal, Union

import cv2
import numpy as np
import rembg
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from segment_anything import (
    SamAutomaticMaskGenerator,
    SamPredictor,
    sam_model_registry,
)

from asset3d_gen.utils.process_media import filter_small_connected_components
from asset3d_gen.validators.quality_checkers import ImageSegChecker

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

__all__ = [
    "resize_pil",
    "trellis_preprocess",
    "SAMRemover",
    "SAMPredictor",
    "RembgRemover",
    "get_segmented_image",
]


def resize_pil(image: Image.Image, max_size: int = 1024) -> Image.Image:
    """Downscale ``image`` so its longer side is at most ``max_size``."""
    scale = min(1, max_size / max(image.size))
    if scale < 1:
        new_size = (int(image.width * scale), int(image.height * scale))
        image = image.resize(new_size, Image.Resampling.LANCZOS)
    return image


def trellis_preprocess(image: Image.Image) -> Image.Image:
    """Crop, pad and resize an RGBA image the way TRELLIS expects."""
    image_np = np.array(image)
    alpha = image_np[:, :, 3]
    # Tight bounding box of pixels whose alpha exceeds ~80% opacity.
    bbox = np.argwhere(alpha > 0.8 * 255)
    bbox = (
        np.min(bbox[:, 1]),
        np.min(bbox[:, 0]),
        np.max(bbox[:, 1]),
        np.max(bbox[:, 0]),
    )
    # Expand the box into a square with 20% padding around the object.
    center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
    size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
    size = int(size * 1.2)
    bbox = (
        center[0] - size // 2,
        center[1] - size // 2,
        center[0] + size // 2,
        center[1] + size // 2,
    )
    image = image.crop(bbox)
    image = image.resize((518, 518), Image.Resampling.LANCZOS)
    # Premultiply alpha so the background turns black, then drop the alpha channel.
    image = np.array(image).astype(np.float32) / 255
    image = image[:, :, :3] * image[:, :, 3:4]
    image = Image.fromarray((image * 255).astype(np.uint8))
    return image


class SAMRemover(object):
    """Load a SAM model and remove image backgrounds with it.

    Attributes:
        checkpoint (str): Path to the model checkpoint.
        model_type (str): Type of the SAM model to load (default: "vit_h").
        area_ratio (float): Area-ratio threshold used to filter small
            connected components from the predicted mask.
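
    Example (illustrative usage; paths are placeholders):
        remover = SAMRemover(model_type="vit_h")
        rgba = remover("input.jpg", save_path="output_rgba.png")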
""" | |

    def __init__(
        self,
        checkpoint: str = None,
        model_type: str = "vit_h",
        area_ratio: float = 15,
    ):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_type = model_type
        self.area_ratio = area_ratio

        if checkpoint is None:
            suffix = "sam"
            model_path = snapshot_download(
                repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
            )
            checkpoint = os.path.join(
                model_path, suffix, "sam_vit_h_4b8939.pth"
            )

        self.mask_generator = self._load_sam_model(checkpoint)

    def _load_sam_model(self, checkpoint: str) -> SamAutomaticMaskGenerator:
        sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
        sam.to(device=self.device)
        return SamAutomaticMaskGenerator(sam)

    def __call__(
        self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
    ) -> Image.Image:
        """Remove the background from an image using the SAM model.

        Args:
            image (Union[str, Image.Image, np.ndarray]): Input image,
                can be a file path, PIL Image, or numpy array.
            save_path (str): Path to save the output image (default: None).

        Returns:
            Image.Image: The image with background removed,
                including an alpha channel.
        """
        # Convert input to a numpy array.
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert("RGB")
        image = resize_pil(image)
        image = np.array(image.convert("RGB"))

        # Generate masks.
        masks = self.mask_generator.generate(image)
        masks = sorted(masks, key=lambda x: x["area"], reverse=True)

        if not masks:
            logger.warning(
                "Segmentation failed: no mask generated, returning raw image."
            )
            output_image = Image.fromarray(image, mode="RGB")
        else:
            # Use the largest mask.
            best_mask = masks[0]["segmentation"]
            mask = (best_mask * 255).astype(np.uint8)
            mask = filter_small_connected_components(
                mask, area_ratio=self.area_ratio
            )
            # Apply the mask to remove the background.
            background_removed = cv2.bitwise_and(image, image, mask=mask)
            output_image = np.dstack((background_removed, mask))
            output_image = Image.fromarray(output_image, mode="RGBA")

        if save_path is not None:
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            output_image.save(save_path)

        return output_image


class SAMPredictor(object):
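    """Run SAM segmentation prompted by user-selected points.

    Each selected point is a ``(point, label)`` pair, where ``label`` is 1
    for foreground and 0 for background.
    """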

    def __init__(
        self,
        checkpoint: str = None,
        model_type: str = "vit_h",
        binary_thresh: float = 0.1,
        device: str = "cuda",
    ):
        self.device = device
        self.model_type = model_type

        if checkpoint is None:
            suffix = "sam"
            model_path = snapshot_download(
                repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
            )
            checkpoint = os.path.join(
                model_path, suffix, "sam_vit_h_4b8939.pth"
            )

        self.predictor = self._load_sam_model(checkpoint)
        self.binary_thresh = binary_thresh

    def _load_sam_model(self, checkpoint: str) -> SamPredictor:
        sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
        sam.to(device=self.device)
        return SamPredictor(sam)

    def preprocess_image(
        self, image: Union[str, Image.Image, np.ndarray]
    ) -> np.ndarray:
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert("RGB")
        image = resize_pil(image)
        image = np.array(image.convert("RGB"))
        return image

    def generate_masks(
        self,
        image: np.ndarray,
        selected_points: list[tuple[list[int], int]],
    ) -> list[tuple[np.ndarray, str]]:
        if len(selected_points) == 0:
            return []
        points = (
            torch.Tensor([p for p, _ in selected_points])
            .to(self.predictor.device)
            .unsqueeze(1)
        )
        labels = (
            torch.Tensor([int(l) for _, l in selected_points])
            .to(self.predictor.device)
            .unsqueeze(1)
        )
        transformed_points = self.predictor.transform.apply_coords_torch(
            points, image.shape[:2]
        )
        masks, scores, _ = self.predictor.predict_torch(
            point_coords=transformed_points,
            point_labels=labels,
            multimask_output=True,
        )
        # Keep the highest-scoring mask for each prompted point.
        valid_mask = masks[:, torch.argmax(scores, dim=1)]
        masks_pos = valid_mask[labels[:, 0] == 1, 0].cpu().detach().numpy()
        masks_neg = valid_mask[labels[:, 0] == 0, 0].cpu().detach().numpy()
        if len(masks_neg) == 0:
            masks_neg = np.zeros_like(masks_pos)
        if len(masks_pos) == 0:
            masks_pos = np.zeros_like(masks_neg)
        # Union the foreground masks, then subtract the background masks.
        masks_neg = masks_neg.max(axis=0, keepdims=True)
        masks_pos = masks_pos.max(axis=0, keepdims=True)
        valid_mask = (masks_pos.astype(int) - masks_neg.astype(int)).clip(0, 1)
        binary_mask = (valid_mask > self.binary_thresh).astype(np.int32)
        return [(mask, f"mask_{i}") for i, mask in enumerate(binary_mask)]

    def get_segmented_image(
        self, image: np.ndarray, masks: list[tuple[np.ndarray, str]]
    ) -> Image.Image:
        seg_image = Image.fromarray(image, mode="RGB")
        alpha_channel = np.zeros(
            (seg_image.height, seg_image.width), dtype=np.uint8
        )
        for mask, _ in masks:
            # Use the maximum to combine multiple masks.
            alpha_channel = np.maximum(alpha_channel, mask)
        alpha_channel = np.clip(alpha_channel, 0, 1)
        alpha_channel = (alpha_channel * 255).astype(np.uint8)
        alpha_image = Image.fromarray(alpha_channel, mode="L")
        r, g, b = seg_image.split()
        seg_image = Image.merge("RGBA", (r, g, b, alpha_image))
        return seg_image

    def __call__(
        self,
        image: Union[str, Image.Image, np.ndarray],
        selected_points: list[tuple[list[int], int]],
    ) -> Image.Image:
        image = self.preprocess_image(image)
        self.predictor.set_image(image)
        masks = self.generate_masks(image, selected_points)
        return self.get_segmented_image(image, masks)


class RembgRemover(object):
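    """Remove image backgrounds with `rembg` (u2net session)."""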

    def __init__(self):
        self.rembg_session = rembg.new_session("u2net")

    def __call__(
        self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
    ) -> Image.Image:
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        image = resize_pil(image)

        output_image = rembg.remove(image, session=self.rembg_session)

        if save_path is not None:
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            output_image.save(save_path)

        return output_image


def invert_rgba_pil(
    image: Image.Image, mask: Image.Image, save_path: str = None
) -> Image.Image:
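    """Attach the inverted ``mask`` as alpha, selecting the complement region."""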
    mask = (255 - np.array(mask))[..., None]
    image_array = np.concatenate([np.array(image), mask], axis=-1)
    inverted_image = Image.fromarray(image_array, "RGBA")

    if save_path is not None:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        inverted_image.save(save_path)

    return inverted_image


def get_segmented_image(
    image: Image.Image,
    sam_remover: SAMRemover,
    rbg_remover: RembgRemover,
    seg_checker: ImageSegChecker = None,
    save_path: str = None,
    mode: Literal["loose", "strict"] = "loose",
) -> Image.Image:
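    """Segment ``image`` with SAM and fall back when the result looks invalid.

    Candidates are tried in order: the SAM mask, the inverted SAM mask, and
    the rembg mask. In "loose" mode the raw image is returned as a last
    resort; in "strict" mode a RuntimeError is raised instead. The final
    image is additionally preprocessed with ``trellis_preprocess``.
    """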

    def _is_valid_seg(raw_img: Image.Image, seg_img: Image.Image) -> bool:
        if seg_checker is None:
            return True
        return raw_img.mode == "RGBA" and seg_checker([raw_img, seg_img])[0]

    out_sam = f"{save_path}_sam.png" if save_path else None
    out_sam_inv = f"{save_path}_sam_inv.png" if save_path else None
    out_rbg = f"{save_path}_rbg.png" if save_path else None

    seg_image = sam_remover(image, out_sam)
    seg_image = seg_image.convert("RGBA")
    _, _, _, alpha = seg_image.split()
    seg_image_inv = invert_rgba_pil(image.convert("RGB"), alpha, out_sam_inv)
    seg_image_rbg = rbg_remover(image, out_rbg)

    final_image = None
    if _is_valid_seg(image, seg_image):
        final_image = seg_image
    elif _is_valid_seg(image, seg_image_inv):
        final_image = seg_image_inv
    elif _is_valid_seg(image, seg_image_rbg):
        logger.warning("Failed to segment with `SAM`, retrying with `rembg`.")
        final_image = seg_image_rbg
    else:
        if mode == "strict":
            raise RuntimeError(
                "Failed to segment with `SAM` or `rembg`, aborting."
            )
        logger.warning("Failed to segment with SAM or rembg, using raw image.")
        final_image = image.convert("RGBA")

    if save_path:
        final_image.save(save_path)

    final_image = trellis_preprocess(final_image)

    return final_image


if __name__ == "__main__":
    input_image = "outputs/text2image/demo_objects/electrical/sample_0.jpg"
    output_image = "sample_0_seg2.png"

    # input_image = "outputs/text2image/tmp/coffee_machine.jpeg"
    # output_image = "outputs/text2image/tmp/coffee_machine_seg.png"
    # input_image = "outputs/text2image/tmp/bucket.jpeg"
    # output_image = "outputs/text2image/tmp/bucket_seg.png"

    remover = SAMRemover(
        # checkpoint="/horizon-bucket/robot_lab/users/xinjie.wang/weights/sam/sam_vit_h_4b8939.pth",  # noqa
        model_type="vit_h",
    )
    remover = RembgRemover()
    # clean_image = remover(input_image)
    # clean_image.save(output_image)

    get_segmented_image(
        Image.open(input_image), remover, remover, None, "./test_seg.png"
    )
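
    # Illustrative sketch of point-prompted segmentation with SAMPredictor.
    # The coordinates below are hypothetical placeholders; pick points that
    # lie on the object (label 1) or on the background (label 0) in your image.
    # predictor = SAMPredictor(model_type="vit_h")
    # points = [([320, 240], 1), ([20, 20], 0)]
    # seg_rgba = predictor(input_image, points)
    # seg_rgba.save("sample_0_points_seg.png")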