# Mask-Adapter / app.py
# Uploaded by wondervictor ("Upload 10 files", commit f773839, verified); original size 7.88 kB.
import multiprocessing as mp
import numpy as np
from PIL import Image
import torch
try:
    import detectron2
except ImportError:
    # detectron2 is installed from its GitHub source on first launch.
    # Use the running interpreter's pip (not whatever `pip` is on PATH) and
    # fail loudly if the install fails, instead of continuing into the
    # `from detectron2...` imports below with a confusing ImportError.
    import importlib
    import subprocess
    import sys

    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "git+https://github.com/facebookresearch/detectron2.git",
        ]
    )
    # The package appeared after the interpreter started; make sure the
    # import system re-scans sys.path before the imports that follow.
    importlib.invalidate_caches()
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.data.detection_utils import read_image
from mask_adapter import add_maskformer2_config, add_fcclip_config, add_mask_adapter_config
from mask_adapter.sam_maskadapter import SAMVisualizationDemo, SAMPointVisualizationDemo
import gradio as gr
import gdown
import open_clip
from sam2.build_sam import build_sam2
from mask_adapter.modeling.meta_arch.mask_adapter_head import build_mask_adapter
# ckpt_url = 'https://drive.google.com/uc?id=1cn-ohxgXDrDfkzC1QdO-fi8IjbjXmgKy'
# output = './ovseg_swinbase_vitL14_ft_mpt.pth'
# gdown.download(ckpt_url, output, quiet=False)
def setup_cfg(config_file):
    """Return a frozen detectron2 config built from the YAML at ``config_file``.

    The DeepLab, MaskFormer2, FC-CLIP and Mask-Adapter config extensions are
    applied to a fresh default config (in that order) before the YAML file is
    merged in; the merged config is then frozen and returned.
    """
    config = get_cfg()
    extenders = (
        add_deeplab_config,
        add_maskformer2_config,
        add_fcclip_config,
        add_mask_adapter_config,
    )
    for extend in extenders:
        extend(config)
    config.merge_from_file(config_file)
    config.freeze()
    return config
def inference_automatic(input_img, class_names):
    """Run automatic (whole-image) open-vocabulary segmentation.

    Args:
        input_img: Path to the uploaded image (the Gradio input component is
            created with ``type='filepath'``).
        class_names: Comma-separated vocabulary, e.g. ``"dog, beach, sky"``.

    Returns:
        The visualized segmentation as an RGB ``PIL.Image``.

    Raises:
        gr.Error: If no image was provided or no non-empty class name was given.
    """
    import os

    if input_img is None:
        raise gr.Error("Please upload an image first.")
    # Strip the spaces users naturally type after commas and drop empty
    # entries (from "a,,b" or a trailing comma), so no blank or
    # space-prefixed name is handed to the demo.
    names = [name.strip() for name in (class_names or "").split(',') if name.strip()]
    if not names:
        raise gr.Error("Please enter at least one class name.")

    mp.set_start_method("spawn", force=True)
    # Overridable so the app can run outside the original author's machine;
    # the default is the original hard-coded path.
    config_file = os.environ.get(
        "MASK_ADAPTER_CONFIG",
        '/home/yongkangli/Mask-Adapter/configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml',
    )
    cfg = setup_cfg(config_file)
    # The models are the module-level instances built once by initialize_models().
    demo = SAMVisualizationDemo(cfg, 0.8, sam2_model, clip_model, mask_adapter)
    img = read_image(input_img, format="BGR")
    _, visualized_output = demo.run_on_image(img, names)
    return Image.fromarray(np.uint8(visualized_output.get_image())).convert('RGB')
