NandiniLokeshReddy
committed on
Upload 38 files
- .gitattributes +4 -0
- EfficientSAM-main/.DS_Store +0 -0
- EfficientSAM-main/.gitignore +5 -0
- EfficientSAM-main/EfficientSAM_example.py +54 -0
- EfficientSAM-main/EfficientSAM_onnx_example.py +77 -0
- EfficientSAM-main/LICENSE +201 -0
- EfficientSAM-main/README.md +63 -0
- EfficientSAM-main/efficient_sam/__init__.py +7 -0
- EfficientSAM-main/efficient_sam/build_efficient_sam.py +22 -0
- EfficientSAM-main/efficient_sam/efficient_sam.py +305 -0
- EfficientSAM-main/efficient_sam/efficient_sam_decoder.py +315 -0
- EfficientSAM-main/efficient_sam/efficient_sam_encoder.py +257 -0
- EfficientSAM-main/efficient_sam/mlp.py +29 -0
- EfficientSAM-main/efficient_sam/two_way_transformer.py +266 -0
- EfficientSAM-main/export_to_onnx.py +139 -0
- EfficientSAM-main/export_to_torchscript.py +17 -0
- EfficientSAM-main/figs/.DS_Store +0 -0
- EfficientSAM-main/figs/examples/demo_box.png +3 -0
- EfficientSAM-main/figs/examples/demo_everything.png +3 -0
- EfficientSAM-main/figs/examples/demo_point.png +3 -0
- EfficientSAM-main/figs/examples/demo_saliency.png +3 -0
- EfficientSAM-main/figs/examples/dogs.jpg +0 -0
- EfficientSAM-main/figs/examples/dogs_efficient_sam_vits_mask.png +0 -0
- EfficientSAM-main/figs/examples/dogs_efficient_sam_vitt_mask.png +0 -0
- EfficientSAM-main/figs/examples/dogs_efficientsam_s_mask.png +0 -0
- EfficientSAM-main/figs/examples/dogs_efficientsam_ti_mask.png +0 -0
- EfficientSAM-main/figs/examples/dogs_squeeze_sam_mask.png +0 -0
- EfficientSAM-main/linter.sh +32 -0
- EfficientSAM-main/notebooks/EfficientSAM_example.ipynb +0 -0
- EfficientSAM-main/notebooks/EfficientSAM_segment_everything_example.ipynb +0 -0
- EfficientSAM-main/onnx_models.py +166 -0
- EfficientSAM-main/setup.cfg +11 -0
- EfficientSAM-main/setup.py +18 -0
- EfficientSAM-main/torchscripted_model/efficient_sam_vitt_torchscript.pt +3 -0
- EfficientSAM-main/weights/efficient_sam_vits.pt.zip +3 -0
- EfficientSAM-main/weights/efficient_sam_vitt.onnx +3 -0
- EfficientSAM-main/weights/efficient_sam_vitt.pt +3 -0
- EfficientSAM-main/weights/efficient_sam_vitt_decoder.onnx +3 -0
- EfficientSAM-main/weights/efficient_sam_vitt_encoder.onnx +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+EfficientSAM-main/figs/examples/demo_box.png filter=lfs diff=lfs merge=lfs -text
+EfficientSAM-main/figs/examples/demo_everything.png filter=lfs diff=lfs merge=lfs -text
+EfficientSAM-main/figs/examples/demo_point.png filter=lfs diff=lfs merge=lfs -text
+EfficientSAM-main/figs/examples/demo_saliency.png filter=lfs diff=lfs merge=lfs -text
EfficientSAM-main/.DS_Store
ADDED
Binary file (10.2 kB)
EfficientSAM-main/.gitignore
ADDED
@@ -0,0 +1,5 @@
*.pyc
*.pyo
*.pyd
__py
**/__pycache__/
EfficientSAM-main/EfficientSAM_example.py
ADDED
@@ -0,0 +1,54 @@
from efficient_sam.build_efficient_sam import build_efficient_sam_vitt, build_efficient_sam_vits
# from squeeze_sam.build_squeeze_sam import build_squeeze_sam

from PIL import Image
from torchvision import transforms
import torch
import numpy as np
import zipfile


models = {}

# Build the EfficientSAM-Ti model.
models['efficientsam_ti'] = build_efficient_sam_vitt()

# Since the EfficientSAM-S checkpoint file is >100MB, we store it as a zip file.
with zipfile.ZipFile("weights/efficient_sam_vits.pt.zip", 'r') as zip_ref:
    zip_ref.extractall("weights")
# Build the EfficientSAM-S model.
models['efficientsam_s'] = build_efficient_sam_vits()

# Build the SqueezeSAM model.
# models['squeeze_sam'] = build_squeeze_sam()

# Load an image.
sample_image_np = np.array(Image.open("figs/examples/dogs.jpg"))
sample_image_tensor = transforms.ToTensor()(sample_image_np)

# Feed a few (x, y) points inside the target object as input.
input_points = torch.tensor([[[[580, 350], [650, 350]]]])
input_labels = torch.tensor([[[1, 1]]])

# Run inference for both the EfficientSAM-Ti and EfficientSAM-S models.
for model_name, model in models.items():
    print('Running inference using ', model_name)
    predicted_logits, predicted_iou = model(
        sample_image_tensor[None, ...],
        input_points,
        input_labels,
    )
    sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)
    predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)
    predicted_logits = torch.take_along_dim(
        predicted_logits, sorted_ids[..., None, None], dim=2
    )
    # The masks are now sorted by their predicted IOUs.
    # The first dimension is the batch size (we have a single image, so it is 1).
    # The second dimension is the number of queries (in this case, only 1).
    # The third dimension is the number of candidate masks output by the model.
    # For this demo we use the first (highest-IOU) mask.
    mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().detach().numpy()
    masked_image_np = sample_image_np.copy().astype(np.uint8) * mask[:, :, None]
    Image.fromarray(masked_image_np).save(f"figs/examples/dogs_{model_name}_mask.png")
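The point labels above use `1` for foreground points; the prompt encoder in `efficient_sam_decoder.py` (added later in this commit) also reserves labels `2` and `3` for the top-left and bottom-right corners of a box prompt. A minimal box-prompt sketch built on the same example; the box coordinates here are illustrative, not taken from this upload:

```python
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from efficient_sam.build_efficient_sam import build_efficient_sam_vitt

# Box-prompt sketch: labels 2 and 3 mark the box corners
# (see PromptEncoder._embed_points). The box values are illustrative.
model = build_efficient_sam_vitt()
image_np = np.array(Image.open("figs/examples/dogs.jpg"))
image_tensor = transforms.ToTensor()(image_np)

# Shape [batch, num_queries, num_points, 2]: top-left and bottom-right corners.
box_points = torch.tensor([[[[400, 200], [800, 600]]]], dtype=torch.float)
box_labels = torch.tensor([[[2, 3]]])

with torch.no_grad():
    logits, iou = model(image_tensor[None, ...], box_points, box_labels)

# Take the first candidate mask (unsorted) and save the masked image.
mask = (logits[0, 0, 0] >= 0).numpy()
Image.fromarray((image_np * mask[:, :, None]).astype(np.uint8)).save(
    "figs/examples/dogs_box_mask.png"
)
```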
EfficientSAM-main/EfficientSAM_onnx_example.py
ADDED
@@ -0,0 +1,77 @@
# ONNX export code is from the [labelme annotation tool](https://github.com/labelmeai/efficient-sam). Huge thanks to Kentaro Wada.

import numpy as np
import torch

import imgviz
import onnxruntime
import time
from PIL import Image


def predict_onnx(input_image, input_points, input_labels):
    if 0:
        # Single combined graph: encoder + decoder in one session.
        inference_session = onnxruntime.InferenceSession(
            "weights/efficient_sam_vitt.onnx"
        )
        (
            predicted_logits,
            predicted_iou,
            predicted_lowres_logits,
        ) = inference_session.run(
            output_names=None,
            input_feed={
                "batched_images": input_image,
                "batched_point_coords": input_points,
                "batched_point_labels": input_labels,
            },
        )
    else:
        # Separate encoder and decoder graphs.
        inference_session = onnxruntime.InferenceSession(
            "weights/efficient_sam_vitt_encoder.onnx"
        )
        t_start = time.time()
        image_embeddings, = inference_session.run(
            output_names=None,
            input_feed={
                "batched_images": input_image,
            },
        )
        print("encoder time", time.time() - t_start)

        inference_session = onnxruntime.InferenceSession(
            "weights/efficient_sam_vitt_decoder.onnx"
        )
        t_start = time.time()
        (
            predicted_logits,
            predicted_iou,
            predicted_lowres_logits,
        ) = inference_session.run(
            output_names=None,
            input_feed={
                "image_embeddings": image_embeddings,
                "batched_point_coords": input_points,
                "batched_point_labels": input_labels,
                "orig_im_size": np.array(input_image.shape[2:], dtype=np.int64),
            },
        )
        print("decoder time", time.time() - t_start)
    mask = predicted_logits[0, 0, 0, :, :] >= 0
    imgviz.io.imsave("figs/examples/dogs_onnx_mask.png", mask)


def main():
    image = np.array(Image.open("figs/examples/dogs.jpg"))

    input_image = image.transpose(2, 0, 1)[None].astype(np.float32) / 255.0
    # batch_size, num_queries, num_points, 2
    input_points = np.array([[[[580, 350], [650, 350]]]], dtype=np.float32)
    # batch_size, num_queries, num_points
    input_labels = np.array([[[1, 1]]], dtype=np.float32)

    predict_onnx(input_image, input_points, input_labels)


if __name__ == "__main__":
    main()
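Since both the PyTorch checkpoint and the exported ONNX graphs ship in `weights/`, a quick sanity check is to compare their outputs on the same input. The sketch below assumes the combined ONNX export mirrors the eager model's forward (including any sorting of candidate masks), which is not confirmed by this diff, so it only prints shapes and a rough deviation:

```python
import numpy as np
import onnxruntime
import torch

from efficient_sam.build_efficient_sam import build_efficient_sam_vitt

# Random 1024x1024 input plus the same point prompt as the example above.
image = np.random.rand(1, 3, 1024, 1024).astype(np.float32)
points = np.array([[[[580, 350], [650, 350]]]], dtype=np.float32)
labels = np.array([[[1, 1]]], dtype=np.float32)

session = onnxruntime.InferenceSession("weights/efficient_sam_vitt.onnx")
onnx_logits, onnx_iou, _ = session.run(
    None,
    {
        "batched_images": image,
        "batched_point_coords": points,
        "batched_point_labels": labels,
    },
)

model = build_efficient_sam_vitt()
with torch.no_grad():
    pt_logits, pt_iou = model(
        torch.from_numpy(image), torch.from_numpy(points), torch.from_numpy(labels)
    )

print("onnx:", onnx_logits.shape, "pytorch:", tuple(pt_logits.shape))
if onnx_logits.shape == tuple(pt_logits.shape):
    print("max |logit diff|:", np.abs(onnx_logits - pt_logits.numpy()).max())
```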
EfficientSAM-main/LICENSE
ADDED
@@ -0,0 +1,201 @@
                              Apache License
                        Version 2.0, January 2004
                     http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
EfficientSAM-main/README.md
ADDED
@@ -0,0 +1,63 @@
# EfficientSAM
EfficientSAM: Leveraged Masked Image Pretraining for Efficient Segment Anything

## News
[Jan.12 2024] The ONNX version of EfficientSAM, including a separate encoder and decoder, is available on the [Hugging Face Space](https://huggingface.co/spaces/yunyangx/EfficientSAM/tree/main) (thanks to @wkentaro Kentaro Wada for implementing the ONNX export).

[Dec.31 2023] EfficientSAM is integrated into the annotation tool [Labelme](https://github.com/labelmeai/labelme) (huge thanks to the Labelme team and @wkentaro Kentaro Wada).

[Dec.11 2023] The EfficientSAM model code with checkpoints is fully available in this repository. The [example](https://github.com/yformer/EfficientSAM/blob/main/EfficientSAM_example.py) script shows how to instantiate the model with a checkpoint and query points on an image.

[Dec.10 2023] A Grounded EfficientSAM demo is available at [Grounded-Efficient-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/main/EfficientSAM) (huge thanks to the IDEA-Research team and @rentainhe for supporting the [grounded-efficient-sam demo](https://github.com/IDEA-Research/Grounded-Segment-Anything/blob/main/EfficientSAM/grounded_efficient_sam.py) under [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything)).

[Dec.6 2023] An EfficientSAM demo is available on the [Hugging Face Space](https://huggingface.co/spaces/yunyangx/EfficientSAM) (huge thanks to the HF team for their support).

[Dec.5 2023] We release the TorchScript version of EfficientSAM and share a Colab notebook.

## Online Demo & Examples
The online demo and examples can be found on the [project page](https://yformer.github.io/efficient-sam/).

## EfficientSAM Instance Segmentation Examples
|   |   |
:-------------------------:|:-------------------------:
Point-prompt | ![point-prompt](figs/examples/demo_point.png)
Box-prompt | ![box-prompt](figs/examples/demo_box.png)
Segment everything | ![segment everything](figs/examples/demo_everything.png)
Saliency | ![Saliency](figs/examples/demo_saliency.png)

## Model
EfficientSAM checkpoints are available under the weights folder of this GitHub repository. Example instantiations and runs of the models can be found in [EfficientSAM_example.py](https://github.com/yformer/EfficientSAM/blob/main/EfficientSAM_example.py).

| EfficientSAM-S | EfficientSAM-Ti |
|------------------------------|------------------------------|
| [Download](https://github.com/yformer/EfficientSAM/blob/main/weights/efficient_sam_vits.pt.zip) | [Download](https://github.com/yformer/EfficientSAM/blob/main/weights/efficient_sam_vitt.pt) |

You can use EfficientSAM directly with the checkpoints:
```
from efficient_sam.build_efficient_sam import build_efficient_sam_vitt, build_efficient_sam_vits
efficientsam = build_efficient_sam_vitt()
```

## Jupyter Notebook Example
The notebooks are shared [here](https://github.com/yformer/EfficientSAM/blob/main/notebooks).

## Acknowledgement
+ [SAM](https://github.com/facebookresearch/segment-anything)
+ [MobileSAM](https://github.com/ChaoningZhang/MobileSAM)
+ [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM)
+ [U-2-Net](https://github.com/xuebinqin/U-2-Net)

If you use EfficientSAM in your research or applications, please cite it with this BibTeX:
```bibtex
@article{xiong2023efficientsam,
  title={EfficientSAM: Leveraged Masked Image Pretraining for Efficient Segment Anything},
  author={Yunyang Xiong and Bala Varadarajan and Lemeng Wu and Xiaoyu Xiang and Fanyi Xiao and Chenchen Zhu and Xiaoliang Dai and Dilin Wang and Fei Sun and Forrest Iandola and Raghuraman Krishnamoorthi and Vikas Chandra},
  journal={arXiv:2312.00863},
  year={2023}
}
```
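The commit also ships a scripted model at `torchscripted_model/efficient_sam_vitt_torchscript.pt` (see the file list above). A minimal loading sketch with `torch.jit.load`; the scripted module's call signature is assumed to match the eager model's `forward(images, points, labels)`, which this diff does not confirm:

```python
import torch
from PIL import Image
from torchvision import transforms

# Load the TorchScript EfficientSAM-Ti model shipped in this commit.
# Assumes the scripted forward takes (images, points, labels) like the eager model.
scripted = torch.jit.load("torchscripted_model/efficient_sam_vitt_torchscript.pt")
scripted.eval()

image = transforms.ToTensor()(Image.open("figs/examples/dogs.jpg"))[None, ...]
points = torch.tensor([[[[580, 350], [650, 350]]]], dtype=torch.float)
labels = torch.tensor([[[1, 1]]])

with torch.no_grad():
    logits, iou = scripted(image, points, labels)
print(logits.shape, iou.shape)
```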
EfficientSAM-main/efficient_sam/__init__.py
ADDED
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

from .build_efficient_sam import (
    build_efficient_sam_vitt,
    build_efficient_sam_vits,
)
EfficientSAM-main/efficient_sam/build_efficient_sam.py
ADDED
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .efficient_sam import build_efficient_sam


def build_efficient_sam_vitt():
    return build_efficient_sam(
        encoder_patch_embed_dim=192,
        encoder_num_heads=3,
        checkpoint="weights/efficient_sam_vitt.pt",
    ).eval()


def build_efficient_sam_vits():
    return build_efficient_sam(
        encoder_patch_embed_dim=384,
        encoder_num_heads=6,
        checkpoint="weights/efficient_sam_vits.pt",
    ).eval()
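Both builders return the model in `eval()` mode with the checkpoint already loaded from `weights/`. A quick sketch comparing the two variants' sizes (it assumes the EfficientSAM-S checkpoint has been extracted from its zip as in `EfficientSAM_example.py`):

```python
from efficient_sam.build_efficient_sam import (
    build_efficient_sam_vitt,
    build_efficient_sam_vits,
)

# Instantiate both variants and report their parameter counts.
for name, builder in [("EfficientSAM-Ti", build_efficient_sam_vitt),
                      ("EfficientSAM-S", build_efficient_sam_vits)]:
    model = builder()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{name}: {n_params / 1e6:.1f}M parameters")
```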
EfficientSAM-main/efficient_sam/efficient_sam.py
ADDED
@@ -0,0 +1,305 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Any, List, Tuple, Type

import torch
import torch.nn.functional as F

from torch import nn, Tensor

from .efficient_sam_decoder import MaskDecoder, PromptEncoder
from .efficient_sam_encoder import ImageEncoderViT
from .two_way_transformer import TwoWayAttentionBlock, TwoWayTransformer


class EfficientSam(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        decoder_max_num_input_points: int,
        mask_decoder: MaskDecoder,
        pixel_mean: List[float] = [0.485, 0.456, 0.406],
        pixel_std: List[float] = [0.229, 0.224, 0.225],
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.

        Arguments:
          image_encoder (ImageEncoderViT): The backbone used to encode the
            image into image embeddings that allow for efficient mask prediction.
          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
            and encoded prompts.
          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
          pixel_std (list(float)): Std values for normalizing pixels in the input image.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.decoder_max_num_input_points = decoder_max_num_input_points
        self.mask_decoder = mask_decoder
        self.register_buffer(
            "pixel_mean", torch.Tensor(pixel_mean).view(1, 3, 1, 1), False
        )
        self.register_buffer(
            "pixel_std", torch.Tensor(pixel_std).view(1, 3, 1, 1), False
        )

    @torch.jit.export
    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        batched_points: torch.Tensor,
        batched_point_labels: torch.Tensor,
        multimask_output: bool,
        input_h: int,
        input_w: int,
        output_h: int = -1,
        output_w: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predicts masks given image embeddings and prompts. This only runs the decoder.

        Arguments:
          image_embeddings: A tensor of shape [B, C, H, W] or [B*max_num_queries, C, H, W]
          batched_points: A tensor of shape [B, max_num_queries, num_pts, 2]
          batched_point_labels: A tensor of shape [B, max_num_queries, num_pts]
        Returns:
          A tuple of two tensors:
            low_res_mask: A tensor of shape [B, max_num_queries, 256, 256] of predicted masks
            iou_predictions: A tensor of shape [B, max_num_queries] of estimated IOU scores
        """
        batch_size, max_num_queries, num_pts, _ = batched_points.shape
        num_pts = batched_points.shape[2]
        rescaled_batched_points = self.get_rescaled_pts(batched_points, input_h, input_w)

        if num_pts > self.decoder_max_num_input_points:
            rescaled_batched_points = rescaled_batched_points[
                :, :, : self.decoder_max_num_input_points, :
            ]
            batched_point_labels = batched_point_labels[
                :, :, : self.decoder_max_num_input_points
            ]
        elif num_pts < self.decoder_max_num_input_points:
            rescaled_batched_points = F.pad(
                rescaled_batched_points,
                (0, 0, 0, self.decoder_max_num_input_points - num_pts),
                value=-1.0,
            )
            batched_point_labels = F.pad(
                batched_point_labels,
                (0, self.decoder_max_num_input_points - num_pts),
                value=-1.0,
            )

        sparse_embeddings = self.prompt_encoder(
            rescaled_batched_points.reshape(
                batch_size * max_num_queries, self.decoder_max_num_input_points, 2
            ),
            batched_point_labels.reshape(
                batch_size * max_num_queries, self.decoder_max_num_input_points
            ),
        )
        sparse_embeddings = sparse_embeddings.view(
            batch_size,
            max_num_queries,
            sparse_embeddings.shape[1],
            sparse_embeddings.shape[2],
        )
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings,
            self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            multimask_output=multimask_output,
        )
        _, num_predictions, low_res_size, _ = low_res_masks.shape

        if output_w > 0 and output_h > 0:
            output_masks = F.interpolate(
                low_res_masks, (output_h, output_w), mode="bicubic"
            )
            output_masks = torch.reshape(
                output_masks,
                (batch_size, max_num_queries, num_predictions, output_h, output_w),
            )
        else:
            output_masks = torch.reshape(
                low_res_masks,
                (
                    batch_size,
                    max_num_queries,
                    num_predictions,
                    low_res_size,
                    low_res_size,
                ),
            )
        iou_predictions = torch.reshape(
            iou_predictions, (batch_size, max_num_queries, num_predictions)
        )
        return output_masks, iou_predictions

    def get_rescaled_pts(self, batched_points: torch.Tensor, input_h: int, input_w: int):
        return torch.stack(
            [
                torch.where(
                    batched_points[..., 0] >= 0,
                    batched_points[..., 0] * self.image_encoder.img_size / input_w,
                    -1.0,
                ),
                torch.where(
                    batched_points[..., 1] >= 0,
                    batched_points[..., 1] * self.image_encoder.img_size / input_h,
                    -1.0,
                ),
            ],
            dim=-1,
        )

    @torch.jit.export
    def get_image_embeddings(self, batched_images) -> torch.Tensor:
        """
        Computes image embeddings for the provided images. This only runs
        the image encoder; the decoder can then be run separately via
        predict_masks.

        Arguments:
          batched_images: A tensor of shape [B, 3, H, W]
        Returns:
          Image embeddings of shape [B, C, H', W'].
        """
        batched_images = self.preprocess(batched_images)
        return self.image_encoder(batched_images)

    def forward(
        self,
        batched_images: torch.Tensor,
        batched_points: torch.Tensor,
        batched_point_labels: torch.Tensor,
        scale_to_original_image_size: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predicts masks end-to-end from provided images and prompts.
        If prompts are not known in advance, using SamPredictor is
        recommended over calling the model directly.

        Arguments:
          batched_images: A tensor of shape [B, 3, H, W]
          batched_points: A tensor of shape [B, num_queries, max_num_pts, 2]
          batched_point_labels: A tensor of shape [B, num_queries, max_num_pts]

        Returns:
          A tuple of two tensors:
            output_masks: A tensor of shape [B, num_queries, num_predictions, H, W] of predicted masks
            iou_predictions: A tensor of shape [B, num_queries, num_predictions] of estimated IOU scores
        """
        batch_size, _, input_h, input_w = batched_images.shape
        image_embeddings = self.get_image_embeddings(batched_images)
        return self.predict_masks(
            image_embeddings,
            batched_points,
            batched_point_labels,
            multimask_output=True,
            input_h=input_h,
            input_w=input_w,
            output_h=input_h if scale_to_original_image_size else -1,
            output_w=input_w if scale_to_original_image_size else -1,
        )

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and resize to a square input."""
        if (
            x.shape[2] != self.image_encoder.img_size
            or x.shape[3] != self.image_encoder.img_size
        ):
            x = F.interpolate(
                x,
                (self.image_encoder.img_size, self.image_encoder.img_size),
                mode="bilinear",
            )
        return (x - self.pixel_mean) / self.pixel_std


def build_efficient_sam(encoder_patch_embed_dim, encoder_num_heads, checkpoint=None):
    img_size = 1024
    encoder_patch_size = 16
    encoder_depth = 12
    encoder_mlp_ratio = 4.0
    encoder_neck_dims = [256, 256]
    decoder_max_num_input_points = 6
    decoder_transformer_depth = 2
    decoder_transformer_mlp_dim = 2048
    decoder_num_heads = 8
    decoder_upscaling_layer_dims = [64, 32]
    num_multimask_outputs = 3
    iou_head_depth = 3
    iou_head_hidden_dim = 256
    activation = "gelu"
    normalization_type = "layer_norm"
    normalize_before_activation = False

    assert activation == "relu" or activation == "gelu"
    if activation == "relu":
        activation_fn = nn.ReLU
    else:
        activation_fn = nn.GELU

    image_encoder = ImageEncoderViT(
        img_size=img_size,
        patch_size=encoder_patch_size,
        in_chans=3,
        patch_embed_dim=encoder_patch_embed_dim,
        normalization_type=normalization_type,
        depth=encoder_depth,
        num_heads=encoder_num_heads,
        mlp_ratio=encoder_mlp_ratio,
        neck_dims=encoder_neck_dims,
        act_layer=activation_fn,
    )

    image_embedding_size = image_encoder.image_embedding_size
    encoder_transformer_output_dim = image_encoder.transformer_output_dim

    sam = EfficientSam(
        image_encoder=image_encoder,
        prompt_encoder=PromptEncoder(
            embed_dim=encoder_transformer_output_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(img_size, img_size),
        ),
        decoder_max_num_input_points=decoder_max_num_input_points,
        mask_decoder=MaskDecoder(
            transformer_dim=encoder_transformer_output_dim,
            transformer=TwoWayTransformer(
                depth=decoder_transformer_depth,
                embedding_dim=encoder_transformer_output_dim,
                num_heads=decoder_num_heads,
                mlp_dim=decoder_transformer_mlp_dim,
                activation=activation_fn,
                normalize_before_activation=normalize_before_activation,
            ),
            num_multimask_outputs=num_multimask_outputs,
            activation=activation_fn,
            normalization_type=normalization_type,
            normalize_before_activation=normalize_before_activation,
            iou_head_depth=iou_head_depth - 1,
            iou_head_hidden_dim=iou_head_hidden_dim,
            upscaling_layer_dims=decoder_upscaling_layer_dims,
        ),
        pixel_mean=[0.485, 0.456, 0.406],
        pixel_std=[0.229, 0.224, 0.225],
    )
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f, map_location="cpu")
        sam.load_state_dict(state_dict["model"])
    return sam
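Because `get_image_embeddings` and `predict_masks` are exported separately, the expensive encoder pass can be run once and the decoder re-run for many prompts. A hedged sketch of that pattern (the prompt coordinates and random image are illustrative):

```python
import torch
from efficient_sam.build_efficient_sam import build_efficient_sam_vitt

# Encode the image once, then decode several point prompts against the
# cached embeddings. Coordinates below are illustrative.
model = build_efficient_sam_vitt()
image = torch.rand(1, 3, 1024, 1024)

with torch.no_grad():
    embeddings = model.get_image_embeddings(image)
    for x, y in [(300, 300), (512, 512), (700, 400)]:
        points = torch.tensor([[[[x, y]]]], dtype=torch.float)
        labels = torch.tensor([[[1]]])
        masks, iou = model.predict_masks(
            embeddings, points, labels,
            multimask_output=True, input_h=1024, input_w=1024,
            output_h=1024, output_w=1024,
        )
        print((x, y), masks.shape, iou.shape)
```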
EfficientSAM-main/efficient_sam/efficient_sam_decoder.py
ADDED
@@ -0,0 +1,315 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Tuple, Type

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .mlp import MLPBlock


class PromptEncoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
    ) -> None:
        """
        Encodes prompts for input to SAM's mask decoder.

        Arguments:
          embed_dim (int): The prompts' embedding dimension
          image_embedding_size (tuple(int, int)): The spatial size of the
            image embedding, as (H, W).
          input_image_size (tuple(int, int)): The padded size of the image as input
            to the image encoder, as (H, W).
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
        self.invalid_points = nn.Embedding(1, embed_dim)
        self.point_embeddings = nn.Embedding(1, embed_dim)
        self.bbox_top_left_embeddings = nn.Embedding(1, embed_dim)
        self.bbox_bottom_right_embeddings = nn.Embedding(1, embed_dim)

    def get_dense_pe(self) -> torch.Tensor:
        """
        Returns the positional encoding used to encode point prompts,
        applied to a dense set of points the shape of the image encoding.

        Returns:
          torch.Tensor: Positional encoding with shape
            1x(embed_dim)x(embedding_h)x(embedding_w)
        """
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)

    def _embed_points(
        self,
        points: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """Embeds point prompts."""
        points = points + 0.5  # Shift to center of pixel
        point_embedding = self.pe_layer.forward_with_coords(
            points, self.input_image_size
        )
        invalid_label_ids = torch.eq(labels, -1)[:, :, None]
        point_label_ids = torch.eq(labels, 1)[:, :, None]
        topleft_label_ids = torch.eq(labels, 2)[:, :, None]
        bottomright_label_ids = torch.eq(labels, 3)[:, :, None]
        point_embedding = point_embedding + self.invalid_points.weight[:, None, :] * invalid_label_ids
        point_embedding = point_embedding + self.point_embeddings.weight[:, None, :] * point_label_ids
        point_embedding = point_embedding + self.bbox_top_left_embeddings.weight[:, None, :] * topleft_label_ids
        point_embedding = point_embedding + self.bbox_bottom_right_embeddings.weight[:, None, :] * bottomright_label_ids
        return point_embedding

    def forward(
        self,
        coords,
        labels,
    ) -> torch.Tensor:
        """
        Embeds different types of prompts, returning both sparse and dense
        embeddings.

        Arguments:
          points: A tensor of shape [B, 2]
          labels: An integer tensor of shape [B] where each element is 1, 2 or 3.

        Returns:
          torch.Tensor: sparse embeddings for the points and boxes, with shape
            BxNx(embed_dim), where N is determined by the number of input points
            and boxes.
        """
        return self._embed_points(coords, labels)


class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.
    """

    def __init__(self, num_pos_feats: int) -> None:
        super().__init__()
        self.register_buffer(
            "positional_encoding_gaussian_matrix", torch.randn((2, num_pos_feats))
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified size."""
        h, w = size
        device = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones([h, w], device=device, dtype=torch.float32)
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / h
        x_embed = x_embed / w

        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
        return pe.permute(2, 0, 1)  # C x H x W

    def forward_with_coords(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """Positionally encode points that are not normalized to [0,1]."""
        coords = coords_input.clone()
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        return self._pe_encoding(coords.to(torch.float))  # B x N x C


class MaskDecoder(nn.Module):
    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int,
        activation: Type[nn.Module],
        normalization_type: str,
        normalize_before_activation: bool,
        iou_head_depth: int,
        iou_head_hidden_dim: int,
        upscaling_layer_dims: List[int],
    ) -> None:
        """
        Predicts masks given an image and prompt embeddings, using a
        transformer architecture.

        Arguments:
          transformer_dim (int): the channel dimension of the transformer
          transformer (nn.Module): the transformer used to predict masks
          num_multimask_outputs (int): the number of masks to predict
            when disambiguating masks
          activation (nn.Module): the type of activation to use when
            upscaling masks
          iou_head_depth (int): the depth of the MLP used to predict
            mask quality
          iou_head_hidden_dim (int): the hidden dimension of the MLP
            used to predict mask quality
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        if num_multimask_outputs > 1:
            self.num_mask_tokens = num_multimask_outputs + 1
        else:
            self.num_mask_tokens = 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
        output_dim_after_upscaling = transformer_dim

        self.final_output_upscaling_layers = nn.ModuleList([])
        for idx, layer_dims in enumerate(upscaling_layer_dims):
            self.final_output_upscaling_layers.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        output_dim_after_upscaling,
                        layer_dims,
                        kernel_size=2,
                        stride=2,
                    ),
                    nn.GroupNorm(1, layer_dims)
                    if idx < len(upscaling_layer_dims) - 1
                    else nn.Identity(),
                    activation(),
                )
            )
            output_dim_after_upscaling = layer_dims

        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLPBlock(
                    input_dim=transformer_dim,
                    hidden_dim=transformer_dim,
                    output_dim=output_dim_after_upscaling,
                    num_layers=2,
                    act=activation,
                )
                for i in range(self.num_mask_tokens)
            ]
        )

        self.iou_prediction_head = MLPBlock(
            input_dim=transformer_dim,
            hidden_dim=iou_head_hidden_dim,
            output_dim=self.num_mask_tokens,
            num_layers=iou_head_depth,
            act=activation,
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Arguments:
          image_embeddings: A tensor of shape [B, C, H, W] or [B*max_num_queries, C, H, W]
          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings (the batch dimension is broadcastable).
          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
          multimask_output (bool): Whether to return multiple masks or a single
            mask.

        Returns:
          torch.Tensor: batched predicted masks
          torch.Tensor: batched predictions of mask quality
        """
        (
            batch_size,
            max_num_queries,
            sparse_embed_dim_1,
            sparse_embed_dim_2,
        ) = sparse_prompt_embeddings.shape

        (
            _,
            image_embed_dim_c,
            image_embed_dim_h,
            image_embed_dim_w,
        ) = image_embeddings.shape

        # Tile the image embedding for all queries.
        image_embeddings_tiled = torch.tile(
            image_embeddings[:, None, :, :, :], [1, max_num_queries, 1, 1, 1]
        ).view(
            batch_size * max_num_queries,
            image_embed_dim_c,
            image_embed_dim_h,
            image_embed_dim_w,
        )
        sparse_prompt_embeddings = sparse_prompt_embeddings.reshape(
            batch_size * max_num_queries, sparse_embed_dim_1, sparse_embed_dim_2
        )
        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings_tiled,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
        )
        if multimask_output and self.num_multimask_outputs > 1:
            return masks[:, 1:, :], iou_pred[:, 1:]
        else:
            return masks[:, :1, :], iou_pred[:, :1]

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details."""
        # Concatenate output tokens
        output_tokens = torch.cat(
            [self.iou_token.weight, self.mask_tokens.weight], dim=0
        )
        output_tokens = output_tokens.unsqueeze(0).expand(
            sparse_prompt_embeddings.size(0), -1, -1
        )
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
        # Expand per-image data in batch direction to be per-mask
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = image_embeddings.shape
        hs, src = self.transformer(image_embeddings, pos_src, tokens)
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        upscaled_embedding = src.transpose(1, 2).view(b, c, h, w)

        for upscaling_layer in self.final_output_upscaling_layers:
            upscaled_embedding = upscaling_layer(upscaled_embedding)
        hyper_in_list: List[torch.Tensor] = []
        for i, output_hypernetworks_mlp in enumerate(self.output_hypernetworks_mlps):
            hyper_in_list.append(output_hypernetworks_mlp(mask_tokens_out[:, i, :]))
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding.shape
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)
        return masks, iou_pred
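A small sketch of how `PositionEmbeddingRandom` produces the dense and per-point positional encodings consumed by the decoder; shapes follow the docstrings above:

```python
import torch
from efficient_sam.efficient_sam_decoder import PositionEmbeddingRandom

# The dense positional encoding is C x H x W, with embed_dim = 2 * num_pos_feats.
pe_layer = PositionEmbeddingRandom(num_pos_feats=128)
dense_pe = pe_layer((64, 64))          # grid the size of the image embedding
print(dense_pe.shape)                  # torch.Size([256, 64, 64])

# Point prompts are encoded against the padded input image size.
coords = torch.tensor([[[580.0, 350.0], [650.0, 350.0]]])  # B x N x 2
point_pe = pe_layer.forward_with_coords(coords, image_size=(1024, 1024))
print(point_pe.shape)                  # torch.Size([1, 2, 256])
```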
EfficientSAM-main/efficient_sam/efficient_sam_encoder.py
ADDED
@@ -0,0 +1,257 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import List, Optional, Tuple, Type

import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding"""

    def __init__(
        self,
        img_size,
        patch_size,
        in_chans,
        embed_dim,
    ):
        super().__init__()
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            bias=True,
        )

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x)
        return x


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        qkv_bias,
        qk_scale=None,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


class Block(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        act_layer=nn.GELU,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
        )
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


@torch.jit.export
def get_abs_pos(
    abs_pos: torch.Tensor, has_cls_token: bool, hw: List[int]
) -> torch.Tensor:
    """
    Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
    dimension for the original embeddings.
    Args:
        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
        has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
        hw (Tuple): size of input image tokens.

    Returns:
        Absolute positional embeddings after processing with shape (1, H, W, C)
    """
    h = hw[0]
    w = hw[1]
    if has_cls_token:
        abs_pos = abs_pos[:, 1:]
    xy_num = abs_pos.shape[1]
    size = int(math.sqrt(xy_num))
    assert size * size == xy_num

    if size != h or size != w:
        new_abs_pos = F.interpolate(
            abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
            size=(h, w),
            mode="bicubic",
            align_corners=False,
        )
        return new_abs_pos.permute(0, 2, 3, 1)
    else:
        return abs_pos.reshape(1, h, w, -1)


# Image encoder for efficient SAM.
class ImageEncoderViT(nn.Module):
    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_chans: int,
        patch_embed_dim: int,
        normalization_type: str,
        depth: int,
        num_heads: int,
        mlp_ratio: float,
        neck_dims: List[int],
        act_layer: Type[nn.Module],
    ) -> None:
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            patch_embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            act_layer (nn.Module): Activation layer.
        """
        super().__init__()

        self.img_size = img_size
        self.image_embedding_size = img_size // ((patch_size if patch_size > 0 else 1))
        self.transformer_output_dim = ([patch_embed_dim] + neck_dims)[-1]
        self.pretrain_use_cls_token = True
        pretrain_img_size = 224
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, patch_embed_dim)
        # Initialize absolute positional embedding with pretrain image size.
        num_patches = (pretrain_img_size // patch_size) * (
            pretrain_img_size // patch_size
        )
        num_positions = num_patches + 1
        self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, patch_embed_dim))
        self.blocks = nn.ModuleList()
        for i in range(depth):
            vit_block = Block(patch_embed_dim, num_heads, mlp_ratio, True)
            self.blocks.append(vit_block)
        self.neck = nn.Sequential(
            nn.Conv2d(
                patch_embed_dim,
                neck_dims[0],
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(neck_dims[0]),
            nn.Conv2d(
                neck_dims[0],
                neck_dims[0],
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(neck_dims[0]),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert (
            x.shape[2] == self.img_size and x.shape[3] == self.img_size
        ), "input image size must match self.img_size"
        x = self.patch_embed(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        x = x + get_abs_pos(
            self.pos_embed, self.pretrain_use_cls_token, [x.shape[1], x.shape[2]]
        )
        num_patches = x.shape[1]
        assert x.shape[2] == num_patches
        x = x.reshape(x.shape[0], num_patches * num_patches, x.shape[3])
        for blk in self.blocks:
            x = blk(x)
        x = x.reshape(x.shape[0], num_patches, num_patches, x.shape[2])
        x = self.neck(x.permute(0, 3, 1, 2))
        return x
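For reference, a minimal sketch of instantiating the ImageEncoderViT above. The hyperparameters are illustrative assumptions, not the shipped ViT-T/ViT-S configuration, and the import assumes the repo root is on PYTHONPATH:

    import torch
    import torch.nn as nn
    from efficient_sam.efficient_sam_encoder import ImageEncoderViT

    encoder = ImageEncoderViT(
        img_size=224,               # forward() asserts the input matches this exactly
        patch_size=16,
        in_chans=3,
        patch_embed_dim=192,
        normalization_type="layer_norm",
        depth=2,                    # tiny depth, demo only
        num_heads=3,
        mlp_ratio=4.0,
        neck_dims=[256, 256],
        act_layer=nn.GELU,
    )
    with torch.no_grad():
        feats = encoder(torch.randn(1, 3, 224, 224))
    print(feats.shape)  # torch.Size([1, 256, 14, 14]): neck_dims[0] channels on a 14x14 grid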
EfficientSAM-main/efficient_sam/mlp.py
ADDED
@@ -0,0 +1,29 @@
from typing import Type

from torch import nn


# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLPBlock(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        act: Type[nn.Module],
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Sequential(nn.Linear(n, k), act())
            for n, k in zip([input_dim] + h, [hidden_dim] * num_layers)
        )
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.fc(x)
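A quick sketch of how MLPBlock behaves: num_layers Linear+activation pairs followed by a final Linear projection (the sizes below are arbitrary):

    import torch
    from torch import nn
    from efficient_sam.mlp import MLPBlock

    # Two hidden Linear+ReLU layers (256->512, 512->512), then a 512->32 projection.
    mlp = MLPBlock(input_dim=256, hidden_dim=512, output_dim=32, num_layers=2, act=nn.ReLU)
    out = mlp(torch.randn(4, 256))
    print(out.shape)  # torch.Size([4, 32])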
EfficientSAM-main/efficient_sam/two_way_transformer.py
ADDED
@@ -0,0 +1,266 @@
import math
from typing import Tuple, Type

import torch
from torch import nn, Tensor

from .mlp import MLPBlock


class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module],
        normalize_before_activation: bool,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.

        Args:
          depth (int): number of layers in the transformer
          embedding_dim (int): the channel dimension for the input embeddings
          num_heads (int): the number of heads for multihead attention. Must
            divide embedding_dim
          mlp_dim (int): the channel dimension internal to the MLP block
          activation (nn.Module): the activation to use in the MLP block
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            curr_layer = TwoWayAttentionBlock(
                embedding_dim=embedding_dim,
                num_heads=num_heads,
                mlp_dim=mlp_dim,
                activation=activation,
                normalize_before_activation=normalize_before_activation,
                attention_downsample_rate=attention_downsample_rate,
                skip_first_layer_pe=(i == 0),
            )
            self.layers.append(curr_layer)

        self.final_attn_token_to_image = AttentionForTwoWayAttentionBlock(
            embedding_dim,
            num_heads,
            downsample_rate=attention_downsample_rate,
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          image_embedding (torch.Tensor): image to attend to. Should be shape
            B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): the positional encoding to add to the image. Must
            have the same shape as image_embedding.
          point_embedding (torch.Tensor): the embedding to add to the query points.
            Must have shape B x N_points x embedding_dim for any N_points.

        Returns:
          torch.Tensor: the processed point_embedding
          torch.Tensor: the processed image_embedding
        """

        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for idx, layer in enumerate(self.layers):
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)
        return queries, keys


class TwoWayAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module],
        normalize_before_activation: bool,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention of sparse
        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
        block on sparse inputs, and (4) cross attention of dense inputs to sparse
        inputs.

        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        self.self_attn = AttentionForTwoWayAttentionBlock(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = AttentionForTwoWayAttentionBlock(
            embedding_dim,
            num_heads,
            downsample_rate=attention_downsample_rate,
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(
            embedding_dim,
            mlp_dim,
            embedding_dim,
            1,
            activation,
        )

        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = AttentionForTwoWayAttentionBlock(
            embedding_dim,
            num_heads,
            downsample_rate=attention_downsample_rate,
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Self attention block
        if not self.skip_first_layer_pe:
            queries = queries + query_pe
        attn_out = self.self_attn(q=queries, k=queries, v=queries)
        queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys


class AttentionForTwoWayAttentionBlock(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert (
            self.internal_dim % num_heads == 0
        ), "num_heads must divide embedding_dim."
        self.c_per_head = self.internal_dim / num_heads
        self.inv_sqrt_c_per_head = 1.0 / math.sqrt(self.c_per_head)

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        # The fan_out is incorrect, but matches pytorch's initialization
        # for which qkv is a single 3*embedding_dim x embedding_dim matrix
        fan_in = self.embedding_dim
        fan_out = 3 * self.internal_dim
        # Xavier uniform with our custom fan_out
        bnd = math.sqrt(6 / (fan_in + fan_out))
        nn.init.uniform_(self.q_proj.weight, -bnd, bnd)
        nn.init.uniform_(self.k_proj.weight, -bnd, bnd)
        nn.init.uniform_(self.v_proj.weight, -bnd, bnd)
        # out_proj.weight is left with default initialization, like pytorch attention
        nn.init.zeros_(self.q_proj.bias)
        nn.init.zeros_(self.k_proj.bias)
        nn.init.zeros_(self.v_proj.bias)
        nn.init.zeros_(self.out_proj.bias)

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Attention
        _, _, _, c_per_head = q.shape
        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
        attn = attn * self.inv_sqrt_c_per_head
        attn = torch.softmax(attn, dim=-1)
        # Get output
        out = attn @ v
        out = self._recombine_heads(out)
        out = self.out_proj(out)
        return out
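A minimal sketch of running the TwoWayTransformer above on toy tensors; the dimensions are illustrative assumptions, since the real model wires this up inside the mask decoder:

    import torch
    from torch import nn
    from efficient_sam.two_way_transformer import TwoWayTransformer

    transformer = TwoWayTransformer(
        depth=2,
        embedding_dim=256,
        num_heads=8,
        mlp_dim=1024,
        activation=nn.GELU,
        normalize_before_activation=False,
    )
    image_embedding = torch.randn(1, 256, 64, 64)   # B x C x H x W
    image_pe = torch.randn(1, 256, 64, 64)          # positional encoding, same shape
    point_embedding = torch.randn(1, 7, 256)        # B x N_tokens x C (prompt + output tokens)
    queries, keys = transformer(image_embedding, image_pe, point_embedding)
    print(queries.shape, keys.shape)  # (1, 7, 256) and (1, 4096, 256)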
EfficientSAM-main/export_to_onnx.py
ADDED
@@ -0,0 +1,139 @@
# ONNX export code is from [labelme annotation tool](https://github.com/labelmeai/efficient-sam). Huge thanks to Kentaro Wada.

import onnxruntime
import torch

from efficient_sam.build_efficient_sam import build_efficient_sam_vits
from efficient_sam.build_efficient_sam import build_efficient_sam_vitt

import onnx_models


def export_onnx(onnx_model, output, dynamic_axes, dummy_inputs, output_names):
    with open(output, "wb") as f:
        print(f"Exporting onnx model to {output}...")
        torch.onnx.export(
            onnx_model,
            tuple(dummy_inputs.values()),
            f,
            export_params=True,
            verbose=False,
            opset_version=17,
            do_constant_folding=True,
            input_names=list(dummy_inputs.keys()),
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )

    inference_session = onnxruntime.InferenceSession(output)
    output = inference_session.run(
        output_names=output_names,
        input_feed={k: v.numpy() for k, v in dummy_inputs.items()},
    )
    print(output_names)
    print([output_i.shape for output_i in output])


def export_onnx_esam(model, output):
    onnx_model = onnx_models.OnnxEfficientSam(model=model)
    dynamic_axes = {
        "batched_images": {0: "batch", 2: "height", 3: "width"},
        "batched_point_coords": {2: "num_points"},
        "batched_point_labels": {2: "num_points"},
    }
    dummy_inputs = {
        "batched_images": torch.randn(1, 3, 1080, 1920, dtype=torch.float),
        "batched_point_coords": torch.randint(
            low=0, high=1080, size=(1, 1, 5, 2), dtype=torch.float
        ),
        "batched_point_labels": torch.randint(
            low=0, high=4, size=(1, 1, 5), dtype=torch.float
        ),
    }
    output_names = ["output_masks", "iou_predictions"]
    export_onnx(
        onnx_model=onnx_model,
        output=output,
        dynamic_axes=dynamic_axes,
        dummy_inputs=dummy_inputs,
        output_names=output_names,
    )


def export_onnx_esam_encoder(model, output):
    onnx_model = onnx_models.OnnxEfficientSamEncoder(model=model)
    dynamic_axes = {
        "batched_images": {0: "batch", 2: "height", 3: "width"},
    }
    dummy_inputs = {
        "batched_images": torch.randn(1, 3, 1080, 1920, dtype=torch.float),
    }
    output_names = ["image_embeddings"]
    export_onnx(
        onnx_model=onnx_model,
        output=output,
        dynamic_axes=dynamic_axes,
        dummy_inputs=dummy_inputs,
        output_names=output_names,
    )


def export_onnx_esam_decoder(model, output):
    onnx_model = onnx_models.OnnxEfficientSamDecoder(model=model)
    dynamic_axes = {
        "image_embeddings": {0: "batch"},
        "batched_point_coords": {2: "num_points"},
        "batched_point_labels": {2: "num_points"},
    }
    dummy_inputs = {
        "image_embeddings": torch.randn(1, 256, 64, 64, dtype=torch.float),
        "batched_point_coords": torch.randint(
            low=0, high=1080, size=(1, 1, 5, 2), dtype=torch.float
        ),
        "batched_point_labels": torch.randint(
            low=0, high=4, size=(1, 1, 5), dtype=torch.float
        ),
        "orig_im_size": torch.tensor([1080, 1920], dtype=torch.long),
    }
    output_names = ["output_masks", "iou_predictions"]
    export_onnx(
        onnx_model=onnx_model,
        output=output,
        dynamic_axes=dynamic_axes,
        dummy_inputs=dummy_inputs,
        output_names=output_names,
    )


def main():
    # faster
    export_onnx_esam(
        model=build_efficient_sam_vitt(),
        output="weights/efficient_sam_vitt.onnx",
    )
    export_onnx_esam_encoder(
        model=build_efficient_sam_vitt(),
        output="weights/efficient_sam_vitt_encoder.onnx",
    )
    export_onnx_esam_decoder(
        model=build_efficient_sam_vitt(),
        output="weights/efficient_sam_vitt_decoder.onnx",
    )

    # more accurate
    export_onnx_esam(
        model=build_efficient_sam_vits(),
        output="weights/efficient_sam_vits.onnx",
    )
    export_onnx_esam_encoder(
        model=build_efficient_sam_vits(),
        output="weights/efficient_sam_vits_encoder.onnx",
    )
    export_onnx_esam_decoder(
        model=build_efficient_sam_vits(),
        output="weights/efficient_sam_vits_decoder.onnx",
    )


if __name__ == "__main__":
    main()
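Once exported, the single-file model can be driven with onnxruntime using the input/output names defined above; a sketch with a random stand-in image and arbitrary click coordinates:

    import numpy as np
    import onnxruntime

    session = onnxruntime.InferenceSession("weights/efficient_sam_vitt.onnx")
    image = np.random.rand(1, 3, 1080, 1920).astype(np.float32)  # stand-in for a real RGB image in [0, 1]
    points = np.array([[[[580.0, 350.0], [650.0, 350.0]]]], dtype=np.float32)  # 1 query, 2 points (x, y)
    labels = np.array([[[1.0, 1.0]]], dtype=np.float32)                         # 1 = foreground click
    masks, iou = session.run(
        ["output_masks", "iou_predictions"],
        {
            "batched_images": image,
            "batched_point_coords": points,
            "batched_point_labels": labels,
        },
    )
    print(masks.shape, iou.shape)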
EfficientSAM-main/export_to_torchscript.py
ADDED
@@ -0,0 +1,17 @@
import torch
from efficient_sam.build_efficient_sam import build_efficient_sam_vitt, build_efficient_sam_vits
# from squeeze_sam.build_squeeze_sam import build_squeeze_sam
import zipfile
import os

# Efficient SAM (VIT-tiny)
torch.jit.save(torch.jit.script(build_efficient_sam_vitt()), "torchscripted_model/efficient_sam_vitt_torchscript.pt")

# Efficient SAM (VIT-small)
# Since VIT-small is >100MB, we store the zip file.
with zipfile.ZipFile("weights/efficient_sam_vits.pt.zip", 'r') as zip_ref:
    zip_ref.extractall("weights")
torch.jit.save(torch.jit.script(build_efficient_sam_vits()), "torchscripted_model/efficient_sam_vits_torchscript.pt")

# Squeeze SAM (UNET)
# torch.jit.save(torch.jit.script(build_squeeze_sam()), "torchscripted_model/squeeze_sam_torchscript.pt")
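A sketch of loading the TorchScript artifact saved above; it assumes the scripted EfficientSam module keeps its (images, points, labels) forward signature, and the image tensor and prompt values are arbitrary stand-ins:

    import torch

    model = torch.jit.load("torchscripted_model/efficient_sam_vitt_torchscript.pt")
    image = torch.rand(1, 3, 1024, 1024)                            # stand-in for a [0, 1] RGB image
    points = torch.tensor([[[[500.0, 630.0], [580.0, 630.0]]]])     # B x num_queries x num_points x 2
    labels = torch.tensor([[[1.0, 1.0]]])                           # 1 = positive click
    with torch.no_grad():
        # Assumed return values: mask logits and per-mask IoU predictions.
        predicted_logits, predicted_iou = model(image, points, labels)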
EfficientSAM-main/figs/.DS_Store
ADDED
Binary file (6.15 kB)
EfficientSAM-main/figs/examples/demo_box.png
ADDED
Git LFS Details
EfficientSAM-main/figs/examples/demo_everything.png
ADDED
Git LFS Details
EfficientSAM-main/figs/examples/demo_point.png
ADDED
Git LFS Details
EfficientSAM-main/figs/examples/demo_saliency.png
ADDED
Git LFS Details
EfficientSAM-main/figs/examples/dogs.jpg
ADDED
EfficientSAM-main/figs/examples/dogs_efficient_sam_vits_mask.png
ADDED
EfficientSAM-main/figs/examples/dogs_efficient_sam_vitt_mask.png
ADDED
EfficientSAM-main/figs/examples/dogs_efficientsam_s_mask.png
ADDED
EfficientSAM-main/figs/examples/dogs_efficientsam_ti_mask.png
ADDED
EfficientSAM-main/figs/examples/dogs_squeeze_sam_mask.png
ADDED
EfficientSAM-main/linter.sh
ADDED
@@ -0,0 +1,32 @@
#!/bin/bash -e
# Copyright (c) Facebook, Inc. and its affiliates.

{
  black --version | grep -E "23\." > /dev/null
} || {
  echo "Linter requires 'black==23.*' !"
  exit 1
}

ISORT_VERSION=$(isort --version-number)
if [[ "$ISORT_VERSION" != 5.12* ]]; then
  echo "Linter requires isort==5.12.0 !"
  exit 1
fi

echo "Running isort ..."
isort . --atomic

echo "Running black ..."
black -l 100 .

echo "Running flake8 ..."
if [ -x "$(command -v flake8)" ]; then
  flake8 .
else
  python3 -m flake8 .
fi

echo "Running mypy..."

mypy --exclude 'setup.py|notebooks' .
EfficientSAM-main/notebooks/EfficientSAM_example.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
EfficientSAM-main/notebooks/EfficientSAM_segment_everything_example.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
EfficientSAM-main/onnx_models.py
ADDED
@@ -0,0 +1,166 @@
# ONNX export code is from [labelme annotation tool](https://github.com/labelmeai/efficient-sam). Huge thanks to Kentaro Wada.

import torch
import torch.nn.functional as F


class OnnxEfficientSam(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    @property
    def decoder_max_num_input_points(self):
        return self.model.decoder_max_num_input_points

    @property
    def image_encoder(self):
        return self.model.image_encoder

    @property
    def get_image_embeddings(self):
        return self.model.get_image_embeddings

    @property
    def prompt_encoder(self):
        return self.model.prompt_encoder

    @property
    def mask_decoder(self):
        return self.model.mask_decoder

    def forward(
        self,
        batched_images: torch.Tensor,
        batched_points: torch.Tensor,
        batched_point_labels: torch.Tensor,
    ):
        batch_size, _, input_h, input_w = batched_images.shape
        image_embeddings = self.get_image_embeddings(batched_images)
        return self.predict_masks(
            image_embeddings,
            batched_points,
            batched_point_labels,
            multimask_output=True,
            input_h=input_h,
            input_w=input_w,
            output_h=input_h,
            output_w=input_w,
        )

    def get_rescaled_pts(
        self, batched_points: torch.Tensor, input_h: int, input_w: int
    ):
        return torch.stack(
            [
                batched_points[..., 0] * self.image_encoder.img_size / input_w,
                batched_points[..., 1] * self.image_encoder.img_size / input_h,
            ],
            dim=-1,
        )

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        batched_points: torch.Tensor,
        batched_point_labels: torch.Tensor,
        multimask_output: bool,
        input_h: int,
        input_w: int,
        output_h: int = -1,
        output_w: int = -1,
    ):
        batch_size, max_num_queries, num_pts, _ = batched_points.shape
        num_pts = batched_points.shape[2]
        rescaled_batched_points = self.get_rescaled_pts(
            batched_points, input_h, input_w
        )

        if num_pts > self.decoder_max_num_input_points:
            rescaled_batched_points = rescaled_batched_points[
                :, :, : self.decoder_max_num_input_points, :
            ]
            batched_point_labels = batched_point_labels[
                :, :, : self.decoder_max_num_input_points
            ]
        elif num_pts < self.decoder_max_num_input_points:
            rescaled_batched_points = F.pad(
                rescaled_batched_points,
                (0, 0, 0, self.decoder_max_num_input_points - num_pts),
                value=-1.0,
            )
            batched_point_labels = F.pad(
                batched_point_labels,
                (0, self.decoder_max_num_input_points - num_pts),
                value=-1.0,
            )

        sparse_embeddings = self.prompt_encoder(
            rescaled_batched_points.reshape(
                batch_size * max_num_queries, self.decoder_max_num_input_points, 2
            ),
            batched_point_labels.reshape(
                batch_size * max_num_queries, self.decoder_max_num_input_points
            ),
        )
        sparse_embeddings = sparse_embeddings.view(
            batch_size,
            max_num_queries,
            sparse_embeddings.shape[1],
            sparse_embeddings.shape[2],
        )
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings,
            self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            multimask_output=multimask_output,
        )
        _, num_predictions, low_res_size, _ = low_res_masks.shape

        if output_w > 0 and output_h > 0:
            output_masks = F.interpolate(
                low_res_masks,
                (output_h, output_w),
                # NOTE: "bicubic" is inefficient on onnx
                mode="bilinear",
            )
            output_masks = torch.reshape(
                output_masks,
                (batch_size, max_num_queries, num_predictions, output_h, output_w),
            )
        else:
            output_masks = torch.reshape(
                low_res_masks,
                (
                    batch_size,
                    max_num_queries,
                    num_predictions,
                    low_res_size,
                    low_res_size,
                ),
            )
        iou_predictions = torch.reshape(
            iou_predictions, (batch_size, max_num_queries, num_predictions)
        )
        return output_masks, iou_predictions, low_res_masks


class OnnxEfficientSamEncoder(OnnxEfficientSam):
    def forward(self, batched_images: torch.Tensor):
        return self.model.get_image_embeddings(batched_images)


class OnnxEfficientSamDecoder(OnnxEfficientSam):
    def forward(
        self, image_embeddings, batched_points, batched_point_labels, orig_im_size
    ):
        return self.predict_masks(
            image_embeddings=image_embeddings,
            batched_points=batched_points,
            batched_point_labels=batched_point_labels,
            multimask_output=True,
            input_h=orig_im_size[0],
            input_w=orig_im_size[1],
            output_h=orig_im_size[0],
            output_w=orig_im_size[1],
        )
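A sketch of the split encoder/decoder flow these wrappers enable, assuming both ONNX files were produced by export_to_onnx.py; shapes and click coordinates are illustrative:

    import numpy as np
    import onnxruntime

    encoder = onnxruntime.InferenceSession("weights/efficient_sam_vitt_encoder.onnx")
    decoder = onnxruntime.InferenceSession("weights/efficient_sam_vitt_decoder.onnx")

    image = np.random.rand(1, 3, 1080, 1920).astype(np.float32)  # stand-in for a preprocessed image
    (embeddings,) = encoder.run(["image_embeddings"], {"batched_images": image})

    points = np.array([[[[960.0, 540.0]]]], dtype=np.float32)     # one click near the center
    labels = np.array([[[1.0]]], dtype=np.float32)                # 1 = foreground
    masks, iou = decoder.run(
        ["output_masks", "iou_predictions"],
        {
            "image_embeddings": embeddings,
            "batched_point_coords": points,
            "batched_point_labels": labels,
            "orig_im_size": np.array([1080, 1920], dtype=np.int64),
        },
    )
    print(masks.shape, iou.shape)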
EfficientSAM-main/setup.cfg
ADDED
@@ -0,0 +1,11 @@
[isort]
line_length=100
multi_line_output=3
include_trailing_comma=True
known_standard_library=numpy,setuptools
skip_glob=*/__init__.py
known_myself=efficient_sam
known_third_party=matplotlib,torch,torchvision,black,isort
no_lines_before=STDLIB,THIRDPARTY
sections=FUTURE,STDLIB,THIRDPARTY,MYSELF,FIRSTPARTY,LOCALFOLDER
default_section=FIRSTPARTY
EfficientSAM-main/setup.py
ADDED
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from setuptools import find_packages, setup

setup(
    name="efficient_sam",
    version="1.0",
    install_requires=[],
    packages=find_packages(exclude="notebooks"),
    extras_require={
        "all": ["matplotlib", "onnx", "onnxruntime"],
        "dev": ["flake8", "isort", "black", "mypy"],
    },
)
EfficientSAM-main/torchscripted_model/efficient_sam_vitt_torchscript.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1d300fa20a94f40bf307616a3979f4d4b6a4a347edd08ed37d69535c2185188e
size 41074768
EfficientSAM-main/weights/efficient_sam_vits.pt.zip
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e1801de05adeea87a6b779b0bedf3ab6751e03c21facb82d2c660867c02813fc
size 98304933
EfficientSAM-main/weights/efficient_sam_vitt.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:143c3198a7b2a15f23c21cdb723432fb3fbcdbabbdad3483cf3babd8b95c1397
size 41365520
EfficientSAM-main/weights/efficient_sam_vitt.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dff858b19600a46461cbb7de98f796b23a7a888d9f5e34c0b033f7d6eb9e4e6a
size 40982470
EfficientSAM-main/weights/efficient_sam_vitt_decoder.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a62f8fa5ea080447c0689418d69e58f1e83e0b7adf9c142e2bd9bcc8045c0b11
size 16565728
EfficientSAM-main/weights/efficient_sam_vitt_encoder.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:84ed466ffcc5c1f8d08409bc34a23bb364ab2c15e402cb12d4335a42be0e0951
size 24799761