import os
# os.system("pip uninstall -y gradio")
# os.system("pip install gradio==4.44.1")
# os.system("pip install gradio_image_prompter")
import gradio as gr
import torch
from PIL import Image
import numpy as np
import requests
from io import BytesIO
from transformers import SamModel, SamProcessor
from gradio_image_prompter import ImagePrompter
# define variables
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Alternative checkpoints (rough per-request latency from the original comments):
# model_id = "facebook/sam-vit-huge"  # ~60s
model_id = "Zigeng/SlimSAM-uniform-50"  # ~50s
# model_id = "facebook/sam-vit-base"  # ~50s
model = SamModel.from_pretrained(model_id).to(device)
processor = SamProcessor.from_pretrained(model_id)
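# NOTE: SlimSAM-uniform-50 is a pruned SAM checkpoint that loads through the same
# SamModel/SamProcessor interface, so the facebook/sam-vit-* checkpoints above are
# drop-in replacements if segmentation quality matters more than latency.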
# Description
title = "# 🍔 Segment food with clicks 🍜"
instruction = """ # Instruction
This segmentation tool is built with the Hugging Face SAM model. To label the true mask, please follow these steps \n
🔥 Step 1: Copy the segmentation candidate image link, paste it into 'Enter Image URL', and click 'Upload Image' \n
🔥 Step 2: Add positive points (right click), negative points (middle click), and a bounding box (click and drag - at most ONE box) for the food \n
🔥 Step 3: Click 'Segment with prompts' to segment the image and check whether one of the 3 options is a correct segmentation \n
🔥 Step 4: If not, repeat adding prompts and segmenting until a correct one is generated. Prompt history is retained unless the image is reloaded \n
🔥 Step 5: Download the satisfactory segmentation image via the icon in the top-right corner of the image, and name it 'correct_seg_xxx', where xxx is the photo ID
"""
css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
# functions
def read_image(url):
    response = requests.get(url)
    # Convert to RGB so the processor always sees 3 channels and the alpha
    # channel added in get_mask_image stays well-formed
    img = Image.open(BytesIO(response.content)).convert("RGB")
    # Wrap in the dict format ImagePrompter expects; the empty "points" list
    # resets the prompt history
    formatted_image = {
        "image": np.array(img),
        "points": [],
    }
    return formatted_image
def get_mask_image(raw_image, mask):
    # Turn the boolean mask into a 0/255 uint8 alpha channel
    tmp_mask = np.array(mask * 1, dtype=np.uint8)
    tmp_mask[tmp_mask == 1] = 255
    tmp_mask2 = np.expand_dims(tmp_mask, axis=2)
    # Append the alpha channel so pixels outside the mask render as transparent
    tmp_img_arr = np.array(raw_image)
    tmp_img_arr = np.concatenate((tmp_img_arr, tmp_mask2), axis=2)
    return tmp_img_arr
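# The RGBA arrays returned above pair with gr.Image(format="png") in the UI below:
# PNG preserves the alpha channel, so downloaded masks keep transparent backgrounds.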
def format_prompt_points(points):
    # ImagePrompter stores each prompt as [x1, y1, type1, x2, y2, type2]:
    # a drag is a box marked by type1 == 2.0 (top-left corner) and
    # type2 == 3.0 (bottom-right corner); anything else is a single click
    # whose type1 encodes the label
    prompt_points = []
    point_labels = []
    prompt_boxes = []
    for point in points:
        if point[2] == 2.0 and point[5] == 3.0:
            prompt_boxes.append([point[0], point[1], point[3], point[4]])
        else:
            prompt_points.append([point[0], point[1]])
            label = 1 if point[2] == 1.0 else 0  # 1 = positive, 0 = negative
            point_labels.append(label)
    # Nest into the batch shapes SamProcessor expects; None when a prompt type is absent
    prompt_points = [[prompt_points]] if len(prompt_points) > 0 else None
    point_labels = [point_labels] if len(point_labels) > 0 else None
    prompt_boxes = [prompt_boxes] if len(prompt_boxes) > 0 else None
    return prompt_points, point_labels, prompt_boxes
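# Example with hypothetical values: one positive click plus one box
#   format_prompt_points([[100, 200, 1.0, 0, 0, 4.0], [10, 10, 2.0, 300, 400, 3.0]])
#   -> ([[[[100, 200]]]], [[1]], [[[10, 10, 300, 400]]])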
def segment_with_points(prompts):
    image = np.array(prompts["image"])  # the image as a numpy array
    points = prompts["points"]  # full prompt history from ImagePrompter
    prompt_points, point_labels, prompt_boxes = format_prompt_points(points)
    # The processor resizes the image and rescales prompt coordinates to match
    inputs = processor(image,
                       input_boxes=prompt_boxes,
                       input_points=prompt_points,
                       input_labels=point_labels,
                       return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Upscale the predicted low-res masks back to the original image size
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
    scores = outputs.iou_scores  # per-mask quality estimates (not used by the UI)
    # SAM proposes 3 candidate masks per prompt set; show all of them
    mask_images = [get_mask_image(image, m) for m in masks[0][0]]
    mask_img1, mask_img2, mask_img3 = mask_images
    return mask_img1, mask_img2, mask_img3
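# Sketch of a possible extension (hypothetical, not wired into the UI): order the
# three candidate masks by SAM's predicted IoU so the best one is shown first.
def rank_masks_by_score(mask_images, scores):
    order = torch.argsort(scores[0][0], descending=True)  # iou_scores shape: [1, 1, 3]
    return [mask_images[int(i)] for i in order]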
def clear():
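    # Reset the prompter and the three mask output panels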
return None, None, None, None
with gr.Blocks(css=css, title='Segment Food with Prompts') as demo:
with gr.Row():
with gr.Column(scale=1):
gr.Markdown(title)
gr.Markdown("")
image_url = gr.Textbox(label="Enter Image URL",
value = "https://img.cdn4dd.com/u/media/4da0fbcf-5e3d-45d4-8995-663fbcf3c3c8.jpg")
run_with_url = gr.Button("Upload Image")
segment_btn = gr.Button("Segment with prompts", variant='primary')
clear_btn = gr.Button("Clear points", variant='secondary')
with gr.Column(scale=1):
gr.Markdown(instruction)
# Images
with gr.Row(variant="panel"):
with gr.Column(scale=0):
candidate_pic = ImagePrompter(show_label=False)
segpic_output1 = gr.Image(format="png")
with gr.Column(scale=0):
segpic_output2 = gr.Image(format="png")
segpic_output3 = gr.Image(format="png")
# Define interaction relationship
    run_with_url.click(read_image,
                       inputs=[image_url],
                       outputs=[candidate_pic])
    segment_btn.click(segment_with_points,
                      inputs=candidate_pic,
                      outputs=[segpic_output1, segpic_output2, segpic_output3])
clear_btn.click(clear, outputs=[candidate_pic, segpic_output1, segpic_output2, segpic_output3])
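# Queue requests so long-running segmentation calls don't block the Gradio server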
demo.queue()
demo.launch()