Spaces:
Sleeping
Sleeping
""" | |
Copyright (c) 2024-present Naver Cloud Corp. | |
This source code is based on code from the Segment Anything Model (SAM) | |
(https://github.com/facebookresearch/segment-anything). | |
This source code is licensed under the license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |
import torch | |
from typing import Any, Callable | |
import onnxruntime | |
def np2tensor(np_array, device):
    """Wrap a NumPy array as a torch tensor and move it to ``device``.

    ``torch.from_numpy`` shares memory with ``np_array``. The following
    ``.to(device)`` hands back that same tensor when it already lives on
    ``device`` and produces a copy otherwise.
    """
    shared = torch.from_numpy(np_array)
    return shared.to(device)
def tensor2np(torch_tensor):
    """Return ``torch_tensor`` as a NumPy array, detached from autograd and on the CPU."""
    detached = torch_tensor.detach()
    on_host = detached.cpu()
    return on_host.numpy()
class ZIM_Encoder():
    """Image encoder backed by an ONNX model executed with ONNX Runtime.

    The session is created on the CPU execution provider; call :meth:`cuda`
    to switch it to CUDA. :meth:`forward` (also available as ``__call__``)
    takes a torch tensor and returns torch tensors on the input's device.
    """

    def __init__(self, onnx_path, num_threads=16):
        """Create a CPU inference session for the model at ``onnx_path``.

        Args:
            onnx_path: Path to the ONNX encoder model; kept on ``self.onnx_path``.
            num_threads: Value applied to both ``intra_op_num_threads`` and
                ``inter_op_num_threads`` of the session options.
        """
        self.onnx_path = onnx_path
        session_options = onnxruntime.SessionOptions()
        session_options.intra_op_num_threads = num_threads
        session_options.inter_op_num_threads = num_threads
        self.ort_session = onnxruntime.InferenceSession(
            onnx_path,
            sess_options=session_options,
            providers=["CPUExecutionProvider"],
        )

    def cuda(self, device_id=0):
        """Switch the session to the CUDA execution provider on ``device_id``.

        Providers are passed in decreasing priority, so CPU is listed after
        CUDA as an explicit fallback.

        Returns:
            ``self``, so calls can be chained (``ZIM_Encoder(path).cuda()``),
            mirroring ``torch.nn.Module.cuda``. Previously this returned ``None``.
        """
        providers = [
            ("CUDAExecutionProvider", {"device_id": device_id}),
            "CPUExecutionProvider",
        ]
        self.ort_session.set_providers(providers)
        return self

    def forward(
        self,
        image,
    ):
        """Run the encoder on ``image``.

        Args:
            image: torch tensor fed to the model input named ``"image"``.
                Its device determines where the outputs are placed.
                NOTE(review): dtype/shape must match the exported model
                (presumably float32 NCHW) — confirm against the ONNX graph.

        Returns:
            ``(image_embeddings, (feat_D0, feat_D1, feat_D2))``. Each entry is
            a torch tensor on ``image.device``.

        Raises:
            ValueError: if the model does not produce exactly four outputs.
        """
        device = image.device
        ort_inputs = {
            "image": tensor2np(image),
        }
        image_embeddings, feat_D0, feat_D1, feat_D2 = self.ort_session.run(None, ort_inputs)
        image_embeddings = np2tensor(image_embeddings, device)
        decoder_feats = tuple(
            np2tensor(feat, device) for feat in (feat_D0, feat_D1, feat_D2)
        )
        return image_embeddings, decoder_feats

    __call__: Callable[..., Any] = forward