zhemai28 commited on
Commit
9098a03
Β·
1 Parent(s): ab62477

segmentation points

Browse files
Files changed (4) hide show
  1. README.md +6 -5
  2. app.py +263 -48
  3. arial.ttf +0 -0
  4. requirements.txt +5 -8
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
- title: Segtesting
3
- emoji: 😻
4
- colorFrom: green
5
- colorTo: gray
6
  sdk: gradio
7
- sdk_version: 5.18.0
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Segment Anything
3
+ emoji: πŸ“š
4
+ colorFrom: yellow
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 3.47.1
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,58 +1,273 @@
1
  import gradio as gr
2
- from transformers import AutoModel, AutoProcessor
3
  import torch
 
 
4
  import requests
5
- from PIL import Image
6
  from io import BytesIO
7
 
8
- fashion_items = ['top', 'trousers', 'jumper']
9
-
10
- # Load model and processor
11
- model_name = 'Marqo/marqo-fashionSigLIP'
12
- model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
13
- processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
14
-
15
- # Preprocess and normalize text data
16
- with torch.no_grad():
17
- # Ensure truncation and padding are activated
18
- processed_texts = processor(
19
- text=fashion_items,
20
- return_tensors="pt",
21
- truncation=True, # Ensure text is truncated to fit model input size
22
- padding=True # Pad shorter sequences so that all are the same length
23
- )['input_ids']
24
-
25
- text_features = model.get_text_features(processed_texts)
26
- text_features = text_features / text_features.norm(dim=-1, keepdim=True)
27
-
28
- # Prediction function
29
- def predict_from_url(url):
30
- # Check if the URL is empty
31
- if not url:
32
- return {"Error": "Please input a URL"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- try:
35
- image = Image.open(BytesIO(requests.get(url).content))
36
- except Exception as e:
37
- return {"Error": f"Failed to load image: {str(e)}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- processed_image = processor(images=image, return_tensors="pt")['pixel_values']
 
 
 
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  with torch.no_grad():
42
- image_features = model.get_image_features(processed_image)
43
- image_features = image_features / image_features.norm(dim=-1, keepdim=True)
44
- text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)
45
-
46
- return {fashion_items[i]: float(text_probs[0, i]) for i in range(len(fashion_items))}
47
-
48
- # Gradio interface
49
- demo = gr.Interface(
50
- fn=predict_from_url,
51
- inputs=gr.Textbox(label="Enter Image URL"),
52
- outputs=gr.Label(label="Classification Results"),
53
- title="Fashion Item Classifier",
54
- allow_flagging="never"
55
- )
56
-
57
- # Launch the interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  demo.launch()
 
1
  import gradio as gr
 
2
  import torch
3
+ from PIL import ImageDraw, Image, ImageFont
4
+ import numpy as np
5
  import requests
 
6
  from io import BytesIO
7
 
8
+ import matplotlib.pyplot as plt
9
+ import torch
10
+ from transformers import SamModel, SamProcessor
11
+
12
+ import os
13
+
14
+
15
# --- Module configuration --------------------------------------------------
# Font used for on-image instruction text; arial.ttf ships next to app.py.
# BUG FIX: font_path was previously overwritten below with a hard-coded
# '/Users/zhe.mai/...' developer path, which breaks every other machine.
path = os.getcwd()
font_path = os.path.join(path, 'arial.ttf')

# Load the pre-trained SAM model (ViT-huge checkpoint) once at startup.
# NOTE(review): from_pretrained downloads weights from the HF hub on first run.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

# Mutable prompt state shared across the Gradio callbacks below.
# global_point_label values: 1 = positive point, 0 = negative point,
# 9 = bounding-box corner (see format_prompt_points).
global_points = []
global_point_label = []
previous_box_points = 0

# UI text
title = "<center><strong><font size='8'> πŸ” Segment food with clicks 🍜</font></strong></center>"

instruction = """ # Instruction
This segmentation tool is built with the HuggingFace SAM model. To label a true mask, please follow these steps \n
πŸ”₯ Step 1: Copy the segmentation candidate image link, paste it in 'Enter Image URL' and click 'Upload Image' \n
πŸ”₯ Step 2: Add positive (Add Mask), negative (Remove Area), and bounding box prompts for the food \n
πŸ”₯ Step 3: Click on 'Segment with prompts' to segment the image and see if there's a correct segmentation among the 3 options \n
πŸ”₯ Step 4: If not, you can repeat the process of adding prompts and segmenting until a correct one is generated. Prompt history will be retained unless reloading the image \n
πŸ”₯ Step 5: Download the satisfactory segmentation image through the icon on the top right corner of the image, please name it with 'correct_seg_xxx' where xxx is the photo ID
"""
css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
45
+
46
def read_image(url):
    """Fetch the image at *url* and reset the stored prompt state.

    Invoked when the user uploads a new image; any previously clicked
    points/labels belong to the old image, so they are discarded.
    """
    global global_points
    global global_point_label

    resp = requests.get(url)
    image = Image.open(BytesIO(resp.content))

    global_points = []
    global_point_label = []
    return image
56
+
57
+ # def show_mask(mask, ax, random_color=False):
58
+ # if random_color:
59
+ # color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
60
+ # else:
61
+ # color = np.array([30/255, 144/255, 255/255, 0.6])
62
+ # h, w = mask.shape[-2:]
63
+ # mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
64
+ # ax.imshow(mask_image)
65
+
66
+ # def show_points(coords, labels, ax, marker_size=375):
67
+ # pos_points = coords[labels==1]
68
+ # neg_points = coords[labels==0]
69
+ # ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
70
+ # ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
71
+
72
+ # def show_masks_and_points_on_image(raw_image, mask, input_points, input_labels, args):
73
+ # masks = masks.squeeze() if len(masks.shape) == 4 else masks.unsqueeze(0) if len(masks.shape) == 2 else masks
74
+ # scores = scores.squeeze() if (scores.shape[0] == 1) & (len(scores.shape) == 3) else scores
75
+ # #
76
+ # input_points = np.array(input_points)
77
+ # labels = np.array(input_labels)
78
+ # #
79
+ # mask = mask.cpu().detach()
80
+ # plt.imshow(np.array(raw_image))
81
+ # ax = plt.gca()
82
+ # show_mask(mask, ax)
83
+ # show_points(input_points, labels, ax, marker_size=375)
84
+ # ax.axis("off")
85
+
86
+ # save_path = args.output
87
+ # if not os.path.exists(save_path):
88
+ # os.makedirs(save_path)
89
+ # plt.axis("off")
90
+ # fig = plt.gcf()
91
+ # plt.draw()
92
 
93
+ # try:
94
+ # buf = fig.canvas.tostring_rgb()
95
+ # except AttributeError:
96
+ # fig.canvas.draw()
97
+ # buf = fig.canvas.tostring_rgb()
98
 
99
+ # cols, rows = fig.canvas.get_width_height()
100
+ # img_array = np.fromstring(buf, dtype=np.uint8).reshape(rows, cols, 3)
101
+ # cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
102
+
103
def format_prompt_points(points, labels):
    """Split raw click prompts into SAM point and box inputs.

    Labels 1/0 mark positive/negative point prompts; label 9 marks a
    bounding-box corner.  Box corners are paired into [x0, y0, x1, y1]
    rows.  Returns (prompt_points, point_labels, prompt_boxes), where
    prompt_boxes is None when no box corner was clicked.
    """
    prompt_points = []
    point_labels = []
    box_corners = []
    for xy, lab in zip(points, labels):
        if lab == 9:
            box_corners.append(xy)
        else:
            prompt_points.append(xy)
            point_labels.append(lab)

    prompt_boxes = None
    if box_corners:
        # Nesting matches what SamProcessor expects for input_boxes.
        prompt_boxes = [[np.array(box_corners).reshape(-1, 4).tolist()]]
    return prompt_points, point_labels, prompt_boxes
111
+
112
+ # def get_mask_image(raw_image, mask):
113
+ # tmp_mask = np.array(mask)
114
+ # tmp_img_arr = np.array(raw_image)
115
+ # tmp_img_arr[tmp_mask == False] = [255,255,255]
116
+ # return tmp_img_arr
117
+
118
def get_mask_image(raw_image, mask):
    """Attach *mask* to *raw_image* as an alpha channel.

    The boolean mask is scaled to 0/255 and appended as a fourth channel,
    so pixels outside the mask become fully transparent when the result
    is saved as PNG.
    """
    alpha = np.array(mask * 1)
    alpha[alpha == 1] = 255
    alpha = np.expand_dims(alpha, axis=2)

    rgba = np.concatenate((np.array(raw_image), alpha), axis=2)
    return rgba
126
+
127
+
128
def segment_with_points(
    input,
    input_size=1024,
    iou_threshold=0.7,
    conf_threshold=0.25,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=True,
):
    """Run SAM on the image at URL *input* using the accumulated prompts.

    Reads the click/box prompts stored in the module-level globals,
    queries the model, and returns the three candidate masks rendered as
    RGBA images (mask as alpha channel).  The extra keyword parameters
    are currently unused but kept for interface compatibility.
    """
    global global_points
    global global_point_label

    # Download the candidate image.
    raw_image = Image.open(requests.get(input, stream=True).raw).convert("RGB")

    # Turn the stored clicks into SAM point/box prompts.
    prompt_points, point_labels, prompt_boxes = format_prompt_points(
        global_points, global_point_label)
    print(prompt_points, point_labels, prompt_boxes)

    # Inference only: no gradients needed.
    inputs = processor(
        raw_image,
        input_boxes=prompt_boxes,
        input_points=[[prompt_points]],
        input_labels=[point_labels],
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Rescale the predicted masks back to the original image size.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    scores = outputs.iou_scores  # kept for parity with the original; unused

    # SAM proposes three masks per prompt set - render each as RGBA.
    mask_img1, mask_img2, mask_img3 = [
        get_mask_image(raw_image, m) for m in masks[0][0]
    ]
    return mask_img1, mask_img2, mask_img3
166
+
167
def find_font_size(text, font_path, image, target_width_ratio):
    """Estimate the font size at which *text* spans *target_width_ratio*
    of *image*'s width.

    Measures the text once at a probe size and scales linearly (text
    width is proportional to font size).
    """
    probe_size = 100
    probe_font = ImageFont.truetype(font_path, probe_size)
    width_at_probe = get_text_size(text, image, probe_font)
    estimated = probe_size / (width_at_probe / image.width) * target_width_ratio
    return round(estimated)
173
+
174
def get_text_size(text, image, font):
    """Return the rendered pixel width of *text* in *font*, measured on a
    scratch canvas the same size as *image*."""
    scratch = Image.new('RGB', (image.width, image.height))
    return ImageDraw.Draw(scratch).textlength(text, font)
178
+
179
+
180
def get_points_with_draw(image, label, evt: gr.SelectData):
    """Record the clicked point and draw its visual marker onto *image*.

    'Add Mask' / 'Remove Area' clicks become positive/negative point
    prompts (drawn as filled circles).  'Bounding Box' clicks are stored
    as box corners (label 9): the first corner shows a hint text, the
    second draws the rectangle.
    """
    global global_points
    global global_point_label
    global previous_box_points

    x, y = evt.index[0], evt.index[1]
    point_radius = 15
    point_color = (255, 255, 0) if label == 'Add Mask' else (255, 0, 255)
    global_points.append([x, y])
    global_point_label.append(1 if label == 'Add Mask' else 0 if label == 'Remove Area' else 9)

    print(x, y, label)
    print(previous_box_points)

    draw = ImageDraw.Draw(image)
    if label != 'Bounding Box':
        # Plain point prompt: filled circle centered on the click.
        marker = [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)]
        draw.ellipse(marker, fill=point_color)
    else:
        # Even count => this click opens a new box; odd => it closes one.
        if previous_box_points % 2 == 0:
            target_width_ratio = 0.9
            text = "Please Click Another Point For Bounding Box"
            font_size = find_font_size(text, font_path, image, target_width_ratio)
            font = ImageFont.truetype(font_path, font_size)
            draw.text((x, y), text, fill=(0, 0, 0), font=font)
        else:
            [previous_x, previous_y] = global_points[-2]
            print((previous_x, previous_y), (x, y))
            draw.rectangle([(previous_x, previous_y), (x, y)], outline=(0, 0, 255), width=10)
        # NOTE(review): counter advances once per bounding-box click so
        # corners pair up via the %2 test above — confirm this matches the
        # intended placement in the original (diff indentation is ambiguous).
        previous_box_points += 1
    return image
210
+
211
def clear():
    """Reset all stored prompt state and blank the four image panes.

    Returns a None for each of (input image, option 1, option 2,
    option 3) so Gradio clears the corresponding components.
    """
    global global_points
    global global_point_label
    # BUG FIX: 'global previous_box_points' was missing, so the assignment
    # below created a local and the box-corner parity was never reset.
    global previous_box_points

    global_points = []
    global_point_label = []
    previous_box_points = 0
    return None, None, None, None
219
+
220
+
221
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# Four image panes: the clickable input plus the three SAM candidates.
cond_img_p = gr.Image(label="Input with points", type='pil')
segm_img_p1 = gr.Image(label="Segmented Image Option 1", interactive=False, type='pil', format="png")
segm_img_p2 = gr.Image(label="Segmented Image Option 2", interactive=False, type='pil', format="png")
segm_img_p3 = gr.Image(label="Segmented Image Option 3", interactive=False, type='pil', format="png")

with gr.Blocks(css=css, title='Segment Food with Prompts') as demo:
    # Header row: title + URL input on the left, usage instructions on the right.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(title)
            gr.Markdown("")
            image_url = gr.Textbox(label="Enter Image URL",
                                   value="https://img.cdn4dd.com/u/media/4da0fbcf-5e3d-45d4-8995-663fbcf3c3c8.jpg")
            run_with_url = gr.Button("Upload Image")
        with gr.Column(scale=1):
            gr.Markdown(instruction)

    # 2x2 grid of image panes.
    with gr.Row(variant="panel"):
        with gr.Column(scale=0):
            cond_img_p.render()
            segm_img_p2.render()
        with gr.Column(scale=0):
            segm_img_p1.render()
            segm_img_p3.render()

    # Prompt-type selector and the two action buttons.
    with gr.Row():
        with gr.Column():
            add_or_remove = gr.Radio(["Add Mask", "Remove Area", "Bounding Box"],
                                     value="Add Mask",
                                     label="Point label")
        with gr.Column():
            segment_btn_p = gr.Button("Segment with prompts", variant='primary')
            clear_btn_p = gr.Button("Clear points", variant='secondary')

    # Event wiring: upload, click-to-prompt, segment, clear.
    run_with_url.click(read_image,
                       inputs=[image_url],
                       outputs=[cond_img_p])
    cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)
    segment_btn_p.click(segment_with_points,
                        inputs=[image_url],
                        outputs=[segm_img_p1, segm_img_p2, segm_img_p3])
    clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p1, segm_img_p2, segm_img_p3])

demo.queue()
demo.launch()
arial.ttf ADDED
Binary file (312 kB). View file
 
requirements.txt CHANGED
@@ -1,9 +1,6 @@
1
- transformers
2
- torch
3
- requests
4
- Pillow
5
- open_clip_torch
6
- ftfy
7
 
8
- # This is only needed for local deployment
9
- gradio
 
1
+ matplotlib==3.2.2
2
+ numpy
3
+ opencv-python
4
+ transformers==4.49.0
5
+ pillow==11.1.0
 
6