File size: 5,823 Bytes
81f1748
 
 
9b7c3b1
81f1748
ab62477
 
9098a03
 
ab62477
 
 
9098a03
 
 
 
81f1748
9098a03
81f1748
e03a844
81f1748
9098a03
81f1748
 
 
 
 
 
9098a03
 
 
 
 
 
 
81f1748
9098a03
 
 
 
 
 
81f1748
9098a03
 
 
 
81f1748
 
 
 
 
9098a03
 
 
 
 
 
 
 
 
 
81f1748
 
 
 
 
 
 
 
 
 
 
 
9c2a2dc
 
 
81f1748
9098a03
 
81f1748
9098a03
 
81f1748
 
 
 
9098a03
 
81f1748
9c2a2dc
 
 
9098a03
ab62477
9098a03
 
 
 
 
81f1748
 
9098a03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81f1748
 
9098a03
 
 
 
 
 
81f1748
 
9098a03
81f1748
 
9098a03
 
 
 
 
81f1748
9098a03
81f1748
 
9098a03
81f1748
9098a03
81f1748
9098a03
 
9c2a2dc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import os
# os.system("pip uninstall -y gradio")
# os.system("pip install gradio==4.44.1")
# os.system("pip install gradio_image_prompter")

import gradio as gr
import torch
from PIL import ImageDraw, Image, ImageFont
import numpy as np
import requests
from io import BytesIO

import matplotlib.pyplot as plt
import torch
from transformers import SamModel, SamProcessor

from gradio_image_prompter import ImagePrompter

import os

# Runtime configuration and model loading (executed once, at import time).
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Checkpoint selection. Alternatives that were tried by the original author:
#   "facebook/sam-vit-huge"  (#60s)
#   "facebook/sam-vit-base"  (#50s)
# NOTE(review): the "#60s"/"#50s" annotations are presumably measured run
# times per checkpoint — confirm before relying on them.
model_id = 'Zigeng/SlimSAM-uniform-50' #50s
# The model is moved to `device`; the processor is loaded from the same
# checkpoint id so that pre/post-processing matches the model.
model = SamModel.from_pretrained(model_id).to(device)
processor = SamProcessor.from_pretrained(model_id)

# Description
# Static page texts. `title` and `instruction` are rendered with gr.Markdown,
# `css` is passed to gr.Blocks. The emoji were previously stored as mis-decoded
# UTF-8 ("mojibake") and showed up as garbage characters in the UI.
title = "<center><strong><font size='8'> 🔍 Segment food with clicks 🍜</font></strong></center>"

instruction = """ # Instruction
        This segmentation tool is built with HuggingFace SAM model. To use to label true mask, please follow the following steps \n
        🔥 Step 1: Copy segmentation candidate image link and paste in 'Enter Image URL' and click 'Upload Image' \n
        🔥 Step 2: Add positive (right click), negative (middle click), and bounding box (click and drag - only ONE box at most) for the food \n
        🔥 Step 3: Click on 'Segment with prompts' to segment Image and see if there's a correct segmentation on the 3 options \n
        🔥 Step 4: If not, you can repeat the process of adding prompt and segment until a correct one is generated. Prompt history will be retained unless reloading the image \n
        🔥 Step 5: Download the satisfied segmentation image through the icon on top right corner of the image, please name it with 'correct_seg_xxx' where xxx is the photo ID 
        """
css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"

# functions

def read_image(url):
    """Download the image at *url* and wrap it in the prompter's value format.

    Args:
        url: HTTP(S) address of the image to load.

    Returns:
        A dict with the decoded image as an RGB numpy array under ``"image"``
        and an empty prompt list under ``"points"``.

    Raises:
        gr.Error: If the URL is empty, the download fails (network error,
            timeout, non-2xx status) or the payload is not a readable image.
    """
    if not url or not url.strip():
        raise gr.Error("Please enter an image URL.")
    try:
        # A timeout keeps a stalled server from blocking the worker forever.
        response = requests.get(url.strip(), timeout=30)
        response.raise_for_status()
        # Normalise to 3-channel RGB so downstream code can append exactly one
        # alpha channel regardless of the source mode (RGBA, P, L, CMYK, ...).
        img = Image.open(BytesIO(response.content)).convert("RGB")
    except (requests.RequestException, OSError) as exc:
        # PIL signals undecodable data with OSError subclasses.
        raise gr.Error(f"Could not load image from URL: {exc}") from exc
    return {
        "image": np.array(img),
        "points": [],  # a freshly loaded image starts without prompts
    }

def get_mask_image(raw_image, mask):
    """Return *raw_image* with *mask* appended as an alpha channel.

    Args:
        raw_image: Array-like image of shape (H, W, C), or (H, W) for a
            single-channel image.
        mask: Array-like of shape (H, W); non-zero / True entries become
            fully opaque (255), everything else fully transparent (0).

    Returns:
        A new numpy array of shape (H, W, C + 1) that keeps the dtype of
        *raw_image* (the previous implementation silently promoted the whole
        image to int64 because the mask was built as an integer array).
    """
    pixels = np.asarray(raw_image)
    if pixels.ndim == 2:
        # Give single-channel images an explicit channel axis so the alpha
        # plane can be concatenated along axis 2.
        pixels = pixels[:, :, np.newaxis]
    opaque = np.asarray(mask).astype(bool)
    alpha = np.where(opaque, 255, 0).astype(pixels.dtype)
    return np.concatenate((pixels, alpha[:, :, np.newaxis]), axis=2)

