Mask-Adapter / app.py
wondervictor's picture
Update app.py
e53c1e6 verified
## Some code was adapted from OVSeg and OV-SAM. Thanks to their authors for their excellent work.
## OVSeg code: https://github.com/facebookresearch/ov-seg
## OV-SAM code: https://github.com/HarborYuan/ovsam
import spaces
import multiprocessing as mp
import numpy as np
from PIL import Image,ImageDraw
import torch
try:
    import detectron2
except ImportError:
    # detectron2 is not importable: install it from its GitHub repository.
    # Only ImportError is caught so Ctrl-C / SystemExit are not swallowed.
    import importlib
    import subprocess
    import sys
    # Use this interpreter's pip so the package lands in the running environment,
    # and fail loudly if the install fails (the detectron2 imports below would
    # fail anyway, with a less helpful message).
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "git+https://github.com/facebookresearch/detectron2.git"]
    )
    # The package was installed after start-up; refresh the import system's
    # finder caches so the following `from detectron2...` imports can see it.
    importlib.invalidate_caches()
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.data.detection_utils import read_image
from mask_adapter import add_maskformer2_config, add_fcclip_config, add_mask_adapter_config
from mask_adapter.sam_maskadapter import SAMVisualizationDemo, SAMPointVisualizationDemo
import gradio as gr
import open_clip
from sam2.build_sam import build_sam2
from mask_adapter.modeling.meta_arch.mask_adapter_head import build_mask_adapter
from mask_adapter.data.datasets import openseg_classes
# Default open vocabulary: COCO panoptic category names followed by the
# ADE20K-150 category names, both taken from openseg_classes' "*_with_prompt_eng" lists.
COCO_CATEGORIES_pan = openseg_classes.get_coco_categories_with_prompt_eng()
stuff_classes = [category["name"] for category in COCO_CATEGORIES_pan]
ADE20K_150_CATEGORIES_ = openseg_classes.get_ade20k_categories_with_prompt_eng()
ade20k_stuff_classes = [category["name"] for category in ADE20K_150_CATEGORIES_]
class_names_coco_ade20k = [*stuff_classes, *ade20k_stuff_classes]
def setup_cfg(config_file):
    """Build a frozen detectron2 config for the Mask-Adapter model.

    The default config is extended via the DeepLab, MaskFormer2, FC-CLIP and
    Mask-Adapter ``add_*_config`` helpers (in that order), then overridden with
    the values in ``config_file``.

    Args:
        config_file: Path to a YAML config file.

    Returns:
        The merged, frozen config node.
    """
    config = get_cfg()
    for extend_config in (
        add_deeplab_config,
        add_maskformer2_config,
        add_fcclip_config,
        add_mask_adapter_config,
    ):
        extend_config(config)
    config.merge_from_file(config_file)
    config.freeze()
    return config
class IMGState:
    """Per-session prompt state for the interactive (Box / Point) tabs.

    Attributes:
        img: The image stored before any markers were drawn, or None.
        selected_points: Clicked ``[x, y]`` points.
        selected_points_labels: One label per clicked point (1 = positive).
        selected_bboxes: Clicked box corners as ``[x, y]`` pairs.
        available_to_set: True while no image has been stored.
    """

    def __init__(self):
        self.clear()

    def set_img(self, img):
        """Store ``img`` and mark the state as no longer available for a new image."""
        self.img = img
        self.available_to_set = False

    def clear(self):
        """Forget the image and every prompt; the state accepts a new image again."""
        self.img = None
        self.clean()
        self.available_to_set = True

    def clean(self):
        """Drop all clicked points, their labels and box corners, keeping the image."""
        self.selected_points, self.selected_points_labels, self.selected_bboxes = [], [], []

    @property
    def available(self):
        """Whether a new image may be stored (i.e. ``set_img`` has not been called since the last clear)."""
        return self.available_to_set
