# NOTE: The following two status lines are residue from the Hugging Face
# Spaces page this file was scraped from (the Space reported "Build error",
# likely due to the deprecated `.style()` Gradio API used below).
import os
import cv2
import sys
import numpy as np
import gradio as gr
from PIL import Image
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

# Checkpoint file for each supported SAM backbone, keyed by model type.
# Keys must match the choices offered in the `model_type` dropdown below.
models = {
    'vit_b': './checkpoints/sam_vit_b_01ec64.pth',
    'vit_l': './checkpoints/sam_vit_l_0b3195.pth',
    'vit_h': './checkpoints/sam_vit_h_4b8939.pth',
}
def inference(device, model_type, input_img, points_per_side, pred_iou_thresh,
              stability_score_thresh, min_mask_region_area,
              stability_score_offset, box_nms_thresh, crop_n_layers,
              crop_nms_thresh):
    """Run SAM automatic mask generation on an image and colorize the masks.

    Parameters
    ----------
    device : str
        Torch device the model is moved to ('cpu' or 'cuda').
    model_type : str
        SAM backbone key ('vit_b', 'vit_l' or 'vit_h'); selects the
        checkpoint from the module-level ``models`` dict.
    input_img : np.ndarray
        H x W x 3 uint8 RGB image (Gradio "numpy" image component).
    points_per_side, pred_iou_thresh, stability_score_thresh,
    min_mask_region_area, stability_score_offset, box_nms_thresh,
    crop_n_layers, crop_nms_thresh :
        Forwarded unchanged to ``SamAutomaticMaskGenerator``.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        ``(result, mask_all)`` — the input blended with the colorized masks,
        and the colorized mask image alone; both H x W x 3 floats in [0, 1].
    """
    sam = sam_model_registry[model_type](checkpoint=models[model_type]).to(device)
    mask_generator = SamAutomaticMaskGenerator(
        sam,
        points_per_side=points_per_side,
        pred_iou_thresh=pred_iou_thresh,
        stability_score_thresh=stability_score_thresh,
        stability_score_offset=stability_score_offset,
        box_nms_thresh=box_nms_thresh,
        crop_n_layers=crop_n_layers,
        crop_nms_thresh=crop_nms_thresh,
        crop_overlap_ratio=512 / 1500,
        crop_n_points_downscale_factor=1,
        point_grids=None,
        min_mask_region_area=min_mask_region_area,
        output_mode='binary_mask',
    )
    masks = mask_generator.generate(input_img)
    # Paint large masks first so smaller masks drawn later stay visible.
    sorted_anns = sorted(masks, key=lambda ann: ann['area'], reverse=True)
    # Unmasked pixels stay white (ones).
    mask_all = np.ones((input_img.shape[0], input_img.shape[1], 3))
    for ann in sorted_anns:
        seg = ann['segmentation']  # H x W boolean mask
        # Vectorized: assign one random RGB color to every pixel of the mask
        # (replaces the former per-channel loop over `mask_all[m==True, i]`).
        mask_all[seg] = np.random.random(3)
    # Alpha-blend: 30% original image, 70% color overlay, in [0, 1].
    result = input_img / 255 * 0.3 + mask_all * 0.7
    return result, mask_all
