# Some code was modified from OVSeg and OV-SAM. Thanks for their excellent work.
# OVSeg code: https://github.com/facebookresearch/ov-seg
# OV-SAM code: https://github.com/HarborYuan/ovsam
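# Gradio demo for Mask-Adapter with SAM-2: an "Automatic Mode" driven by a
# user-supplied class list, plus interactive "Point Mode" and "Box Mode" prompting.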
import spaces
import multiprocessing as mp
import numpy as np
from PIL import Image, ImageDraw
import torch
try:
    import detectron2
except ImportError:
    # detectron2 is not on PyPI; install it from source if it is missing.
    import os
    os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.data.detection_utils import read_image
from mask_adapter import add_maskformer2_config, add_fcclip_config, add_mask_adapter_config
from mask_adapter.sam_maskadapter import SAMVisualizationDemo, SAMPointVisualizationDemo
import gradio as gr
import open_clip
from sam2.build_sam import build_sam2
from mask_adapter.modeling.meta_arch.mask_adapter_head import build_mask_adapter
from mask_adapter.data.datasets import openseg_classes
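# Default open vocabulary: COCO panoptic categories plus the ADE20K-150
# categories, both with prompt engineering applied.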
COCO_CATEGORIES_pan = openseg_classes.get_coco_categories_with_prompt_eng()
stuff_classes = [k["name"] for k in COCO_CATEGORIES_pan]

ADE20K_150_CATEGORIES_ = openseg_classes.get_ade20k_categories_with_prompt_eng()
ade20k_stuff_classes = [k["name"] for k in ADE20K_150_CATEGORIES_]

class_names_coco_ade20k = stuff_classes + ade20k_stuff_classes
def setup_cfg(config_file):
    cfg = get_cfg()
    add_deeplab_config(cfg)
    add_maskformer2_config(cfg)
    add_fcclip_config(cfg)
    add_mask_adapter_config(cfg)
    cfg.merge_from_file(config_file)
    cfg.freeze()
    return cfg
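# Per-session interaction state: the current image plus any point/box prompts
# collected so far.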
class IMGState:
    def __init__(self):
        self.img = None
        self.selected_points = []
        self.selected_points_labels = []
        self.selected_bboxes = []
        self.available_to_set = True

    def set_img(self, img):
        self.img = img
        self.available_to_set = False

    def clear(self):
        self.img = None
        self.selected_points = []
        self.selected_points_labels = []
        self.selected_bboxes = []
        self.available_to_set = True

    def clean(self):
        self.selected_points = []
        self.selected_points_labels = []
        self.selected_bboxes = []

    def available(self):
        return self.available_to_set
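# Automatic mode: embed the comma-separated class names with CLIP, then run
# SAMVisualizationDemo over the whole image against those text features.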
def inference_automatic(input_img, class_names):
    mp.set_start_method("spawn", force=True)
    config_file = './configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml'
    cfg = setup_cfg(config_file)
    demo = SAMVisualizationDemo(cfg, 0.8, sam2_model, clip_model, mask_adapter)

    class_names = class_names.split(',')
    img = read_image(input_img, format="BGR")
    if len(class_names) == 1:
        class_names.append('others')
    txts = [f'a photo of {cls_name}' for cls_name in class_names]
    text = open_clip.tokenize(txts)
    text_features = clip_model.encode_text(text.cuda())
    text_features /= text_features.norm(dim=-1, keepdim=True)

    _, visualized_output = demo.run_on_image(img, class_names, text_features)
    return Image.fromarray(np.uint8(visualized_output.get_image())).convert('RGB')
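# Point mode: segment around the clicked point(s). With no class list given,
# precomputed COCO+ADE20K text embeddings are used instead of re-encoding text.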
def inference_point(input_img, img_state, class_names_input):
    mp.set_start_method("spawn", force=True)
    points = img_state.selected_points
    print(f"Selected points: {points}")

    config_file = './configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml'
    cfg = setup_cfg(config_file)
    demo = SAMPointVisualizationDemo(cfg, 0.8, sam2_model, clip_model, mask_adapter)

    if not class_names_input:
        class_names_input = class_names_coco_ade20k

    if class_names_input == class_names_coco_ade20k:
        text_features = torch.from_numpy(np.load("./text_embedding/coco_ade20k_text_embedding_new.npy")).cuda()
        _, visualized_output = demo.run_on_image_with_points(img_state.img, points, text_features)
    else:
        class_names_input = class_names_input.split(',')
        txts = [f'a photo of {cls_name}' for cls_name in class_names_input]
        text = open_clip.tokenize(txts)
        text_features = clip_model.encode_text(text.cuda())
        text_features /= text_features.norm(dim=-1, keepdim=True)
        _, visualized_output = demo.run_on_image_with_points(img_state.img, points, text_features, class_names_input)
    return visualized_output
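# Shared model handles, created once in initialize_models() and reused by all
# inference calls.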
sam2_model = None
clip_model = None
mask_adapter = None
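# Box mode: the two stored corner clicks are normalized into an
# (x_min, y_min, x_max, y_max) box before segmentation.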
def inference_box(input_img, img_state, class_names_input):
    mp.set_start_method("spawn", force=True)
    box_points = img_state.selected_bboxes
    bbox = (
        min(box_points[0][0], box_points[1][0]),
        min(box_points[0][1], box_points[1][1]),
        max(box_points[0][0], box_points[1][0]),
        max(box_points[0][1], box_points[1][1]),
    )
    bbox = np.array(bbox)

    config_file = './configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml'
    cfg = setup_cfg(config_file)
    demo = SAMPointVisualizationDemo(cfg, 0.8, sam2_model, clip_model, mask_adapter)

    if not class_names_input:
        class_names_input = class_names_coco_ade20k

    if class_names_input == class_names_coco_ade20k:
        text_features = torch.from_numpy(np.load("./text_embedding/coco_ade20k_text_embedding_new.npy")).cuda()
        _, visualized_output = demo.run_on_image_with_boxes(img_state.img, bbox, text_features)
    else:
        class_names_input = class_names_input.split(',')
        txts = [f'a photo of {cls_name}' for cls_name in class_names_input]
        text = open_clip.tokenize(txts)
        text_features = clip_model.encode_text(text.cuda())
        text_features /= text_features.norm(dim=-1, keepdim=True)
        _, visualized_output = demo.run_on_image_with_boxes(img_state.img, bbox, text_features, class_names_input)
    return visualized_output
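# Click handler for Point Mode: record the point as a positive prompt and draw
# a four-pointed star marker on the displayed image.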
def get_points_with_draw(image, img_state, evt: gr.SelectData):
    label = 'Add Mask'
    x, y = evt.index[0], evt.index[1]
    point_radius, point_color = 10, (97, 217, 54) if label == "Add Mask" else (237, 34, 13)
    img_state.selected_points.append([x, y])
    img_state.selected_points_labels.append(1 if label == "Add Mask" else 0)

    if img_state.img is None:
        img_state.set_img(np.array(image))

    draw = ImageDraw.Draw(image)
    draw.polygon(
        [
            (x, y - point_radius),
            (x + point_radius * 0.25, y - point_radius * 0.25),
            (x + point_radius, y),
            (x + point_radius * 0.25, y + point_radius * 0.25),
            (x, y + point_radius),
            (x - point_radius * 0.25, y + point_radius * 0.25),
            (x - point_radius, y),
            (x - point_radius * 0.25, y - point_radius * 0.25),
        ],
        fill=point_color,
    )
    return img_state, image
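# Click handler for Box Mode: store up to two corner clicks; once both corners
# exist, draw the implied rectangle (a third click starts a fresh box).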
def get_bbox_with_draw(image, img_state, evt: gr.SelectData):
    x, y = evt.index[0], evt.index[1]
    point_radius, point_color, box_outline = 5, (237, 34, 13), 2
    box_color = (237, 34, 13)

    if len(img_state.selected_bboxes) in [0, 1]:
        img_state.selected_bboxes.append([x, y])
    elif len(img_state.selected_bboxes) == 2:
        img_state.selected_bboxes = [[x, y]]
        image = Image.fromarray(img_state.img)
    else:
        raise ValueError(f"Unexpected number of stored box corners: {len(img_state.selected_bboxes)}")

    if img_state.img is None:
        img_state.set_img(np.array(image))

    draw = ImageDraw.Draw(image)
    draw.ellipse(
        [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
        fill=point_color,
    )

    if len(img_state.selected_bboxes) == 2:
        box_points = img_state.selected_bboxes
        bbox = (
            min(box_points[0][0], box_points[1][0]),
            min(box_points[0][1], box_points[1][1]),
            max(box_points[0][0], box_points[1][0]),
            max(box_points[0][1], box_points[1][1]),
        )
        draw.rectangle(bbox, outline=box_color, width=box_outline)
    return img_state, image
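# Run box inference only once both corners have been selected.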
def check_and_infer_box(input_image, img_state_bbox, class_names_input_box):
    if len(img_state_bbox.selected_bboxes) == 2:
        return inference_box(input_image, img_state_bbox, class_names_input_box)
    return None
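# Load SAM-2, the OpenCLIP ConvNeXt-Large backbone, and the Mask-Adapter head
# once at startup; repeated calls are no-ops thanks to the None checks.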
def initialize_models(sam_path, adapter_pth, model_cfg, cfg):
    cfg = setup_cfg(cfg)
    global sam2_model, clip_model, mask_adapter

    if sam2_model is None:
        sam2_model = build_sam2(model_cfg, sam_path, device="cpu", apply_postprocessing=False)
        sam2_model = sam2_model.to("cuda")
        print("SAM2 model initialized.")

    if clip_model is None:
        clip_model, _, _ = open_clip.create_model_and_transforms("convnext_large_d_320", pretrained="laion2b_s29b_b131k_ft_soup")
        clip_model = clip_model.eval()
        clip_model = clip_model.to("cuda")
        print("CLIP model initialized.")

    if mask_adapter is None:
        mask_adapter = build_mask_adapter(cfg, "MASKAdapterHead").to("cuda")
        mask_adapter = mask_adapter.eval()
        adapter_state_dict = torch.load(adapter_pth)
        mask_adapter.load_state_dict(adapter_state_dict)
        print("Mask Adapter model initialized.")
def preprocess_example(input_img, img_state):
    img_state.clear()
    return img_state, None

def clear_everything(img_state):
    img_state.clear()
    return img_state, None, None, gr.Textbox(value='', lines=1, placeholder=class_names_coco_ade20k, label='Class Names')

def clean_prompts(img_state):
    img_state.clean()
    if img_state.img is None:
        # Nothing has been clicked yet, so there is no image to restore.
        return img_state, None, None
    return img_state, Image.fromarray(img_state.img), None
# Initialize configuration and models.
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
sam_path = './sam2.1_hiera_large.pt'
adapter_pth = './model_0279999_with_sem_new.pth'
cfg = './configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml'
initialize_models(sam_path, adapter_pth, model_cfg, cfg)
# Examples for testing
examples = [
    ['./demo/images/000000001025.jpg', 'dog, beach, trees, sea, sky, snow, person, rocks, buildings, birds, beach umbrella, beach chair'],
    ['./demo/images/ADE_val_00000979.jpg', 'sky,sea,mountain,pier,beach,island,landscape,horizon'],
    ['./demo/images/ADE_val_00001200.jpg', 'bridge, mountains, trees, water, sky, buildings, boats, animals, flowers, waterfalls, grasslands, rocks'],
]
examples_point = [
    ['./demo/images/ADE_val_00000739.jpg'],
    ['./demo/images/000000052462.jpg'],
    ['./demo/images/000000081766.jpg'],
    ['./demo/images/ADE_val_00000001.jpg'],
    ['./demo/images/000000033707.jpg'],
    ['./demo/images/ADE_val_00000572.jpg'],
]

output_labels = ['segmentation map']
title = '<center><h2>Mask-Adapter + Segment Anything-2</h2></center>'

description = """
<b>Mask-Adapter: The Devil is in the Masks for Open-Vocabulary Segmentation</b><br>
Mask-Adapter extends to SAM or SAM-2 without additional training, achieving strong results across multiple open-vocabulary segmentation benchmarks.<br>
<div style="display: flex; gap: 20px;">
    <a href="https://arxiv.org/abs/2406.20076">
        <img src="https://img.shields.io/badge/arXiv-Paper-red" alt="arXiv Paper">
    </a>
    <a href="https://github.com/hustvl/MaskAdapter">
        <img src="https://img.shields.io/badge/GitHub-Code-blue" alt="GitHub Code">
    </a>
</div>
"""
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Tabs():
        with gr.TabItem("Automatic Mode"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(type='filepath', label="Input Image")
                    class_names = gr.Textbox(lines=1, placeholder=None, label='Class Names')
                with gr.Column():
                    output_image = gr.Image(type="pil", label='Segmentation Map')
                    # Buttons sit directly below the segmentation map.
                    run_button = gr.Button("Run Automatic Segmentation", elem_id="run_button", variant='primary')
                    run_button.click(inference_automatic, inputs=[input_image, class_names], outputs=output_image)
                    clear_button = gr.Button("Clear")
                    clear_button.click(lambda: None, inputs=None, outputs=output_image)

            with gr.Row():
                gr.Examples(examples=examples, inputs=[input_image, class_names], outputs=output_image)
with gr.TabItem("Box Mode"): | |
img_state_bbox = gr.State(value=IMGState()) | |
with gr.Row(): # 水平排列 | |
with gr.Column(scale=1): | |
input_image = gr.Image( label="Input Image", type="pil") | |
class_names_input_box = gr.Textbox(lines=1, placeholder=class_names_coco_ade20k, label='Class Names') | |
with gr.Column(scale=1): | |
output_image_box = gr.Image(type="pil", label='Segmentation Map',interactive=False) # 输出分割图 | |
clear_prompt_button_box = gr.Button("Clean Prompt") | |
clear_button_box = gr.Button("Restart") | |
gr.Markdown("Click the top-left and bottom-right corners of the image to select a rectangular area") | |
input_image.select( | |
get_bbox_with_draw, | |
[input_image, img_state_bbox], | |
outputs=[img_state_bbox, input_image] | |
).then( | |
check_and_infer_box, | |
inputs=[input_image, img_state_bbox,class_names_input_box], | |
outputs=[output_image_box] | |
) | |
clear_prompt_button_box.click( | |
clean_prompts, | |
inputs=[img_state_bbox], | |
outputs=[img_state_bbox, input_image, output_image_box] | |
) | |
clear_button_box.click( | |
clear_everything, | |
inputs=[img_state_bbox], | |
outputs=[img_state_bbox, input_image, output_image_box,class_names_input_box] | |
) | |
input_image.clear( | |
clear_everything, | |
inputs=[img_state_bbox], | |
outputs=[img_state_bbox, input_image, output_image_box,class_names_input_box] | |
) | |
output_image_box.clear( | |
clear_everything, | |
inputs=[img_state_bbox], | |
outputs=[img_state_bbox, input_image, output_image_box,class_names_input_box] | |
) | |
gr.Examples( | |
examples=examples_point, | |
inputs=[input_image, img_state_bbox], | |
outputs=[img_state_bbox, output_image_box], | |
examples_per_page=6, | |
fn=preprocess_example, | |
run_on_click=True, | |
cache_examples=False, | |
) | |
with gr.TabItem("Point Mode"): | |
img_state_points = gr.State(value=IMGState()) | |
with gr.Row(): # 水平排列 | |
with gr.Column(scale=1): | |
input_image = gr.Image( label="Input Image", type="pil") | |
class_names_input_point = gr.Textbox(lines=1, placeholder=class_names_coco_ade20k, label='Class Names') | |
with gr.Column(scale=1): | |
output_image_point = gr.Image(type="pil", label='Segmentation Map',interactive=False) # 输出分割图 | |
clear_prompt_button_point = gr.Button("Clean Prompt") | |
clear_button_point = gr.Button("Restart") | |
input_image.select( | |
get_points_with_draw, | |
[input_image, img_state_points], | |
outputs=[img_state_points, input_image] | |
).then( | |
inference_point, | |
inputs=[input_image, img_state_points,class_names_input_point], | |
outputs=[output_image_point] | |
) | |
clear_prompt_button_point.click( | |
clean_prompts, | |
inputs=[img_state_points], | |
outputs=[img_state_points, input_image, output_image_point] | |
) | |
clear_button_point.click( | |
clear_everything, | |
inputs=[img_state_points], | |
outputs=[img_state_points, input_image, output_image_point,class_names_input_point] | |
) | |
input_image.clear( | |
clear_everything, | |
inputs=[img_state_points], | |
outputs=[img_state_points, input_image, output_image_point,class_names_input_point] | |
) | |
output_image_point.clear( | |
clear_everything, | |
inputs=[img_state_points], | |
outputs=[img_state_points, input_image, output_image_point,class_names_input_point] | |
) | |
gr.Examples( | |
examples=examples_point, | |
inputs=[input_image, img_state_points], | |
outputs=[img_state_points, output_image_point], | |
examples_per_page=6, | |
fn=preprocess_example, | |
run_on_click=True, | |
cache_examples=False, | |
) | |
demo.launch()