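"""Gradio app for interactive food segmentation with the Hugging Face SAM model.

Users paste an image URL, click positive points, negative points, or
bounding-box corners as prompts, and pick the best of three candidate masks.
"""
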
import os
from io import BytesIO

import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import SamModel, SamProcessor


# Font used for on-image hints (arial.ttf expected next to this script)
font_path = os.path.join(os.getcwd(), 'arial.ttf')
print(font_path)

# Load the pre-trained SAM model (ViT-Huge checkpoint) and its processor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

# Prompt state: clicked points, their labels (1 = positive, 0 = negative,
# 9 = bounding-box corner), and how many box corners have been clicked so far
global_points = []
global_point_label = []
previous_box_points = 0

# Description
title = "<center><strong><font size='8'> πŸ” Segment food with clicks 🍜</font></strong></center>"

instruction = """ # Instruction
        This segmentation tool is built with the Hugging Face SAM model. To label the true mask, please follow these steps \n
        πŸ”₯ Step 1: Copy the candidate image link, paste it in 'Enter Image URL', and click 'Upload Image' \n
        πŸ”₯ Step 2: Add positive points (Add Mask), negative points (Remove Area), and/or a bounding box for the food \n
        πŸ”₯ Step 3: Click 'Segment with prompts' to segment the image and check whether one of the 3 options is correct \n
        πŸ”₯ Step 4: If not, repeat adding prompts and segmenting until a correct mask is generated. Prompt history is retained unless the image is reloaded \n
        πŸ”₯ Step 5: Download the satisfactory segmentation image via the icon in the top-right corner of the image, and name it 'correct_seg_xxx' where xxx is the photo ID
        """
css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"


def read_image(url):
    # Download the image and reset all stored prompts for the new image
    response = requests.get(url)
    response.raise_for_status()
    img = Image.open(BytesIO(response.content)).convert("RGB")

    global global_points
    global global_point_label
    global previous_box_points

    global_points = []
    global_point_label = []
    previous_box_points = 0
    return img


def format_prompt_points(points, labels):
    # Split clicked points into SAM point prompts (labels 0/1) and
    # bounding-box prompts (label 9 marks box corners, clicked in pairs)
    prompt_points = [xy for xy, l in zip(points, labels) if l != 9]
    point_labels = [l for l in labels if l != 9]

    prompt_boxes = None
    if len(point_labels) < len(labels):
        # Pair consecutive corner clicks into boxes and normalise each to
        # [x_min, y_min, x_max, y_max], regardless of click order
        corners = np.array([xy for xy, l in zip(points, labels) if l == 9]).reshape(-1, 2, 2)
        boxes = np.concatenate([corners.min(axis=1), corners.max(axis=1)], axis=1)
        # SamProcessor expects input_boxes nested as [batch, num_boxes, 4]
        prompt_boxes = [boxes.tolist()]
    return prompt_points, point_labels, prompt_boxes
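# A minimal worked example of the label convention (hypothetical click data):
#   points = [[10, 20], [30, 40], [5, 5], [50, 60]]
#   labels = [1, 0, 9, 9]   # one positive, one negative, one box (two corners)
#   -> prompt_points = [[10, 20], [30, 40]]
#      point_labels  = [1, 0]
#      prompt_boxes  = [[[5, 5, 50, 60]]]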


def get_mask_image(raw_image, mask):
    # Use the predicted mask as an alpha channel: masked pixels stay opaque,
    # the background becomes transparent in the returned RGBA array
    alpha = (np.array(mask) * 255).astype(np.uint8)
    alpha = np.expand_dims(alpha, axis=2)

    rgba = np.concatenate((np.array(raw_image), alpha), axis=2)
    return rgba
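# For example, a (H, W) boolean mask over a (H, W, 3) RGB image yields a
# (H, W, 4) RGBA array whose alpha channel is 255 inside the mask, 0 outside.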


def segment_with_points(image_url):
    global global_points
    global global_point_label

    # Re-download the image so segmentation runs on the clean original,
    # not the copy with prompt markers drawn on it
    raw_image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")

    # Assemble point and box prompts from the accumulated clicks
    prompt_points, point_labels, prompt_boxes = format_prompt_points(global_points, global_point_label)
    print(prompt_points, point_labels, prompt_boxes)

    # Run SAM; point prompts are nested as [batch, point_batch, points, 2]
    inputs = processor(raw_image,
                       input_boxes=prompt_boxes,
                       input_points=[[prompt_points]] if prompt_points else None,
                       input_labels=[point_labels] if point_labels else None,
                       return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Upscale the low-resolution mask logits back to the original image size
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())

    # SAM returns three candidate masks per prompt; render each one
    mask_img1, mask_img2, mask_img3 = [get_mask_image(raw_image, m) for m in masks[0][0]]
    return mask_img1, mask_img2, mask_img3
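# Note: outputs.iou_scores carries SAM's own quality estimate for each of the
# three candidate masks; this UI shows all three and leaves the choice to the labeler.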

def find_font_size(text, font_path, image, target_width_ratio):
    # Render the text at a trial size, then scale that size so the text spans
    # roughly target_width_ratio of the image width
    tested_font_size = 100
    tested_font = ImageFont.truetype(font_path, tested_font_size)
    observed_width = get_text_size(text, image, tested_font)
    estimated_font_size = tested_font_size / (observed_width / image.width) * target_width_ratio
    return round(estimated_font_size)

def get_text_size(text, image, font):
    # Measure the rendered pixel width of `text` with the given font
    im = Image.new('RGB', (image.width, image.height))
    draw = ImageDraw.Draw(im)
    return draw.textlength(text, font)
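# Rough illustration with hypothetical numbers: if the trial render is 200 px
# wide at size 100 on a 1000 px wide image, a 0.9 target ratio gives
# 100 / (200 / 1000) * 0.9 = 450, so the hint text is drawn at size 450.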


def get_points_with_draw(image, label, evt: gr.SelectData):
    global global_points
    global global_point_label
    global previous_box_points

    x, y = evt.index[0], evt.index[1]
    point_radius = 15
    point_color = (255, 255, 0) if label == 'Add Mask' else (255, 0, 255)
    global_points.append([x, y])
    global_point_label.append(1 if label == 'Add Mask' else 0 if label == 'Remove Area' else 9)

    print(x, y, label)
    print(previous_box_points)

    draw = ImageDraw.Draw(image)
    if label != 'Bounding Box':
        draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)
    else:
        if previous_box_points % 2 == 0:
            # First corner of a new box: prompt the user for the second corner
            target_width_ratio = 0.9
            text = "Please Click Another Point For Bounding Box"
            font_size = find_font_size(text, font_path, image, target_width_ratio)
            font = ImageFont.truetype(font_path, font_size)
            draw.text((x, y), text, fill=(0, 0, 0), font=font)
        else:
            # Second corner: draw the rectangle; sort the coordinates since
            # ImageDraw.rectangle requires (left, top) before (right, bottom)
            [previous_x, previous_y] = global_points[-2]
            print((previous_x, previous_y), (x, y))
            x0, x1 = sorted((previous_x, x))
            y0, y1 = sorted((previous_y, y))
            draw.rectangle([(x0, y0), (x1, y1)], outline=(0, 0, 255), width=10)
        previous_box_points += 1
    return image
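# Box prompts are collected two clicks at a time: the first click only draws a
# hint asking for the opposite corner, the second click draws the rectangle;
# format_prompt_points later turns each pair into one [x_min, y_min, x_max, y_max] box.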

def clear():
    # previous_box_points must be declared global here, otherwise the
    # assignment below would create a local and leave the counter stale
    global global_points
    global global_point_label
    global previous_box_points

    global_points = []
    global_point_label = []
    previous_box_points = 0
    return None, None, None, None


# Configure layout
cond_img_p = gr.Image(label="Input with points", type='pil')
segm_img_p1 = gr.Image(label="Segmented Image Option 1", interactive=False, type='pil', format="png")
segm_img_p2 = gr.Image(label="Segmented Image Option 2", interactive=False, type='pil', format="png")
segm_img_p3 = gr.Image(label="Segmented Image Option 3", interactive=False, type='pil', format="png")

with gr.Blocks(css=css, title='Segment Food with Prompts') as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(title)
            gr.Markdown("")
            image_url = gr.Textbox(label="Enter Image URL", 
                value = "https://img.cdn4dd.com/u/media/4da0fbcf-5e3d-45d4-8995-663fbcf3c3c8.jpg")
            run_with_url = gr.Button("Upload Image")
        with gr.Column(scale=1):
            gr.Markdown(instruction)

    # Images
    with gr.Row(variant="panel"):
        with gr.Column(scale=0):
            cond_img_p.render()
            segm_img_p2.render()
        with gr.Column(scale=0):
            segm_img_p1.render()
            segm_img_p3.render()
            
    # Submit & Clear
    with gr.Row():
        with gr.Column():
            add_or_remove = gr.Radio(["Add Mask", "Remove Area", "Bounding Box"], 
                    value="Add Mask", 
                    label="Point label")
        with gr.Column():
            segment_btn_p = gr.Button("Segment with prompts", variant='primary')
            clear_btn_p = gr.Button("Clear points", variant='secondary')

    # Define interaction relationship
    run_with_url.click(read_image,
                       inputs=[image_url],
                       outputs=[cond_img_p])

    cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)

    segment_btn_p.click(segment_with_points,
                        inputs=[image_url],
                        outputs=[segm_img_p1, segm_img_p2, segm_img_p3])
    
    clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p1, segm_img_p2, segm_img_p3])

demo.queue()
demo.launch()
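# If a public link is needed (e.g. when running in a hosted notebook), Gradio's
# share flag can be passed instead: demo.launch(share=True)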