# Gradio UI: model/device pickers, tunable generator hyper-parameters,
# an input image, and tabbed outputs for the blend and the raw mask.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(
            '''# Segment Anything!🚀
            分割一切!CV的GPT-3时刻!
            [**官方网址**](https://segment-anything.com/)
            '''
        )
    with gr.Row():
        # Backbone selection (keys of the module-level `models` dict).
        model_type = gr.Dropdown(["vit_b", "vit_l", "vit_h"], value='vit_b', label="选择模型")
        # Inference device selection.
        device = gr.Dropdown(["cpu", "cuda"], value='cuda', label="选择你的硬件")
    # SamAutomaticMaskGenerator hyper-parameters, collapsed by default.
    with gr.Accordion(label='参数调整', open=False):
        with gr.Row():
            points_per_side = gr.Number(value=32, label="points_per_side", precision=0,
                                        info='''The number of points to be sampled along one side of the image. The total
                                        number of points is points_per_side**2.''')
            pred_iou_thresh = gr.Slider(value=0.88, minimum=0, maximum=1.0, step=0.01, label="pred_iou_thresh",
                                        info='''A filtering threshold in [0,1], using the model's predicted mask quality.''')
            stability_score_thresh = gr.Slider(value=0.95, minimum=0, maximum=1.0, step=0.01, label="stability_score_thresh",
                                               info='''A filtering threshold in [0,1], using the stability of the mask under
                                               changes to the cutoff used to binarize the model's mask predictions.''')
            min_mask_region_area = gr.Number(value=0, label="min_mask_region_area", precision=0,
                                             info='''If >0, postprocessing will be applied to remove disconnected regions
                                             and holes in masks with area smaller than min_mask_region_area.''')
        with gr.Row():
            stability_score_offset = gr.Number(value=1, label="stability_score_offset",
                                               info='''The amount to shift the cutoff when calculated the stability score.''')
            # Typo fix in the help text: "ression" -> "suppression".
            box_nms_thresh = gr.Slider(value=0.7, minimum=0, maximum=1.0, step=0.01, label="box_nms_thresh",
                                       info='''The box IoU cutoff used by non-maximal suppression to filter duplicate masks.''')
            crop_n_layers = gr.Number(value=0, label="crop_n_layers", precision=0,
                                      info='''If >0, mask prediction will be run again on crops of the image.
                                      Sets the number of layers to run, where each layer has 2**i_layer number of image crops.''')
            crop_nms_thresh = gr.Slider(value=0.7, minimum=0, maximum=1.0, step=0.01, label="crop_nms_thresh",
                                        info='''The box IoU cutoff used by non-maximal suppression to filter duplicate
                                        masks between different crops.''')
    # Image input and tabbed outputs.
    # NOTE: `gr.Row().style(equal_height=True)` is deprecated (removed in
    # Gradio 4.x) and is the likely cause of the Space's build error; the
    # constructor keyword is the supported replacement since Gradio 3.41.
    with gr.Row(equal_height=True):
        with gr.Column():
            input_image = gr.Image(type="numpy")
            with gr.Row():
                button = gr.Button("Auto!")
        with gr.Tab(label='原图+mask'):
            image_output = gr.Image(type='numpy')
        with gr.Tab(label='Mask'):
            mask_output = gr.Image(type='numpy')
    gr.Examples(
        examples=[os.path.join(os.path.dirname(__file__), "./images/53960-scaled.jpg"),
                  os.path.join(os.path.dirname(__file__), "./images/2388455-scaled.jpg"),
                  os.path.join(os.path.dirname(__file__), "./images/1.jpg"),
                  os.path.join(os.path.dirname(__file__), "./images/2.jpg"),
                  os.path.join(os.path.dirname(__file__), "./images/3.jpg"),
                  os.path.join(os.path.dirname(__file__), "./images/4.jpg"),
                  os.path.join(os.path.dirname(__file__), "./images/5.jpg"),
                  os.path.join(os.path.dirname(__file__), "./images/6.jpg"),
                  os.path.join(os.path.dirname(__file__), "./images/7.jpg"),
                  os.path.join(os.path.dirname(__file__), "./images/8.jpg"),
                  ],
        inputs=input_image,
        outputs=image_output,
    )
    # Wire the button: run inference and fill both output tabs.
    button.click(inference, inputs=[device, model_type, input_image, points_per_side, pred_iou_thresh,
                                    stability_score_thresh, min_mask_region_area, stability_score_offset, box_nms_thresh,
                                    crop_n_layers, crop_nms_thresh],
                 outputs=[image_output, mask_output])

demo.launch(debug=True)