|
|
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
|
|
class OnnxEfficientSam(torch.nn.Module):
    """Export-friendly wrapper around an EfficientSAM model.

    The wrapped model's components are exposed through read-only properties,
    and ``forward`` runs image encoding followed by point-prompted mask
    prediction. Point prompts are padded or truncated to the model's
    ``decoder_max_num_input_points`` so the prompt encoder always receives a
    fixed number of points per query.
    """

    def __init__(self, model):
        """Wrap ``model``, whose sub-modules and settings are used below."""
        super().__init__()
        self.model = model

    @property
    def decoder_max_num_input_points(self):
        """Number of points per query that is fed to the prompt encoder."""
        return self.model.decoder_max_num_input_points

    @property
    def image_encoder(self):
        """The wrapped model's image encoder (its ``img_size`` is used for rescaling)."""
        return self.model.image_encoder

    @property
    def get_image_embeddings(self):
        """The wrapped model's ``get_image_embeddings`` callable."""
        return self.model.get_image_embeddings

    @property
    def prompt_encoder(self):
        """The wrapped model's prompt encoder."""
        return self.model.prompt_encoder

    @property
    def mask_decoder(self):
        """The wrapped model's mask decoder."""
        return self.model.mask_decoder

    def forward(
        self,
        batched_images: torch.Tensor,
        batched_points: torch.Tensor,
        batched_point_labels: torch.Tensor,
    ):
        """Encode ``batched_images`` and predict masks for the point prompts.

        ``batched_images`` must be 4-D; its last two dims are taken as the
        image height and width. Points are interpreted in that pixel frame,
        and the output masks are resized back to that size.

        Returns:
            ``(output_masks, iou_predictions, low_res_masks)``, as returned by
            :meth:`predict_masks` with ``multimask_output=True``.
        """
        _, _, input_h, input_w = batched_images.shape
        image_embeddings = self.get_image_embeddings(batched_images)
        return self.predict_masks(
            image_embeddings,
            batched_points,
            batched_point_labels,
            multimask_output=True,
            input_h=input_h,
            input_w=input_w,
            output_h=input_h,
            output_w=input_w,
        )

    def get_rescaled_pts(
        self, batched_points: torch.Tensor, input_h: int, input_w: int
    ):
        """Map point coordinates from an ``input_h x input_w`` frame to the encoder frame.

        Only the first two channels of the last dimension are used, as (x, y).
        x is scaled by ``img_size / input_w`` and y by ``img_size / input_h``,
        where ``img_size`` is the image encoder's ``img_size``.

        Returns:
            A tensor with the same leading dims and a last dim of size 2.
        """
        img_size = self.image_encoder.img_size
        return torch.stack(
            [
                batched_points[..., 0] * img_size / input_w,
                batched_points[..., 1] * img_size / input_h,
            ],
            dim=-1,
        )

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        batched_points: torch.Tensor,
        batched_point_labels: torch.Tensor,
        multimask_output: bool,
        input_h: int,
        input_w: int,
        output_h: int = -1,
        output_w: int = -1,
    ):
        """Predict masks from precomputed image embeddings and point prompts.

        Args:
            image_embeddings: Passed unchanged to the mask decoder.
            batched_points: 4-D ``(batch, queries, num_pts, >=2)``. Coordinates
                are in an ``input_h x input_w`` pixel frame.
            batched_point_labels: ``(batch, queries, num_pts)``.
            multimask_output: Forwarded to the mask decoder.
            input_h, input_w: Size of the frame the points are expressed in.
            output_h, output_w: If both are positive, masks are bilinearly
                resized to this size. Otherwise the decoder's low-res size is kept.

        Returns:
            ``(output_masks, iou_predictions, low_res_masks)``:

            - ``output_masks`` is ``(batch, queries, P, H, W)``.
            - ``iou_predictions`` is ``(batch, queries, P)``.
            - ``low_res_masks`` is the decoder's 4-D output, returned without
              reshaping. It is presumably ``(batch * queries, P, h, w)``;
              confirm against the decoder.
        """
        batch_size, max_num_queries, num_pts, _ = batched_points.shape
        max_pts = self.decoder_max_num_input_points
        rescaled_batched_points = self.get_rescaled_pts(
            batched_points, input_h, input_w
        )

        # Force exactly `max_pts` points per query: truncate extras, or pad the
        # points dimension with -1 coordinates and -1 labels.
        # NOTE(review): label -1 presumably marks "no point" in the prompt
        # encoder; confirm.
        if num_pts > max_pts:
            rescaled_batched_points = rescaled_batched_points[:, :, :max_pts, :]
            batched_point_labels = batched_point_labels[:, :, :max_pts]
        elif num_pts < max_pts:
            missing = max_pts - num_pts
            rescaled_batched_points = F.pad(
                rescaled_batched_points, (0, 0, 0, missing), value=-1.0
            )
            batched_point_labels = F.pad(
                batched_point_labels, (0, missing), value=-1.0
            )

        # The prompt encoder works on a flattened (batch * queries) axis.
        sparse_embeddings = self.prompt_encoder(
            rescaled_batched_points.reshape(
                batch_size * max_num_queries, max_pts, 2
            ),
            batched_point_labels.reshape(batch_size * max_num_queries, max_pts),
        )
        sparse_embeddings = sparse_embeddings.view(
            batch_size,
            max_num_queries,
            sparse_embeddings.shape[1],
            sparse_embeddings.shape[2],
        )
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings,
            self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            multimask_output=multimask_output,
        )
        # Read height and width separately; the masks need not be square.
        _, num_predictions, low_res_h, low_res_w = low_res_masks.shape

        if output_w > 0 and output_h > 0:
            # NOTE(review): bilinear (rather than bicubic) is presumably used
            # for ONNX export efficiency; confirm.
            output_masks = F.interpolate(
                low_res_masks,
                (output_h, output_w),
                mode="bilinear",
            )
            mask_h, mask_w = output_h, output_w
        else:
            output_masks = low_res_masks
            mask_h, mask_w = low_res_h, low_res_w

        output_masks = torch.reshape(
            output_masks,
            (batch_size, max_num_queries, num_predictions, mask_h, mask_w),
        )
        iou_predictions = torch.reshape(
            iou_predictions, (batch_size, max_num_queries, num_predictions)
        )
        return output_masks, iou_predictions, low_res_masks
