# Hugging Face Spaces status: Running
import os | |
# os.system("pip uninstall -y gradio") | |
# os.system("pip install gradio==4.44.1") | |
# os.system("pip install gradio_image_prompter") | |
import gradio as gr | |
import torch | |
from PIL import ImageDraw, Image, ImageFont | |
import numpy as np | |
import requests | |
from io import BytesIO | |
import matplotlib.pyplot as plt | |
import torch | |
from transformers import SamModel, SamProcessor | |
from gradio_image_prompter import ImagePrompter | |
import os | |
# Model setup: choose the compute device, then load the SAM checkpoint and
# its matching processor once, at import time.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Checkpoint in use. The numbers noted beside each id are presumably timings
# the author observed. Other checkpoints that were tried:
#   "facebook/sam-vit-huge"  (60s)
#   "facebook/sam-vit-base"  (50s)
model_id = 'Zigeng/SlimSAM-uniform-50'  # 50s
processor = SamProcessor.from_pretrained(model_id)
model = SamModel.from_pretrained(model_id)
model = model.to(device)
# UI text shown on the Gradio page.
# NOTE(review): the "π" glyphs in the title look like a mis-decoded emoji;
# the intended character could not be recovered, so they are left unchanged.
title = "<center><strong><font size='8'> π Segment food with clicks π</font></strong></center>"
# Markdown rendered beside the controls. Each "\n" escape followed by the
# literal line break yields a blank line, i.e. a Markdown paragraph break.
instruction = """ # Instruction
This segmentation tool is built with the HuggingFace SAM model. To label a true mask, please follow these steps: \n
🔥 Step 1: Copy the segmentation candidate image link, paste it into 'Enter Image URL' and click 'Upload Image' \n
🔥 Step 2: Add positive (right click), negative (middle click), and bounding box (click and drag - only ONE box at most) prompts for the food \n
🔥 Step 3: Click 'Segment with prompts' to segment the image and check whether one of the 3 options is a correct segmentation \n
🔥 Step 4: If not, you can repeat the process of adding prompts and segmenting until a correct one is generated. Prompt history will be retained unless you reload the image \n
🔥 Step 5: Download the satisfactory segmentation image through the icon at the top right corner of the image, and name it 'correct_seg_xxx' where xxx is the photo ID
"""
css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
# functions | |
def read_image(url, *, timeout=30):
    """Download an image and wrap it in the value format ImagePrompter takes.

    Args:
        url: HTTP(S) address of the image. Surrounding whitespace (common when
            pasting a link) is ignored.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        A dict with the image as an RGB numpy array under ``"image"`` and an
        empty ``"points"`` list, so any previous prompts are discarded.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    # Without a timeout, an unresponsive host would block this request forever.
    response = requests.get(url.strip(), timeout=timeout)
    # Fail on 4xx/5xx rather than handing an HTML error page to PIL.
    response.raise_for_status()
    with Image.open(BytesIO(response.content)) as img:
        # SAM and the RGBA overlay both need exactly three colour channels;
        # normalise palette, grayscale and RGBA images here.
        rgb = img.convert("RGB")
    return {
        "image": np.array(rgb),
        "points": [],
    }
def get_mask_image(raw_image, mask):
    """Return ``raw_image`` as an RGBA array whose alpha channel is ``mask``.

    Pixels where ``mask`` is non-zero get alpha 255 (opaque); all others get
    alpha 0 (transparent), so only the segmented region remains visible.

    Args:
        raw_image: Array-like image of shape (H, W), (H, W, 1), (H, W, 3) or
            (H, W, 4). An existing alpha channel is discarded and replaced.
        mask: Boolean or 0/1 array-like (numpy array or CPU torch tensor) of
            shape (H, W).

    Returns:
        A ``uint8`` numpy array of shape (H, W, 4).
    """
    rgb = np.asarray(raw_image)
    if rgb.ndim == 2:
        rgb = rgb[..., np.newaxis]
    if rgb.shape[-1] == 1:
        # Grayscale: replicate the single channel into R, G and B.
        rgb = np.repeat(rgb, 3, axis=-1)
    # Drop any existing alpha channel; the mask becomes the new alpha.
    rgb = rgb[..., :3].astype(np.uint8, copy=False)
    # Build alpha directly as uint8. Concatenating a uint8 image with an
    # integer (int64) mask would promote the whole result to int64, which
    # downstream image encoders may not treat as ordinary 0-255 pixel data.
    alpha = np.where(np.asarray(mask) != 0, 255, 0).astype(np.uint8)
    return np.concatenate((rgb, alpha[..., np.newaxis]), axis=2)
def format_prompt_points(points):
    """Split ImagePrompter prompts into SAM point, label and box inputs.

    Each prompt is indexed as ``[x1, y1, t1, x2, y2, t2]``. A prompt with
    ``t1 == 2`` and ``t2 == 3`` is treated as a box with corners (x1, y1) and
    (x2, y2). Any other prompt is a single click at (x1, y1), labelled
    positive (1) when ``t1 == 1`` and negative (0) otherwise.

    Args:
        points: Iterable of prompts as described above. ``None`` or an empty
            iterable means "no prompts".

    Returns:
        ``(prompt_points, point_labels, prompt_boxes)`` nested for a single
        image, as ``SamProcessor`` expects:

        - points as ``[[[[x, y], ...]]]``;
        - labels as ``[[label, ...]]``;
        - boxes as ``[[[x_min, y_min, x_max, y_max], ...]]``.

        Each element is ``None`` when there is nothing of that kind.
    """
    clicks, labels, boxes = [], [], []
    for prompt in points or []:
        x1, y1, kind1 = prompt[0], prompt[1], prompt[2]
        x2, y2, kind2 = prompt[3], prompt[4], prompt[5]
        if kind1 == 2.0 and kind2 == 3.0:
            # Normalise corner order in case the box was dragged up/left and
            # the corners arrive in drag order rather than min/max order.
            boxes.append([min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)])
        else:
            clicks.append([x1, y1])
            labels.append(1 if kind1 == 1.0 else 0)
    prompt_points = [[clicks]] if clicks else None
    point_labels = [labels] if labels else None
    prompt_boxes = [boxes] if boxes else None
    return prompt_points, point_labels, prompt_boxes
def segment_with_points(prompts):
    """Run SAM on the prompted image and return three candidate cut-outs.

    Args:
        prompts: ImagePrompter value, a dict holding the image under
            ``"image"`` and the click/box prompts under ``"points"``.

    Returns:
        Three RGBA numpy arrays, one per mask candidate in the first image's
        first prompt batch, each built with ``get_mask_image``.

    Raises:
        gr.Error: If no image has been loaded into the prompter yet.
    """
    if prompts is None or prompts.get("image") is None:
        # Surface a readable message in the UI instead of a TypeError.
        raise gr.Error("Please upload an image before segmenting.")
    image = np.array(prompts["image"])
    # The SAM processor expects three colour channels.
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.ndim == 3 and image.shape[-1] == 4:
        image = image[..., :3]

    prompt_points, point_labels, prompt_boxes = format_prompt_points(
        prompts.get("points") or []
    )

    inputs = processor(image,
                       input_boxes=prompt_boxes,
                       input_points=prompt_points,
                       input_labels=point_labels,
                       return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Resize the low-resolution predicted masks back to the original image size.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # masks[0] is the single input image; the next [0] selects its first prompt
    # batch. The three-way unpack below assumes exactly three candidates there.
    mask_img1, mask_img2, mask_img3 = [get_mask_image(image, m) for m in masks[0][0]]
    return mask_img1, mask_img2, mask_img3
def clear():
    """Return four ``None`` values, used to empty four Gradio components."""
    return (None,) * 4
# Build the Gradio UI. Components appear on screen in the order they are
# created inside the `with` layout contexts, so statement order here is layout.
with gr.Blocks(css=css, title='Segment Food with Prompts') as demo:
    # Top row: title, URL input and action buttons (left), instructions (right).
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(title)
            gr.Markdown("")  # presumably an empty spacer under the title
            image_url = gr.Textbox(label="Enter Image URL",
                                   value = "https://img.cdn4dd.com/u/media/4da0fbcf-5e3d-45d4-8995-663fbcf3c3c8.jpg")
            run_with_url = gr.Button("Upload Image")
            segment_btn = gr.Button("Segment with prompts", variant='primary')
            clear_btn = gr.Button("Clear points", variant='secondary')
        with gr.Column(scale=1):
            gr.Markdown(instruction)
    # Image row: the prompt canvas plus the three candidate segmentations,
    # rendered as PNG.
    with gr.Row(variant="panel"):
        with gr.Column(scale=0):
            candidate_pic = ImagePrompter(show_label=False)
            segpic_output1 = gr.Image(format="png")
        with gr.Column(scale=0):
            segpic_output2 = gr.Image(format="png")
            segpic_output3 = gr.Image(format="png")
    # Event wiring.
    # "Upload Image": fetch the URL and load it into the prompter.
    run_with_url.click(read_image,
                       inputs=[image_url],
                       outputs=[candidate_pic])
    # "Segment with prompts": run SAM on the current image + prompts and fill
    # the three output images.
    segment_btn.click(segment_with_points,
                      inputs=candidate_pic,
                      outputs=[segpic_output1, segpic_output2, segpic_output3])
    # "Clear points": resets the prompter (image included) and all outputs.
    clear_btn.click(clear, outputs=[candidate_pic, segpic_output1, segpic_output2, segpic_output3])
# Enable request queuing, then start the server.
demo.queue()
demo.launch()