def inference_point(input_img, evt: gr.SelectData):
    """Segment the region under a single point the user clicked on the image.

    Gradio injects ``evt`` because of its ``gr.SelectData`` annotation when the
    image component's ``select`` event fires.

    Args:
        input_img: Path to the image that was clicked.
        evt: Selection event; ``evt.index[0]`` / ``evt.index[1]`` are read as
            the clicked x / y coordinates.

    Returns:
        Whatever ``run_on_image_with_points`` returns as its second value,
        unchanged. NOTE(review): presumably already an image type accepted by
        the ``gr.Image(type="pil")`` output — verify against
        SAMPointVisualizationDemo.
    """
    import os
    import time

    x, y = evt.index[0], evt.index[1]
    points = [[x, y]]  # only a single clicked point is used as the prompt
    print(f"Selected point: {points}")

    # perf_counter is monotonic and high-resolution, so the durations printed
    # below are not skewed by wall-clock adjustments (unlike time.time()).
    start_time = time.perf_counter()
    mp.set_start_method("spawn", force=True)
    # Overridable so the app can run outside the original author's machine;
    # the default is the original hard-coded path.
    config_file = os.environ.get(
        "MASK_ADAPTER_CONFIG",
        '/home/yongkangli/Mask-Adapter/configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml',
    )
    cfg = setup_cfg(config_file)
    demo = SAMPointVisualizationDemo(cfg, 0.8, sam2_model, clip_model, mask_adapter)
    print("init time", time.perf_counter() - start_time)

    start_time = time.perf_counter()
    img = read_image(input_img, format="BGR")
    _, visualized_output = demo.run_on_image_with_points(img, points)
    print("inf time", time.perf_counter() - start_time)
    return visualized_output
# Module-level model slots; initialize_models() builds and stores each one
# only while it is still None.
sam2_model = clip_model = mask_adapter = None
def _strip_adapter_prefix(key):
    """Remove a leading ``mask_adapter.`` and then a leading ``adapter.`` from ``key``.

    Only prefixes are stripped, so module names that merely contain
    ``adapter.`` further inside the key (e.g. ``foo_adapter.weight``) are kept intact.
    """
    for prefix in ('mask_adapter.', 'adapter.'):
        if key.startswith(prefix):
            key = key[len(prefix):]
    return key


# Loading and initialization function
def initialize_models(sam_path, adapter_pth, model_cfg, cfg):
    """Build the SAM2, CLIP and Mask-Adapter models and store them in module globals.

    Each model is built only if its global is still ``None``, so calling this
    again after a successful initialization does no work.

    Args:
        sam_path: Path to the SAM2 checkpoint passed to ``build_sam2``.
        adapter_pth: Path to a checkpoint whose ``"model"`` entry holds the
            Mask-Adapter weights under keys starting with ``adapter`` or
            ``mask_adapter``.
        model_cfg: SAM2 model config passed to ``build_sam2``.
        cfg: Path to the Mask-Adapter YAML config (loaded via ``setup_cfg``).
    """
    global sam2_model, clip_model, mask_adapter
    # SAM2 initialization
    if sam2_model is None:
        sam2_model = build_sam2(model_cfg, sam_path, device="cuda", apply_postprocessing=False)
        print("SAM2 model initialized.")
    # CLIP model initialization
    if clip_model is None:
        clip_model, _, _ = open_clip.create_model_and_transforms(
            "convnext_large_d_320", pretrained="laion2b_s29b_b131k_ft_soup"
        )
        print("CLIP model initialized.")
    # Mask-Adapter initialization (the YAML config is only needed here)
    if mask_adapter is None:
        adapter_cfg = setup_cfg(cfg)
        mask_adapter = build_mask_adapter(adapter_cfg, "MASKAdapterHead").cuda()
        # Load the adapter state dict. Map to CPU first: load_state_dict copies
        # into the CUDA module anyway, so the full checkpoint never has to be
        # materialized on the GPU.
        checkpoint = torch.load(adapter_pth, map_location="cpu")
        adapter_state_dict = {
            _strip_adapter_prefix(k): v
            for k, v in checkpoint["model"].items()
            if k.startswith('adapter') or k.startswith('mask_adapter')
        }
        mask_adapter.load_state_dict(adapter_state_dict)
        print("Mask Adapter model initialized.")
# Initialize configuration and models.
# Every path can be overridden through an environment variable so the app can
# run outside the original author's machine; the defaults are unchanged.
import os

