import os
# Install the prompt widget at runtime (it is a third-party custom component,
# not part of stock Gradio); the commented pins below are for version fallback.
# os.system("pip uninstall -y gradio")
# os.system("pip install gradio==4.44.1")
os.system("pip install gradio_image_prompter")

from io import BytesIO

import gradio as gr
import numpy as np
import requests
import torch
from gradio_image_prompter import ImagePrompter
from PIL import Image
from transformers import SamModel, SamProcessor
# define variables
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model_id = "facebook/sam-vit-huge"  # ~60s
model_id = "Zigeng/SlimSAM-uniform-50"  # ~50s
# model_id = "facebook/sam-vit-base"  # ~50s
model = SamModel.from_pretrained(model_id).to(device)
processor = SamProcessor.from_pretrained(model_id)
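
# For reference, a minimal single-point call with this processor/model pair
# follows the sketch below (the coordinates are made up for illustration, and
# `img` stands for any RGB numpy array or PIL image):
#
#   inputs = processor(img, input_points=[[[450, 600]]], return_tensors="pt").to(device)
#   with torch.no_grad():
#       outputs = model(**inputs)
#   masks = processor.image_processor.post_process_masks(
#       outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(),
#       inputs["reshaped_input_sizes"].cpu())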
# Description
title = "<center><strong><font size='8'>Segment food with clicks</font></strong></center>"
instruction = """ # Instruction
This segmentation tool is built on the HuggingFace SAM model. To label a ground-truth mask, follow these steps: \n
🔥 Step 1: Copy the link of a candidate image, paste it into 'Enter Image URL', and click 'Upload Image' \n
🔥 Step 2: Add positive points (right click), negative points (middle click), and a bounding box (click and drag - ONE box at most) around the food \n
🔥 Step 3: Click 'Segment with prompts' and check whether one of the 3 candidate segmentations is correct \n
🔥 Step 4: If not, keep adding prompts and re-segmenting until a correct one is generated; prompt history is retained until the image is reloaded \n
🔥 Step 5: Download the satisfactory segmentation via the icon in the top-right corner of the image, and name it 'correct_seg_xxx', where xxx is the photo ID
"""
css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
# functions
def read_image(url):
    """Fetch an image from a URL and wrap it in the dict format ImagePrompter expects."""
    response = requests.get(url)
    # force 3 channels so the alpha-channel concat in get_mask_image stays valid
    img = Image.open(BytesIO(response.content)).convert("RGB")
    formatted_image = {
        "image": np.array(img),
        "points": [],  # start with an empty prompt history
    }
    return formatted_image
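
# Example of the returned structure (hypothetical URL), matching the value
# format the ImagePrompter component consumes:
#
#   read_image("https://example.com/food.jpg")
#   # -> {"image": <H x W x 3 uint8 array>, "points": []}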
def get_mask_image(raw_image, mask):
    """Attach a binary mask to the image as an alpha channel (RGBA array)."""
    alpha = np.array(mask, dtype=np.uint8) * 255  # True -> 255 (opaque), False -> 0 (transparent)
    alpha = np.expand_dims(alpha, axis=2)         # (H, W) -> (H, W, 1)
    img_arr = np.array(raw_image)
    return np.concatenate((img_arr, alpha), axis=2)  # (H, W, 3) + (H, W, 1) -> RGBA
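
# Example: for a 2x2 mask [[True, False], [False, True]] the alpha channel
# becomes [[255, 0], [0, 255]], so pixels outside the mask render transparent
# in the PNG outputs below.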
def format_prompt_points(points):
    """Split ImagePrompter annotations into SAM point, label, and box prompts."""
    prompt_points = []
    point_labels = []
    prompt_boxes = []
    for point in points:
        if point[2] == 2.0 and point[5] == 3.0:
            # labels (2, 3) mark the two corners of a bounding box
            prompt_boxes.append([point[0], point[1], point[3], point[4]])
        else:
            # otherwise it is a click: label 1 = positive, anything else = negative
            prompt_points.append([point[0], point[1]])
            point_labels.append(1 if point[2] == 1.0 else 0)
    return prompt_points, point_labels, prompt_boxes
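
# Sketch of the 6-number rows ImagePrompter produces (only the positions read
# above matter here; "_" marks values this function ignores):
#
#   [x1, y1, 2.0, x2, y2, 3.0]  -> bounding box from (x1, y1) to (x2, y2)
#   [x,  y,  1.0, _,  _,  _  ]  -> positive click at (x, y), label 1
#   [x,  y,  0.0, _,  _,  _  ]  -> negative click at (x, y), label 0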
def segment_with_points(prompts):
    image = np.array(prompts["image"])  # the annotated image as a numpy array
    points = prompts["points"]          # the accumulated prompt history

    prompt_points, point_labels, prompt_boxes = format_prompt_points(points)

    # segment; pass None rather than empty lists when a prompt type is absent
    inputs = processor(image,
                       input_boxes=[prompt_boxes] if prompt_boxes else None,
                       input_points=[[prompt_points]] if prompt_points else None,
                       input_labels=[point_labels] if point_labels else None,
                       return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # rescale the predicted masks back to the original image resolution
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
    scores = outputs.iou_scores  # per-mask IoU predictions (not surfaced in the UI)

    mask_images = [get_mask_image(image, m) for m in masks[0][0]]
    mask_img1, mask_img2, mask_img3 = mask_images
    return mask_img1, mask_img2, mask_img3
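
# Note on shapes: post_process_masks returns one tensor per input image;
# masks[0][0] holds SAM's three multimask candidates at the original image
# resolution, which is why exactly three panels are populated below.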
def clear():
    # reset the prompter and the three segmentation panels
    return None, None, None, None
with gr.Blocks(css=css, title='Segment Food with Prompts') as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(title)
            gr.Markdown("")
            image_url = gr.Textbox(label="Enter Image URL",
                                   value="https://img.cdn4dd.com/u/media/4da0fbcf-5e3d-45d4-8995-663fbcf3c3c8.jpg")
            run_with_url = gr.Button("Upload Image")
            segment_btn = gr.Button("Segment with prompts", variant='primary')
            clear_btn = gr.Button("Clear points", variant='secondary')
        with gr.Column(scale=1):
            gr.Markdown(instruction)
    # Images
    with gr.Row(variant="panel"):
        with gr.Column(scale=0):
            candidate_pic = ImagePrompter(show_label=False)
            segpic_output1 = gr.Image(format="png")
        with gr.Column(scale=0):
            segpic_output2 = gr.Image(format="png")
            segpic_output3 = gr.Image(format="png")
    # Define interaction relationships
    run_with_url.click(read_image,
                       inputs=[image_url],
                       outputs=[candidate_pic])
    segment_btn.click(segment_with_points,
                      inputs=candidate_pic,
                      outputs=[segpic_output1, segpic_output2, segpic_output3])
    clear_btn.click(clear, outputs=[candidate_pic, segpic_output1, segpic_output2, segpic_output3])
demo.queue()
demo.launch()