NandiniLokeshReddy committed on
Commit d8dd7fb · verified · 1 Parent(s): e0b4b52

Upload 38 files

Files changed (39)
  1. .gitattributes +4 -0
  2. EfficientSAM-main/.DS_Store +0 -0
  3. EfficientSAM-main/.gitignore +5 -0
  4. EfficientSAM-main/EfficientSAM_example.py +54 -0
  5. EfficientSAM-main/EfficientSAM_onnx_example.py +77 -0
  6. EfficientSAM-main/LICENSE +201 -0
  7. EfficientSAM-main/README.md +63 -0
  8. EfficientSAM-main/efficient_sam/__init__.py +7 -0
  9. EfficientSAM-main/efficient_sam/build_efficient_sam.py +22 -0
  10. EfficientSAM-main/efficient_sam/efficient_sam.py +305 -0
  11. EfficientSAM-main/efficient_sam/efficient_sam_decoder.py +315 -0
  12. EfficientSAM-main/efficient_sam/efficient_sam_encoder.py +257 -0
  13. EfficientSAM-main/efficient_sam/mlp.py +29 -0
  14. EfficientSAM-main/efficient_sam/two_way_transformer.py +266 -0
  15. EfficientSAM-main/export_to_onnx.py +139 -0
  16. EfficientSAM-main/export_to_torchscript.py +17 -0
  17. EfficientSAM-main/figs/.DS_Store +0 -0
  18. EfficientSAM-main/figs/examples/demo_box.png +3 -0
  19. EfficientSAM-main/figs/examples/demo_everything.png +3 -0
  20. EfficientSAM-main/figs/examples/demo_point.png +3 -0
  21. EfficientSAM-main/figs/examples/demo_saliency.png +3 -0
  22. EfficientSAM-main/figs/examples/dogs.jpg +0 -0
  23. EfficientSAM-main/figs/examples/dogs_efficient_sam_vits_mask.png +0 -0
  24. EfficientSAM-main/figs/examples/dogs_efficient_sam_vitt_mask.png +0 -0
  25. EfficientSAM-main/figs/examples/dogs_efficientsam_s_mask.png +0 -0
  26. EfficientSAM-main/figs/examples/dogs_efficientsam_ti_mask.png +0 -0
  27. EfficientSAM-main/figs/examples/dogs_squeeze_sam_mask.png +0 -0
  28. EfficientSAM-main/linter.sh +32 -0
  29. EfficientSAM-main/notebooks/EfficientSAM_example.ipynb +0 -0
  30. EfficientSAM-main/notebooks/EfficientSAM_segment_everything_example.ipynb +0 -0
  31. EfficientSAM-main/onnx_models.py +166 -0
  32. EfficientSAM-main/setup.cfg +11 -0
  33. EfficientSAM-main/setup.py +18 -0
  34. EfficientSAM-main/torchscripted_model/efficient_sam_vitt_torchscript.pt +3 -0
  35. EfficientSAM-main/weights/efficient_sam_vits.pt.zip +3 -0
  36. EfficientSAM-main/weights/efficient_sam_vitt.onnx +3 -0
  37. EfficientSAM-main/weights/efficient_sam_vitt.pt +3 -0
  38. EfficientSAM-main/weights/efficient_sam_vitt_decoder.onnx +3 -0
  39. EfficientSAM-main/weights/efficient_sam_vitt_encoder.onnx +3 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ EfficientSAM-main/figs/examples/demo_box.png filter=lfs diff=lfs merge=lfs -text
+ EfficientSAM-main/figs/examples/demo_everything.png filter=lfs diff=lfs merge=lfs -text
+ EfficientSAM-main/figs/examples/demo_point.png filter=lfs diff=lfs merge=lfs -text
+ EfficientSAM-main/figs/examples/demo_saliency.png filter=lfs diff=lfs merge=lfs -text
EfficientSAM-main/.DS_Store ADDED
Binary file (10.2 kB).
 
EfficientSAM-main/.gitignore ADDED
@@ -0,0 +1,5 @@
*.pyc
*.pyo
*.pyd
__py
**/__pycache__/
EfficientSAM-main/EfficientSAM_example.py ADDED
@@ -0,0 +1,54 @@
from efficient_sam.build_efficient_sam import build_efficient_sam_vitt, build_efficient_sam_vits
# from squeeze_sam.build_squeeze_sam import build_squeeze_sam

from PIL import Image
from torchvision import transforms
import torch
import numpy as np
import zipfile


models = {}

# Build the EfficientSAM-Ti model.
models['efficientsam_ti'] = build_efficient_sam_vitt()

# Since the EfficientSAM-S checkpoint file is >100MB, we store it as a zip file.
with zipfile.ZipFile("weights/efficient_sam_vits.pt.zip", 'r') as zip_ref:
    zip_ref.extractall("weights")
# Build the EfficientSAM-S model.
models['efficientsam_s'] = build_efficient_sam_vits()

# Build the SqueezeSAM model.
# models['squeeze_sam'] = build_squeeze_sam()

# Load an image.
sample_image_np = np.array(Image.open("figs/examples/dogs.jpg"))
sample_image_tensor = transforms.ToTensor()(sample_image_np)

# Feed a few (x, y) points in the mask as input.
input_points = torch.tensor([[[[580, 350], [650, 350]]]])
input_labels = torch.tensor([[[1, 1]]])

# Run inference for both the EfficientSAM-Ti and EfficientSAM-S models.
for model_name, model in models.items():
    print('Running inference using ', model_name)
    predicted_logits, predicted_iou = model(
        sample_image_tensor[None, ...],
        input_points,
        input_labels,
    )
    sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)
    predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)
    predicted_logits = torch.take_along_dim(
        predicted_logits, sorted_ids[..., None, None], dim=2
    )
    # The masks are already sorted by their predicted IOUs.
    # The first dimension is the batch size (we have a single image, so it is 1).
    # The second dimension is the number of input queries (in this case, it is only 1).
    # The third dimension is the number of candidate masks output by the model.
    # For this demo we use the first mask.
    mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().detach().numpy()
    masked_image_np = sample_image_np.copy().astype(np.uint8) * mask[:, :, None]
    Image.fromarray(masked_image_np).save(f"figs/examples/dogs_{model_name}_mask.png")
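The script above uses positive point prompts (label 1). The prompt encoder in efficient_sam_decoder.py also understands box prompts, encoded as two points labeled 2 (top-left corner) and 3 (bottom-right corner). A minimal sketch reusing the model and image loaded above; the box coordinates are illustrative, not from the original demo:

```python
# Hypothetical box prompt: top-left and bottom-right corners of a box around an object.
box_points = torch.tensor([[[[500, 300], [700, 450]]]], dtype=torch.float32)
box_labels = torch.tensor([[[2, 3]]])  # 2 = box top-left corner, 3 = box bottom-right corner

with torch.no_grad():
    logits, iou = models['efficientsam_ti'](
        sample_image_tensor[None, ...], box_points, box_labels
    )
box_mask = torch.ge(logits[0, 0, 0, :, :], 0).cpu().numpy()  # first candidate mask
```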
EfficientSAM-main/EfficientSAM_onnx_example.py ADDED
@@ -0,0 +1,77 @@
# ONNX export code is from the [labelme annotation tool](https://github.com/labelmeai/efficient-sam). Huge thanks to Kentaro Wada.

import numpy as np
import torch

import imgviz
import onnxruntime
import time
from PIL import Image


def predict_onnx(input_image, input_points, input_labels):
    # Flip this to 1 to run the combined ONNX model instead of the separate
    # encoder/decoder sessions benchmarked below.
    if 0:
        inference_session = onnxruntime.InferenceSession(
            "weights/efficient_sam_vitt.onnx"
        )
        (
            predicted_logits,
            predicted_iou,
            predicted_lowres_logits,
        ) = inference_session.run(
            output_names=None,
            input_feed={
                "batched_images": input_image,
                "batched_point_coords": input_points,
                "batched_point_labels": input_labels,
            },
        )
    else:
        inference_session = onnxruntime.InferenceSession(
            "weights/efficient_sam_vitt_encoder.onnx"
        )
        t_start = time.time()
        image_embeddings, = inference_session.run(
            output_names=None,
            input_feed={
                "batched_images": input_image,
            },
        )
        print("encoder time", time.time() - t_start)

        inference_session = onnxruntime.InferenceSession(
            "weights/efficient_sam_vitt_decoder.onnx"
        )
        t_start = time.time()
        (
            predicted_logits,
            predicted_iou,
            predicted_lowres_logits,
        ) = inference_session.run(
            output_names=None,
            input_feed={
                "image_embeddings": image_embeddings,
                "batched_point_coords": input_points,
                "batched_point_labels": input_labels,
                "orig_im_size": np.array(input_image.shape[2:], dtype=np.int64),
            },
        )
        print("decoder time", time.time() - t_start)
    mask = predicted_logits[0, 0, 0, :, :] >= 0
    imgviz.io.imsave("figs/examples/dogs_onnx_mask.png", mask)


def main():
    image = np.array(Image.open("figs/examples/dogs.jpg"))

    input_image = image.transpose(2, 0, 1)[None].astype(np.float32) / 255.0
    # batch_size, num_queries, num_points, 2
    input_points = np.array([[[[580, 350], [650, 350]]]], dtype=np.float32)
    # batch_size, num_queries, num_points
    input_labels = np.array([[[1, 1]]], dtype=np.float32)

    predict_onnx(input_image, input_points, input_labels)


if __name__ == "__main__":
    main()
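imgviz saves the raw boolean mask here; to eyeball the segmentation itself, the mask can be applied to the input image the same way EfficientSAM_example.py does for the PyTorch models. A small helper sketch (the function name is ours, not part of the upload):

```python
import numpy as np
from PIL import Image

def save_masked_overlay(image: np.ndarray, mask: np.ndarray, path: str) -> None:
    """Zero out pixels outside the predicted mask and save the result."""
    masked = (image * mask[:, :, None]).astype(np.uint8)
    Image.fromarray(masked).save(path)

# e.g. save_masked_overlay(image, mask, "figs/examples/dogs_onnx_overlay.png")
# with `image` from main() and `mask` from predict_onnx().
```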
EfficientSAM-main/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
EfficientSAM-main/README.md ADDED
@@ -0,0 +1,63 @@
# EfficientSAM
EfficientSAM: Leveraged Masked Image Pretraining for Efficient Segment Anything

## News
[Jan.12 2024] The ONNX version of EfficientSAM, including a separate encoder and decoder, is available on the [Hugging Face Space](https://huggingface.co/spaces/yunyangx/EfficientSAM/tree/main) (thanks to @wkentaro Kentaro Wada for implementing the ONNX export).

[Dec.31 2023] EfficientSAM is integrated into the annotation tool [Labelme](https://github.com/labelmeai/labelme) (huge thanks to the Labelme team and @wkentaro Kentaro Wada).

[Dec.11 2023] The EfficientSAM model code with checkpoints is fully available in this repository. The [example](https://github.com/yformer/EfficientSAM/blob/main/EfficientSAM_example.py) script shows how to instantiate the model with a checkpoint and query points on an image.

[Dec.10 2023] The Grounded EfficientSAM demo is available on [Grounded-Efficient-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/main/EfficientSAM) (huge thanks to the IDEA-Research team and @rentainhe for supporting the [grounded-efficient-sam demo](https://github.com/IDEA-Research/Grounded-Segment-Anything/blob/main/EfficientSAM/grounded_efficient_sam.py) under [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything)).

[Dec.6 2023] The EfficientSAM demo is available on the [Hugging Face Space](https://huggingface.co/spaces/yunyangx/EfficientSAM) (huge thanks to the HF team for their support).

[Dec.5 2023] We release the TorchScript version of EfficientSAM and share a Colab.

## Online Demo & Examples
The online demo and examples can be found on the [project page](https://yformer.github.io/efficient-sam/).

## EfficientSAM Instance Segmentation Examples
| | |
:-------------------------:|:-------------------------:
Point-prompt | ![point-prompt](figs/examples/demo_point.png)
Box-prompt | ![box-prompt](figs/examples/demo_box.png)
Segment everything | ![segment everything](figs/examples/demo_everything.png)
Saliency | ![Saliency](figs/examples/demo_saliency.png)

## Model
EfficientSAM checkpoints are available under the weights folder of this GitHub repository. Example instantiations and runs of the models can be found in [EfficientSAM_example.py](https://github.com/yformer/EfficientSAM/blob/main/EfficientSAM_example.py).

| EfficientSAM-S | EfficientSAM-Ti |
|------------------------------|------------------------------|
| [Download](https://github.com/yformer/EfficientSAM/blob/main/weights/efficient_sam_vits.pt.zip) | [Download](https://github.com/yformer/EfficientSAM/blob/main/weights/efficient_sam_vitt.pt) |

You can directly use EfficientSAM with the checkpoints:
```
from efficient_sam.build_efficient_sam import build_efficient_sam_vitt, build_efficient_sam_vits
efficientsam = build_efficient_sam_vitt()
```

## Jupyter Notebook Example
The notebooks are shared [here](https://github.com/yformer/EfficientSAM/blob/main/notebooks).

## Acknowledgement
+ [SAM](https://github.com/facebookresearch/segment-anything)
+ [MobileSAM](https://github.com/ChaoningZhang/MobileSAM)
+ [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM)
+ [U-2-Net](https://github.com/xuebinqin/U-2-Net)

If you're using EfficientSAM in your research or applications, please cite it using this BibTeX:
```bibtex
@article{xiong2023efficientsam,
  title={EfficientSAM: Leveraged Masked Image Pretraining for Efficient Segment Anything},
  author={Yunyang Xiong and Bala Varadarajan and Lemeng Wu and Xiaoyu Xiang and Fanyi Xiao and Chenchen Zhu and Xiaoliang Dai and Dilin Wang and Fei Sun and Forrest Iandola and Raghuraman Krishnamoorthi and Vikas Chandra},
  journal={arXiv:2312.00863},
  year={2023}
}
```
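Beyond instantiation, here is a condensed point-prompt example, a sketch distilled from EfficientSAM_example.py in this upload (paths are relative to EfficientSAM-main):

```python
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from efficient_sam.build_efficient_sam import build_efficient_sam_vitt

model = build_efficient_sam_vitt()  # EfficientSAM-Ti with its bundled checkpoint
image_np = np.array(Image.open("figs/examples/dogs.jpg"))
image = transforms.ToTensor()(image_np)[None, ...]    # [1, 3, H, W]
points = torch.tensor([[[[580, 350], [650, 350]]]])   # [B, num_queries, num_pts, 2]
labels = torch.tensor([[[1, 1]]])                     # 1 = positive click
with torch.no_grad():
    logits, iou = model(image, points, labels)
mask = torch.ge(logits[0, 0, 0], 0).numpy()  # first candidate mask; sort by `iou` to pick the best
```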
EfficientSAM-main/efficient_sam/__init__.py ADDED
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

from .build_efficient_sam import (
    build_efficient_sam_vitt,
    build_efficient_sam_vits,
)
EfficientSAM-main/efficient_sam/build_efficient_sam.py ADDED
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .efficient_sam import build_efficient_sam


def build_efficient_sam_vitt():
    return build_efficient_sam(
        encoder_patch_embed_dim=192,
        encoder_num_heads=3,
        checkpoint="weights/efficient_sam_vitt.pt",
    ).eval()


def build_efficient_sam_vits():
    return build_efficient_sam(
        encoder_patch_embed_dim=384,
        encoder_num_heads=6,
        checkpoint="weights/efficient_sam_vits.pt",
    ).eval()
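Both builders load checkpoints by relative path, so they expect to run from the repository root; for EfficientSAM-S the checkpoint ships zipped in this upload (the .pt is >100MB), so it has to be extracted first, as EfficientSAM_example.py does:

```python
import zipfile

from efficient_sam.build_efficient_sam import build_efficient_sam_vits

# Unpack the ViT-S checkpoint next to the ViT-T one before building the model.
with zipfile.ZipFile("weights/efficient_sam_vits.pt.zip", "r") as zip_ref:
    zip_ref.extractall("weights")
model = build_efficient_sam_vits()
```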
EfficientSAM-main/efficient_sam/efficient_sam.py ADDED
@@ -0,0 +1,305 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Any, List, Tuple, Type
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+
13
+ from torch import nn, Tensor
14
+
15
+ from .efficient_sam_decoder import MaskDecoder, PromptEncoder
16
+ from .efficient_sam_encoder import ImageEncoderViT
17
+ from .two_way_transformer import TwoWayAttentionBlock, TwoWayTransformer
18
+
19
+ class EfficientSam(nn.Module):
20
+ mask_threshold: float = 0.0
21
+ image_format: str = "RGB"
22
+
23
+ def __init__(
24
+ self,
25
+ image_encoder: ImageEncoderViT,
26
+ prompt_encoder: PromptEncoder,
27
+ decoder_max_num_input_points: int,
28
+ mask_decoder: MaskDecoder,
29
+ pixel_mean: List[float] = [0.485, 0.456, 0.406],
30
+ pixel_std: List[float] = [0.229, 0.224, 0.225],
31
+ ) -> None:
32
+ """
33
+ SAM predicts object masks from an image and input prompts.
34
+
35
+ Arguments:
36
+ image_encoder (ImageEncoderViT): The backbone used to encode the
37
+ image into image embeddings that allow for efficient mask prediction.
38
+ prompt_encoder (PromptEncoder): Encodes various types of input prompts.
39
+ mask_decoder (MaskDecoder): Predicts masks from the image embeddings
40
+ and encoded prompts.
41
+ pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
42
+ pixel_std (list(float)): Std values for normalizing pixels in the input image.
43
+ """
44
+ super().__init__()
45
+ self.image_encoder = image_encoder
46
+ self.prompt_encoder = prompt_encoder
47
+ self.decoder_max_num_input_points = decoder_max_num_input_points
48
+ self.mask_decoder = mask_decoder
49
+ self.register_buffer(
50
+ "pixel_mean", torch.Tensor(pixel_mean).view(1, 3, 1, 1), False
51
+ )
52
+ self.register_buffer(
53
+ "pixel_std", torch.Tensor(pixel_std).view(1, 3, 1, 1), False
54
+ )
55
+
56
+ @torch.jit.export
57
+ def predict_masks(
58
+ self,
59
+ image_embeddings: torch.Tensor,
60
+ batched_points: torch.Tensor,
61
+ batched_point_labels: torch.Tensor,
62
+ multimask_output: bool,
63
+ input_h: int,
64
+ input_w: int,
65
+ output_h: int = -1,
66
+ output_w: int = -1,
67
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
68
+ """
69
+ Predicts masks given image embeddings and prompts. This only runs the decoder.
70
+
71
+ Arguments:
72
+ image_embeddings: A tensor of shape [B, C, H, W] or [B*max_num_queries, C, H, W]
73
+ batched_points: A tensor of shape [B, max_num_queries, num_pts, 2]
74
+ batched_point_labels: A tensor of shape [B, max_num_queries, num_pts]
75
+ Returns:
76
+ A tuple of two tensors:
77
+ low_res_mask: A tensor of shape [B, max_num_queries, 256, 256] of predicted masks
78
+ iou_predictions: A tensor of shape [B, max_num_queries] of estimated IOU scores
79
+ """
80
+
81
+ batch_size, max_num_queries, num_pts, _ = batched_points.shape
82
+ num_pts = batched_points.shape[2]
83
+ rescaled_batched_points = self.get_rescaled_pts(batched_points, input_h, input_w)
84
+
85
+ if num_pts > self.decoder_max_num_input_points:
86
+ rescaled_batched_points = rescaled_batched_points[
87
+ :, :, : self.decoder_max_num_input_points, :
88
+ ]
89
+ batched_point_labels = batched_point_labels[
90
+ :, :, : self.decoder_max_num_input_points
91
+ ]
92
+ elif num_pts < self.decoder_max_num_input_points:
93
+ rescaled_batched_points = F.pad(
94
+ rescaled_batched_points,
95
+ (0, 0, 0, self.decoder_max_num_input_points - num_pts),
96
+ value=-1.0,
97
+ )
98
+ batched_point_labels = F.pad(
99
+ batched_point_labels,
100
+ (0, self.decoder_max_num_input_points - num_pts),
101
+ value=-1.0,
102
+ )
103
+
104
+ sparse_embeddings = self.prompt_encoder(
105
+ rescaled_batched_points.reshape(
106
+ batch_size * max_num_queries, self.decoder_max_num_input_points, 2
107
+ ),
108
+ batched_point_labels.reshape(
109
+ batch_size * max_num_queries, self.decoder_max_num_input_points
110
+ ),
111
+ )
112
+ sparse_embeddings = sparse_embeddings.view(
113
+ batch_size,
114
+ max_num_queries,
115
+ sparse_embeddings.shape[1],
116
+ sparse_embeddings.shape[2],
117
+ )
118
+ low_res_masks, iou_predictions = self.mask_decoder(
119
+ image_embeddings,
120
+ self.prompt_encoder.get_dense_pe(),
121
+ sparse_prompt_embeddings=sparse_embeddings,
122
+ multimask_output=multimask_output,
123
+ )
124
+ _, num_predictions, low_res_size, _ = low_res_masks.shape
125
+
126
+ if output_w > 0 and output_h > 0:
127
+ output_masks = F.interpolate(
128
+ low_res_masks, (output_h, output_w), mode="bicubic"
129
+ )
130
+ output_masks = torch.reshape(
131
+ output_masks,
132
+ (batch_size, max_num_queries, num_predictions, output_h, output_w),
133
+ )
134
+ else:
135
+ output_masks = torch.reshape(
136
+ low_res_masks,
137
+ (
138
+ batch_size,
139
+ max_num_queries,
140
+ num_predictions,
141
+ low_res_size,
142
+ low_res_size,
143
+ ),
144
+ )
145
+ iou_predictions = torch.reshape(
146
+ iou_predictions, (batch_size, max_num_queries, num_predictions)
147
+ )
148
+ return output_masks, iou_predictions
149
+
150
+ def get_rescaled_pts(self, batched_points: torch.Tensor, input_h: int, input_w: int):
151
+ return torch.stack(
152
+ [
153
+ torch.where(
154
+ batched_points[..., 0] >= 0,
155
+ batched_points[..., 0] * self.image_encoder.img_size / input_w,
156
+ -1.0,
157
+ ),
158
+ torch.where(
159
+ batched_points[..., 1] >= 0,
160
+ batched_points[..., 1] * self.image_encoder.img_size / input_h,
161
+ -1.0,
162
+ ),
163
+ ],
164
+ dim=-1,
165
+ )
166
+
167
+ @torch.jit.export
168
+ def get_image_embeddings(self, batched_images) -> torch.Tensor:
169
+ """
170
+ Predicts masks end-to-end from provided images and prompts.
171
+ If prompts are not known in advance, using SamPredictor is
172
+ recommended over calling the model directly.
173
+
174
+ Arguments:
175
+ batched_images: A tensor of shape [B, 3, H, W]
176
+ Returns:
177
+ List of image embeddings each of of shape [B, C(i), H(i), W(i)].
178
+ The last embedding corresponds to the final layer.
179
+ """
180
+ batched_images = self.preprocess(batched_images)
181
+ return self.image_encoder(batched_images)
182
+
183
+ def forward(
184
+ self,
185
+ batched_images: torch.Tensor,
186
+ batched_points: torch.Tensor,
187
+ batched_point_labels: torch.Tensor,
188
+ scale_to_original_image_size: bool = True,
189
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
190
+ """
191
+ Predicts masks end-to-end from provided images and prompts.
192
+ If prompts are not known in advance, using SamPredictor is
193
+ recommended over calling the model directly.
194
+
195
+ Arguments:
196
+ batched_images: A tensor of shape [B, 3, H, W]
197
+ batched_points: A tensor of shape [B, num_queries, max_num_pts, 2]
198
+ batched_point_labels: A tensor of shape [B, num_queries, max_num_pts]
199
+
200
+ Returns:
201
+ A list tuples of two tensors where the ith element is by considering the first i+1 points.
202
+ low_res_mask: A tensor of shape [B, 256, 256] of predicted masks
203
+ iou_predictions: A tensor of shape [B, max_num_queries] of estimated IOU scores
204
+ """
205
+ batch_size, _, input_h, input_w = batched_images.shape
206
+ image_embeddings = self.get_image_embeddings(batched_images)
207
+ return self.predict_masks(
208
+ image_embeddings,
209
+ batched_points,
210
+ batched_point_labels,
211
+ multimask_output=True,
212
+ input_h=input_h,
213
+ input_w=input_w,
214
+ output_h=input_h if scale_to_original_image_size else -1,
215
+ output_w=input_w if scale_to_original_image_size else -1,
216
+ )
217
+
218
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
219
+ """Normalize pixel values and pad to a square input."""
220
+ if (
221
+ x.shape[2] != self.image_encoder.img_size
222
+ or x.shape[3] != self.image_encoder.img_size
223
+ ):
224
+ x = F.interpolate(
225
+ x,
226
+ (self.image_encoder.img_size, self.image_encoder.img_size),
227
+ mode="bilinear",
228
+ )
229
+ return (x - self.pixel_mean) / self.pixel_std
230
+
231
+
232
+ def build_efficient_sam(encoder_patch_embed_dim, encoder_num_heads, checkpoint=None):
233
+ img_size = 1024
234
+ encoder_patch_size = 16
235
+ encoder_depth = 12
236
+ encoder_mlp_ratio = 4.0
237
+ encoder_neck_dims = [256, 256]
238
+ decoder_max_num_input_points = 6
239
+ decoder_transformer_depth = 2
240
+ decoder_transformer_mlp_dim = 2048
241
+ decoder_num_heads = 8
242
+ decoder_upscaling_layer_dims = [64, 32]
243
+ num_multimask_outputs = 3
244
+ iou_head_depth = 3
245
+ iou_head_hidden_dim = 256
246
+ activation = "gelu"
247
+ normalization_type = "layer_norm"
248
+ normalize_before_activation = False
249
+
250
+ assert activation == "relu" or activation == "gelu"
251
+ if activation == "relu":
252
+ activation_fn = nn.ReLU
253
+ else:
254
+ activation_fn = nn.GELU
255
+
256
+ image_encoder = ImageEncoderViT(
257
+ img_size=img_size,
258
+ patch_size=encoder_patch_size,
259
+ in_chans=3,
260
+ patch_embed_dim=encoder_patch_embed_dim,
261
+ normalization_type=normalization_type,
262
+ depth=encoder_depth,
263
+ num_heads=encoder_num_heads,
264
+ mlp_ratio=encoder_mlp_ratio,
265
+ neck_dims=encoder_neck_dims,
266
+ act_layer=activation_fn,
267
+ )
268
+
269
+ image_embedding_size = image_encoder.image_embedding_size
270
+ encoder_transformer_output_dim = image_encoder.transformer_output_dim
271
+
272
+ sam = EfficientSam(
273
+ image_encoder=image_encoder,
274
+ prompt_encoder=PromptEncoder(
275
+ embed_dim=encoder_transformer_output_dim,
276
+ image_embedding_size=(image_embedding_size, image_embedding_size),
277
+ input_image_size=(img_size, img_size),
278
+ ),
279
+ decoder_max_num_input_points=decoder_max_num_input_points,
280
+ mask_decoder=MaskDecoder(
281
+ transformer_dim=encoder_transformer_output_dim,
282
+ transformer=TwoWayTransformer(
283
+ depth=decoder_transformer_depth,
284
+ embedding_dim=encoder_transformer_output_dim,
285
+ num_heads=decoder_num_heads,
286
+ mlp_dim=decoder_transformer_mlp_dim,
287
+ activation=activation_fn,
288
+ normalize_before_activation=normalize_before_activation,
289
+ ),
290
+ num_multimask_outputs=num_multimask_outputs,
291
+ activation=activation_fn,
292
+ normalization_type=normalization_type,
293
+ normalize_before_activation=normalize_before_activation,
294
+ iou_head_depth=iou_head_depth - 1,
295
+ iou_head_hidden_dim=iou_head_hidden_dim,
296
+ upscaling_layer_dims=decoder_upscaling_layer_dims,
297
+ ),
298
+ pixel_mean=[0.485, 0.456, 0.406],
299
+ pixel_std=[0.229, 0.224, 0.225],
300
+ )
301
+ if checkpoint is not None:
302
+ with open(checkpoint, "rb") as f:
303
+ state_dict = torch.load(f, map_location="cpu")
304
+ sam.load_state_dict(state_dict["model"])
305
+ return sam
EfficientSAM-main/efficient_sam/efficient_sam_decoder.py ADDED
@@ -0,0 +1,315 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Tuple, Type
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from .mlp import MLPBlock
15
+
16
+
17
+ class PromptEncoder(nn.Module):
18
+ def __init__(
19
+ self,
20
+ embed_dim: int,
21
+ image_embedding_size: Tuple[int, int],
22
+ input_image_size: Tuple[int, int],
23
+ ) -> None:
24
+ """
25
+ Encodes prompts for input to SAM's mask decoder.
26
+
27
+ Arguments:
28
+ embed_dim (int): The prompts' embedding dimension
29
+ image_embedding_size (tuple(int, int)): The spatial size of the
30
+ image embedding, as (H, W).
31
+ input_image_size (int): The padded size of the image as input
32
+ to the image encoder, as (H, W).
33
+ """
34
+ super().__init__()
35
+ self.embed_dim = embed_dim
36
+ self.input_image_size = input_image_size
37
+ self.image_embedding_size = image_embedding_size
38
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
39
+ self.invalid_points = nn.Embedding(1, embed_dim)
40
+ self.point_embeddings = nn.Embedding(1, embed_dim)
41
+ self.bbox_top_left_embeddings = nn.Embedding(1, embed_dim)
42
+ self.bbox_bottom_right_embeddings = nn.Embedding(1, embed_dim)
43
+
44
+ def get_dense_pe(self) -> torch.Tensor:
45
+ """
46
+ Returns the positional encoding used to encode point prompts,
47
+ applied to a dense set of points the shape of the image encoding.
48
+
49
+ Returns:
50
+ torch.Tensor: Positional encoding with shape
51
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
52
+ """
53
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
54
+
55
+ def _embed_points(
56
+ self,
57
+ points: torch.Tensor,
58
+ labels: torch.Tensor,
59
+ ) -> torch.Tensor:
60
+ """Embeds point prompts."""
61
+ points = points + 0.5 # Shift to center of pixel
62
+ point_embedding = self.pe_layer.forward_with_coords(
63
+ points, self.input_image_size
64
+ )
65
+ invalid_label_ids = torch.eq(labels, -1)[:,:,None]
66
+ point_label_ids = torch.eq(labels, 1)[:,:,None]
67
+ topleft_label_ids = torch.eq(labels, 2)[:,:,None]
68
+ bottomright_label_ids = torch.eq(labels, 3)[:,:,None]
69
+ point_embedding = point_embedding + self.invalid_points.weight[:,None,:] * invalid_label_ids
70
+ point_embedding = point_embedding + self.point_embeddings.weight[:,None,:] * point_label_ids
71
+ point_embedding = point_embedding + self.bbox_top_left_embeddings.weight[:,None,:] * topleft_label_ids
72
+ point_embedding = point_embedding + self.bbox_bottom_right_embeddings.weight[:,None,:] * bottomright_label_ids
73
+ return point_embedding
74
+
75
+ def forward(
76
+ self,
77
+ coords,
78
+ labels,
79
+ ) -> torch.Tensor:
80
+ """
81
+ Embeds different types of prompts, returning both sparse and dense
82
+ embeddings.
83
+
84
+ Arguments:
85
+ points: A tensor of shape [B, 2]
86
+ labels: An integer tensor of shape [B] where each element is 1,2 or 3.
87
+
88
+ Returns:
89
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
90
+ BxNx(embed_dim), where N is determined by the number of input points
91
+ and boxes.
92
+ """
93
+ return self._embed_points(coords, labels)
94
+
95
+
96
+ class PositionEmbeddingRandom(nn.Module):
97
+ """
98
+ Positional encoding using random spatial frequencies.
99
+ """
100
+
101
+ def __init__(self, num_pos_feats: int) -> None:
102
+ super().__init__()
103
+ self.register_buffer(
104
+ "positional_encoding_gaussian_matrix", torch.randn((2, num_pos_feats))
105
+ )
106
+
107
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
108
+ """Positionally encode points that are normalized to [0,1]."""
109
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
110
+ coords = 2 * coords - 1
111
+ coords = coords @ self.positional_encoding_gaussian_matrix
112
+ coords = 2 * np.pi * coords
113
+ # outputs d_1 x ... x d_n x C shape
114
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
115
+
116
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
117
+ """Generate positional encoding for a grid of the specified size."""
118
+ h, w = size
119
+ device = self.positional_encoding_gaussian_matrix.device
120
+ grid = torch.ones([h, w], device=device, dtype=torch.float32)
121
+ y_embed = grid.cumsum(dim=0) - 0.5
122
+ x_embed = grid.cumsum(dim=1) - 0.5
123
+ y_embed = y_embed / h
124
+ x_embed = x_embed / w
125
+
126
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
127
+ return pe.permute(2, 0, 1) # C x H x W
128
+
129
+ def forward_with_coords(
130
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
131
+ ) -> torch.Tensor:
132
+ """Positionally encode points that are not normalized to [0,1]."""
133
+ coords = coords_input.clone()
134
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
135
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
136
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
137
+
138
+
139
+ class MaskDecoder(nn.Module):
140
+ def __init__(
141
+ self,
142
+ *,
143
+ transformer_dim: int,
144
+ transformer: nn.Module,
145
+ num_multimask_outputs: int,
146
+ activation: Type[nn.Module],
147
+ normalization_type: str,
148
+ normalize_before_activation: bool,
149
+ iou_head_depth: int,
150
+ iou_head_hidden_dim: int,
151
+ upscaling_layer_dims: List[int],
152
+ ) -> None:
153
+ """
154
+ Predicts masks given an image and prompt embeddings, using a
155
+ transformer architecture.
156
+
157
+ Arguments:
158
+ transformer_dim (int): the channel dimension of the transformer
159
+ transformer (nn.Module): the transformer used to predict masks
160
+ num_multimask_outputs (int): the number of masks to predict
161
+ when disambiguating masks
162
+ activation (nn.Module): the type of activation to use when
163
+ upscaling masks
164
+ iou_head_depth (int): the depth of the MLP used to predict
165
+ mask quality
166
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
167
+ used to predict mask quality
168
+ """
169
+ super().__init__()
170
+ self.transformer_dim = transformer_dim
171
+ self.transformer = transformer
172
+
173
+ self.num_multimask_outputs = num_multimask_outputs
174
+
175
+ self.iou_token = nn.Embedding(1, transformer_dim)
176
+ if num_multimask_outputs > 1:
177
+ self.num_mask_tokens = num_multimask_outputs + 1
178
+ else:
179
+ self.num_mask_tokens = 1
180
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
181
+ output_dim_after_upscaling = transformer_dim
182
+
183
+ self.final_output_upscaling_layers = nn.ModuleList([])
184
+ for idx, layer_dims in enumerate(upscaling_layer_dims):
185
+ self.final_output_upscaling_layers.append(
186
+ nn.Sequential(
187
+ nn.ConvTranspose2d(
188
+ output_dim_after_upscaling,
189
+ layer_dims,
190
+ kernel_size=2,
191
+ stride=2,
192
+ ),
193
+ nn.GroupNorm(1, layer_dims)
194
+ if idx < len(upscaling_layer_dims) - 1
195
+ else nn.Identity(),
196
+ activation(),
197
+ )
198
+ )
199
+ output_dim_after_upscaling = layer_dims
200
+
201
+ self.output_hypernetworks_mlps = nn.ModuleList(
202
+ [
203
+ MLPBlock(
204
+ input_dim=transformer_dim,
205
+ hidden_dim=transformer_dim,
206
+ output_dim=output_dim_after_upscaling,
207
+ num_layers=2,
208
+ act=activation,
209
+ )
210
+ for i in range(self.num_mask_tokens)
211
+ ]
212
+ )
213
+
214
+ self.iou_prediction_head = MLPBlock(
215
+ input_dim=transformer_dim,
216
+ hidden_dim=iou_head_hidden_dim,
217
+ output_dim=self.num_mask_tokens,
218
+ num_layers=iou_head_depth,
219
+ act=activation,
220
+ )
221
+
222
+ def forward(
223
+ self,
224
+ image_embeddings: torch.Tensor,
225
+ image_pe: torch.Tensor,
226
+ sparse_prompt_embeddings: torch.Tensor,
227
+ multimask_output: bool,
228
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
229
+ """
230
+ Predict masks given image and prompt embeddings.
231
+
232
+ Arguments:
233
+ image_embeddings: A tensor of shape [B, C, H, W] or [B*max_num_queries, C, H, W]
234
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings (the batch dimension is broadcastable).
235
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
236
+ multimask_output (bool): Whether to return multiple masks or a single
237
+ mask.
238
+
239
+ Returns:
240
+ torch.Tensor: batched predicted masks
241
+ torch.Tensor: batched predictions of mask quality
242
+ """
243
+
244
+ (
245
+ batch_size,
246
+ max_num_queries,
247
+ sparse_embed_dim_1,
248
+ sparse_embed_dim_2,
249
+ ) = sparse_prompt_embeddings.shape
250
+
251
+ (
252
+ _,
253
+ image_embed_dim_c,
254
+ image_embed_dim_h,
255
+ image_embed_dim_w,
256
+ ) = image_embeddings.shape
257
+
258
+ # Tile the image embedding for all queries.
259
+ image_embeddings_tiled = torch.tile(
260
+ image_embeddings[:, None, :, :, :], [1, max_num_queries, 1, 1, 1]
261
+ ).view(
262
+ batch_size * max_num_queries,
263
+ image_embed_dim_c,
264
+ image_embed_dim_h,
265
+ image_embed_dim_w,
266
+ )
267
+ sparse_prompt_embeddings = sparse_prompt_embeddings.reshape(
268
+ batch_size * max_num_queries, sparse_embed_dim_1, sparse_embed_dim_2
269
+ )
270
+ masks, iou_pred = self.predict_masks(
271
+ image_embeddings=image_embeddings_tiled,
272
+ image_pe=image_pe,
273
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
274
+ )
275
+ if multimask_output and self.num_multimask_outputs > 1:
276
+ return masks[:, 1:, :], iou_pred[:, 1:]
277
+ else:
278
+ return masks[:, :1, :], iou_pred[:, :1]
279
+
280
+ def predict_masks(
281
+ self,
282
+ image_embeddings: torch.Tensor,
283
+ image_pe: torch.Tensor,
284
+ sparse_prompt_embeddings: torch.Tensor,
285
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
286
+ """Predicts masks. See 'forward' for more details."""
287
+ # Concatenate output tokens
288
+ output_tokens = torch.cat(
289
+ [self.iou_token.weight, self.mask_tokens.weight], dim=0
290
+ )
291
+ output_tokens = output_tokens.unsqueeze(0).expand(
292
+ sparse_prompt_embeddings.size(0), -1, -1
293
+ )
294
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
295
+ # Expand per-image data in batch direction to be per-mask
296
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
297
+ b, c, h, w = image_embeddings.shape
298
+ hs, src = self.transformer(image_embeddings, pos_src, tokens)
299
+ iou_token_out = hs[:, 0, :]
300
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
301
+
302
+ # Upscale mask embeddings and predict masks using the mask tokens
303
+ upscaled_embedding = src.transpose(1, 2).view(b, c, h, w)
304
+
305
+ for upscaling_layer in self.final_output_upscaling_layers:
306
+ upscaled_embedding = upscaling_layer(upscaled_embedding)
307
+ hyper_in_list: List[torch.Tensor] = []
308
+ for i, output_hypernetworks_mlp in enumerate(self.output_hypernetworks_mlps):
309
+ hyper_in_list.append(output_hypernetworks_mlp(mask_tokens_out[:, i, :]))
310
+ hyper_in = torch.stack(hyper_in_list, dim=1)
311
+ b, c, h, w = upscaled_embedding.shape
312
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
313
+ # Generate mask quality predictions
314
+ iou_pred = self.iou_prediction_head(iou_token_out)
315
+ return masks, iou_pred
EfficientSAM-main/efficient_sam/efficient_sam_encoder.py ADDED
@@ -0,0 +1,257 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import List, Optional, Tuple, Type
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+
15
+ class LayerNorm2d(nn.Module):
16
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
17
+ super().__init__()
18
+ self.weight = nn.Parameter(torch.ones(num_channels))
19
+ self.bias = nn.Parameter(torch.zeros(num_channels))
20
+ self.eps = eps
21
+
22
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
23
+ u = x.mean(1, keepdim=True)
24
+ s = (x - u).pow(2).mean(1, keepdim=True)
25
+ x = (x - u) / torch.sqrt(s + self.eps)
26
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
27
+ return x
28
+
29
+
30
+ class PatchEmbed(nn.Module):
31
+ """2D Image to Patch Embedding"""
32
+
33
+ def __init__(
34
+ self,
35
+ img_size,
36
+ patch_size,
37
+ in_chans,
38
+ embed_dim,
39
+ ):
40
+ super().__init__()
41
+ self.proj = nn.Conv2d(
42
+ in_chans,
43
+ embed_dim,
44
+ kernel_size=(patch_size, patch_size),
45
+ stride=(patch_size, patch_size),
46
+ bias=True,
47
+ )
48
+
49
+ def forward(self, x):
50
+ B, C, H, W = x.shape
51
+ x = self.proj(x)
52
+ return x
53
+
54
+
55
+ class Attention(nn.Module):
56
+ def __init__(
57
+ self,
58
+ dim,
59
+ num_heads,
60
+ qkv_bias,
61
+ qk_scale=None,
62
+ ):
63
+ super().__init__()
64
+ self.num_heads = num_heads
65
+ head_dim = dim // num_heads
66
+ self.scale = qk_scale or head_dim**-0.5
67
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
68
+ self.proj = nn.Linear(dim, dim)
69
+
70
+ def forward(self, x):
71
+ B, N, C = x.shape
72
+ qkv = (
73
+ self.qkv(x)
74
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
75
+ .permute(2, 0, 3, 1, 4)
76
+ )
77
+ q, k, v = (
78
+ qkv[0],
79
+ qkv[1],
80
+ qkv[2],
81
+ )
82
+ attn = (q @ k.transpose(-2, -1)) * self.scale
83
+ attn = attn.softmax(dim=-1)
84
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
85
+ x = self.proj(x)
86
+ return x
87
+
88
+
89
+ class Mlp(nn.Module):
90
+ def __init__(
91
+ self,
92
+ in_features,
93
+ hidden_features=None,
94
+ out_features=None,
95
+ act_layer=nn.GELU,
96
+ ):
97
+ super().__init__()
98
+ out_features = out_features or in_features
99
+ hidden_features = hidden_features or in_features
100
+ self.fc1 = nn.Linear(in_features, hidden_features)
101
+ self.act = act_layer()
102
+ self.fc2 = nn.Linear(hidden_features, out_features)
103
+
104
+ def forward(self, x):
105
+ x = self.fc1(x)
106
+ x = self.act(x)
107
+ x = self.fc2(x)
108
+ return x
109
+
110
+
111
+ class Block(nn.Module):
112
+ def __init__(
113
+ self,
114
+ dim,
115
+ num_heads,
116
+ mlp_ratio=4.0,
117
+ qkv_bias=False,
118
+ qk_scale=None,
119
+ act_layer=nn.GELU,
120
+ ):
121
+ super().__init__()
122
+ self.norm1 = nn.LayerNorm(dim, eps=1e-6)
123
+ self.attn = Attention(
124
+ dim,
125
+ num_heads=num_heads,
126
+ qkv_bias=qkv_bias,
127
+ qk_scale=qk_scale,
128
+ )
129
+ self.norm2 = nn.LayerNorm(dim, eps=1e-6)
130
+ mlp_hidden_dim = int(dim * mlp_ratio)
131
+ self.mlp = Mlp(
132
+ in_features=dim,
133
+ hidden_features=mlp_hidden_dim,
134
+ act_layer=act_layer,
135
+ )
136
+
137
+ def forward(self, x):
138
+ x = x + self.attn(self.norm1(x))
139
+ x = x + self.mlp(self.norm2(x))
140
+ return x
141
+
142
+
143
+ @torch.jit.export
144
+ def get_abs_pos(
145
+ abs_pos: torch.Tensor, has_cls_token: bool, hw: List[int]
146
+ ) -> torch.Tensor:
147
+ """
148
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
149
+ dimension for the original embeddings.
150
+ Args:
151
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
152
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
153
+ hw (Tuple): size of input image tokens.
154
+
155
+ Returns:
156
+ Absolute positional embeddings after processing with shape (1, H, W, C)
157
+ """
158
+ h = hw[0]
159
+ w = hw[1]
160
+ if has_cls_token:
161
+ abs_pos = abs_pos[:, 1:]
162
+ xy_num = abs_pos.shape[1]
163
+ size = int(math.sqrt(xy_num))
164
+ assert size * size == xy_num
165
+
166
+ if size != h or size != w:
167
+ new_abs_pos = F.interpolate(
168
+ abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
169
+ size=(h, w),
170
+ mode="bicubic",
171
+ align_corners=False,
172
+ )
173
+ return new_abs_pos.permute(0, 2, 3, 1)
174
+ else:
175
+ return abs_pos.reshape(1, h, w, -1)
176
+
177
+
178
+ # Image encoder for efficient SAM.
179
+ class ImageEncoderViT(nn.Module):
180
+ def __init__(
181
+ self,
182
+ img_size: int,
183
+ patch_size: int,
184
+ in_chans: int,
185
+ patch_embed_dim: int,
186
+ normalization_type: str,
187
+ depth: int,
188
+ num_heads: int,
189
+ mlp_ratio: float,
190
+ neck_dims: List[int],
191
+ act_layer: Type[nn.Module],
192
+ ) -> None:
193
+ """
194
+ Args:
195
+ img_size (int): Input image size.
196
+ patch_size (int): Patch size.
197
+ in_chans (int): Number of input image channels.
198
+ patch_embed_dim (int): Patch embedding dimension.
199
+ depth (int): Depth of ViT.
200
+ num_heads (int): Number of attention heads in each ViT block.
201
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
202
+ act_layer (nn.Module): Activation layer.
203
+ """
204
+ super().__init__()
205
+
206
+ self.img_size = img_size
207
+ self.image_embedding_size = img_size // ((patch_size if patch_size > 0 else 1))
208
+ self.transformer_output_dim = ([patch_embed_dim] + neck_dims)[-1]
209
+ self.pretrain_use_cls_token = True
210
+ pretrain_img_size = 224
211
+ self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, patch_embed_dim)
212
+ # Initialize absolute positional embedding with pretrain image size.
213
+ num_patches = (pretrain_img_size // patch_size) * (
214
+ pretrain_img_size // patch_size
215
+ )
216
+ num_positions = num_patches + 1
217
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, patch_embed_dim))
218
+ self.blocks = nn.ModuleList()
219
+ for i in range(depth):
220
+ vit_block = Block(patch_embed_dim, num_heads, mlp_ratio, True)
221
+ self.blocks.append(vit_block)
222
+ self.neck = nn.Sequential(
223
+ nn.Conv2d(
224
+ patch_embed_dim,
225
+ neck_dims[0],
226
+ kernel_size=1,
227
+ bias=False,
228
+ ),
229
+ LayerNorm2d(neck_dims[0]),
230
+ nn.Conv2d(
231
+ neck_dims[0],
232
+ neck_dims[0],
233
+ kernel_size=3,
234
+ padding=1,
235
+ bias=False,
236
+ ),
237
+ LayerNorm2d(neck_dims[0]),
238
+ )
239
+
240
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
241
+ assert (
242
+ x.shape[2] == self.img_size and x.shape[3] == self.img_size
243
+ ), "input image size must match self.img_size"
244
+ x = self.patch_embed(x)
245
+ # B C H W -> B H W C
246
+ x = x.permute(0, 2, 3, 1)
247
+ x = x + get_abs_pos(
248
+ self.pos_embed, self.pretrain_use_cls_token, [x.shape[1], x.shape[2]]
249
+ )
250
+ num_patches = x.shape[1]
251
+ assert x.shape[2] == num_patches
252
+ x = x.reshape(x.shape[0], num_patches * num_patches, x.shape[3])
253
+ for blk in self.blocks:
254
+ x = blk(x)
255
+ x = x.reshape(x.shape[0], num_patches, num_patches, x.shape[2])
256
+ x = self.neck(x.permute(0, 3, 1, 2))
257
+ return x
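The encoder asserts that its input matches img_size (1024 for the released configs); EfficientSam.preprocess handles the resize and normalization, so arbitrary image sizes can go through get_image_embeddings. A rough sketch of running just the encoder path (shapes follow from img_size=1024, patch_size=16, and neck_dims=[256, 256] above):

```python
import torch
from efficient_sam.build_efficient_sam import build_efficient_sam_vitt

model = build_efficient_sam_vitt()
with torch.no_grad():
    # preprocess() inside get_image_embeddings resizes/normalizes to 1024x1024.
    embeddings = model.get_image_embeddings(torch.rand(1, 3, 480, 640))
print(embeddings.shape)  # should be [1, 256, 64, 64]: 256 neck channels, 1024/16 spatial
```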
EfficientSAM-main/efficient_sam/mlp.py ADDED
@@ -0,0 +1,29 @@
+ from typing import Type
+
+ from torch import nn
+
+
+ # Lightly adapted from
+ # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
+ class MLPBlock(nn.Module):
+     def __init__(
+         self,
+         input_dim: int,
+         hidden_dim: int,
+         output_dim: int,
+         num_layers: int,
+         act: Type[nn.Module],
+     ) -> None:
+         super().__init__()
+         self.num_layers = num_layers
+         h = [hidden_dim] * (num_layers - 1)
+         self.layers = nn.ModuleList(
+             nn.Sequential(nn.Linear(n, k), act())
+             for n, k in zip([input_dim] + h, [hidden_dim] * num_layers)
+         )
+         self.fc = nn.Linear(hidden_dim, output_dim)
+
+     def forward(self, x):
+         for layer in self.layers:
+             x = layer(x)
+         return self.fc(x)
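For reference, a minimal usage sketch (illustrative, not part of this diff): MLPBlock stacks num_layers Linear+activation pairs of width hidden_dim and finishes with a Linear to output_dim, so the last dimension is all that changes.

# Usage sketch for MLPBlock as defined above; argument values are examples only.
import torch
from torch import nn
from efficient_sam.mlp import MLPBlock

mlp = MLPBlock(input_dim=256, hidden_dim=2048, output_dim=256, num_layers=3, act=nn.GELU)
out = mlp(torch.randn(2, 10, 256))
print(out.shape)  # torch.Size([2, 10, 256])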
EfficientSAM-main/efficient_sam/two_way_transformer.py ADDED
@@ -0,0 +1,266 @@
+ import math
+ from typing import Tuple, Type
+ import torch
+ from torch import nn, Tensor
+ from .mlp import MLPBlock
+
+
+
+
+ class TwoWayTransformer(nn.Module):
+     def __init__(
+         self,
+         depth: int,
+         embedding_dim: int,
+         num_heads: int,
+         mlp_dim: int,
+         activation: Type[nn.Module],
+         normalize_before_activation: bool,
+         attention_downsample_rate: int = 2,
+     ) -> None:
+         """
+         A transformer decoder that attends to an input image using
+         queries whose positional embedding is supplied.
+
+         Args:
+           depth (int): number of layers in the transformer
+           embedding_dim (int): the channel dimension for the input embeddings
+           num_heads (int): the number of heads for multihead attention. Must
+             divide embedding_dim
+           mlp_dim (int): the channel dimension internal to the MLP block
+           activation (nn.Module): the activation to use in the MLP block
+         """
+         super().__init__()
+         self.depth = depth
+         self.embedding_dim = embedding_dim
+         self.num_heads = num_heads
+         self.mlp_dim = mlp_dim
+         self.layers = nn.ModuleList()
+
+         for i in range(depth):
+             curr_layer = TwoWayAttentionBlock(
+                 embedding_dim=embedding_dim,
+                 num_heads=num_heads,
+                 mlp_dim=mlp_dim,
+                 activation=activation,
+                 normalize_before_activation=normalize_before_activation,
+                 attention_downsample_rate=attention_downsample_rate,
+                 skip_first_layer_pe=(i == 0),
+             )
+             self.layers.append(curr_layer)
+
+         self.final_attn_token_to_image = AttentionForTwoWayAttentionBlock(
+             embedding_dim,
+             num_heads,
+             downsample_rate=attention_downsample_rate,
+         )
+         self.norm_final_attn = nn.LayerNorm(embedding_dim)
+
+     def forward(
+         self,
+         image_embedding: Tensor,
+         image_pe: Tensor,
+         point_embedding: Tensor,
+     ) -> Tuple[Tensor, Tensor]:
+         """
+         Args:
+           image_embedding (torch.Tensor): image to attend to. Should be shape
+             B x embedding_dim x h x w for any h and w.
+           image_pe (torch.Tensor): the positional encoding to add to the image. Must
+             have the same shape as image_embedding.
+           point_embedding (torch.Tensor): the embedding to add to the query points.
+             Must have shape B x N_points x embedding_dim for any N_points.
+
+         Returns:
+           torch.Tensor: the processed point_embedding
+           torch.Tensor: the processed image_embedding
+         """
+
+         # BxCxHxW -> BxHWxC == B x N_image_tokens x C
+         bs, c, h, w = image_embedding.shape
+         image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
+         image_pe = image_pe.flatten(2).permute(0, 2, 1)
+
+         # Prepare queries
+         queries = point_embedding
+         keys = image_embedding
+
+         # Apply transformer blocks and final layernorm
+         for idx, layer in enumerate(self.layers):
+             queries, keys = layer(
+                 queries=queries,
+                 keys=keys,
+                 query_pe=point_embedding,
+                 key_pe=image_pe,
+             )
+
+         # Apply the final attention layer from the points to the image
+         q = queries + point_embedding
+         k = keys + image_pe
+         attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
+         queries = queries + attn_out
+         queries = self.norm_final_attn(queries)
+         return queries, keys
+
+
+ class TwoWayAttentionBlock(nn.Module):
+     def __init__(
+         self,
+         embedding_dim: int,
+         num_heads: int,
+         mlp_dim: int,
+         activation: Type[nn.Module],
+         normalize_before_activation: bool,
+         attention_downsample_rate: int = 2,
+         skip_first_layer_pe: bool = False,
+     ) -> None:
+         """
+         A transformer block with four layers: (1) self-attention of sparse
+         inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
+         block on sparse inputs, and (4) cross attention of dense inputs to sparse
+         inputs.
+
+         Arguments:
+           embedding_dim (int): the channel dimension of the embeddings
+           num_heads (int): the number of heads in the attention layers
+           mlp_dim (int): the hidden dimension of the mlp block
+           activation (nn.Module): the activation of the mlp block
+           skip_first_layer_pe (bool): skip the PE on the first layer
+         """
+         super().__init__()
+         self.self_attn = AttentionForTwoWayAttentionBlock(embedding_dim, num_heads)
+         self.norm1 = nn.LayerNorm(embedding_dim)
+
+         self.cross_attn_token_to_image = AttentionForTwoWayAttentionBlock(
+             embedding_dim,
+             num_heads,
+             downsample_rate=attention_downsample_rate,
+         )
+         self.norm2 = nn.LayerNorm(embedding_dim)
+
+         self.mlp = MLPBlock(
+             embedding_dim,
+             mlp_dim,
+             embedding_dim,
+             1,
+             activation,
+         )
+
+         self.norm3 = nn.LayerNorm(embedding_dim)
+
+         self.norm4 = nn.LayerNorm(embedding_dim)
+         self.cross_attn_image_to_token = AttentionForTwoWayAttentionBlock(
+             embedding_dim,
+             num_heads,
+             downsample_rate=attention_downsample_rate,
+         )
+
+         self.skip_first_layer_pe = skip_first_layer_pe
+
+     def forward(
+         self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
+     ) -> Tuple[Tensor, Tensor]:
+         # Self attention block
+         if not self.skip_first_layer_pe:
+             queries = queries + query_pe
+         attn_out = self.self_attn(q=queries, k=queries, v=queries)
+         queries = queries + attn_out
+         queries = self.norm1(queries)
+
+         # Cross attention block, tokens attending to image embedding
+         q = queries + query_pe
+         k = keys + key_pe
+         attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
+         queries = queries + attn_out
+         queries = self.norm2(queries)
+
+         # MLP block
+         mlp_out = self.mlp(queries)
+         queries = queries + mlp_out
+         queries = self.norm3(queries)
+
+         # Cross attention block, image embedding attending to tokens
+         q = queries + query_pe
+         k = keys + key_pe
+         attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
+         keys = keys + attn_out
+         keys = self.norm4(keys)
+
+         return queries, keys
+
+
+ class AttentionForTwoWayAttentionBlock(nn.Module):
+     """
+     An attention layer that allows for downscaling the size of the embedding
+     after projection to queries, keys, and values.
+     """
+
+     def __init__(
+         self,
+         embedding_dim: int,
+         num_heads: int,
+         downsample_rate: int = 1,
+     ) -> None:
+         super().__init__()
+         self.embedding_dim = embedding_dim
+         self.internal_dim = embedding_dim // downsample_rate
+         self.num_heads = num_heads
+         assert (
+             self.internal_dim % num_heads == 0
+         ), "num_heads must divide embedding_dim."
+         self.c_per_head = self.internal_dim / num_heads
+         self.inv_sqrt_c_per_head = 1.0 / math.sqrt(self.c_per_head)
+
+         self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+         self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
+         self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
+         self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+         self._reset_parameters()
+
+     def _reset_parameters(self) -> None:
+         # The fan_out is incorrect, but matches pytorch's initialization
+         # for which qkv is a single 3*embedding_dim x embedding_dim matrix
+         fan_in = self.embedding_dim
+         fan_out = 3 * self.internal_dim
+         # Xavier uniform with our custom fan_out
+         bnd = math.sqrt(6 / (fan_in + fan_out))
+         nn.init.uniform_(self.q_proj.weight, -bnd, bnd)
+         nn.init.uniform_(self.k_proj.weight, -bnd, bnd)
+         nn.init.uniform_(self.v_proj.weight, -bnd, bnd)
+         # out_proj.weight is left with default initialization, like pytorch attention
+         nn.init.zeros_(self.q_proj.bias)
+         nn.init.zeros_(self.k_proj.bias)
+         nn.init.zeros_(self.v_proj.bias)
+         nn.init.zeros_(self.out_proj.bias)
+
+     def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
+         b, n, c = x.shape
+         x = x.reshape(b, n, num_heads, c // num_heads)
+         return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head
+
+     def _recombine_heads(self, x: Tensor) -> Tensor:
+         b, n_heads, n_tokens, c_per_head = x.shape
+         x = x.transpose(1, 2)
+         return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C
+
+     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+         # Input projections
+         q = self.q_proj(q)
+         k = self.k_proj(k)
+         v = self.v_proj(v)
+
+         # Separate into heads
+         q = self._separate_heads(q, self.num_heads)
+         k = self._separate_heads(k, self.num_heads)
+         v = self._separate_heads(v, self.num_heads)
+
+         # Attention
+         _, _, _, c_per_head = q.shape
+         attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
+         attn = attn * self.inv_sqrt_c_per_head
+         attn = torch.softmax(attn, dim=-1)
+         # Get output
+         out = attn @ v
+         out = self._recombine_heads(out)
+         out = self.out_proj(out)
+         return out
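A minimal shape walkthrough for TwoWayTransformer (illustrative, not part of this diff; the constructor values are example settings, not necessarily those used by the packaged builders). The queries come back as B x N_points x C and the keys as the flattened B x HW x C image tokens.

# Illustrative shape check for TwoWayTransformer with example hyperparameters.
import torch
from torch import nn
from efficient_sam.two_way_transformer import TwoWayTransformer

t = TwoWayTransformer(
    depth=2,
    embedding_dim=256,
    num_heads=8,
    mlp_dim=2048,
    activation=nn.GELU,
    normalize_before_activation=False,
)
image_embedding = torch.randn(1, 256, 64, 64)  # B x C x H x W
image_pe = torch.randn(1, 256, 64, 64)         # same shape as image_embedding
point_embedding = torch.randn(1, 7, 256)       # B x N_points x C
queries, keys = t(image_embedding, image_pe, point_embedding)
print(queries.shape, keys.shape)  # (1, 7, 256) and (1, 4096, 256)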
EfficientSAM-main/export_to_onnx.py ADDED
@@ -0,0 +1,139 @@
+ # ONNX export code is from [labelme annotation tool](https://github.com/labelmeai/efficient-sam). Huge thanks to Kentaro Wada.
+
+ import onnxruntime
+ import torch
+
+ from efficient_sam.build_efficient_sam import build_efficient_sam_vits
+ from efficient_sam.build_efficient_sam import build_efficient_sam_vitt
+
+ import onnx_models
+
+
+ def export_onnx(onnx_model, output, dynamic_axes, dummy_inputs, output_names):
+     with open(output, "wb") as f:
+         print(f"Exporting onnx model to {output}...")
+         torch.onnx.export(
+             onnx_model,
+             tuple(dummy_inputs.values()),
+             f,
+             export_params=True,
+             verbose=False,
+             opset_version=17,
+             do_constant_folding=True,
+             input_names=list(dummy_inputs.keys()),
+             output_names=output_names,
+             dynamic_axes=dynamic_axes,
+         )
+
+     inference_session = onnxruntime.InferenceSession(output)
+     output = inference_session.run(
+         output_names=output_names,
+         input_feed={k: v.numpy() for k, v in dummy_inputs.items()},
+     )
+     print(output_names)
+     print([output_i.shape for output_i in output])
+
+
+ def export_onnx_esam(model, output):
+     onnx_model = onnx_models.OnnxEfficientSam(model=model)
+     dynamic_axes = {
+         "batched_images": {0: "batch", 2: "height", 3: "width"},
+         "batched_point_coords": {2: "num_points"},
+         "batched_point_labels": {2: "num_points"},
+     }
+     dummy_inputs = {
+         "batched_images": torch.randn(1, 3, 1080, 1920, dtype=torch.float),
+         "batched_point_coords": torch.randint(
+             low=0, high=1080, size=(1, 1, 5, 2), dtype=torch.float
+         ),
+         "batched_point_labels": torch.randint(
+             low=0, high=4, size=(1, 1, 5), dtype=torch.float
+         ),
+     }
+     output_names = ["output_masks", "iou_predictions"]
+     export_onnx(
+         onnx_model=onnx_model,
+         output=output,
+         dynamic_axes=dynamic_axes,
+         dummy_inputs=dummy_inputs,
+         output_names=output_names,
+     )
+
+
+ def export_onnx_esam_encoder(model, output):
+     onnx_model = onnx_models.OnnxEfficientSamEncoder(model=model)
+     dynamic_axes = {
+         "batched_images": {0: "batch", 2: "height", 3: "width"},
+     }
+     dummy_inputs = {
+         "batched_images": torch.randn(1, 3, 1080, 1920, dtype=torch.float),
+     }
+     output_names = ["image_embeddings"]
+     export_onnx(
+         onnx_model=onnx_model,
+         output=output,
+         dynamic_axes=dynamic_axes,
+         dummy_inputs=dummy_inputs,
+         output_names=output_names,
+     )
+
+
+ def export_onnx_esam_decoder(model, output):
+     onnx_model = onnx_models.OnnxEfficientSamDecoder(model=model)
+     dynamic_axes = {
+         "image_embeddings": {0: "batch"},
+         "batched_point_coords": {2: "num_points"},
+         "batched_point_labels": {2: "num_points"},
+     }
+     dummy_inputs = {
+         "image_embeddings": torch.randn(1, 256, 64, 64, dtype=torch.float),
+         "batched_point_coords": torch.randint(
+             low=0, high=1080, size=(1, 1, 5, 2), dtype=torch.float
+         ),
+         "batched_point_labels": torch.randint(
+             low=0, high=4, size=(1, 1, 5), dtype=torch.float
+         ),
+         "orig_im_size": torch.tensor([1080, 1920], dtype=torch.long),
+     }
+     output_names = ["output_masks", "iou_predictions"]
+     export_onnx(
+         onnx_model=onnx_model,
+         output=output,
+         dynamic_axes=dynamic_axes,
+         dummy_inputs=dummy_inputs,
+         output_names=output_names,
+     )
+
+
+ def main():
+     # faster
+     export_onnx_esam(
+         model=build_efficient_sam_vitt(),
+         output="weights/efficient_sam_vitt.onnx",
+     )
+     export_onnx_esam_encoder(
+         model=build_efficient_sam_vitt(),
+         output="weights/efficient_sam_vitt_encoder.onnx",
+     )
+     export_onnx_esam_decoder(
+         model=build_efficient_sam_vitt(),
+         output="weights/efficient_sam_vitt_decoder.onnx",
+     )
+
+     # more accurate
+     export_onnx_esam(
+         model=build_efficient_sam_vits(),
+         output="weights/efficient_sam_vits.onnx",
+     )
+     export_onnx_esam_encoder(
+         model=build_efficient_sam_vits(),
+         output="weights/efficient_sam_vits_encoder.onnx",
+     )
+     export_onnx_esam_decoder(
+         model=build_efficient_sam_vits(),
+         output="weights/efficient_sam_vits_decoder.onnx",
+     )
+
+
+ if __name__ == "__main__":
+     main()
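After running the script above, the single-file export can be driven with onnxruntime using the same tensor names and dynamic axes declared during export. A minimal sketch (illustrative, not part of this diff; the image tensor and click coordinates are placeholders):

# Inference sketch for the combined ONNX export; input/output names follow export_to_onnx.py.
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("weights/efficient_sam_vitt.onnx")
image = np.random.rand(1, 3, 1080, 1920).astype(np.float32)  # stand-in for a real CHW image in [0, 1]
points = np.array([[[[580.0, 350.0]]]], dtype=np.float32)    # 1 batch x 1 query x 1 point x (x, y)
labels = np.array([[[1.0]]], dtype=np.float32)               # 1 = foreground click
masks, iou = session.run(
    ["output_masks", "iou_predictions"],
    {
        "batched_images": image,
        "batched_point_coords": points,
        "batched_point_labels": labels,
    },
)
print(masks.shape, iou.shape)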
EfficientSAM-main/export_to_torchscript.py ADDED
@@ -0,0 +1,17 @@
+ import torch
+ from efficient_sam.build_efficient_sam import build_efficient_sam_vitt, build_efficient_sam_vits
+ # from squeeze_sam.build_squeeze_sam import build_squeeze_sam
+ import zipfile
+ import os
+
+ # Efficient SAM (VIT-tiny)
+ torch.jit.save(torch.jit.script(build_efficient_sam_vitt()), "torchscripted_model/efficient_sam_vitt_torchscript.pt")
+
+ # Efficient SAM (VIT-small)
+ # Since VIT-small is >100MB, we store the zip file.
+ with zipfile.ZipFile("weights/efficient_sam_vits.pt.zip", 'r') as zip_ref:
+     zip_ref.extractall("weights")
+ torch.jit.save(torch.jit.script(build_efficient_sam_vits()), "torchscripted_model/efficient_sam_vits_torchscript.pt")
+
+ # Squeeze SAM (UNET)
+ # torch.jit.save(torch.jit.script(build_squeeze_sam()), "torchscripted_model/squeeze_sam_torchscript.pt")
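The saved TorchScript archive can then be loaded without the Python source tree. A loading sketch (illustrative, not part of this diff; it assumes the scripted model keeps the batched images / point coords / point labels call signature mirrored by the ONNX wrappers in onnx_models.py, and that it returns mask logits plus IoU predictions):

# Loading sketch for the TorchScript export above; shapes and outputs are assumptions.
import torch

model = torch.jit.load("torchscripted_model/efficient_sam_vitt_torchscript.pt")
model.eval()
image = torch.zeros(1, 3, 1080, 1920)
points = torch.tensor([[[[580.0, 350.0], [620.0, 380.0]]]])  # B x num_queries x num_points x 2
labels = torch.tensor([[[1.0, 1.0]]])                        # matching point labels
with torch.no_grad():
    outputs = model(image, points, labels)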
EfficientSAM-main/figs/.DS_Store ADDED
Binary file (6.15 kB). View file
 
EfficientSAM-main/figs/examples/demo_box.png ADDED

Git LFS Details

  • SHA256: 7f8a5afbf5e785fcc14ec2190410ee8d72c983b305eb40d8368a072edb6b4689
  • Pointer size: 132 Bytes
  • Size of remote file: 3.58 MB
EfficientSAM-main/figs/examples/demo_everything.png ADDED

Git LFS Details

  • SHA256: 09a0d74c93b81e1968249302212156c110ebf5d5b68f9483b05c5f678b5981fc
  • Pointer size: 132 Bytes
  • Size of remote file: 3.08 MB
EfficientSAM-main/figs/examples/demo_point.png ADDED

Git LFS Details

  • SHA256: 66b1f7745da8f62eebf89b939acc9fe5c8c21e9a546c05d592a57cf0eeb24f77
  • Pointer size: 132 Bytes
  • Size of remote file: 5.04 MB
EfficientSAM-main/figs/examples/demo_saliency.png ADDED

Git LFS Details

  • SHA256: 9aade441926db1702be187bef4bc270d1dd33f016702e509f6b2613a7546c6a2
  • Pointer size: 132 Bytes
  • Size of remote file: 2.55 MB
EfficientSAM-main/figs/examples/dogs.jpg ADDED
EfficientSAM-main/figs/examples/dogs_efficient_sam_vits_mask.png ADDED
EfficientSAM-main/figs/examples/dogs_efficient_sam_vitt_mask.png ADDED
EfficientSAM-main/figs/examples/dogs_efficientsam_s_mask.png ADDED
EfficientSAM-main/figs/examples/dogs_efficientsam_ti_mask.png ADDED
EfficientSAM-main/figs/examples/dogs_squeeze_sam_mask.png ADDED
EfficientSAM-main/linter.sh ADDED
@@ -0,0 +1,32 @@
+ #!/bin/bash -e
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ {
+   black --version | grep -E "23\." > /dev/null
+ } || {
+   echo "Linter requires 'black==23.*' !"
+   exit 1
+ }
+
+ ISORT_VERSION=$(isort --version-number)
+ if [[ "$ISORT_VERSION" != 5.12* ]]; then
+   echo "Linter requires isort==5.12.0 !"
+   exit 1
+ fi
+
+ echo "Running isort ..."
+ isort . --atomic
+
+ echo "Running black ..."
+ black -l 100 .
+
+ echo "Running flake8 ..."
+ if [ -x "$(command -v flake8)" ]; then
+   flake8 .
+ else
+   python3 -m flake8 .
+ fi
+
+ echo "Running mypy..."
+
+ mypy --exclude 'setup.py|notebooks' .
EfficientSAM-main/notebooks/EfficientSAM_example.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
EfficientSAM-main/notebooks/EfficientSAM_segment_everything_example.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
EfficientSAM-main/onnx_models.py ADDED
@@ -0,0 +1,166 @@
+ # ONNX export code is from [labelme annotation tool](https://github.com/labelmeai/efficient-sam). Huge thanks to Kentaro Wada.
+
+ import torch
+ import torch.nn.functional as F
+
+
+ class OnnxEfficientSam(torch.nn.Module):
+     def __init__(self, model):
+         super().__init__()
+         self.model = model
+
+     @property
+     def decoder_max_num_input_points(self):
+         return self.model.decoder_max_num_input_points
+
+     @property
+     def image_encoder(self):
+         return self.model.image_encoder
+
+     @property
+     def get_image_embeddings(self):
+         return self.model.get_image_embeddings
+
+     @property
+     def prompt_encoder(self):
+         return self.model.prompt_encoder
+
+     @property
+     def mask_decoder(self):
+         return self.model.mask_decoder
+
+     def forward(
+         self,
+         batched_images: torch.Tensor,
+         batched_points: torch.Tensor,
+         batched_point_labels: torch.Tensor,
+     ):
+         batch_size, _, input_h, input_w = batched_images.shape
+         image_embeddings = self.get_image_embeddings(batched_images)
+         return self.predict_masks(
+             image_embeddings,
+             batched_points,
+             batched_point_labels,
+             multimask_output=True,
+             input_h=input_h,
+             input_w=input_w,
+             output_h=input_h,
+             output_w=input_w,
+         )
+
+     def get_rescaled_pts(
+         self, batched_points: torch.Tensor, input_h: int, input_w: int
+     ):
+         return torch.stack(
+             [
+                 batched_points[..., 0] * self.image_encoder.img_size / input_w,
+                 batched_points[..., 1] * self.image_encoder.img_size / input_h,
+             ],
+             dim=-1,
+         )
+
+     def predict_masks(
+         self,
+         image_embeddings: torch.Tensor,
+         batched_points: torch.Tensor,
+         batched_point_labels: torch.Tensor,
+         multimask_output: bool,
+         input_h: int,
+         input_w: int,
+         output_h: int = -1,
+         output_w: int = -1,
+     ):
+         batch_size, max_num_queries, num_pts, _ = batched_points.shape
+         num_pts = batched_points.shape[2]
+         rescaled_batched_points = self.get_rescaled_pts(
+             batched_points, input_h, input_w
+         )
+
+         if num_pts > self.decoder_max_num_input_points:
+             rescaled_batched_points = rescaled_batched_points[
+                 :, :, : self.decoder_max_num_input_points, :
+             ]
+             batched_point_labels = batched_point_labels[
+                 :, :, : self.decoder_max_num_input_points
+             ]
+         elif num_pts < self.decoder_max_num_input_points:
+             rescaled_batched_points = F.pad(
+                 rescaled_batched_points,
+                 (0, 0, 0, self.decoder_max_num_input_points - num_pts),
+                 value=-1.0,
+             )
+             batched_point_labels = F.pad(
+                 batched_point_labels,
+                 (0, self.decoder_max_num_input_points - num_pts),
+                 value=-1.0,
+             )
+
+         sparse_embeddings = self.prompt_encoder(
+             rescaled_batched_points.reshape(
+                 batch_size * max_num_queries, self.decoder_max_num_input_points, 2
+             ),
+             batched_point_labels.reshape(
+                 batch_size * max_num_queries, self.decoder_max_num_input_points
+             ),
+         )
+         sparse_embeddings = sparse_embeddings.view(
+             batch_size,
+             max_num_queries,
+             sparse_embeddings.shape[1],
+             sparse_embeddings.shape[2],
+         )
+         low_res_masks, iou_predictions = self.mask_decoder(
+             image_embeddings,
+             self.prompt_encoder.get_dense_pe(),
+             sparse_prompt_embeddings=sparse_embeddings,
+             multimask_output=multimask_output,
+         )
+         _, num_predictions, low_res_size, _ = low_res_masks.shape
+
+         if output_w > 0 and output_h > 0:
+             output_masks = F.interpolate(
+                 low_res_masks,
+                 (output_h, output_w),
+                 # NOTE: "bicubic" is inefficient on onnx
+                 mode="bilinear",
+             )
+             output_masks = torch.reshape(
+                 output_masks,
+                 (batch_size, max_num_queries, num_predictions, output_h, output_w),
+             )
+         else:
+             output_masks = torch.reshape(
+                 low_res_masks,
+                 (
+                     batch_size,
+                     max_num_queries,
+                     num_predictions,
+                     low_res_size,
+                     low_res_size,
+                 ),
+             )
+         iou_predictions = torch.reshape(
+             iou_predictions, (batch_size, max_num_queries, num_predictions)
+         )
+         return output_masks, iou_predictions, low_res_masks
+
+
+ class OnnxEfficientSamEncoder(OnnxEfficientSam):
+     def forward(self, batched_images: torch.Tensor):
+         return self.model.get_image_embeddings(batched_images)
+
+
+ class OnnxEfficientSamDecoder(OnnxEfficientSam):
+     def forward(
+         self, image_embeddings, batched_points, batched_point_labels, orig_im_size
+     ):
+         return self.predict_masks(
+             image_embeddings=image_embeddings,
+             batched_points=batched_points,
+             batched_point_labels=batched_point_labels,
+             multimask_output=True,
+             input_h=orig_im_size[0],
+             input_w=orig_im_size[1],
+             output_h=orig_im_size[0],
+             output_w=orig_im_size[1],
+         )
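The encoder/decoder split above lets an application embed an image once and then re-run only the decoder for each new prompt. A two-stage sketch against the exported ONNX files (illustrative, not part of this diff; tensor names follow export_to_onnx.py and the image tensor is a placeholder):

# Two-stage sketch: compute image embeddings once, then decode per prompt.
import numpy as np
import onnxruntime

encoder = onnxruntime.InferenceSession("weights/efficient_sam_vitt_encoder.onnx")
decoder = onnxruntime.InferenceSession("weights/efficient_sam_vitt_decoder.onnx")

image = np.random.rand(1, 3, 1080, 1920).astype(np.float32)  # placeholder image tensor
(embeddings,) = encoder.run(["image_embeddings"], {"batched_images": image})

masks, iou = decoder.run(
    ["output_masks", "iou_predictions"],
    {
        "image_embeddings": embeddings,
        "batched_point_coords": np.array([[[[580.0, 350.0]]]], dtype=np.float32),
        "batched_point_labels": np.array([[[1.0]]], dtype=np.float32),
        "orig_im_size": np.array([1080, 1920], dtype=np.int64),
    },
)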
EfficientSAM-main/setup.cfg ADDED
@@ -0,0 +1,11 @@
+ [isort]
+ line_length=100
+ multi_line_output=3
+ include_trailing_comma=True
+ known_standard_library=numpy,setuptools
+ skip_glob=*/__init__.py
+ known_myself=efficient_sam
+ known_third_party=matplotlib,torch,torchvision,black,isort
+ no_lines_before=STDLIB,THIRDPARTY
+ sections=FUTURE,STDLIB,THIRDPARTY,MYSELF,FIRSTPARTY,LOCALFOLDER
+ default_section=FIRSTPARTY
EfficientSAM-main/setup.py ADDED
@@ -0,0 +1,18 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from setuptools import find_packages, setup
+
+ setup(
+     name="efficient_sam",
+     version="1.0",
+     install_requires=[],
+     packages=find_packages(exclude="notebooks"),
+     extras_require={
+         "all": ["matplotlib", "onnx", "onnxruntime"],
+         "dev": ["flake8", "isort", "black", "mypy"],
+     },
+ )
EfficientSAM-main/torchscripted_model/efficient_sam_vitt_torchscript.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1d300fa20a94f40bf307616a3979f4d4b6a4a347edd08ed37d69535c2185188e
+ size 41074768
EfficientSAM-main/weights/efficient_sam_vits.pt.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e1801de05adeea87a6b779b0bedf3ab6751e03c21facb82d2c660867c02813fc
+ size 98304933
EfficientSAM-main/weights/efficient_sam_vitt.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:143c3198a7b2a15f23c21cdb723432fb3fbcdbabbdad3483cf3babd8b95c1397
+ size 41365520
EfficientSAM-main/weights/efficient_sam_vitt.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dff858b19600a46461cbb7de98f796b23a7a888d9f5e34c0b033f7d6eb9e4e6a
+ size 40982470
EfficientSAM-main/weights/efficient_sam_vitt_decoder.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a62f8fa5ea080447c0689418d69e58f1e83e0b7adf9c142e2bd9bcc8045c0b11
+ size 16565728
EfficientSAM-main/weights/efficient_sam_vitt_encoder.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:84ed466ffcc5c1f8d08409bc34a23bb364ab2c15e402cb12d4335a42be0e0951
+ size 24799761