@spaces.GPU
@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.float32)
def inference_automatic(input_img, class_names):
    """Segment ``input_img`` automatically and label the masks with ``class_names``.

    Args:
        input_img: Image file path (the Automatic Mode image component uses
            ``type='filepath'``); it is read in BGR order with ``read_image``.
        class_names: Comma-separated vocabulary typed by the user. Whitespace
            around each name is stripped and empty entries (e.g. from ``"a,,b"``)
            are ignored.

    Returns:
        The visualization from ``SAMVisualizationDemo.run_on_image`` as an RGB
        ``PIL.Image``.

    Raises:
        gr.Error: If no image was given or no class name remains after parsing,
            so the user sees a readable message instead of a stack trace.
    """
    mp.set_start_method("spawn", force=True)
    if input_img is None:
        raise gr.Error("Please upload an input image.")
    # "dog, beach,,sky" -> ["dog", "beach", "sky"]. Without this, names keep
    # their leading space in the labels and an empty entry becomes a spurious
    # "a photo of " class.
    class_names = [name.strip() for name in (class_names or "").split(',') if name.strip()]
    if not class_names:
        raise gr.Error("Please enter at least one class name.")
    config_file = './configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml'
    cfg = setup_cfg(config_file)
    vis_demo = SAMVisualizationDemo(cfg, 0.8, sam2_model, clip_model, mask_adapter)
    img = read_image(input_img, format="BGR")
    # A single user class is paired with a catch-all 'others' class
    # (presumably so masks have an alternative label).
    if len(class_names) == 1:
        class_names.append('others')
    txts = [f'a photo of {cls_name}' for cls_name in class_names]
    text = open_clip.tokenize(txts)
    text_features = clip_model.encode_text(text.cuda())
    # L2-normalize each text embedding along the last dimension.
    text_features /= text_features.norm(dim=-1, keepdim=True)
    _, visualized_output = vis_demo.run_on_image(img, class_names, text_features)
    return Image.fromarray(np.uint8(visualized_output.get_image())).convert('RGB')
@spaces.GPU
@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.float32)
def inference_point(input_img, img_state, class_names_input):
    """Segment and classify the region(s) under the clicked points in ``img_state``.

    Args:
        input_img: The input image component value. Unused: inference runs on
            ``img_state.img``, the copy stored before markers were drawn on it.
        img_state: ``IMGState`` holding the stored image and ``selected_points``.
        class_names_input: Comma-separated vocabulary. Whitespace around names is
            stripped and empty entries are ignored. If nothing usable remains
            (None, empty, blank, or a non-string value), the precomputed
            COCO + ADE20K text embeddings are used instead.

    Returns:
        The visualized output of ``SAMPointVisualizationDemo.run_on_image_with_points``.
    """
    mp.set_start_method("spawn", force=True)
    points = img_state.selected_points
    print(f"Selected point: {points}")
    config_file = './configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml'
    cfg = setup_cfg(config_file)
    vis_demo = SAMPointVisualizationDemo(cfg, 0.8, sam2_model, clip_model, mask_adapter)
    raw_names = class_names_input if isinstance(class_names_input, str) else ""
    # "a, b,,c" -> ["a", "b", "c"]; avoids leading-space labels and empty classes.
    names = [name.strip() for name in raw_names.split(',') if name.strip()]
    if not names:
        # No custom vocabulary: use the precomputed default-vocabulary embeddings.
        text_features = torch.from_numpy(np.load("./text_embedding/coco_ade20k_text_embedding_new.npy")).cuda()
        _, visualized_output = vis_demo.run_on_image_with_points(img_state.img, points, text_features)
    else:
        txts = [f'a photo of {cls_name}' for cls_name in names]
        text = open_clip.tokenize(txts)
        text_features = clip_model.encode_text(text.cuda())
        # L2-normalize each text embedding along the last dimension.
        text_features /= text_features.norm(dim=-1, keepdim=True)
        _, visualized_output = vis_demo.run_on_image_with_points(img_state.img, points, text_features, names)
    return visualized_output
