Ankan Ghosh commited on
Commit
620fb76
·
1 Parent(s): 95dd819

Upload 7 files

Browse files
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import cv2
4
+ import gradio as gr
5
+ import matplotlib
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import torch
9
+
10
+ from PIL import Image
11
+
12
+ from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
13
+
14
+ # suppress server-side GUI windows
15
+ matplotlib.pyplot.switch_backend('Agg')
16
+
17
+ # setup models
18
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
+ sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth")
20
+ sam.to(device=device)
21
+ mask_generator = SamAutomaticMaskGenerator(sam)
22
+ predictor = SamPredictor(sam)
23
+
24
+
25
+ # copied from: https://github.com/facebookresearch/segment-anything
26
+ def show_anns(anns):
27
+ if len(anns) == 0:
28
+ return
29
+ sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
30
+ ax = plt.gca()
31
+ ax.set_autoscale_on(False)
32
+ polygons = []
33
+ color = []
34
+ for ann in sorted_anns:
35
+ m = ann['segmentation']
36
+ img = np.ones((m.shape[0], m.shape[1], 3))
37
+ color_mask = np.random.random((1, 3)).tolist()[0]
38
+ for i in range(3):
39
+ img[:,:,i] = color_mask[i]
40
+ ax.imshow(np.dstack((img, m*0.35)))
41
+
42
+
43
+ # demo function
44
+ def segment_image(input_image):
45
+
46
+ if input_image is not None:
47
+
48
+ # generate masks
49
+ masks = mask_generator.generate(input_image)
50
+
51
+ # add masks to image
52
+ plt.clf()
53
+ ppi = 100
54
+ height, width, _ = input_image.shape
55
+ plt.figure(figsize=(width / ppi, height / ppi)) # convert pixel to inches
56
+ plt.imshow(input_image)
57
+ show_anns(masks)
58
+ plt.axis('off')
59
+
60
+ # save and get figure
61
+ plt.savefig('output_figure.png', bbox_inches='tight')
62
+ output_image = cv2.imread('output_figure.png')
63
+ return Image.fromarray(output_image)
64
+
65
+
66
+ with gr.Blocks() as demo:
67
+
68
+ with gr.Row():
69
+ gr.Markdown("## Segment Anything (by Meta AI Research)")
70
+ with gr.Row():
71
+ gr.Markdown("The Segment Anything Model (SAM) produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a dataset of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.")
72
+
73
+ with gr.Row():
74
+
75
+ with gr.Column():
76
+ image_input = gr.Image()
77
+ segment_image_button = gr.Button('Generate Mask')
78
+
79
+ with gr.Column():
80
+ image_output = gr.Image()
81
+
82
+ segment_image_button.click(segment_image, inputs=[image_input], outputs=image_output)
83
+
84
+ gr.Examples(
85
+ examples=[
86
+ ['./examples/dog.jpg'],
87
+ ['./examples/groceries.jpg'],
88
+ ['./examples/truck.jpg']
89
+
90
+ ],
91
+ inputs=[image_input],
92
+ outputs=[image_output],
93
+ fn=segment_image,
94
+ #cache_examples=True
95
+ )
96
+
97
+ demo.launch()
examples/cat.jpg ADDED
examples/dog.jpg ADDED
examples/groceries.jpg ADDED
examples/truck.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ opencv-python
3
+ matplotlib
4
+ numpy
5
+ torch
6
+ torchvision
7
+ git+https://github.com/facebookresearch/segment-anything.git
sam_vit_b_01ec64.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912
3
+ size 375042383