|
import os |
|
from typing import List |
|
|
|
import numpy as np |
|
import onnxruntime as ort |
|
import pooch |
|
from PIL import Image |
|
from PIL.Image import Image as PILImage |
|
|
|
from .base import BaseSession |
|
|
|
|
|
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
    """Compute the (height, width) that makes the longer side ``long_side_length``.

    Both dimensions are scaled by the same factor, preserving aspect ratio,
    and rounded half up to integers.

    Args:
        oldh: Original height.
        oldw: Original width.
        long_side_length: Desired length of the longer side.

    Returns:
        A ``(new_height, new_width)`` tuple of ints.
    """
    factor = long_side_length * 1.0 / max(oldh, oldw)

    def _round_half_up(value):
        # Adding 0.5 then truncating rounds non-negative values half up.
        return int(value + 0.5)

    return tuple(_round_half_up(dim * factor) for dim in (oldh, oldw))
|
|
|
|
|
def apply_coords(coords: np.ndarray, original_size, target_length) -> np.ndarray:
    """Map ``(x, y)`` points from the original image size into the resized frame.

    The resized frame is the one produced by ``get_preprocess_shape`` for
    ``target_length``.

    Args:
        coords: Array whose last axis holds ``x`` at index 0 and ``y`` at index 1.
        original_size: ``(height, width)`` of the source image.
        target_length: Length of the longer side after resizing.

    Returns:
        A new float array of scaled coordinates; ``coords`` is left untouched.
    """
    src_h, src_w = original_size
    dst_h, dst_w = get_preprocess_shape(src_h, src_w, target_length)

    # ``astype`` always allocates a fresh array, so the caller's input is never
    # modified by the in-place scaling below.
    scaled = coords.astype(float)
    scaled[..., 0] *= dst_w / src_w  # x follows the width ratio
    scaled[..., 1] *= dst_h / src_h  # y follows the height ratio
    return scaled
|
|
|
|
|
def resize_longes_side(img: PILImage, size=1024):
    """Resize ``img`` so its longer side is ``size`` pixels, keeping aspect ratio.

    The target shape comes from ``get_preprocess_shape`` (round half up). That
    is the same computation ``apply_coords`` uses to map prompt points into the
    resized frame, so the image and the points agree on the scale. Truncating
    the short side instead could make it one pixel smaller than the point
    mapping assumes.

    Args:
        img: Source image; ``img.size`` is read as ``(width, height)``.
        size: Target length of the longer side.

    Returns:
        The image returned by ``img.resize`` for the computed ``(width, height)``.
    """
    w, h = img.size
    new_h, new_w = get_preprocess_shape(h, w, size)
    # A very elongated image can round its short side down to 0, which is not
    # a valid image size; keep at least one pixel along each axis.
    return img.resize((max(new_w, 1), max(new_h, 1)))
|
|
|
|
|
def pad_to_square(img: np.ndarray, size=1024):
    """Zero-pad ``img`` on the bottom and right to ``size`` x ``size`` and cast to float32.

    Args:
        img: Array whose first two axes are height and width. Any trailing axes
            (e.g. channels) are kept at their current length. Both spatial
            dimensions must be at most ``size``.
        size: Target side length of the square output.

    Returns:
        A new float32 array of shape ``(size, size, *img.shape[2:])``.

    Raises:
        ValueError: If ``img`` has fewer than two dimensions, or a spatial
            dimension is larger than ``size``.
    """
    if img.ndim < 2:
        raise ValueError(f"expected an array with at least 2 dimensions, got {img.ndim}")

    h, w = img.shape[:2]
    padh = size - h
    padw = size - w
    if padh < 0 or padw < 0:
        raise ValueError(
            f"image of spatial shape {(h, w)} does not fit in a {size}x{size} square"
        )

    # Pad only the two spatial axes (bottom/right); trailing axes get no padding.
    pad_width = [(0, padh), (0, padw)] + [(0, 0)] * (img.ndim - 2)
    padded = np.pad(img, pad_width, mode="constant")
    return padded.astype(np.float32)
