import os
# os.system("pip uninstall -y gradio")
# os.system("pip install gradio==4.44.1")
# os.system("pip install gradio_image_prompter")
import gradio as gr
import torch
import numpy as np
import requests
from io import BytesIO
from PIL import Image
from transformers import SamModel, SamProcessor
from gradio_image_prompter import ImagePrompter

# define variables
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model_id = "facebook/sam-vit-huge"  # ~60s per request
model_id = "Zigeng/SlimSAM-uniform-50"  # ~50s per request
# model_id = "facebook/sam-vit-base"  # ~50s per request
model = SamModel.from_pretrained(model_id).to(device)
processor = SamProcessor.from_pretrained(model_id)
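# Optional speed-up (untested in this app, kept commented out): on a GPU, loading
# the model in half precision can reduce memory use and latency. `torch_dtype` is
# a standard `from_pretrained` argument; the processor's pixel values would also
# need casting to float16 before the forward pass.
# model = SamModel.from_pretrained(model_id, torch_dtype=torch.float16).to(device)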
# Description
title = "<h1>🍔 Segment food with clicks 🍜</h1>"

instruction = """
# Instruction
This segmentation tool is built with the Hugging Face SAM model. To label a ground-truth mask, follow these steps \n
🔥 Step 1: Copy the link of a segmentation candidate image, paste it into 'Enter Image URL', and click 'Upload Image' \n
🔥 Step 2: Add positive points (right click), negative points (middle click), and a bounding box (click and drag - at most ONE box) for the food \n
🔥 Step 3: Click 'Segment with prompts' to segment the image and check whether one of the 3 options is a correct segmentation \n
🔥 Step 4: If not, repeat adding prompts and segmenting until a correct one is generated; prompt history is retained until the image is reloaded \n
🔥 Step 5: Download the satisfactory segmentation image via the icon in the top-right corner of the image, and name it 'correct_seg_xxx', where xxx is the photo ID
"""

css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"

# functions
def read_image(url):
    """Fetch an image from a URL and wrap it in the ImagePrompter input format."""
    response = requests.get(url)
    img = Image.open(BytesIO(response.content)).convert("RGB")  # normalize to 3 channels
    formatted_image = {
        "image": np.array(img),
        "points": [],
    }  # the dict format ImagePrompter expects
    return formatted_image


def get_mask_image(raw_image, mask):
    """Render a binary mask as an RGBA image, using the mask as the alpha channel."""
    tmp_mask = np.array(mask * 1)
    tmp_mask[tmp_mask == 1] = 255
    tmp_mask2 = np.expand_dims(tmp_mask, axis=2)
    tmp_img_arr = np.array(raw_image)
    tmp_img_arr = np.concatenate((tmp_img_arr, tmp_mask2), axis=2)
    return tmp_img_arr.astype(np.uint8)  # cast so Gradio renders it as an image


def format_prompt_points(points):
    """Split ImagePrompter points into SAM point prompts, labels, and boxes.

    Each entry is [x1, y1, type1, x2, y2, type2]: type codes 2 and 3 mark the two
    corners of a box; otherwise the entry is a click whose label sits at index 2
    (1 = positive, anything else = negative).
    """
    prompt_points = []
    point_labels = []
    prompt_boxes = []
    for point in points:
        print(point)
        if point[2] == 2.0 and point[5] == 3.0:
            prompt_boxes.append([point[0], point[1], point[3], point[4]])
        else:
            prompt_points.append([point[0], point[1]])
            label = 1 if point[2] == 1.0 else 0
            point_labels.append(label)
    # SamProcessor expects nested batch dimensions, or None when a prompt type is absent
    prompt_points = [[prompt_points]] if len(prompt_points) > 0 else None
    point_labels = [point_labels] if len(point_labels) > 0 else None
    prompt_boxes = [prompt_boxes] if len(prompt_boxes) > 0 else None
    return prompt_points, point_labels, prompt_boxes


def segment_with_points(prompts):
    image = np.array(prompts["image"])  # convert the image to a numpy array
    points = prompts["points"]  # prompt points collected by ImagePrompter
    prompt_points, point_labels, prompt_boxes = format_prompt_points(points)
    print(prompt_points, point_labels, prompt_boxes)

    # segment
    inputs = processor(image, input_boxes=prompt_boxes, input_points=prompt_points,
                       input_labels=point_labels, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # rescale the predicted masks back to the original image size
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
    scores = outputs.iou_scores  # predicted IoU per candidate mask, available for ranking

    # SAM returns 3 candidate masks per prompt; show all of them
    mask_images = [get_mask_image(image, m) for m in masks[0][0]]
    mask_img1, mask_img2, mask_img3 = mask_images
    return mask_img1, mask_img2, mask_img3


def clear():
    return None, None, None, None
with gr.Blocks(css=css, title='Segment Food with Prompts') as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(title)
            gr.Markdown("")
            image_url = gr.Textbox(label="Enter Image URL",
                                   value="https://img.cdn4dd.com/u/media/4da0fbcf-5e3d-45d4-8995-663fbcf3c3c8.jpg")
            run_with_url = gr.Button("Upload Image")
            segment_btn = gr.Button("Segment with prompts", variant='primary')
            clear_btn = gr.Button("Clear points", variant='secondary')
        with gr.Column(scale=1):
            gr.Markdown(instruction)

    # Images
    with gr.Row(variant="panel"):
        with gr.Column(scale=0):
            candidate_pic = ImagePrompter(show_label=False)
            segpic_output1 = gr.Image(format="png")
        with gr.Column(scale=0):
            segpic_output2 = gr.Image(format="png")
            segpic_output3 = gr.Image(format="png")

    # Define interaction relationships
    run_with_url.click(read_image,
                       inputs=[image_url],
                       outputs=[candidate_pic])
    segment_btn.click(segment_with_points,
                      inputs=candidate_pic,
                      outputs=[segpic_output1, segpic_output2, segpic_output3])
    clear_btn.click(clear,
                    outputs=[candidate_pic, segpic_output1, segpic_output2, segpic_output3])

demo.queue()
demo.launch()
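# Headless sanity check (a sketch using the default sample URL; the click
# coordinates are hypothetical). demo.launch() blocks, so run this in place of
# the launch call when testing:
# prompts = read_image("https://img.cdn4dd.com/u/media/4da0fbcf-5e3d-45d4-8995-663fbcf3c3c8.jpg")
# prompts["points"] = [[250.0, 200.0, 1.0, 0.0, 0.0, 4.0]]  # one positive click
# m1, m2, m3 = segment_with_points(prompts)
# Image.fromarray(m1).save("correct_seg_000.png")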