# Model handles populated by initialize_models(); each stays None until loaded.
sam2_model = clip_model = mask_adapter = None
@spaces.GPU
@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.float32)
def inference_box(input_img, img_state, class_names_input):
    """Segment and classify the region inside the box drawn on ``img_state``.

    Args:
        input_img: The input image component value. Unused: inference runs on
            ``img_state.img``, the copy stored before markers were drawn on it.
        img_state: ``IMGState`` whose ``selected_bboxes`` holds two ``[x, y]``
            corners in any order.
        class_names_input: Comma-separated vocabulary. Whitespace around names is
            stripped and empty entries are ignored. If nothing usable remains
            (None, empty, blank, or a non-string value), the precomputed
            COCO + ADE20K text embeddings are used instead.

    Returns:
        The visualized output of ``SAMPointVisualizationDemo.run_on_image_with_boxes``,
        or None if the box does not yet have exactly two corners.
    """
    # Need both corners; previously this guard was commented out and fewer
    # corners raised IndexError below.
    if len(img_state.selected_bboxes) != 2:
        return None
    mp.set_start_method("spawn", force=True)
    (x0, y0), (x1, y1) = img_state.selected_bboxes
    # Normalize the two clicked corners to [x_min, y_min, x_max, y_max].
    bbox = np.array((min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)))
    config_file = './configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml'
    cfg = setup_cfg(config_file)
    vis_demo = SAMPointVisualizationDemo(cfg, 0.8, sam2_model, clip_model, mask_adapter)
    raw_names = class_names_input if isinstance(class_names_input, str) else ""
    # "a, b,,c" -> ["a", "b", "c"]; avoids leading-space labels and empty classes.
    names = [name.strip() for name in raw_names.split(',') if name.strip()]
    if not names:
        # No custom vocabulary: use the precomputed default-vocabulary embeddings.
        text_features = torch.from_numpy(np.load("./text_embedding/coco_ade20k_text_embedding_new.npy")).cuda()
        _, visualized_output = vis_demo.run_on_image_with_boxes(img_state.img, bbox, text_features)
    else:
        txts = [f'a photo of {cls_name}' for cls_name in names]
        text = open_clip.tokenize(txts)
        text_features = clip_model.encode_text(text.cuda())
        # L2-normalize each text embedding along the last dimension.
        text_features /= text_features.norm(dim=-1, keepdim=True)
        _, visualized_output = vis_demo.run_on_image_with_boxes(img_state.img, bbox, text_features, names)
    return visualized_output
def get_points_with_draw(image, img_state, evt: gr.SelectData):
    """Gradio select handler: record the clicked point and draw a star marker on it.

    The label is fixed to 'Add Mask', so every click is stored as a positive
    point (label 1) and drawn in green. On the first click the unannotated
    image is stored in ``img_state`` before anything is drawn.

    Returns:
        ``(img_state, image)`` with the marker drawn onto ``image`` in place.
    """
    label = 'Add Mask'
    is_positive = label == "Add Mask"
    x, y = evt.index[0], evt.index[1]
    radius = 10
    color = (97, 217, 54) if is_positive else (237, 34, 13)

    img_state.selected_points.append([x, y])
    img_state.selected_points_labels.append(1 if is_positive else 0)
    if img_state.img is None:
        img_state.set_img(np.array(image))

    # Unit offsets of an 8-vertex star (tips at distance 1, inner vertices at 0.25),
    # listed clockwise starting from the top tip.
    star_offsets = (
        (0, -1), (0.25, -0.25), (1, 0), (0.25, 0.25),
        (0, 1), (-0.25, 0.25), (-1, 0), (-0.25, -0.25),
    )
    ImageDraw.Draw(image).polygon(
        [(x + dx * radius, y + dy * radius) for dx, dy in star_offsets],
        fill=color,
    )
    return img_state, image
