# Segment Anything 8-Bit ONNX

How to run: the snippet below embeds an image with the encoder, then decodes a mask for a single point prompt. It requires `onnxruntime`, `numpy`, `pillow`, and `matplotlib` (`pip install onnxruntime numpy pillow matplotlib`):

```python
import onnxruntime as ort
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# Path to the image file
image_path = "example.png"

# Load the image and resize it so the longest side is 1024 (SAM's input size)
image = Image.open(image_path).convert("RGB")
orig_width, orig_height = image.size
scale = 1024 / max(orig_width, orig_height)
new_width = int(orig_width * scale + 0.5)
new_height = int(orig_height * scale + 0.5)
resized = image.resize((new_width, new_height), Image.BILINEAR)

# Normalize with SAM's pixel mean/std and convert to NCHW float32
input_tensor = np.array(resized, dtype=np.float32)
mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
input_tensor = (input_tensor - mean) / std
input_tensor = input_tensor.transpose(2, 0, 1)[None, :, :, :]

# Pad input tensor to 1024x1024
pad_height = 1024 - input_tensor.shape[2]
pad_width = 1024 - input_tensor.shape[3]
input_tensor = np.pad(input_tensor, ((0, 0), (0, 0), (0, pad_height), (0, pad_width)))
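# The encoder expects a fixed 1024x1024 input; padding only the bottom/right
# edges keeps pixel coordinates aligned with the top-left origin.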

# Load the encoder model and run inference
encoder = ort.InferenceSession("sam_encoder.onnx")
embeddings = encoder.run(None, {"images": input_tensor})[0]
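# For the standard SAM backbones the embedding is a (1, 256, 64, 64) feature
# map; it can be computed once and reused for many prompts on the same image.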

# Choose a point prompt (e.g., x=150, y=100) in the original image
point = [150, 100]

# Scale the point to the resized image, and append the zero "padding point"
# the decoder expects when no box prompt is given
coords = np.array([[point, [0.0, 0.0]]], dtype=np.float64)
coords[..., 0] = coords[..., 0] * (new_width / orig_width)
coords[..., 1] = coords[..., 1] * (new_height / orig_height)
onnx_coord = coords.astype(np.float32)

# Prepare the remaining decoder inputs: no low-res mask prior is supplied
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)  # empty mask prior
onnx_has_mask_input = np.zeros(1, dtype=np.float32)             # 0 = ignore mask_input
onnx_label = np.array([[1, -1]], dtype=np.float32)              # 1 = foreground point, -1 = padding point

# Load the decoder model and run inference
decoder = ort.InferenceSession("sam_decoder.onnx")
masks_output, _, _ = decoder.run(None, {
    "image_embeddings": embeddings,
    "point_coords": onnx_coord,
    "point_labels": onnx_label,
    "mask_input": onnx_mask_input,
    "has_mask_input": onnx_has_mask_input,
    "orig_im_size": np.array([orig_height, orig_width], dtype=np.float32)
})
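# The decoder returns (masks, iou_predictions, low_res_masks); masks has shape
# (1, num_masks, orig_height, orig_width) because orig_im_size was provided.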

# Process the output mask: take the first mask and threshold the logits at 0
mask = masks_output[0][0]
mask = (mask > 0).astype(np.uint8) * 255
```
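
To sanity-check the result, you can overlay the binary mask on the input image with matplotlib; a minimal sketch using the variables defined above (the star marks the point prompt):

```python
# Overlay the predicted mask on the original image in semi-transparent red
plt.figure(figsize=(8, 8))
plt.imshow(image)
overlay = np.zeros((orig_height, orig_width, 4), dtype=np.float32)
overlay[..., 0] = 1.0                 # red channel
overlay[..., 3] = (mask > 0) * 0.5    # alpha: 0.5 inside the mask, 0 outside
plt.imshow(overlay)
plt.scatter([point[0]], [point[1]], c="lime", marker="*", s=200)
plt.axis("off")
plt.show()
```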