# ONNX export code is from the labelme annotation tool (https://github.com/labelmeai/efficient-sam). Huge thanks to Kentaro Wada.
import onnxruntime
import torch
from efficient_sam.build_efficient_sam import build_efficient_sam_vits
from efficient_sam.build_efficient_sam import build_efficient_sam_vitt
import onnx_models


def export_onnx(onnx_model, output, dynamic_axes, dummy_inputs, output_names):
    with open(output, "wb") as f:
        print(f"Exporting onnx model to {output}...")
        torch.onnx.export(
            onnx_model,
            tuple(dummy_inputs.values()),
            f,
            export_params=True,
            verbose=False,
            opset_version=17,
            do_constant_folding=True,
            input_names=list(dummy_inputs.keys()),
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )
    # Sanity-check the export: run the ONNX model on the dummy inputs and print
    # the output names and shapes.
    inference_session = onnxruntime.InferenceSession(output)
    outputs = inference_session.run(
        output_names=output_names,
        input_feed={k: v.numpy() for k, v in dummy_inputs.items()},
    )
    print(output_names)
    print([output_i.shape for output_i in outputs])


def export_onnx_esam(model, output):
    # Export the full EfficientSAM model as a single ONNX graph: image and point
    # prompts in, masks and IoU predictions out.
    onnx_model = onnx_models.OnnxEfficientSam(model=model)
    dynamic_axes = {
        "batched_images": {0: "batch", 2: "height", 3: "width"},
        "batched_point_coords": {2: "num_points"},
        "batched_point_labels": {2: "num_points"},
    }
    dummy_inputs = {
        "batched_images": torch.randn(1, 3, 1080, 1920, dtype=torch.float),
        "batched_point_coords": torch.randint(
            low=0, high=1080, size=(1, 1, 5, 2), dtype=torch.float
        ),
        "batched_point_labels": torch.randint(
            low=0, high=4, size=(1, 1, 5), dtype=torch.float
        ),
    }
    output_names = ["output_masks", "iou_predictions"]
    export_onnx(
        onnx_model=onnx_model,
        output=output,
        dynamic_axes=dynamic_axes,
        dummy_inputs=dummy_inputs,
        output_names=output_names,
    )
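

# A minimal inference sketch, not part of the original export script: it assumes
# export_onnx_esam() above produced weights/efficient_sam_vitt.onnx and reuses the
# input/output names defined by its dummy inputs. The helper name, the random test
# image, and the example point prompt are illustrative only.
def run_end_to_end_example(onnx_path="weights/efficient_sam_vitt.onnx"):
    import numpy as np

    session = onnxruntime.InferenceSession(onnx_path)
    # (1, 3, H, W) float32 image; a real caller would load and normalize a photo.
    image = np.random.rand(1, 3, 1080, 1920).astype(np.float32)
    # One prompt with two points in (x, y) pixel coordinates; label 1.0 is assumed
    # to mean a foreground click, following the SAM convention.
    point_coords = np.array([[[[500.0, 300.0], [700.0, 400.0]]]], dtype=np.float32)
    point_labels = np.array([[[1.0, 1.0]]], dtype=np.float32)
    masks, iou_predictions = session.run(
        output_names=["output_masks", "iou_predictions"],
        input_feed={
            "batched_images": image,
            "batched_point_coords": point_coords,
            "batched_point_labels": point_labels,
        },
    )
    print(masks.shape, iou_predictions.shape)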


def export_onnx_esam_encoder(model, output):
    onnx_model = onnx_models.OnnxEfficientSamEncoder(model=model)
    dynamic_axes = {
        "batched_images": {0: "batch", 2: "height", 3: "width"},
    }
    dummy_inputs = {
        "batched_images": torch.randn(1, 3, 1080, 1920, dtype=torch.float),
    }
    output_names = ["image_embeddings"]
    export_onnx(
        onnx_model=onnx_model,
        output=output,
        dynamic_axes=dynamic_axes,
        dummy_inputs=dummy_inputs,
        output_names=output_names,
    )


def export_onnx_esam_decoder(model, output):
    onnx_model = onnx_models.OnnxEfficientSamDecoder(model=model)
    dynamic_axes = {
        "image_embeddings": {0: "batch"},
        "batched_point_coords": {2: "num_points"},
        "batched_point_labels": {2: "num_points"},
    }
    dummy_inputs = {
        "image_embeddings": torch.randn(1, 256, 64, 64, dtype=torch.float),
        "batched_point_coords": torch.randint(
            low=0, high=1080, size=(1, 1, 5, 2), dtype=torch.float
        ),
        "batched_point_labels": torch.randint(
            low=0, high=4, size=(1, 1, 5), dtype=torch.float
        ),
        "orig_im_size": torch.tensor([1080, 1920], dtype=torch.long),
    }
    output_names = ["output_masks", "iou_predictions"]
    export_onnx(
        onnx_model=onnx_model,
        output=output,
        dynamic_axes=dynamic_axes,
        dummy_inputs=dummy_inputs,
        output_names=output_names,
    )
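

# A hedged sketch, not in the original script, of chaining the split export: run
# the encoder once per image, then the decoder per prompt with the cached
# embeddings. Input/output names follow the dummy inputs above; the helper name
# and the sample prompt are illustrative only.
def run_encoder_decoder_example(
    encoder_path="weights/efficient_sam_vitt_encoder.onnx",
    decoder_path="weights/efficient_sam_vitt_decoder.onnx",
):
    import numpy as np

    encoder = onnxruntime.InferenceSession(encoder_path)
    decoder = onnxruntime.InferenceSession(decoder_path)
    image = np.random.rand(1, 3, 1080, 1920).astype(np.float32)
    # Heavy step: compute image embeddings once and reuse them for every prompt.
    (image_embeddings,) = encoder.run(
        output_names=["image_embeddings"],
        input_feed={"batched_images": image},
    )
    # Light step: decode masks for a single point prompt (label 1.0 assumed to be
    # a foreground click, following the SAM convention).
    masks, iou_predictions = decoder.run(
        output_names=["output_masks", "iou_predictions"],
        input_feed={
            "image_embeddings": image_embeddings,
            "batched_point_coords": np.array([[[[960.0, 540.0]]]], dtype=np.float32),
            "batched_point_labels": np.array([[[1.0]]], dtype=np.float32),
            "orig_im_size": np.array([1080, 1920], dtype=np.int64),
        },
    )
    print(masks.shape, iou_predictions.shape)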


def main():
    # EfficientSAM ViT-Tiny: faster
    export_onnx_esam(
        model=build_efficient_sam_vitt(),
        output="weights/efficient_sam_vitt.onnx",
    )
    export_onnx_esam_encoder(
        model=build_efficient_sam_vitt(),
        output="weights/efficient_sam_vitt_encoder.onnx",
    )
    export_onnx_esam_decoder(
        model=build_efficient_sam_vitt(),
        output="weights/efficient_sam_vitt_decoder.onnx",
    )

    # EfficientSAM ViT-Small: more accurate
    export_onnx_esam(
        model=build_efficient_sam_vits(),
        output="weights/efficient_sam_vits.onnx",
    )
    export_onnx_esam_encoder(
        model=build_efficient_sam_vits(),
        output="weights/efficient_sam_vits_encoder.onnx",
    )
    export_onnx_esam_decoder(
        model=build_efficient_sam_vits(),
        output="weights/efficient_sam_vits_decoder.onnx",
    )


if __name__ == "__main__":
    main()