CuriousDolphin An-619 committed
Commit 2ad5f92
0 Parent(s)

Duplicate from An-619/FastSAM

Co-authored-by: Yongqi An <[email protected]>
.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
+ *.pyc
+ *.pyo
+ *.pyd
+ .DS_Store
+ .idea
+ gradio_cached_examples
README.md ADDED
@@ -0,0 +1,47 @@
+ ---
+ title: FastSAM
+ emoji: 🐠
+ colorFrom: pink
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 3.35.2
+ app_file: app_gradio.py
+ pinned: false
+ license: apache-2.0
+ duplicated_from: An-619/FastSAM
+ ---
+
+ # Fast Segment Anything
+
+ Official PyTorch implementation of <a href="https://github.com/CASIA-IVA-Lab/FastSAM">Fast Segment Anything</a>.
+
+ The **Fast Segment Anything Model (FastSAM)** is a CNN-based Segment Anything Model trained on only 2% of the SA-1B dataset published by the SAM authors. FastSAM achieves performance comparable to
+ the SAM method at **50× higher run-time speed**.
+
+
+ ## License
+
+ The model is licensed under the [Apache 2.0 license](LICENSE).
+
+
+ ## Acknowledgement
+
+ - [Segment Anything](https://segment-anything.com/) provides the SA-1B dataset and the base codes.
+ - [YOLOv8](https://github.com/ultralytics/ultralytics) provides codes and pre-trained models.
+ - [YOLACT](https://arxiv.org/abs/2112.10003) provides a powerful instance segmentation method.
+ - [Grounded-Segment-Anything](https://huggingface.co/spaces/yizhangliu/Grounded-Segment-Anything) provides a useful web demo template.
+
+ ## Citing FastSAM
+
+ If you find this project useful for your research, please consider citing the following BibTeX entry.
+
+ ```
+ @misc{zhao2023fast,
+       title={Fast Segment Anything},
+       author={Xu Zhao and Wenchao Ding and Yongqi An and Yinglong Du and Tao Yu and Min Li and Ming Tang and Jinqiao Wang},
+       year={2023},
+       eprint={2306.12156},
+       archivePrefix={arXiv},
+       primaryClass={cs.CV}
+ }
+ ```
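For quick testing outside the Gradio UI, the sketch below mirrors what `app_gradio.py` (added below) does for Everything mode: the FastSAM checkpoint is loaded through the ultralytics `YOLO` API, inference runs with the same IoU/confidence defaults the demo exposes, and the masks are drawn with `fast_process` from `utils/tools_gradio.py`. The paths (`weights/FastSAM.pt`, `examples/sa_8776.jpg`) are the files added in this commit; treat this as an illustrative sketch of the app's flow, not an official API.

```python
# Minimal Everything-mode sketch, mirroring app_gradio.py (assumes this repo layout).
from PIL import Image
from ultralytics import YOLO
from utils.tools_gradio import fast_process

model = YOLO('./weights/FastSAM.pt')           # FastSAM ships as a YOLOv8-format checkpoint
image = Image.open('examples/sa_8776.jpg')

results = model(image, device='cpu', retina_masks=True, iou=0.7, conf=0.25, imgsz=1024)
masks = results[0].masks.data                  # (num_masks, H, W) tensor, one channel per object

overlay = fast_process(annotations=masks, image=image, device='cpu', scale=1,
                       better_quality=False, mask_random_color=True,
                       bbox=None, use_retina=True, withContours=True)
overlay.save('segmented.png')                  # PIL image with masks and contours drawn on top
```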
app_gradio.py ADDED
@@ -0,0 +1,370 @@
+ from ultralytics import YOLO
+ import gradio as gr
+ import torch
+ from utils.tools_gradio import fast_process
+ from utils.tools import format_results, box_prompt, point_prompt, text_prompt
+ from PIL import ImageDraw
+ import numpy as np
+
+ # Load the pre-trained model
+ model = YOLO('./weights/FastSAM.pt')
+
+ device = torch.device(
+     "cuda"
+     if torch.cuda.is_available()
+     else "mps"
+     if torch.backends.mps.is_available()
+     else "cpu"
+ )
+
+ # Description
+ title = "<center><strong><font size='8'>🏃 Fast Segment Anything 🤗</font></strong></center>"
+
+ news = """ # 📖 News
+ 🔥 2023/06/29: Support the text mode (thanks to [gaoxinge](https://github.com/CASIA-IVA-Lab/FastSAM/pull/47)).
+
+ 🔥 2023/06/26: Support the points mode. (Better and faster interaction will come soon!)
+
+ 🔥 2023/06/24: Add the 'Advanced options' in Everything mode for more detailed adjustment.
+ """
+
+ description_e = """This is a demo of the GitHub project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM). You are welcome to give it a star ⭐️.
+
+ 🎯 Upload an image and segment it with Fast Segment Anything (Everything mode). The other modes will come soon.
+
+ ⌛️ It takes about 6 seconds to generate segmentation results. The concurrency_count of the queue is 1, so please wait a moment when it is crowded.
+
+ 🚀 To get faster results, you can use a smaller input size and leave high_visual_quality unchecked.
+
+ 📣 You can also obtain the segmentation results of any image through this Colab: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1oX14f6IneGGw612WgVlAiy91UHwFAvr9?usp=sharing)
+
+ 😚 A huge thanks goes out to the @HuggingFace Team for supporting us with a GPU grant.
+
+ 🏠 Check out our [Model Card 🏃](https://huggingface.co/An-619/FastSAM)
+
+ """
+
+ description_p = """ # 🎯 Instructions for points mode
+ This is a demo of the GitHub project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM). You are welcome to give it a star ⭐️.
+
+ 1. Upload an image or choose an example.
+
+ 2. Choose the point label ('Add Mask' means a positive point; 'Remove Area' means a negative point that is not segmented).
+
+ 3. Add points one by one on the image.
+
+ 4. Click the 'Segment with points prompt' button to get the segmentation results.
+
+ **5. If you get an error, clicking the 'Clear points' button and trying again may help.**
+
+ """
+
+ examples = [["examples/sa_8776.jpg"], ["examples/sa_414.jpg"], ["examples/sa_1309.jpg"], ["examples/sa_11025.jpg"],
+             ["examples/sa_561.jpg"], ["examples/sa_192.jpg"], ["examples/sa_10039.jpg"], ["examples/sa_862.jpg"]]
+
+ default_example = examples[0]
+
+ css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
+
+
+ def segment_everything(
+     input,
+     input_size=1024,
+     iou_threshold=0.7,
+     conf_threshold=0.25,
+     better_quality=False,
+     withContours=True,
+     use_retina=True,
+     text="",
+     mask_random_color=True,
+ ):
+     input_size = int(input_size)  # make sure imgsz is an integer
+     # Thanks for the suggestion by hysts in HuggingFace.
+     w, h = input.size
+     scale = input_size / max(w, h)
+     new_w = int(w * scale)
+     new_h = int(h * scale)
+     input = input.resize((new_w, new_h))
+
+     results = model(input,
+                     device=device,
+                     retina_masks=True,
+                     iou=iou_threshold,
+                     conf=conf_threshold,
+                     imgsz=input_size,)
+
+     if len(text) > 0:
+         results = format_results(results[0], 0)
+         annotations, _ = text_prompt(results, text, input, device=device)
+         annotations = np.array([annotations])
+     else:
+         annotations = results[0].masks.data
+
+     fig = fast_process(annotations=annotations,
+                        image=input,
+                        device=device,
+                        scale=(1024 // input_size),
+                        better_quality=better_quality,
+                        mask_random_color=mask_random_color,
+                        bbox=None,
+                        use_retina=use_retina,
+                        withContours=withContours,)
+     return fig
+
+
+ def segment_with_points(
+     input,
+     input_size=1024,
+     iou_threshold=0.7,
+     conf_threshold=0.25,
+     better_quality=False,
+     withContours=True,
+     use_retina=True,
+     mask_random_color=True,
+ ):
+     global global_points
+     global global_point_label
+
+     input_size = int(input_size)  # make sure imgsz is an integer
+     # Thanks for the suggestion by hysts in HuggingFace.
+     w, h = input.size
+     scale = input_size / max(w, h)
+     new_w = int(w * scale)
+     new_h = int(h * scale)
+     input = input.resize((new_w, new_h))
+
+     scaled_points = [[int(x * scale) for x in point] for point in global_points]
+
+     results = model(input,
+                     device=device,
+                     retina_masks=True,
+                     iou=iou_threshold,
+                     conf=conf_threshold,
+                     imgsz=input_size,)
+
+     results = format_results(results[0], 0)
+     annotations, _ = point_prompt(results, scaled_points, global_point_label, new_h, new_w)
+     annotations = np.array([annotations])
+
+     fig = fast_process(annotations=annotations,
+                        image=input,
+                        device=device,
+                        scale=(1024 // input_size),
+                        better_quality=better_quality,
+                        mask_random_color=mask_random_color,
+                        bbox=None,
+                        use_retina=use_retina,
+                        withContours=withContours,)
+
+     global_points = []
+     global_point_label = []
+     return fig, None
+
+
+ def get_points_with_draw(image, label, evt: gr.SelectData):
+     global global_points
+     global global_point_label
+
+     x, y = evt.index[0], evt.index[1]
+     point_radius, point_color = 15, (255, 255, 0) if label == 'Add Mask' else (255, 0, 255)
+     global_points.append([x, y])
+     global_point_label.append(1 if label == 'Add Mask' else 0)
+
+     print(x, y, label == 'Add Mask')
+
+     # Create an object that can draw on the image
+     draw = ImageDraw.Draw(image)
+     draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)
+     return image
+
+
+ cond_img_e = gr.Image(label="Input", value=default_example[0], type='pil')
+ cond_img_p = gr.Image(label="Input with points", value=default_example[0], type='pil')
+ cond_img_t = gr.Image(label="Input with text", value="examples/dogs.jpg", type='pil')
+
+ segm_img_e = gr.Image(label="Segmented Image", interactive=False, type='pil')
+ segm_img_p = gr.Image(label="Segmented Image with points", interactive=False, type='pil')
+ segm_img_t = gr.Image(label="Segmented Image with text", interactive=False, type='pil')
+
+ global_points = []
+ global_point_label = []
+
+ input_size_slider_e = gr.components.Slider(minimum=512,
+                                            maximum=1024,
+                                            value=1024,
+                                            step=64,
+                                            label='Input_size',
+                                            info='Our model was trained on a size of 1024')
+
+ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
+     with gr.Row():
+         with gr.Column(scale=1):
+             # Title
+             gr.Markdown(title)
+
+         with gr.Column(scale=1):
+             # News
+             gr.Markdown(news)
+
+     with gr.Tab("Everything mode"):
+         # Images
+         with gr.Row(variant="panel"):
+             with gr.Column(scale=1):
+                 cond_img_e.render()
+
+             with gr.Column(scale=1):
+                 segm_img_e.render()
+
+         # Submit & Clear
+         with gr.Row():
+             with gr.Column():
+                 input_size_slider_e.render()
+
+                 with gr.Row():
+                     contour_check_e = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
+
+                     with gr.Column():
+                         segment_btn_e = gr.Button("Segment Everything", variant='primary')
+                         clear_btn_e = gr.Button("Clear", variant="secondary")
+
+                 gr.Markdown("Try some of the examples below ⬇️")
+                 gr.Examples(examples=examples,
+                             inputs=[cond_img_e],
+                             outputs=segm_img_e,
+                             fn=segment_everything,
+                             cache_examples=True,
+                             examples_per_page=4)
+
+             with gr.Column():
+                 with gr.Accordion("Advanced options", open=False):
+                     # text_box = gr.Textbox(label="text prompt")
+                     iou_threshold_e = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
+                     conf_threshold_e = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
+                     with gr.Row():
+                         mor_check_e = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
+                         with gr.Column():
+                             retina_check_e = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')
+                 # Description
+                 gr.Markdown(description_e)
+
+     with gr.Tab("Points mode"):
+         # Images
+         with gr.Row(variant="panel"):
+             with gr.Column(scale=1):
+                 cond_img_p.render()
+
+             with gr.Column(scale=1):
+                 segm_img_p.render()
+
+         # Submit & Clear
+         with gr.Row():
+             with gr.Column():
+                 with gr.Row():
+                     add_or_remove = gr.Radio(["Add Mask", "Remove Area"], value="Add Mask", label="Point_label (foreground/background)")
+
+                     with gr.Column():
+                         segment_btn_p = gr.Button("Segment with points prompt", variant='primary')
+                         clear_btn_p = gr.Button("Clear points", variant='secondary')
+
+                 gr.Markdown("Try some of the examples below ⬇️")
+                 gr.Examples(examples=examples,
+                             inputs=[cond_img_p],
+                             # outputs=segm_img_p,
+                             # fn=segment_with_points,
+                             # cache_examples=True,
+                             examples_per_page=4)
+
+             with gr.Column():
+                 # Description
+                 gr.Markdown(description_p)
+
+     with gr.Tab("Text mode"):
+         # Images
+         with gr.Row(variant="panel"):
+             with gr.Column(scale=1):
+                 cond_img_t.render()
+
+             with gr.Column(scale=1):
+                 segm_img_t.render()
+
+         # Submit & Clear
+         with gr.Row():
+             with gr.Column():
+                 input_size_slider_t = gr.components.Slider(minimum=512,
+                                                            maximum=1024,
+                                                            value=1024,
+                                                            step=64,
+                                                            label='Input_size',
+                                                            info='Our model was trained on a size of 1024')
+                 with gr.Row():
+                     with gr.Column():
+                         contour_check_t = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
+                         text_box = gr.Textbox(label="text prompt", value="a black dog")
+
+                     with gr.Column():
+                         segment_btn_t = gr.Button("Segment with text", variant='primary')
+                         clear_btn_t = gr.Button("Clear", variant="secondary")
+
+                 gr.Markdown("Try some of the examples below ⬇️")
+                 gr.Examples(examples=["examples/dogs.jpg"],
+                             inputs=[cond_img_e],
+                             # outputs=segm_img_e,
+                             # fn=segment_everything,
+                             # cache_examples=True,
+                             examples_per_page=4)
+
+             with gr.Column():
+                 with gr.Accordion("Advanced options", open=False):
+                     iou_threshold_t = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
+                     conf_threshold_t = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
+                     with gr.Row():
+                         mor_check_t = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
+                         with gr.Column():
+                             retina_check_t = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')
+
+                 # Description
+                 gr.Markdown(description_e)
+
+     cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)
+
+     segment_btn_e.click(segment_everything,
+                         inputs=[
+                             cond_img_e,
+                             input_size_slider_e,
+                             iou_threshold_e,
+                             conf_threshold_e,
+                             mor_check_e,
+                             contour_check_e,
+                             retina_check_e,
+                         ],
+                         outputs=segm_img_e)
+
+     segment_btn_p.click(segment_with_points,
+                         inputs=[cond_img_p],
+                         outputs=[segm_img_p, cond_img_p])
+
+     segment_btn_t.click(segment_everything,
+                         inputs=[
+                             cond_img_t,
+                             input_size_slider_t,
+                             iou_threshold_t,
+                             conf_threshold_t,
+                             mor_check_t,
+                             contour_check_t,
+                             retina_check_t,
+                             text_box,
+                         ],
+                         outputs=segm_img_t)
+
+     def clear():
+         return None, None
+
+     def clear_text():
+         return None, None, None
+
+     clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e])
+     clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p])
+     clear_btn_t.click(clear_text, outputs=[cond_img_p, segm_img_p, text_box])
+
+ demo.queue()
+ demo.launch()
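The points-mode path above boils down to three steps: run the model once, convert the raw output with `format_results`, then let `point_prompt` merge masks according to the clicked foreground/background points before rendering with `fast_process`. The snippet below is a minimal offline sketch of that flow without Gradio; the image path comes from this commit's examples, while the point coordinates and labels are made up for illustration.

```python
# Offline sketch of the points-mode flow in app_gradio.py (illustrative only).
import numpy as np
from PIL import Image
from ultralytics import YOLO
from utils.tools import format_results, point_prompt
from utils.tools_gradio import fast_process

model = YOLO('./weights/FastSAM.pt')
image = Image.open('examples/sa_8776.jpg')

results = model(image, device='cpu', retina_masks=True, iou=0.7, conf=0.25, imgsz=1024)
annotations = format_results(results[0], 0)   # list of dicts with 'segmentation', 'bbox', 'area', ...

points = [[400, 300], [50, 50]]               # example click coordinates (x, y), must lie inside the image
labels = [1, 0]                               # 1 = 'Add Mask' (positive), 0 = 'Remove Area' (negative)
mask, _ = point_prompt(annotations, points, labels, image.height, image.width)

overlay = fast_process(annotations=np.array([mask]), image=image, device='cpu',
                       scale=1, use_retina=True, withContours=True)
overlay.save('segmented_points.png')
```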
examples/dogs.jpg ADDED

Git LFS Details

  • SHA256: 49b29517d3a6457bf8bd0b83a80cbeb24c2466bf3e5804bd503ebe60e430d784
  • Pointer size: 131 Bytes
  • Size of remote file: 449 kB
examples/sa_10039.jpg ADDED

Git LFS Details

  • SHA256: 4a9735583a997fa08e5eb36b3ba8bf17a31771bb2aea71e6d51ab9824c1d141e
  • Pointer size: 131 Bytes
  • Size of remote file: 391 kB
examples/sa_11025.jpg ADDED

Git LFS Details

  • SHA256: b7edd63aa5121414bc29a760770606d09387561ff990c89f9b82c35803bd20aa
  • Pointer size: 131 Bytes
  • Size of remote file: 988 kB
examples/sa_1309.jpg ADDED

Git LFS Details

  • SHA256: b1012cbfd3ffe4ee0da940dc45961fbd1ce7546bea566f650514ec56d72b0460
  • Pointer size: 132 Bytes
  • Size of remote file: 1.11 MB
examples/sa_192.jpg ADDED

Git LFS Details

  • SHA256: dcec4fce91382cbfeb2711fff3caeae183c23cb6d8a6c9e2ca0cd2e8eac39512
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
examples/sa_414.jpg ADDED

Git LFS Details

  • SHA256: 69dbead40b43e54d3bb80fb372c2e241b0f3ff2159d32525433a75153e067c65
  • Pointer size: 132 Bytes
  • Size of remote file: 2.23 MB
examples/sa_561.jpg ADDED

Git LFS Details

  • SHA256: 837d725885e427534623dcc7d82ea846fffea046877c94e2e9c5b027d593796b
  • Pointer size: 131 Bytes
  • Size of remote file: 822 kB
examples/sa_862.jpg ADDED

Git LFS Details

  • SHA256: 06efc970f0d95faa6e8c69ee73f2032627569dde1c28bc783faebdaefa5eb2a8
  • Pointer size: 132 Bytes
  • Size of remote file: 1.56 MB
examples/sa_8776.jpg ADDED

Git LFS Details

  • SHA256: 7d71aea32d9f14122378a0707a4243de968d87b292a20a905351b5eacd924212
  • Pointer size: 131 Bytes
  • Size of remote file: 471 kB
requirements.txt ADDED
@@ -0,0 +1,20 @@
+ # Base-----------------------------------
+ matplotlib==3.2.2
+ numpy
+ opencv-python
+
+ git+https://github.com/openai/CLIP.git
+ # Pillow>=7.1.2
+ # PyYAML>=5.3.1
+ # requests>=2.23.0
+ # scipy>=1.4.1
+ # torch
+ # torchvision
+ # tqdm>=4.64.0
+
+ # pandas>=1.1.4
+ # seaborn>=0.11.0
+
+ # Ultralytics-----------------------------------
+ ultralytics==8.0.121
+
utils/__init__.py ADDED
File without changes
utils/tools.py ADDED
@@ -0,0 +1,429 @@
+ import numpy as np
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ import cv2
+ import torch
+ import os
+ import sys
+ import clip
+
+
+ def convert_box_xywh_to_xyxy(box):
+     x1 = box[0]
+     y1 = box[1]
+     x2 = box[0] + box[2]
+     y2 = box[1] + box[3]
+     return [x1, y1, x2, y2]
+
+
+ def segment_image(image, bbox):
+     image_array = np.array(image)
+     segmented_image_array = np.zeros_like(image_array)
+     x1, y1, x2, y2 = bbox
+     segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
+     segmented_image = Image.fromarray(segmented_image_array)
+     black_image = Image.new("RGB", image.size, (255, 255, 255))
+     # transparency_mask = np.zeros_like((), dtype=np.uint8)
+     transparency_mask = np.zeros(
+         (image_array.shape[0], image_array.shape[1]), dtype=np.uint8
+     )
+     transparency_mask[y1:y2, x1:x2] = 255
+     transparency_mask_image = Image.fromarray(transparency_mask, mode="L")
+     black_image.paste(segmented_image, mask=transparency_mask_image)
+     return black_image
+
+
+ def format_results(result, filter=0):
+     annotations = []
+     n = len(result.masks.data)
+     for i in range(n):
+         annotation = {}
+         mask = result.masks.data[i] == 1.0
+
+         if torch.sum(mask) < filter:
+             continue
+         annotation["id"] = i
+         annotation["segmentation"] = mask.cpu().numpy()
+         annotation["bbox"] = result.boxes.data[i]
+         annotation["score"] = result.boxes.conf[i]
+         annotation["area"] = annotation["segmentation"].sum()
+         annotations.append(annotation)
+     return annotations
+
+
+ def filter_masks(annotations):  # filter the overlap mask
+     annotations.sort(key=lambda x: x["area"], reverse=True)
+     to_remove = set()
+     for i in range(0, len(annotations)):
+         a = annotations[i]
+         for j in range(i + 1, len(annotations)):
+             b = annotations[j]
+             if i != j and j not in to_remove:
+                 # check if
+                 if b["area"] < a["area"]:
+                     if (a["segmentation"] & b["segmentation"]).sum() / b[
+                         "segmentation"
+                     ].sum() > 0.8:
+                         to_remove.add(j)
+
+     return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
+
+
+ def get_bbox_from_mask(mask):
+     mask = mask.astype(np.uint8)
+     contours, hierarchy = cv2.findContours(
+         mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+     )
+     x1, y1, w, h = cv2.boundingRect(contours[0])
+     x2, y2 = x1 + w, y1 + h
+     if len(contours) > 1:
+         for b in contours:
+             x_t, y_t, w_t, h_t = cv2.boundingRect(b)
+             # merge multiple bboxes into one
+             x1 = min(x1, x_t)
+             y1 = min(y1, y_t)
+             x2 = max(x2, x_t + w_t)
+             y2 = max(y2, y_t + h_t)
+         h = y2 - y1
+         w = x2 - x1
+     return [x1, y1, x2, y2]
+
+
+ def fast_process(
+     annotations, args, mask_random_color, bbox=None, points=None, edges=False
+ ):
+     if isinstance(annotations[0], dict):
+         annotations = [annotation["segmentation"] for annotation in annotations]
+     result_name = os.path.basename(args.img_path)
+     image = cv2.imread(args.img_path)
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     original_h = image.shape[0]
+     original_w = image.shape[1]
+     if sys.platform == "darwin":
+         plt.switch_backend("TkAgg")
+     plt.figure(figsize=(original_w/100, original_h/100))
+     # Add subplot with no margin.
+     plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+     plt.margins(0, 0)
+     plt.gca().xaxis.set_major_locator(plt.NullLocator())
+     plt.gca().yaxis.set_major_locator(plt.NullLocator())
+     plt.imshow(image)
+     if args.better_quality == True:
+         if isinstance(annotations[0], torch.Tensor):
+             annotations = np.array(annotations.cpu())
+         for i, mask in enumerate(annotations):
+             mask = cv2.morphologyEx(
+                 mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)
+             )
+             annotations[i] = cv2.morphologyEx(
+                 mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)
+             )
+     if args.device == "cpu":
+         annotations = np.array(annotations)
+         fast_show_mask(
+             annotations,
+             plt.gca(),
+             random_color=mask_random_color,
+             bbox=bbox,
+             points=points,
+             point_label=args.point_label,
+             retinamask=args.retina,
+             target_height=original_h,
+             target_width=original_w,
+         )
+     else:
+         if isinstance(annotations[0], np.ndarray):
+             annotations = torch.from_numpy(annotations)
+         fast_show_mask_gpu(
+             annotations,
+             plt.gca(),
+             random_color=args.randomcolor,
+             bbox=bbox,
+             points=points,
+             point_label=args.point_label,
+             retinamask=args.retina,
+             target_height=original_h,
+             target_width=original_w,
+         )
+     if isinstance(annotations, torch.Tensor):
+         annotations = annotations.cpu().numpy()
+     if args.withContours == True:
+         contour_all = []
+         temp = np.zeros((original_h, original_w, 1))
+         for i, mask in enumerate(annotations):
+             if type(mask) == dict:
+                 mask = mask["segmentation"]
+             annotation = mask.astype(np.uint8)
+             if args.retina == False:
+                 annotation = cv2.resize(
+                     annotation,
+                     (original_w, original_h),
+                     interpolation=cv2.INTER_NEAREST,
+                 )
+             contours, hierarchy = cv2.findContours(
+                 annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
+             )
+             for contour in contours:
+                 contour_all.append(contour)
+         cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
+         color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
+         contour_mask = temp / 255 * color.reshape(1, 1, -1)
+         plt.imshow(contour_mask)
+
+     save_path = args.output
+     if not os.path.exists(save_path):
+         os.makedirs(save_path)
+     plt.axis("off")
+     fig = plt.gcf()
+     plt.draw()
+
+     try:
+         buf = fig.canvas.tostring_rgb()
+     except AttributeError:
+         fig.canvas.draw()
+         buf = fig.canvas.tostring_rgb()
+
+     cols, rows = fig.canvas.get_width_height()
+     img_array = np.fromstring(buf, dtype=np.uint8).reshape(rows, cols, 3)
+     cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
+
+
+ # CPU post process
+ def fast_show_mask(
+     annotation,
+     ax,
+     random_color=False,
+     bbox=None,
+     points=None,
+     point_label=None,
+     retinamask=True,
+     target_height=960,
+     target_width=960,
+ ):
+     msak_sum = annotation.shape[0]
+     height = annotation.shape[1]
+     weight = annotation.shape[2]
+     # sort the annotations by area
+     areas = np.sum(annotation, axis=(1, 2))
+     sorted_indices = np.argsort(areas)
+     annotation = annotation[sorted_indices]
+
+     index = (annotation != 0).argmax(axis=0)
+     if random_color == True:
+         color = np.random.random((msak_sum, 1, 1, 3))
+     else:
+         color = np.ones((msak_sum, 1, 1, 3)) * np.array(
+             [30 / 255, 144 / 255, 255 / 255]
+         )
+     transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
+     visual = np.concatenate([color, transparency], axis=-1)
+     mask_image = np.expand_dims(annotation, -1) * visual
+
+     show = np.zeros((height, weight, 4))
+     h_indices, w_indices = np.meshgrid(
+         np.arange(height), np.arange(weight), indexing="ij"
+     )
+     indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
+     # update the values of show with vectorized indexing
+     show[h_indices, w_indices, :] = mask_image[indices]
+     if bbox is not None:
+         x1, y1, x2, y2 = bbox
+         ax.add_patch(
+             plt.Rectangle(
+                 (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
+             )
+         )
+     # draw point
+     if points is not None:
+         plt.scatter(
+             [point[0] for i, point in enumerate(points) if point_label[i] == 1],
+             [point[1] for i, point in enumerate(points) if point_label[i] == 1],
+             s=20,
+             c="y",
+         )
+         plt.scatter(
+             [point[0] for i, point in enumerate(points) if point_label[i] == 0],
+             [point[1] for i, point in enumerate(points) if point_label[i] == 0],
+             s=20,
+             c="m",
+         )
+
+     if retinamask == False:
+         show = cv2.resize(
+             show, (target_width, target_height), interpolation=cv2.INTER_NEAREST
+         )
+     ax.imshow(show)
+
+
+ def fast_show_mask_gpu(
+     annotation,
+     ax,
+     random_color=False,
+     bbox=None,
+     points=None,
+     point_label=None,
+     retinamask=True,
+     target_height=960,
+     target_width=960,
+ ):
+     msak_sum = annotation.shape[0]
+     height = annotation.shape[1]
+     weight = annotation.shape[2]
+     areas = torch.sum(annotation, dim=(1, 2))
+     sorted_indices = torch.argsort(areas, descending=False)
+     annotation = annotation[sorted_indices]
+     # find the index of the first non-zero value at each position
+     index = (annotation != 0).to(torch.long).argmax(dim=0)
+     if random_color == True:
+         color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
+     else:
+         color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor(
+             [30 / 255, 144 / 255, 255 / 255]
+         ).to(annotation.device)
+     transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
+     visual = torch.cat([color, transparency], dim=-1)
+     mask_image = torch.unsqueeze(annotation, -1) * visual
+     # gather by index: for each pixel, index selects which mask's value to take, collapsing mask_image to a single layer
+     show = torch.zeros((height, weight, 4)).to(annotation.device)
+     h_indices, w_indices = torch.meshgrid(
+         torch.arange(height), torch.arange(weight), indexing="ij"
+     )
+     indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
+     # update the values of show with vectorized indexing
+     show[h_indices, w_indices, :] = mask_image[indices]
+     show_cpu = show.cpu().numpy()
+     if bbox is not None:
+         x1, y1, x2, y2 = bbox
+         ax.add_patch(
+             plt.Rectangle(
+                 (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
+             )
+         )
+     # draw point
+     if points is not None:
+         plt.scatter(
+             [point[0] for i, point in enumerate(points) if point_label[i] == 1],
+             [point[1] for i, point in enumerate(points) if point_label[i] == 1],
+             s=20,
+             c="y",
+         )
+         plt.scatter(
+             [point[0] for i, point in enumerate(points) if point_label[i] == 0],
+             [point[1] for i, point in enumerate(points) if point_label[i] == 0],
+             s=20,
+             c="m",
+         )
+     if retinamask == False:
+         show_cpu = cv2.resize(
+             show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
+         )
+     ax.imshow(show_cpu)
+
+
+ # clip
+ @torch.no_grad()
+ def retriev(
+     model, preprocess, elements, search_text: str, device
+ ) -> int:
+     preprocessed_images = [preprocess(image).to(device) for image in elements]
+     tokenized_text = clip.tokenize([search_text]).to(device)
+     stacked_images = torch.stack(preprocessed_images)
+     image_features = model.encode_image(stacked_images)
+     text_features = model.encode_text(tokenized_text)
+     image_features /= image_features.norm(dim=-1, keepdim=True)
+     text_features /= text_features.norm(dim=-1, keepdim=True)
+     probs = 100.0 * image_features @ text_features.T
+     return probs[:, 0].softmax(dim=0)
+
+
+ def crop_image(annotations, image_like):
+     if isinstance(image_like, str):
+         image = Image.open(image_like)
+     else:
+         image = image_like
+     ori_w, ori_h = image.size
+     mask_h, mask_w = annotations[0]["segmentation"].shape
+     if ori_w != mask_w or ori_h != mask_h:
+         image = image.resize((mask_w, mask_h))
+     cropped_boxes = []
+     cropped_images = []
+     not_crop = []
+     filter_id = []
+     # annotations, _ = filter_masks(annotations)
+     # filter_id = list(_)
+     for _, mask in enumerate(annotations):
+         if np.sum(mask["segmentation"]) <= 100:
+             filter_id.append(_)
+             continue
+         bbox = get_bbox_from_mask(mask["segmentation"])  # bbox of the mask
+         cropped_boxes.append(segment_image(image, bbox))  # save the cropped image
+         # cropped_boxes.append(segment_image(image,mask["segmentation"]))
+         cropped_images.append(bbox)  # save the bbox of the cropped image
+
+     return cropped_boxes, cropped_images, not_crop, filter_id, annotations
+
+
+ def box_prompt(masks, bbox, target_height, target_width):
+     h = masks.shape[1]
+     w = masks.shape[2]
+     if h != target_height or w != target_width:
+         bbox = [
+             int(bbox[0] * w / target_width),
+             int(bbox[1] * h / target_height),
+             int(bbox[2] * w / target_width),
+             int(bbox[3] * h / target_height),
+         ]
+     bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
+     bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
+     bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
+     bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
+
+     # IoUs = torch.zeros(len(masks), dtype=torch.float32)
+     bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
+
+     masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2))
+     orig_masks_area = torch.sum(masks, dim=(1, 2))
+
+     union = bbox_area + orig_masks_area - masks_area
+     IoUs = masks_area / union
+     max_iou_index = torch.argmax(IoUs)
+
+     return masks[max_iou_index].cpu().numpy(), max_iou_index
+
+
+ def point_prompt(masks, points, point_label, target_height, target_width):  # numpy processing
+     h = masks[0]["segmentation"].shape[0]
+     w = masks[0]["segmentation"].shape[1]
+     if h != target_height or w != target_width:
+         points = [
+             [int(point[0] * w / target_width), int(point[1] * h / target_height)]
+             for point in points
+         ]
+     onemask = np.zeros((h, w))
+     masks = sorted(masks, key=lambda x: x['area'], reverse=True)
+     for i, annotation in enumerate(masks):
+         if type(annotation) == dict:
+             mask = annotation['segmentation']
+         else:
+             mask = annotation
+         for i, point in enumerate(points):
+             if mask[point[1], point[0]] == 1 and point_label[i] == 1:
+                 onemask[mask] = 1
+             if mask[point[1], point[0]] == 1 and point_label[i] == 0:
+                 onemask[mask] = 0
+     onemask = onemask >= 1
+     return onemask, 0
+
+
+ def text_prompt(annotations, text, img_path, device):
+     cropped_boxes, cropped_images, not_crop, filter_id, annotations_ = crop_image(
+         annotations, img_path
+     )
+     clip_model, preprocess = clip.load("./weights/CLIP_ViT_B_32.pt", device=device)
+     scores = retriev(
+         clip_model, preprocess, cropped_boxes, text, device=device
+     )
+     max_idx = scores.argsort()
+     max_idx = max_idx[-1]
+     max_idx += sum(np.array(filter_id) <= int(max_idx))
+     return annotations_[max_idx]["segmentation"], max_idx
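`box_prompt` above is imported by the app but not wired into a tab yet. Its behaviour is easy to check in isolation: it rescales the box from image coordinates to mask resolution, computes an IoU between the box and every candidate mask, and returns the best match. A toy, self-contained check with synthetic masks and a made-up box (not part of the Space):

```python
# Toy check of box_prompt with two synthetic 100x100 masks.
import torch
from utils.tools import box_prompt

masks = torch.zeros((2, 100, 100))
masks[0, 10:30, 10:30] = 1   # small square object
masks[1, 40:90, 40:90] = 1   # large square object

# Box (x1, y1, x2, y2) around the large object, given in a 200x200 image;
# box_prompt scales it down to the 100x100 mask resolution internally.
best_mask, best_idx = box_prompt(masks, [80, 80, 180, 180], target_height=200, target_width=200)
print(int(best_idx))          # 1: the large mask has the highest IoU with the box
```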
utils/tools_gradio.py ADDED
@@ -0,0 +1,175 @@
+ import numpy as np
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ import cv2
+ import torch
+
+
+ def fast_process(
+     annotations,
+     image,
+     device,
+     scale,
+     better_quality=False,
+     mask_random_color=True,
+     bbox=None,
+     use_retina=True,
+     withContours=True,
+ ):
+     if isinstance(annotations[0], dict):
+         annotations = [annotation['segmentation'] for annotation in annotations]
+
+     original_h = image.height
+     original_w = image.width
+     if better_quality:
+         if isinstance(annotations[0], torch.Tensor):
+             annotations = np.array(annotations.cpu())
+         for i, mask in enumerate(annotations):
+             mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
+             annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
+     if device == 'cpu':
+         annotations = np.array(annotations)
+         inner_mask = fast_show_mask(
+             annotations,
+             plt.gca(),
+             random_color=mask_random_color,
+             bbox=bbox,
+             retinamask=use_retina,
+             target_height=original_h,
+             target_width=original_w,
+         )
+     else:
+         if isinstance(annotations[0], np.ndarray):
+             annotations = torch.from_numpy(annotations)
+         inner_mask = fast_show_mask_gpu(
+             annotations,
+             plt.gca(),
+             random_color=mask_random_color,
+             bbox=bbox,
+             retinamask=use_retina,
+             target_height=original_h,
+             target_width=original_w,
+         )
+     if isinstance(annotations, torch.Tensor):
+         annotations = annotations.cpu().numpy()
+
+     if withContours:
+         contour_all = []
+         temp = np.zeros((original_h, original_w, 1))
+         for i, mask in enumerate(annotations):
+             if type(mask) == dict:
+                 mask = mask['segmentation']
+             annotation = mask.astype(np.uint8)
+             if use_retina == False:
+                 annotation = cv2.resize(
+                     annotation,
+                     (original_w, original_h),
+                     interpolation=cv2.INTER_NEAREST,
+                 )
+             contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+             for contour in contours:
+                 contour_all.append(contour)
+         cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
+         color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
+         contour_mask = temp / 255 * color.reshape(1, 1, -1)
+
+     image = image.convert('RGBA')
+     overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
+     image.paste(overlay_inner, (0, 0), overlay_inner)
+
+     if withContours:
+         overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), 'RGBA')
+         image.paste(overlay_contour, (0, 0), overlay_contour)
+
+     return image
+
+
+ # CPU post process
+ def fast_show_mask(
+     annotation,
+     ax,
+     random_color=False,
+     bbox=None,
+     retinamask=True,
+     target_height=960,
+     target_width=960,
+ ):
+     mask_sum = annotation.shape[0]
+     height = annotation.shape[1]
+     weight = annotation.shape[2]
+     # sort the annotations by area
+     areas = np.sum(annotation, axis=(1, 2))
+     sorted_indices = np.argsort(areas)[::1]
+     annotation = annotation[sorted_indices]
+
+     index = (annotation != 0).argmax(axis=0)
+     if random_color == True:
+         color = np.random.random((mask_sum, 1, 1, 3))
+     else:
+         color = np.ones((mask_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
+     transparency = np.ones((mask_sum, 1, 1, 1)) * 0.6
+     visual = np.concatenate([color, transparency], axis=-1)
+     mask_image = np.expand_dims(annotation, -1) * visual
+
+     mask = np.zeros((height, weight, 4))
+
+     h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
+     indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
+
+     mask[h_indices, w_indices, :] = mask_image[indices]
+     if bbox is not None:
+         x1, y1, x2, y2 = bbox
+         ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
+
+     if retinamask == False:
+         mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
+
+     return mask
+
+
+ def fast_show_mask_gpu(
+     annotation,
+     ax,
+     random_color=False,
+     bbox=None,
+     retinamask=True,
+     target_height=960,
+     target_width=960,
+ ):
+     device = annotation.device
+     mask_sum = annotation.shape[0]
+     height = annotation.shape[1]
+     weight = annotation.shape[2]
+     areas = torch.sum(annotation, dim=(1, 2))
+     sorted_indices = torch.argsort(areas, descending=False)
+     annotation = annotation[sorted_indices]
+     # find the index of the first non-zero value at each position
+     index = (annotation != 0).to(torch.long).argmax(dim=0)
+     if random_color == True:
+         color = torch.rand((mask_sum, 1, 1, 3)).to(device)
+     else:
+         color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor(
+             [30 / 255, 144 / 255, 255 / 255]
+         ).to(device)
+     transparency = torch.ones((mask_sum, 1, 1, 1)).to(device) * 0.6
+     visual = torch.cat([color, transparency], dim=-1)
+     mask_image = torch.unsqueeze(annotation, -1) * visual
+     # gather by index: for each pixel, index selects which mask's value to take, collapsing mask_image to a single layer
+     mask = torch.zeros((height, weight, 4)).to(device)
+     h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
+     indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
+     # update the mask values with vectorized indexing
+     mask[h_indices, w_indices, :] = mask_image[indices]
+     mask_cpu = mask.cpu().numpy()
+     if bbox is not None:
+         x1, y1, x2, y2 = bbox
+         ax.add_patch(
+             plt.Rectangle(
+                 (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
+             )
+         )
+     if retinamask == False:
+         mask_cpu = cv2.resize(
+             mask_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
+         )
+     return mask_cpu
weights/CLIP_ViT_B_32.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af
+ size 353976522
weights/FastSAM.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c0be4e7ddbe4c15333d15a859c676d053c486d0a746a3be6a7a9790d52a9b6d7
+ size 144943063