"""
Copyright (c) 2024-present Naver Cloud Corp.
This source code is based on code from the Segment Anything Model (SAM)
(https://github.com/facebookresearch/segment-anything).
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import torch
from typing import Any, Callable
import onnxruntime
import numpy as np
def np2tensor(np_array, device):
    """Wrap a NumPy array as a torch tensor and move it to *device*."""
    tensor = torch.from_numpy(np_array)
    return tensor.to(device)
def tensor2np(torch_tensor):
    """Detach *torch_tensor* to a host NumPy array; ``None`` passes through."""
    return None if torch_tensor is None else torch_tensor.detach().cpu().numpy()
class ZIM_Decoder():
    """ONNX Runtime wrapper for the ZIM prompt decoder.

    Loads an exported decoder graph, converts torch prompt tensors into
    the NumPy feed dict the session expects, and returns the session
    outputs as torch tensors on the same device as the image embeddings.
    """

    def __init__(self, onnx_path, num_threads=16):
        """Create a CPU inference session for the decoder at *onnx_path*.

        num_threads caps both intra-op and inter-op parallelism.
        """
        self.onnx_path = onnx_path

        opts = onnxruntime.SessionOptions()
        opts.intra_op_num_threads = num_threads
        opts.inter_op_num_threads = num_threads

        self.ort_session = onnxruntime.InferenceSession(
            onnx_path,
            sess_options=opts,
            providers=["CPUExecutionProvider"],
        )
        # Number of mask tokens the decoder emits (fixed by the exported graph).
        self.num_mask_tokens = 4

    def cuda(self, device_id=0):
        """Move the session onto the CUDA execution provider for *device_id*."""
        cuda_provider = ("CUDAExecutionProvider", {"device_id": device_id})
        self.ort_session.set_providers([cuda_provider])

    def forward(
        self,
        interm_feats,
        image_embeddings,
        points,
        boxes,
        attn_mask,
    ):
        """Run the decoder on one prompt set.

        interm_feats: sequence of three intermediate feature tensors.
        points: optional (coords, labels) tensor pair; boxes: optional box
        tensor reshaped to corner-point pairs. NOTE(review): when both are
        given, the box prompt replaces the point prompt in the feed —
        mirrors the original flow; confirm against upstream callers.
        Returns (masks, iou_predictions) as torch tensors.
        """
        device = image_embeddings.device

        feed = {
            "feat_D0": tensor2np(interm_feats[0]),
            "feat_D1": tensor2np(interm_feats[1]),
            "feat_D2": tensor2np(interm_feats[2]),
            "image_embeddings": tensor2np(image_embeddings),
            "attn_mask": tensor2np(attn_mask),
        }

        if points is not None:
            coords, labels = points
            coords_np = tensor2np(coords.float())
            labels_np = tensor2np(labels.float())
            # add paddings as done in SAM: one dummy point at (-0.5, -0.5)
            # with label -1 per batch entry
            pad_pt = np.full((coords_np.shape[0], 1, 2), -0.5, dtype=np.float32)
            pad_lb = np.full((labels_np.shape[0], 1), -1.0, dtype=np.float32)
            feed["point_coords"] = np.concatenate([coords_np, pad_pt], axis=1)
            feed["point_labels"] = np.concatenate([labels_np, pad_lb], axis=1)

        if boxes is not None:
            # Each box becomes two corner points carrying labels 2 and 3
            # (SAM's box-prompt encoding).
            feed["point_coords"] = tensor2np(boxes.reshape(-1, 2, 2))
            feed["point_labels"] = np.repeat(
                np.array([[2, 3]], dtype=np.float32), boxes.shape[0], axis=0
            )

        masks, iou_predictions = self.ort_session.run(None, feed)
        return np2tensor(masks, device), np2tensor(iou_predictions, device)

    __call__: Callable[..., Any] = forward