def get_bbox_with_draw(image, img_state, evt: gr.SelectData):
    """Gradio select handler: record a box corner and draw the marker / box.

    The first and second clicks add corners. A third click starts a new box:
    only the new corner is kept and drawing restarts from the stored clean image.
    Once two corners exist, the rectangle between them is drawn as well.

    Returns:
        ``(img_state, image)`` where ``image`` carries the drawn annotations.

    Raises:
        ValueError: If ``selected_bboxes`` somehow holds more than two corners.
    """
    x, y = evt.index[0], evt.index[1]
    point_radius, box_outline = 5, 2
    marker_color = box_color = (237, 34, 13)

    corners = img_state.selected_bboxes
    corner_count = len(corners)
    if corner_count == 2:
        # Start a fresh box on the unannotated image.
        img_state.selected_bboxes = [[x, y]]
        image = Image.fromarray(img_state.img)
    elif corner_count < 2:
        corners.append([x, y])
    else:
        raise ValueError(f"Cannot be {len(img_state.selected_bboxes)}")

    if img_state.img is None:
        img_state.set_img(np.array(image))

    canvas = ImageDraw.Draw(image)
    canvas.ellipse(
        [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
        fill=marker_color,
    )
    if len(img_state.selected_bboxes) == 2:
        (x0, y0), (x1, y1) = img_state.selected_bboxes
        canvas.rectangle(
            (min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)),
            outline=box_color,
            width=box_outline,
        )
    return img_state, image
def check_and_infer_box(input_image, img_state_bbox, class_names_input_box):
    """Run box inference only once both corners have been clicked.

    Returns:
        The result of ``inference_box``, or None while the box is incomplete.
    """
    if len(img_state_bbox.selected_bboxes) != 2:
        return None
    return inference_box(input_image, img_state_bbox, class_names_input_box)
def initialize_models(sam_path, adapter_pth, model_cfg, cfg):
    """Load SAM2, the CLIP backbone and the Mask-Adapter head into module globals.

    Each model is built only while its global is still None, so models that
    are already loaded are not rebuilt. All models end up on CUDA.

    Args:
        sam_path: SAM2 checkpoint path.
        adapter_pth: Mask-Adapter state-dict path (loaded with ``torch.load``).
        model_cfg: SAM2 model config name/path passed to ``build_sam2``.
        cfg: Path of the detectron2 YAML config used to build the Mask-Adapter head.
    """
    global sam2_model, clip_model, mask_adapter
    parsed_cfg = setup_cfg(cfg)

    if sam2_model is None:
        # Built on CPU, then moved to the GPU.
        sam2_model = build_sam2(model_cfg, sam_path, device="cpu", apply_postprocessing=False).to("cuda")
        print("SAM2 model initialized.")

    if clip_model is None:
        backbone, _, _ = open_clip.create_model_and_transforms("convnext_large_d_320", pretrained="laion2b_s29b_b131k_ft_soup")
        clip_model = backbone.eval().to("cuda")
        print("CLIP model initialized.")

    if mask_adapter is None:
        mask_adapter = build_mask_adapter(parsed_cfg, "MASKAdapterHead").to("cuda").eval()
        mask_adapter.load_state_dict(torch.load(adapter_pth))
        print("Mask Adapter model initialized.")
def preprocess_example(input_img, img_state):
    """Reset the prompt state when a gallery example is selected.

    ``input_img`` is accepted for the Gradio signature but not used.

    Returns:
        ``(img_state, None)``; the None is sent to the paired output image.
    """
    img_state.clear()
    empty_output = None
    return img_state, empty_output
def clear_everything(img_state):
    """Reset a tab completely: prompt state, input image, output image and class names.

    Returns:
        ``(img_state, None, None, textbox)`` where ``textbox`` is a fresh empty
        'Class Names' textbox.
    """
    img_state.clear()
    empty_textbox = gr.Textbox(value='', lines=1, placeholder=class_names_coco_ade20k, label='Class Names')
    return img_state, None, None, empty_textbox
def clean_prompts(img_state):
    """Remove all clicked points / box corners but keep the loaded image.

    Returns:
        ``(img_state, input_image_value, None)``. The input image is reset to
        the stored unannotated copy. If no image has been stored yet (e.g.
        "Clean Prompt" is pressed before the first click), ``gr.update()``
        leaves the input component unchanged instead of crashing on
        ``Image.fromarray(None)``.
    """
    img_state.clean()
    if img_state.img is None:
        return img_state, gr.update(), None
    return img_state, Image.fromarray(img_state.img), None