|
|
|
|
|
class OnnxEfficientSamEncoder(OnnxEfficientSam):
    """Export wrapper that runs only the image-embedding stage of the model."""

    def forward(self, batched_images: torch.Tensor):
        """Return ``model.get_image_embeddings(batched_images)`` unchanged."""
        embed_images = self.model.get_image_embeddings
        embeddings = embed_images(batched_images)
        return embeddings
|
|
|
|
|
class OnnxEfficientSamDecoder(OnnxEfficientSam):
    """Export wrapper that runs only prompt decoding on precomputed embeddings."""

    def forward(
        self, image_embeddings, batched_points, batched_point_labels, orig_im_size
    ):
        """Predict multimask outputs for points given in original-image pixels.

        ``orig_im_size`` is indexed as ``(height, width)``. The same size is
        used to rescale the points and as the output mask size.

        Returns:
            ``(output_masks, iou_predictions, low_res_masks)`` from
            :meth:`predict_masks`.
        """
        orig_h = orig_im_size[0]
        orig_w = orig_im_size[1]
        return self.predict_masks(
            image_embeddings=image_embeddings,
            batched_points=batched_points,
            batched_point_labels=batched_point_labels,
            multimask_output=True,
            input_h=orig_h,
            input_w=orig_w,
            output_h=orig_h,
            output_w=orig_w,
        )
|
|