# ONNX export code is from the labelme annotation tool
# (https://github.com/labelmeai/efficient-sam). Huge thanks to Kentaro Wada.
import torch
import torch.nn.functional as F


class OnnxEfficientSam(torch.nn.Module):
    """ONNX-exportable wrapper around an EfficientSam model (images + point prompts -> masks)."""

    def __init__(self, model):
        super().__init__()
        self.model = model
@property
def decoder_max_num_input_points(self):
return self.model.decoder_max_num_input_points
@property
def image_encoder(self):
return self.model.image_encoder
@property
def get_image_embeddings(self):
return self.model.get_image_embeddings
@property
def prompt_encoder(self):
return self.model.prompt_encoder
@property
def mask_decoder(self):
return self.model.mask_decoder

    def forward(
        self,
        batched_images: torch.Tensor,
        batched_points: torch.Tensor,
        batched_point_labels: torch.Tensor,
    ):
        # Full pipeline: encode the images, then predict masks for the point
        # prompts at the original image resolution.
        batch_size, _, input_h, input_w = batched_images.shape
        image_embeddings = self.get_image_embeddings(batched_images)
return self.predict_masks(
image_embeddings,
batched_points,
batched_point_labels,
multimask_output=True,
input_h=input_h,
input_w=input_w,
output_h=input_h,
output_w=input_w,
)

    def get_rescaled_pts(
        self, batched_points: torch.Tensor, input_h: int, input_w: int
    ):
        # Rescale point coordinates from the original image resolution to the
        # encoder's fixed input resolution (image_encoder.img_size).
        return torch.stack(
[
batched_points[..., 0] * self.image_encoder.img_size / input_w,
batched_points[..., 1] * self.image_encoder.img_size / input_h,
],
dim=-1,
)

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        batched_points: torch.Tensor,
        batched_point_labels: torch.Tensor,
        multimask_output: bool,
        input_h: int,
        input_w: int,
        output_h: int = -1,
        output_w: int = -1,
    ):
        # output_h/output_w <= 0 keeps the predicted masks at the decoder's
        # low resolution instead of upsampling them to an output size.
        batch_size, max_num_queries, num_pts, _ = batched_points.shape
rescaled_batched_points = self.get_rescaled_pts(
batched_points, input_h, input_w
)
        # Truncate or pad the prompts so that exactly decoder_max_num_input_points
        # points reach the prompt encoder (padded entries are filled with -1).
        if num_pts > self.decoder_max_num_input_points:
rescaled_batched_points = rescaled_batched_points[
:, :, : self.decoder_max_num_input_points, :
]
batched_point_labels = batched_point_labels[
:, :, : self.decoder_max_num_input_points
]
elif num_pts < self.decoder_max_num_input_points:
rescaled_batched_points = F.pad(
rescaled_batched_points,
(0, 0, 0, self.decoder_max_num_input_points - num_pts),
value=-1.0,
)
batched_point_labels = F.pad(
batched_point_labels,
(0, self.decoder_max_num_input_points - num_pts),
value=-1.0,
)
        # Encode the flattened point prompts into sparse embeddings, then
        # restore the per-query dimension.
        sparse_embeddings = self.prompt_encoder(
rescaled_batched_points.reshape(
batch_size * max_num_queries, self.decoder_max_num_input_points, 2
),
batched_point_labels.reshape(
batch_size * max_num_queries, self.decoder_max_num_input_points
),
)
sparse_embeddings = sparse_embeddings.view(
batch_size,
max_num_queries,
sparse_embeddings.shape[1],
sparse_embeddings.shape[2],
)
        # Decode low-resolution masks and IoU predictions from the image and
        # prompt embeddings.
        low_res_masks, iou_predictions = self.mask_decoder(
image_embeddings,
self.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
multimask_output=multimask_output,
)
_, num_predictions, low_res_size, _ = low_res_masks.shape
        if output_w > 0 and output_h > 0:
            # Upsample the low-resolution masks to the requested output size.
            output_masks = F.interpolate(
low_res_masks,
(output_h, output_w),
# NOTE: "bicubic" is inefficient on onnx
mode="bilinear",
)
output_masks = torch.reshape(
output_masks,
(batch_size, max_num_queries, num_predictions, output_h, output_w),
)
        else:
            # No output size requested: keep the masks at the decoder resolution.
            output_masks = torch.reshape(
low_res_masks,
(
batch_size,
max_num_queries,
num_predictions,
low_res_size,
low_res_size,
),
)
iou_predictions = torch.reshape(
iou_predictions, (batch_size, max_num_queries, num_predictions)
)
return output_masks, iou_predictions, low_res_masks


class OnnxEfficientSamEncoder(OnnxEfficientSam):
    """Encoder-only export wrapper: batched images -> image embeddings."""

    def forward(self, batched_images: torch.Tensor):
        return self.model.get_image_embeddings(batched_images)


class OnnxEfficientSamDecoder(OnnxEfficientSam):
    """Decoder-only export wrapper: image embeddings + point prompts -> masks."""

    def forward(
        self, image_embeddings, batched_points, batched_point_labels, orig_im_size
    ):
return self.predict_masks(
image_embeddings=image_embeddings,
batched_points=batched_points,
batched_point_labels=batched_point_labels,
multimask_output=True,
input_h=orig_im_size[0],
input_w=orig_im_size[1],
output_h=orig_im_size[0],
output_w=orig_im_size[1],
)
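

# ---------------------------------------------------------------------------
# Illustrative export sketch (not part of the original labelme code): one way
# these wrappers might be exported with torch.onnx.export. The model builder
# import, dummy shapes, output file names, and opset version below are
# assumptions; adjust them to your EfficientSam checkout and weights.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Assumes the upstream EfficientSAM package layout and local weights.
    from efficient_sam.build_efficient_sam import build_efficient_sam_vitt

    model = build_efficient_sam_vitt().eval()

    # Encoder: batched RGB images -> image embeddings.
    encoder = OnnxEfficientSamEncoder(model)
    dummy_image = torch.randn(1, 3, 1024, 1024)
    torch.onnx.export(
        encoder,
        (dummy_image,),
        "efficient_sam_encoder.onnx",
        input_names=["batched_images"],
        output_names=["image_embeddings"],
        opset_version=17,
    )

    # Decoder: image embeddings + point prompts -> masks and IoU predictions.
    decoder = OnnxEfficientSamDecoder(model)
    with torch.no_grad():
        dummy_embeddings = encoder(dummy_image)
    dummy_points = torch.tensor([[[[400.0, 400.0], [600.0, 600.0]]]])  # (1, 1, 2, 2)
    dummy_labels = torch.tensor([[[1.0, 1.0]]])  # (1, 1, 2)
    dummy_im_size = torch.tensor([1024, 1024])
    torch.onnx.export(
        decoder,
        (dummy_embeddings, dummy_points, dummy_labels, dummy_im_size),
        "efficient_sam_decoder.onnx",
        input_names=[
            "image_embeddings",
            "batched_points",
            "batched_point_labels",
            "orig_im_size",
        ],
        output_names=["output_masks", "iou_predictions", "low_res_masks"],
        dynamic_axes={
            "batched_points": {2: "num_points"},
            "batched_point_labels": {2: "num_points"},
        },
        opset_version=17,
    )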