def format_prompt_points(points):
    """Split raw prompter entries into point, label and box prompt lists.

    An entry whose third value is 2.0 and sixth value is 3.0 is treated as a
    box ``[x0, y0, x1, y1]`` taken from indices 0, 1, 3 and 4. Every other
    entry is a point ``[x, y]`` whose label is 1 when the third value equals
    1.0 and 0 otherwise.

    Returns:
        ``(prompt_points, point_labels, prompt_boxes)``. Non-empty results are
        nested as ``[[points]]``, ``[labels]`` and ``[boxes]``; an empty
        category is returned as ``None``.
    """
    coords, labels, boxes = [], [], []
    for entry in points:
        print(entry)
        is_box = entry[2] == 2.0 and entry[5] == 3.0
        if is_box:
            boxes.append([entry[0], entry[1], entry[3], entry[4]])
            continue
        coords.append([entry[0], entry[1]])
        labels.append(int(entry[2] == 1.0))

    # Wrap each non-empty list in the extra nesting levels used by the caller.
    prompt_points = [[coords]] if coords else None
    point_labels = [labels] if labels else None
    prompt_boxes = [boxes] if boxes else None
    return prompt_points, point_labels, prompt_boxes

def segment_with_points(
    prompts
):
    """Run SAM on the prompted image and return three masked candidates.

    Args:
        prompts: Value of the image prompter — a dict with the image under
            ``"image"`` and the list of prompt entries under ``"points"``.

    Returns:
        A 3-tuple of numpy arrays, one per predicted mask, each being the
        input image with the mask attached as an alpha channel.

    Raises:
        gr.Error: If no image has been loaded yet (e.g. right after the
            prompter was cleared), instead of failing with a TypeError.
    """
    if not prompts or prompts.get("image") is None:
        raise gr.Error("Please upload an image before segmenting.")

    image = np.array(prompts["image"])  # Convert the image to a numpy array
    points = prompts.get("points") or []  # A missing prompt list means "no prompts"
    #
    prompt_points, point_labels, prompt_boxes = format_prompt_points(points)
    print(prompt_points, point_labels, prompt_boxes)
    # segment
    inputs = processor(image,
                       input_boxes=prompt_boxes,
                       input_points=prompt_points,
                       input_labels=point_labels,
                       return_tensors="pt").to(device)
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**inputs)
    # Resize the predicted masks back to the original image resolution.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
    # masks[0][0]: first image, first prompt group — presumably one entry per
    # candidate mask; the UI has exactly three output slots, so three are expected.
    mask_images = [get_mask_image(image, m) for m in masks[0][0]]
    mask_img1, mask_img2, mask_img3 = mask_images
    return mask_img1, mask_img2, mask_img3

def clear():
    """Return four ``None`` values, one for each component to be reset."""
    return (None,) * 4

# Build the page: a header row with controls and instructions, followed by a
# row holding the prompt canvas and the three segmentation results.
with gr.Blocks(css=css, title='Segment Food with Prompts') as demo:
    with gr.Row():
        # Left column: title, URL input and the three action buttons.
        with gr.Column(scale=1):
            gr.Markdown(title)
            gr.Markdown("")
            image_url = gr.Textbox(label="Enter Image URL", 
                value = "https://img.cdn4dd.com/u/media/4da0fbcf-5e3d-45d4-8995-663fbcf3c3c8.jpg")
            run_with_url = gr.Button("Upload Image")
            segment_btn = gr.Button("Segment with prompts", variant='primary')
            clear_btn = gr.Button("Clear points", variant='secondary')
        # Right column: usage instructions.
        with gr.Column(scale=1):
            gr.Markdown(instruction)

        # Images
    with gr.Row(variant="panel"):
        # First column: the prompt canvas on top of the first result.
        with gr.Column(scale=0):
            candidate_pic = ImagePrompter(show_label=False)
            segpic_output1 = gr.Image(format="png")
        # Second column: the remaining two results.
        with gr.Column(scale=0):
            segpic_output2 = gr.Image(format="png")
            segpic_output3 = gr.Image(format="png")

    # Define interaction relationship
    # "Upload Image": fetch the URL from the textbox and show it in the prompter.
    run_with_url.click(read_image,
                        inputs=[image_url],
                        outputs=[candidate_pic])

    # "Segment with prompts": run segmentation on the prompter's current value
    # and fill the three result images.
    segment_btn.click(segment_with_points,
                        inputs=candidate_pic,
                        outputs=[segpic_output1, segpic_output2, segpic_output3])
    
    # "Clear points": reset the prompter and all three result images.
    clear_btn.click(clear, outputs=[candidate_pic, segpic_output1, segpic_output2, segpic_output3])

# Enable request queuing, then start the server (runs at import time).
demo.queue()
demo.launch()