zhemai28 committed
Commit 181f304 · 1 Parent(s): a9ddcdb

faster points

Files changed (2):
  1. app.py +49 -183
  2. app_orig_0215.py +272 -0
app.py CHANGED
@@ -9,24 +9,18 @@ import matplotlib.pyplot as plt
 import torch
 from transformers import SamModel, SamProcessor
 
-import os
-
+from gradio_image_prompter import ImagePrompter
 
-# Define variables
-path = os.getcwd()
-font_path = r'{}/arial.ttf'.format(path)
-print(font_path)
+import os
 
-# Load the pre-trained model - FastSAM
-# fastsam_model = FastSAM('./FastSAM-s.pt')
+# define variables
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
-processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
-
-# Points
-global_points = []
-global_point_label = []
-previous_box_points = 0
+# model_id = "facebook/sam-vit-huge"  # 60s
+model_id = 'Zigeng/SlimSAM-uniform-50'  # 50s
+# model_id = "facebook/sam-vit-base"  # 50s
+# model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
+model = SamModel.from_pretrained(model_id).to(device)
+processor = SamProcessor.from_pretrained(model_id)
 
 # Description
 title = "<center><strong><font size='8'> 🍔 Segment food with clicks 🍜</font></strong></center>"
@@ -34,85 +28,23 @@ title = "<center><strong><font size='8'> 🍔 Segment food with clicks 🍜</font></strong></center>"
 instruction = """ # Instruction
 This segmentation tool is built with the HuggingFace SAM model. To use it to label the true mask, please follow these steps \n
 🔥 Step 1: Copy the segmentation candidate image link, paste it in 'Enter Image URL', and click 'Upload Image' \n
-🔥 Step 2: Add positive (Add mask), negative (Remove Area), and bounding box prompts for the food \n
+🔥 Step 2: Add positive (right click), negative (middle click), and bounding box (click and drag - only ONE box at most) prompts for the food \n
 🔥 Step 3: Click 'Segment with prompts' to segment the image and check whether one of the 3 options is a correct segmentation \n
 🔥 Step 4: If not, repeat the process of adding prompts and segmenting until a correct one is generated. Prompt history is retained unless the image is reloaded \n
 🔥 Step 5: Download the satisfactory segmentation image via the icon in the top right corner of the image, and name it 'correct_seg_xxx' where xxx is the photo ID
 """
 css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
 
+# functions
 
 def read_image(url):
     response = requests.get(url)
     img = Image.open(BytesIO(response.content))
-
-    global global_points
-    global global_point_label
-
-    global_points = []
-    global_point_label = []
-    return img
-
-# def show_mask(mask, ax, random_color=False):
-#     if random_color:
-#         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
-#     else:
-#         color = np.array([30/255, 144/255, 255/255, 0.6])
-#     h, w = mask.shape[-2:]
-#     mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
-#     ax.imshow(mask_image)
-
-# def show_points(coords, labels, ax, marker_size=375):
-#     pos_points = coords[labels==1]
-#     neg_points = coords[labels==0]
-#     ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
-#     ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
-
-# def show_masks_and_points_on_image(raw_image, mask, input_points, input_labels, args):
-#     masks = masks.squeeze() if len(masks.shape) == 4 else masks.unsqueeze(0) if len(masks.shape) == 2 else masks
-#     scores = scores.squeeze() if (scores.shape[0] == 1) & (len(scores.shape) == 3) else scores
-#     #
-#     input_points = np.array(input_points)
-#     labels = np.array(input_labels)
-#     #
-#     mask = mask.cpu().detach()
-#     plt.imshow(np.array(raw_image))
-#     ax = plt.gca()
-#     show_mask(mask, ax)
-#     show_points(input_points, labels, ax, marker_size=375)
-#     ax.axis("off")
-
-#     save_path = args.output
-#     if not os.path.exists(save_path):
-#         os.makedirs(save_path)
-#     plt.axis("off")
-#     fig = plt.gcf()
-#     plt.draw()
-
-#     try:
-#         buf = fig.canvas.tostring_rgb()
-#     except AttributeError:
-#         fig.canvas.draw()
-#         buf = fig.canvas.tostring_rgb()
-
-#     cols, rows = fig.canvas.get_width_height()
-#     img_array = np.fromstring(buf, dtype=np.uint8).reshape(rows, cols, 3)
-#     cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
-
-def format_prompt_points(points, labels):
-    prompt_points = [xy for xy, l in zip(points, labels) if l != 9]
-    point_labels = [l for l in labels if l != 9]
-    #
-    prompt_boxes = None
-    if len(point_labels) < len(labels):
-        prompt_boxes = [[np.array([xy for xy, l in zip(points, labels) if l == 9]).reshape(-1, 4).tolist()]]
-    return prompt_points, point_labels, prompt_boxes
-
-# def get_mask_image(raw_image, mask):
-#     tmp_mask = np.array(mask)
-#     tmp_img_arr = np.array(raw_image)
-#     tmp_img_arr[tmp_mask == False] = [255,255,255]
-#     return tmp_img_arr
+    formatted_image = {
+        "image": np.array(img),
+        "points": [],
+    }  # Create the correct format
+    return formatted_image
 
 def get_mask_image(raw_image, mask):
     tmp_mask = np.array(mask * 1)
@@ -123,29 +55,32 @@ def get_mask_image(raw_image, mask):
     tmp_img_arr = np.concatenate((tmp_img_arr, tmp_mask2), axis=2)
     return tmp_img_arr
 
+def format_prompt_points(points):
+    prompt_points = []
+    point_labels = []
+    prompt_boxes = []
+    for point in points:
+        print(point)
+        if point[2] == 2.0 and point[5] == 3.0:
+            prompt_boxes.append([point[0], point[1], point[3], point[4]])
+        else:
+            prompt_points.append([point[0], point[1]])
+            label = 1 if point[2] == 1.0 else 0
+            point_labels.append(label)
+    return prompt_points, point_labels, prompt_boxes
 
 def segment_with_points(
-    input,
-    input_size=1024,
-    iou_threshold=0.7,
-    conf_threshold=0.25,
-    better_quality=False,
-    withContours=True,
-    use_retina=True,
-    mask_random_color=True,
+    prompts
 ):
-    global global_points
-    global global_point_label
-
-    # read image
-    raw_image = Image.open(requests.get(input, stream=True).raw).convert("RGB")
 
-    # get prompts
-    prompt_points, point_labels, prompt_boxes = format_prompt_points(global_points, global_point_label)
+    image = np.array(prompts["image"])  # Convert the image to a numpy array
+    points = prompts["points"]  # Get the points from prompts
+    #
+    prompt_points, point_labels, prompt_boxes = format_prompt_points(points)
     print(prompt_points, point_labels, prompt_boxes)
     # segment
-    inputs = processor(raw_image,
-                       input_boxes=prompt_boxes,
+    inputs = processor(image,
+                       input_boxes=[prompt_boxes],
                        input_points=[[prompt_points]],
                        input_labels=[point_labels],
                        return_tensors="pt").to(device)
@@ -155,74 +90,15 @@
     masks = processor.image_processor.post_process_masks(
         outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
     scores = outputs.iou_scores
-
-    # only show the first mask
-    # fig = show_masks_and_points_on_image(raw_image, masks[0][0][0], [global_points], global_point_label)
-    mask_images = [get_mask_image(raw_image, m) for m in masks[0][0]]
+    #
+    mask_images = [get_mask_image(image, m) for m in masks[0][0]]
    mask_img1, mask_img2, mask_img3 = mask_images
     # return fig, None
     return mask_img1, mask_img2, mask_img3
 
-def find_font_size(text, font_path, image, target_width_ratio):
-    tested_font_size = 100
-    tested_font = ImageFont.truetype(font_path, tested_font_size)
-    observed_width = get_text_size(text, image, tested_font)
-    estimated_font_size = tested_font_size / (observed_width / image.width) * target_width_ratio
-    return round(estimated_font_size)
-
-def get_text_size(text, image, font):
-    im = Image.new('RGB', (image.width, image.height))
-    draw = ImageDraw.Draw(im)
-    return draw.textlength(text, font)
-
-
-def get_points_with_draw(image, label, evt: gr.SelectData):
-    global global_points
-    global global_point_label
-    global previous_box_points
-
-    x, y = evt.index[0], evt.index[1]
-    point_radius = 15
-    point_color = (255, 255, 0) if label == 'Add Mask' else (255, 0, 255)
-    global_points.append([x, y])
-    global_point_label.append(1 if label == 'Add Mask' else 0 if label == 'Remove Area' else 9)
-
-    print(x, y, label)
-    print(previous_box_points)
-
-    draw = ImageDraw.Draw(image)
-    if label != 'Bounding Box':
-        draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)
-    else:
-        if (previous_box_points == 0) | (previous_box_points % 2 == 0):
-            target_width_ratio = 0.9
-            text = "Please Click Another Point For Bounding Box"
-            font_size = find_font_size(text, font_path, image, target_width_ratio)
-            font = ImageFont.truetype(font_path, font_size)
-            draw.text((x, y), text, fill=(0, 0, 0), font=font)
-        else:
-            [previous_x, previous_y] = global_points[-2]
-            print((previous_x, previous_y), (x, y))
-            draw.rectangle([(previous_x, previous_y), (x, y)], outline=(0, 0, 255), width=10)
-        previous_box_points += 1
-    return image
-
 def clear():
-    global global_points
-    global global_point_label
-
-    global_points = []
-    global_point_label = []
-    previous_box_points = 0
     return None, None, None, None
 
-
-# Configure layout
-cond_img_p = gr.Image(label="Input with points", type='pil')
-segm_img_p1 = gr.Image(label="Segmented Image Option 1", interactive=False, type='pil', format="png")
-segm_img_p2 = gr.Image(label="Segmented Image Option 2", interactive=False, type='pil', format="png")
-segm_img_p3 = gr.Image(label="Segmented Image Option 3", interactive=False, type='pil', format="png")
-
 with gr.Blocks(css=css, title='Segment Food with Prompts') as demo:
     with gr.Row():
         with gr.Column(scale=1):
@@ -231,42 +107,32 @@ with gr.Blocks(css=css, title='Segment Food with Prompts') as demo:
             image_url = gr.Textbox(label="Enter Image URL",
                                    value="https://img.cdn4dd.com/u/media/4da0fbcf-5e3d-45d4-8995-663fbcf3c3c8.jpg")
             run_with_url = gr.Button("Upload Image")
+            segment_btn = gr.Button("Segment with prompts", variant='primary')
+            clear_btn = gr.Button("Clear points", variant='secondary')
         with gr.Column(scale=1):
             gr.Markdown(instruction)
 
     # Images
    with gr.Row(variant="panel"):
         with gr.Column(scale=0):
-            cond_img_p.render()
-            segm_img_p2.render()
+            candidate_pic = ImagePrompter(show_label=False)
+            segpic_output1 = gr.Image(format="png")
         with gr.Column(scale=0):
-            segm_img_p1.render()
-            segm_img_p3.render()
-
-    # Submit & Clear
-    with gr.Row():
-        with gr.Column():
-            add_or_remove = gr.Radio(["Add Mask", "Remove Area", "Bounding Box"],
-                                     value="Add Mask",
-                                     label="Point label")
-        with gr.Column():
-            segment_btn_p = gr.Button("Segment with prompts", variant='primary')
-            clear_btn_p = gr.Button("Clear points", variant='secondary')
+            segpic_output2 = gr.Image(format="png")
+            segpic_output3 = gr.Image(format="png")
 
     # Define interaction relationship
     run_with_url.click(read_image,
                        inputs=[image_url],
                        # outputs=[segm_img_p, cond_img_p])
-                       outputs=[cond_img_p])
+                       outputs=[candidate_pic])
 
-    cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)
-
-    segment_btn_p.click(segment_with_points,
-                        inputs=[image_url],
+    segment_btn.click(segment_with_points,
+                      inputs=candidate_pic,
                        # outputs=[segm_img_p, cond_img_p])
-                        outputs=[segm_img_p1, segm_img_p2, segm_img_p3])
+                      outputs=[segpic_output1, segpic_output2, segpic_output3])
 
-    clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p1, segm_img_p2, segm_img_p3])
+    clear_btn.click(clear, outputs=[candidate_pic, segpic_output1, segpic_output2, segpic_output3])
 
     demo.queue()
     demo.launch()
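For quick reference, here is a minimal sketch of the payload shape the reworked segment_with_points consumes. The six-element point encoding (third element 1.0 for a positive click, 2.0 and 3.0 flagging the two corners of a dragged box) is inferred from the branching in format_prompt_points above, not taken from the gradio_image_prompter documentation, and all coordinates are made up:

import numpy as np

# Illustrative ImagePrompter-style payload, shaped like the `prompts`
# argument of segment_with_points() in the new app.py.
prompts = {
    "image": np.zeros((480, 640, 3), dtype=np.uint8),  # placeholder image
    "points": [
        [320, 240, 1.0, 0, 0, 0.0],    # positive click -> point label 1
        [100, 100, 0.0, 0, 0, 0.0],    # negative click -> point label 0
        [50, 60, 2.0, 400, 380, 3.0],  # dragged box    -> one xyxy box
    ],
}

prompt_points, point_labels, prompt_boxes = format_prompt_points(prompts["points"])
# prompt_points -> [[320, 240], [100, 100]]
# point_labels  -> [1, 0]
# prompt_boxes  -> [[50, 60, 400, 380]]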
app_orig_0215.py ADDED
@@ -0,0 +1,272 @@
+import gradio as gr
+import torch
+from PIL import ImageDraw, Image, ImageFont
+import numpy as np
+import requests
+from io import BytesIO
+
+import matplotlib.pyplot as plt
+import torch
+from transformers import SamModel, SamProcessor
+
+import os
+
+
+# Define variables
+path = os.getcwd()
+font_path = r'{}/arial.ttf'.format(path)
+print(font_path)
+
+# Load the pre-trained model - FastSAM
+# fastsam_model = FastSAM('./FastSAM-s.pt')
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
+processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
+
+# Points
+global_points = []
+global_point_label = []
+previous_box_points = 0
+
+# Description
+title = "<center><strong><font size='8'> 🍔 Segment food with clicks 🍜</font></strong></center>"
+
+instruction = """ # Instruction
+This segmentation tool is built with the HuggingFace SAM model. To use it to label the true mask, please follow these steps \n
+🔥 Step 1: Copy the segmentation candidate image link, paste it in 'Enter Image URL', and click 'Upload Image' \n
+🔥 Step 2: Add positive (Add mask), negative (Remove Area), and bounding box prompts for the food \n
+🔥 Step 3: Click 'Segment with prompts' to segment the image and check whether one of the 3 options is a correct segmentation \n
+🔥 Step 4: If not, repeat the process of adding prompts and segmenting until a correct one is generated. Prompt history is retained unless the image is reloaded \n
+🔥 Step 5: Download the satisfactory segmentation image via the icon in the top right corner of the image, and name it 'correct_seg_xxx' where xxx is the photo ID
+"""
+css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
+
+
+def read_image(url):
+    response = requests.get(url)
+    img = Image.open(BytesIO(response.content))
+
+    global global_points
+    global global_point_label
+
+    global_points = []
+    global_point_label = []
+    return img
+
+# def show_mask(mask, ax, random_color=False):
+#     if random_color:
+#         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+#     else:
+#         color = np.array([30/255, 144/255, 255/255, 0.6])
+#     h, w = mask.shape[-2:]
+#     mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+#     ax.imshow(mask_image)
+
+# def show_points(coords, labels, ax, marker_size=375):
+#     pos_points = coords[labels==1]
+#     neg_points = coords[labels==0]
+#     ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
+#     ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
+
+# def show_masks_and_points_on_image(raw_image, mask, input_points, input_labels, args):
+#     masks = masks.squeeze() if len(masks.shape) == 4 else masks.unsqueeze(0) if len(masks.shape) == 2 else masks
+#     scores = scores.squeeze() if (scores.shape[0] == 1) & (len(scores.shape) == 3) else scores
+#     #
+#     input_points = np.array(input_points)
+#     labels = np.array(input_labels)
+#     #
+#     mask = mask.cpu().detach()
+#     plt.imshow(np.array(raw_image))
+#     ax = plt.gca()
+#     show_mask(mask, ax)
+#     show_points(input_points, labels, ax, marker_size=375)
+#     ax.axis("off")
+
+#     save_path = args.output
+#     if not os.path.exists(save_path):
+#         os.makedirs(save_path)
+#     plt.axis("off")
+#     fig = plt.gcf()
+#     plt.draw()
+
+#     try:
+#         buf = fig.canvas.tostring_rgb()
+#     except AttributeError:
+#         fig.canvas.draw()
+#         buf = fig.canvas.tostring_rgb()
+
+#     cols, rows = fig.canvas.get_width_height()
+#     img_array = np.fromstring(buf, dtype=np.uint8).reshape(rows, cols, 3)
+#     cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
+
+def format_prompt_points(points, labels):
+    prompt_points = [xy for xy, l in zip(points, labels) if l != 9]
+    point_labels = [l for l in labels if l != 9]
+    #
+    prompt_boxes = None
+    if len(point_labels) < len(labels):
+        prompt_boxes = [[np.array([xy for xy, l in zip(points, labels) if l == 9]).reshape(-1, 4).tolist()]]
+    return prompt_points, point_labels, prompt_boxes
+
+# def get_mask_image(raw_image, mask):
+#     tmp_mask = np.array(mask)
+#     tmp_img_arr = np.array(raw_image)
+#     tmp_img_arr[tmp_mask == False] = [255,255,255]
+#     return tmp_img_arr
+
+def get_mask_image(raw_image, mask):
+    tmp_mask = np.array(mask * 1)
+    tmp_mask[tmp_mask == 1] = 255
+    tmp_mask2 = np.expand_dims(tmp_mask, axis=2)
+    #
+    tmp_img_arr = np.array(raw_image)
+    tmp_img_arr = np.concatenate((tmp_img_arr, tmp_mask2), axis=2)
+    return tmp_img_arr
+
+
+def segment_with_points(
+    input,
+    input_size=1024,
+    iou_threshold=0.7,
+    conf_threshold=0.25,
+    better_quality=False,
+    withContours=True,
+    use_retina=True,
+    mask_random_color=True,
+):
+    global global_points
+    global global_point_label
+
+    # read image
+    raw_image = Image.open(requests.get(input, stream=True).raw).convert("RGB")
+
+    # get prompts
+    prompt_points, point_labels, prompt_boxes = format_prompt_points(global_points, global_point_label)
+    print(prompt_points, point_labels, prompt_boxes)
+    # segment
+    inputs = processor(raw_image,
+                       input_boxes=prompt_boxes,
+                       input_points=[[prompt_points]],
+                       input_labels=[point_labels],
+                       return_tensors="pt").to(device)
+    with torch.no_grad():
+        outputs = model(**inputs)
+    #
+    masks = processor.image_processor.post_process_masks(
+        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
+    scores = outputs.iou_scores
+
+    # only show the first mask
+    # fig = show_masks_and_points_on_image(raw_image, masks[0][0][0], [global_points], global_point_label)
+    mask_images = [get_mask_image(raw_image, m) for m in masks[0][0]]
+    mask_img1, mask_img2, mask_img3 = mask_images
+    # return fig, None
+    return mask_img1, mask_img2, mask_img3
+
+def find_font_size(text, font_path, image, target_width_ratio):
+    tested_font_size = 100
+    tested_font = ImageFont.truetype(font_path, tested_font_size)
+    observed_width = get_text_size(text, image, tested_font)
+    estimated_font_size = tested_font_size / (observed_width / image.width) * target_width_ratio
+    return round(estimated_font_size)
+
+def get_text_size(text, image, font):
+    im = Image.new('RGB', (image.width, image.height))
+    draw = ImageDraw.Draw(im)
+    return draw.textlength(text, font)
+
+
+def get_points_with_draw(image, label, evt: gr.SelectData):
+    global global_points
+    global global_point_label
+    global previous_box_points
+
+    x, y = evt.index[0], evt.index[1]
+    point_radius = 15
+    point_color = (255, 255, 0) if label == 'Add Mask' else (255, 0, 255)
+    global_points.append([x, y])
+    global_point_label.append(1 if label == 'Add Mask' else 0 if label == 'Remove Area' else 9)
+
+    print(x, y, label)
+    print(previous_box_points)
+
+    draw = ImageDraw.Draw(image)
+    if label != 'Bounding Box':
+        draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)
+    else:
+        if (previous_box_points == 0) | (previous_box_points % 2 == 0):
+            target_width_ratio = 0.9
+            text = "Please Click Another Point For Bounding Box"
+            font_size = find_font_size(text, font_path, image, target_width_ratio)
+            font = ImageFont.truetype(font_path, font_size)
+            draw.text((x, y), text, fill=(0, 0, 0), font=font)
+        else:
+            [previous_x, previous_y] = global_points[-2]
+            print((previous_x, previous_y), (x, y))
+            draw.rectangle([(previous_x, previous_y), (x, y)], outline=(0, 0, 255), width=10)
+        previous_box_points += 1
+    return image
+
+def clear():
+    global global_points
+    global global_point_label
+
+    global_points = []
+    global_point_label = []
+    previous_box_points = 0
+    return None, None, None, None
+
+
+# Configure layout
+cond_img_p = gr.Image(label="Input with points", type='pil')
+segm_img_p1 = gr.Image(label="Segmented Image Option 1", interactive=False, type='pil', format="png")
+segm_img_p2 = gr.Image(label="Segmented Image Option 2", interactive=False, type='pil', format="png")
+segm_img_p3 = gr.Image(label="Segmented Image Option 3", interactive=False, type='pil', format="png")
+
+with gr.Blocks(css=css, title='Segment Food with Prompts') as demo:
+    with gr.Row():
+        with gr.Column(scale=1):
+            gr.Markdown(title)
+            gr.Markdown("")
+            image_url = gr.Textbox(label="Enter Image URL",
+                                   value="https://img.cdn4dd.com/u/media/4da0fbcf-5e3d-45d4-8995-663fbcf3c3c8.jpg")
+            run_with_url = gr.Button("Upload Image")
+        with gr.Column(scale=1):
+            gr.Markdown(instruction)
+
+    # Images
+    with gr.Row(variant="panel"):
+        with gr.Column(scale=0):
+            cond_img_p.render()
+            segm_img_p2.render()
+        with gr.Column(scale=0):
+            segm_img_p1.render()
+            segm_img_p3.render()
+
+    # Submit & Clear
+    with gr.Row():
+        with gr.Column():
+            add_or_remove = gr.Radio(["Add Mask", "Remove Area", "Bounding Box"],
+                                     value="Add Mask",
+                                     label="Point label")
+        with gr.Column():
+            segment_btn_p = gr.Button("Segment with prompts", variant='primary')
+            clear_btn_p = gr.Button("Clear points", variant='secondary')
+
+    # Define interaction relationship
+    run_with_url.click(read_image,
+                       inputs=[image_url],
+                       # outputs=[segm_img_p, cond_img_p])
+                       outputs=[cond_img_p])
+
+    cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)
+
+    segment_btn_p.click(segment_with_points,
+                        inputs=[image_url],
+                        # outputs=[segm_img_p, cond_img_p])
+                        outputs=[segm_img_p1, segm_img_p2, segm_img_p3])
+
+    clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p1, segm_img_p2, segm_img_p3])
+
+    demo.queue()
+    demo.launch()
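get_mask_image, kept unchanged on both sides of this commit, returns the image with the mask appended as an alpha channel, so everything outside the mask becomes transparent in the downloaded PNG. A minimal self-contained sketch of that behavior, using made-up toy data in place of the function's real inputs:

import numpy as np

# Toy stand-ins: a 2x2 gray image and a boolean mask selecting the
# top-left pixel only.
raw_image = np.full((2, 2, 3), 128, dtype=np.uint8)
mask = np.array([[True, False], [False, False]])

# Same steps as get_mask_image: scale the mask to 0/255 and append it
# as a fourth (alpha) channel.
tmp_mask = np.array(mask * 1)
tmp_mask[tmp_mask == 1] = 255
rgba = np.concatenate((raw_image, np.expand_dims(tmp_mask, axis=2)), axis=2)

print(rgba.shape)     # (2, 2, 4)
print(rgba[0, 0, 3])  # 255 -> the masked pixel stays opaque
print(rgba[1, 1, 3])  # 0   -> the background becomes transparent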