|
|
|
|
|
class SamSession(BaseSession):
    """Point-prompted segmentation with a Segment Anything (SAM) ViT-B model.

    The model ships as two ONNX graphs. The encoder turns a 1024x1024 padded,
    normalized image into an embedding. The decoder combines that embedding
    with point prompts to produce masks.
    """

    def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
        """Fetch the encoder/decoder weights and open an inference session for each.

        Args:
            model_name: Stored on the instance as ``model_name``.
            sess_opts: ONNX Runtime session options shared by both sessions.
            *args: Forwarded to ``download_models``.
            **kwargs: Forwarded to ``download_models``.
        """
        self.model_name = model_name

        # Forward the caller's arguments so the download location and checksum
        # policy match what download_models would use if called directly.
        encoder_path, decoder_path = self.__class__.download_models(*args, **kwargs)
        providers = ort.get_available_providers()

        self.encoder = ort.InferenceSession(
            str(encoder_path),
            providers=providers,
            sess_options=sess_opts,
        )
        self.decoder = ort.InferenceSession(
            str(decoder_path),
            providers=providers,
            sess_options=sess_opts,
        )

    def normalize(
        self,
        img: np.ndarray,
        mean=(123.675, 116.28, 103.53),
        std=(58.395, 57.12, 57.375),
        size=(1024, 1024),
        *args,
        **kwargs,
    ):
        """Standardize ``img`` per channel: ``(img - mean) / std``.

        Args:
            img: ``(H, W, C)`` array whose last axis broadcasts against
                ``mean``/``std`` (three channels with the defaults).
            mean: Per-channel means (the defaults appear to be on the 0-255 scale).
            std: Per-channel standard deviations, same scale as ``mean``.
            size: Unused by this method.

        Returns:
            A new array of the same shape as ``img``.
        """
        pixel_mean = np.asarray(mean).reshape(1, 1, -1)
        pixel_std = np.asarray(std).reshape(1, 1, -1)
        return (img - pixel_mean) / pixel_std

    def predict(
        self,
        img: PILImage,
        *args,
        **kwargs,
    ) -> List[PILImage]:
        """Segment ``img`` from point prompts.

        Keyword Args:
            input_points: ``(N, 2)`` array-like of ``(x, y)`` points in the pixel
                coordinates of ``img``.
            input_labels: Length-``N`` array-like with one label per point.

        Returns:
            A list of 8-bit mask images (pixel values 0 or 255), one per entry
            on the first axis of the decoder's mask output.

        Raises:
            ValueError: If ``input_labels`` or ``input_points`` is not given.
        """
        input_labels = kwargs.get("input_labels")
        input_points = kwargs.get("input_points")

        # Validate the prompts before spending time on preprocessing/encoding.
        if input_labels is None:
            raise ValueError("input_labels is required")
        if input_points is None:
            raise ValueError("input_points is required")

        long_side = 1024

        # normalize() uses 3-element mean/std, so the array must have exactly
        # three channels; convert RGBA, palette or grayscale inputs first.
        image = resize_longes_side(img.convert("RGB"), size=long_side)
        image = self.normalize(np.array(image))
        image = pad_to_square(image, size=long_side)

        # HWC -> NCHW with a batch of one for the encoder.
        image = image.transpose(2, 0, 1)[None, :, :, :]
        image_embedding = self.encoder.run(None, {"x": image})[0]

        # Append a (0, 0) padding point labelled -1 and add a batch axis
        # (presumably the SAM ONNX decoder's convention for point-only
        # prompts — confirm against the exported model).
        onnx_coord = np.concatenate([input_points, np.array([[0.0, 0.0]])], axis=0)[
            None, :, :
        ]
        onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
            None, :
        ].astype(np.float32)

        # apply_coords expects (height, width); PIL's img.size is (width, height).
        onnx_coord = apply_coords(onnx_coord, img.size[::-1], long_side).astype(
            np.float32
        )

        decoder_inputs = {
            "image_embeddings": image_embedding,
            "point_coords": onnx_coord,
            "point_labels": onnx_label,
            # No prior mask is supplied: an all-zero mask with the flag at 0.
            "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
            "has_mask_input": np.zeros(1, dtype=np.float32),
            "orig_im_size": np.array(img.size[::-1], dtype=np.float32),
        }

        masks, _, _ = self.decoder.run(None, decoder_inputs)
        masks = masks > 0.0

        return [Image.fromarray((mask[0] * 255).astype(np.uint8)) for mask in masks]

    @classmethod
    def download_models(cls, *args, **kwargs):
        """Download (or reuse cached) encoder and decoder ONNX files.

        Checksums are verified unless ``checksum_disabled`` reports otherwise.

        Args:
            *args: Forwarded to ``u2net_home`` and ``checksum_disabled``.
            **kwargs: Forwarded to ``u2net_home`` and ``checksum_disabled``.

        Returns:
            ``(encoder_path, decoder_path)`` as returned by ``pooch.retrieve``,
            i.e. the locations the files were actually stored at.
        """
        fname_encoder = f"{cls.name()}_encoder.onnx"
        fname_decoder = f"{cls.name()}_decoder.onnx"
        home = cls.u2net_home(*args, **kwargs)
        skip_checksum = cls.checksum_disabled(*args, **kwargs)

        encoder_path = pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
            None if skip_checksum else "md5:13d97c5c79ab13ef86d67cbde5f1b250",
            fname=fname_encoder,
            path=home,
            progressbar=True,
        )

        decoder_path = pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx",
            None if skip_checksum else "md5:fa3d1c36a3187d3de1c8deebf33dd127",
            fname=fname_decoder,
            path=home,
            progressbar=True,
        )

        return (encoder_path, decoder_path)

    @classmethod
    def name(cls, *args, **kwargs):
        """Return the session's model identifier, ``"sam"``."""
        return "sam"
|
|