# Initialize the configuration and the models once, when the app module runs.
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
sam_path = './sam2.1_hiera_large.pt'
adapter_pth = './model_0279999_with_sem_new.pth'
cfg = './configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml'
initialize_models(
    sam_path=sam_path,
    adapter_pth=adapter_pth,
    model_cfg=model_cfg,
    cfg=cfg,
)
# Examples for the Automatic Mode tab: [image path, comma-separated class names].
# Vocabularies must not contain empty entries (e.g. ",,"), which would add a
# blank class name.
examples = [
    ['./demo/images/000000001025.jpg', 'dog, beach, trees, sea, sky, snow, person, rocks, buildings, birds, beach umbrella, beach chair'],
    ['./demo/images/ADE_val_00000979.jpg', 'sky,sea,mountain,pier,beach,island,landscape,horizon'],
    ['./demo/images/ADE_val_00001200.jpg', 'bridge, mountains, trees, water, sky, buildings, boats, animals, flowers, waterfalls, grasslands, rocks'],
]
# Example images shared by the Box Mode and Point Mode galleries (one input each).
examples_point = [
    [image_path]
    for image_path in (
        './demo/images/ADE_val_00000739.jpg',
        './demo/images/000000052462.jpg',
        './demo/images/000000081766.jpg',
        './demo/images/ADE_val_00000001.jpg',
        './demo/images/000000033707.jpg',
        './demo/images/ADE_val_00000572.jpg',
    )
]
output_labels = ['segmentation map']
title = '<center><h2>Mask-Adapter + Segment Anything-2</h2></center>'
# Header HTML shown above the tabs. The arXiv badge points to the Mask-Adapter
# paper (it previously linked 2406.20076, a different hustvl paper).
description = """
<b>Mask-Adapter: The Devil is in the Masks for Open-Vocabulary Segmentation</b><br>
Mask-Adapter effectively extends to SAM or SAM-2 without additional training, achieving impressive results across multiple open-vocabulary segmentation benchmarks.<br>
<div style="display: flex; gap: 20px;">
<a href="https://arxiv.org/abs/2412.04533">
<img src="https://img.shields.io/badge/arXiv-Paper-red" alt="arXiv Paper">
</a>
<a href="https://github.com/hustvl/MaskAdapter">
<img src="https://img.shields.io/badge/GitHub-Code-blue" alt="GitHub Code">
</a>
</div>
"""
# Gradio UI: three tabs (Automatic, Box and Point mode). Component creation
# order inside each `with` block determines the on-screen layout.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Tabs():
        # Automatic mode: image file path + comma-separated class names -> segmentation map.
        with gr.TabItem("Automatic Mode"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(type='filepath', label="Input Image")
                    class_names = gr.Textbox(lines=1, placeholder=None, label='Class Names')
                with gr.Column():
                    output_image = gr.Image(type="pil", label='Segmentation Map')
                    # Buttons below segmentation map (now placed under segmentation map)
                    run_button = gr.Button("Run Automatic Segmentation", elem_id="run_button",variant='primary')
                    run_button.click(inference_automatic, inputs=[input_image, class_names], outputs=output_image)
                    clear_button = gr.Button("Clear")
                    # "Clear" only resets the output image.
                    clear_button.click(lambda: None, inputs=None, outputs=output_image)
            with gr.Row():
                gr.Examples(examples=examples, inputs=[input_image, class_names], outputs=output_image)
        # Box mode: two clicks define a box; inference runs once both corners exist.
        with gr.TabItem("Box Mode"):
            # Per-session prompt state (stored image + clicked corners) for this tab.
            img_state_bbox = gr.State(value=IMGState())
            with gr.Row(): # Horizontal layout
                with gr.Column(scale=1):
                    # input_image is rebound per tab; the listeners below are attached
                    # right after this binding, so they target this tab's component.
                    input_image = gr.Image( label="Input Image", type="pil")
                    # NOTE(review): placeholder receives the class-name list, while
                    # gr.Textbox documents placeholder as a str — confirm how it renders.
                    class_names_input_box = gr.Textbox(lines=1, placeholder=class_names_coco_ade20k, label='Class Names')
                with gr.Column(scale=1):
                    output_image_box = gr.Image(type="pil", label='Segmentation Map',interactive=False) # Output segmentation map
                    clear_prompt_button_box = gr.Button("Clean Prompt")
                    clear_button_box = gr.Button("Restart")
                    gr.Markdown("Click the top-left and bottom-right corners of the image to select a rectangular area")
            # Each click records a corner and redraws; check_and_infer_box only
            # runs inference when exactly two corners are stored.
            input_image.select(
                get_bbox_with_draw,
                [input_image, img_state_bbox],
                outputs=[img_state_bbox, input_image]
            ).then(
                check_and_infer_box,
                inputs=[input_image, img_state_bbox,class_names_input_box],
                outputs=[output_image_box]
            )
            # "Clean Prompt" drops the clicks but keeps the stored image.
            clear_prompt_button_box.click(
                clean_prompts,
                inputs=[img_state_bbox],
                outputs=[img_state_bbox, input_image, output_image_box]
            )
            # "Restart" and clearing either image reset the whole tab.
            clear_button_box.click(
                clear_everything,
                inputs=[img_state_bbox],
                outputs=[img_state_bbox, input_image, output_image_box,class_names_input_box]
            )
            input_image.clear(
                clear_everything,
                inputs=[img_state_bbox],
                outputs=[img_state_bbox, input_image, output_image_box,class_names_input_box]
            )
            output_image_box.clear(
                clear_everything,
                inputs=[img_state_bbox],
                outputs=[img_state_bbox, input_image, output_image_box,class_names_input_box]
            )
            # Selecting an example runs preprocess_example, which resets the prompt state.
            gr.Examples(
                examples=examples_point,
                inputs=[input_image, img_state_bbox],
                outputs=[img_state_bbox, output_image_box],
                examples_per_page=6,
                fn=preprocess_example,
                run_on_click=True,
                cache_examples=False,
            )
        # Point mode: every click adds a point and immediately re-runs inference.
        with gr.TabItem("Point Mode"):
            # Per-session prompt state (stored image + clicked points) for this tab.
            img_state_points = gr.State(value=IMGState())
            with gr.Row(): # Horizontal layout
                with gr.Column(scale=1):
                    # Rebinds input_image for this tab (see the Box Mode note).
                    input_image = gr.Image( label="Input Image", type="pil")
                    class_names_input_point = gr.Textbox(lines=1, placeholder=class_names_coco_ade20k, label='Class Names')
                with gr.Column(scale=1):
                    output_image_point = gr.Image(type="pil", label='Segmentation Map',interactive=False) # Output segmentation map
                    clear_prompt_button_point = gr.Button("Clean Prompt")
                    clear_button_point = gr.Button("Restart")
            input_image.select(
                get_points_with_draw,
                [input_image, img_state_points],
                outputs=[img_state_points, input_image]
            ).then(
                inference_point,
                inputs=[input_image, img_state_points,class_names_input_point],
                outputs=[output_image_point]
            )
            # "Clean Prompt" drops the clicks but keeps the stored image.
            clear_prompt_button_point.click(
                clean_prompts,
                inputs=[img_state_points],
                outputs=[img_state_points, input_image, output_image_point]
            )
            # "Restart" and clearing either image reset the whole tab.
            clear_button_point.click(
                clear_everything,
                inputs=[img_state_points],
                outputs=[img_state_points, input_image, output_image_point,class_names_input_point]
            )
            input_image.clear(
                clear_everything,
                inputs=[img_state_points],
                outputs=[img_state_points, input_image, output_image_point,class_names_input_point]
            )
            output_image_point.clear(
                clear_everything,
                inputs=[img_state_points],
                outputs=[img_state_points, input_image, output_image_point,class_names_input_point]
            )
            # Selecting an example runs preprocess_example, which resets the prompt state.
            gr.Examples(
                examples=examples_point,
                inputs=[input_image, img_state_points],
                outputs=[img_state_points, output_image_point],
                examples_per_page=6,
                fn=preprocess_example,
                run_on_click=True,
                cache_examples=False,
            )
# Start the Gradio server.
demo.launch()