model_cfg = os.environ.get("SAM2_MODEL_CFG", "configs/sam2.1/sam2.1_hiera_l.yaml")
sam_path = os.environ.get(
    "SAM2_CHECKPOINT",
    '/home/yongkangli/segment-anything-2/checkpoints/sam2.1_hiera_large.pt',
)
adapter_pth = os.environ.get("MASK_ADAPTER_CHECKPOINT", './model_0279999_with_sem_new.pth')
# Path (string) to the Mask-Adapter YAML config; initialize_models() parses it.
cfg = os.environ.get(
    "MASK_ADAPTER_CONFIG",
    '/home/yongkangli/Mask-Adapter/configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml',
)
# Build all models once at startup.
initialize_models(sam_path, adapter_pth, model_cfg, cfg)
# Example inputs for the Automatic Mode tab: [image path, comma-separated class names].
examples = [
    ['./demo/images/000000001025.jpg', 'dog, beach, trees, sea, sky, snow, person, rocks, buildings, birds, beach umbrella, beach chair'],
    ['./demo/images/ADE_val_00000979.jpg', 'sky,sea,mountain,pier,beach,island,landscape,horizon'],
    ['./demo/images/ADE_val_00001200.jpg', 'bridge, mountains, trees, water, sky, buildings, boats, animals, flowers, waterfalls, grasslands, rocks'],
]
# Not referenced by the visible code; kept for compatibility.
output_labels = ['segmentation map']
# Page heading and the HTML blurb (paper / code badges) shown under it.
title = '<center><h2>Mask-Adapter + Segment Anything-2</h2></center>'
# The arXiv link points to the Mask-Adapter paper (2412.04533); it previously
# pointed to 2406.20076, which is a different hustvl paper (EVF-SAM).
description = """
<b>Mask-Adapter: The Devil is in the Masks for Open-Vocabulary Segmentation</b><br>
Mask-Adapter effectively extends to SAM or SAM-2 without additional training, achieving impressive results across multiple open-vocabulary segmentation benchmarks.<br>
<div style="display: flex; gap: 20px;">
<a href="https://arxiv.org/abs/2412.04533">
<img src="https://img.shields.io/badge/arXiv-Paper-red" alt="arXiv Paper">
</a>
<a href="https://github.com/hustvl/MaskAdapter">
<img src="https://img.shields.io/badge/GitHub-Code-blue" alt="GitHub Code">
</a>
</div>
"""
# Interface with mode selection using Tabs.
with gr.Blocks() as demo:
    gr.Markdown(title)  # Title
    gr.Markdown(description)  # Description
    with gr.Tabs():
        with gr.TabItem("Automatic Mode"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(type='filepath', label="Input Image")
                    class_names = gr.Textbox(lines=1, placeholder=None, label='Class Names')
                with gr.Column():
                    output_image = gr.Image(type="pil", label='Segmentation Map')
                    # Buttons placed under the segmentation map.
                    run_button = gr.Button("Run Automatic Segmentation")
                    run_button.click(inference_automatic, inputs=[input_image, class_names], outputs=output_image)
                    clear_button = gr.Button("Clear")
                    clear_button.click(lambda: None, inputs=None, outputs=output_image)
            with gr.Row():
                gr.Examples(examples=examples, inputs=[input_image, class_names], outputs=output_image)
        with gr.TabItem("Point Mode"):
            with gr.Row():  # side-by-side layout
                with gr.Column():
                    # Uploaded image; clicking on it triggers segmentation. A
                    # separate name keeps the Automatic-tab `input_image`
                    # component from being shadowed.
                    input_image_point = gr.Image(type='filepath', label="Upload Image", interactive=True)
                with gr.Column():  # second column: segmentation map output
                    output_image_point = gr.Image(type="pil", label='Segmentation Map')
                    # The image's select event drives inference_point directly;
                    # Gradio passes the click as a gr.SelectData argument.
                    input_image_point.select(inference_point, inputs=[input_image_point], outputs=output_image_point)
                    # Button that clears the segmentation map.
                    clear_button_point = gr.Button("Clear Segmentation Map")
                    clear_button_point.click(lambda: None, inputs=None, outputs=output_image_point)

demo.launch()