import multiprocessing as mp
import time

import numpy as np
from PIL import Image
import torch

try:
    import detectron2
except ImportError:
    # Install detectron2 from source if it is not already available.
    import os
    os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.data.detection_utils import read_image

from mask_adapter import add_maskformer2_config, add_fcclip_config, add_mask_adapter_config
from mask_adapter.sam_maskadapter import SAMVisualizationDemo, SAMPointVisualizationDemo

import gradio as gr
import gdown
import open_clip

from sam2.build_sam import build_sam2
from mask_adapter.modeling.meta_arch.mask_adapter_head import build_mask_adapter
# ckpt_url = 'https://drive.google.com/uc?id=1cn-ohxgXDrDfkzC1QdO-fi8IjbjXmgKy'
# output = './ovseg_swinbase_vitL14_ft_mpt.pth'
# gdown.download(ckpt_url, output, quiet=False)
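# The gdown download above is kept disabled; this demo loads its checkpoints from the
# local paths passed to initialize_models() further below.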
def setup_cfg(config_file):
    # Load config from file and command-line arguments.
    cfg = get_cfg()
    add_deeplab_config(cfg)
    add_maskformer2_config(cfg)
    add_fcclip_config(cfg)
    add_mask_adapter_config(cfg)
    cfg.merge_from_file(config_file)
    cfg.freeze()
    return cfg
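# setup_cfg() stacks the DeepLab, Mask2Former, FC-CLIP and Mask-Adapter config
# extensions onto a base detectron2 config before merging the YAML file, so it can be
# reused for any Mask-Adapter config, e.g. (hypothetical path):
#   cfg = setup_cfg('configs/my_mask_adapter_config.yaml')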
def inference_automatic(input_img, class_names):
    # Automatic mode: segment the whole image against a comma-separated class vocabulary.
    mp.set_start_method("spawn", force=True)
    config_file = '/home/yongkangli/Mask-Adapter/configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml'
    cfg = setup_cfg(config_file)

    demo = SAMVisualizationDemo(cfg, 0.8, sam2_model, clip_model, mask_adapter)

    class_names = class_names.split(',')
    img = read_image(input_img, format="BGR")
    _, visualized_output = demo.run_on_image(img, class_names)

    return Image.fromarray(np.uint8(visualized_output.get_image())).convert('RGB')
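# inference_automatic relies on the module-level sam2_model / clip_model / mask_adapter
# globals populated by initialize_models() further below; only the config and the
# SAMVisualizationDemo wrapper are rebuilt per request.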
def inference_point(input_img, evt: gr.SelectData):
    # Point mode: segment the object at the point (x, y) the user clicked on the image.
    x, y = evt.index[0], evt.index[1]
    points = [[x, y]]  # assume a single clicked point as input
    print(f"Selected point: {points}")

    start_time = time.time()
    mp.set_start_method("spawn", force=True)
    config_file = '/home/yongkangli/Mask-Adapter/configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml'
    cfg = setup_cfg(config_file)

    demo = SAMPointVisualizationDemo(cfg, 0.8, sam2_model, clip_model, mask_adapter)
    print("init time", time.time() - start_time)

    start_time = time.time()
    img = read_image(input_img, format="BGR")
    # Run segmentation conditioned on the clicked point.
    _, visualized_output = demo.run_on_image_with_points(img, points)
    print("inf time", time.time() - start_time)

    return visualized_output
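# Like inference_automatic, inference_point rebuilds the cfg and the
# SAMPointVisualizationDemo wrapper on every click; the two timing prints
# separate that per-click setup cost from the actual segmentation time.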
sam2_model = None
clip_model = None
mask_adapter = None
# Model loading and initialization
def initialize_models(sam_path, adapter_pth, model_cfg, cfg):
    cfg = setup_cfg(cfg)
    global sam2_model, clip_model, mask_adapter

    # Initialize SAM2
    if sam2_model is None:
        sam2_model = build_sam2(model_cfg, sam_path, device="cuda", apply_postprocessing=False)
        print("SAM2 model initialized.")

    # Initialize the CLIP model
    if clip_model is None:
        clip_model, _, _ = open_clip.create_model_and_transforms("convnext_large_d_320", pretrained="laion2b_s29b_b131k_ft_soup")
        print("CLIP model initialized.")

    # Initialize the Mask Adapter model
    if mask_adapter is None:
        mask_adapter = build_mask_adapter(cfg, "MASKAdapterHead").cuda()
        # Load the adapter checkpoint, keeping only adapter weights and stripping their prefixes.
        adapter_state_dict = torch.load(adapter_pth)
        adapter_state_dict = {k.replace('mask_adapter.', '').replace('adapter.', ''): v
                              for k, v in adapter_state_dict["model"].items()
                              if k.startswith('adapter') or k.startswith('mask_adapter')}
        mask_adapter.load_state_dict(adapter_state_dict)
        print("Mask Adapter model initialized.")
# Configuration and checkpoint paths
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
sam_path = '/home/yongkangli/segment-anything-2/checkpoints/sam2.1_hiera_large.pt'
adapter_pth = './model_0279999_with_sem_new.pth'
cfg = '/home/yongkangli/Mask-Adapter/configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml'

# Load all models once at startup
initialize_models(sam_path, adapter_pth, model_cfg, cfg)
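# With the models initialized, the pipeline can also be exercised without the Gradio UI.
# A minimal sketch (the image path is one of the repo's demo assets; the class list is
# illustrative):
#   result = inference_automatic('./demo/images/000000001025.jpg', 'dog, beach, sky, person')
#   result.save('automatic_result.png')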
# Examples for testing
examples = [
    ['./demo/images/000000001025.jpg', 'dog, beach, trees, sea, sky, snow, person, rocks, buildings, birds, beach umbrella, beach chair'],
    ['./demo/images/ADE_val_00000979.jpg', 'sky, sea, mountain, pier, beach, island, landscape, horizon'],
    ['./demo/images/ADE_val_00001200.jpg', 'bridge, mountains, trees, water, sky, buildings, boats, animals, flowers, waterfalls, grasslands, rocks'],
]

output_labels = ['segmentation map']
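# Clicking an example in the UI only populates the input widgets (gr.Examples below is
# not given a fn), so segmentation is still triggered with the
# "Run Automatic Segmentation" button.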
title = '<center><h2>Mask-Adapter + Segment Anything-2</h2></center>'

description = """
<b>Mask-Adapter: The Devil is in the Masks for Open-Vocabulary Segmentation</b><br>
Mask-Adapter effectively extends to SAM or SAM-2 without additional training, achieving impressive results across multiple open-vocabulary segmentation benchmarks.<br>

<div style="display: flex; gap: 20px;">
  <a href="https://arxiv.org/abs/2406.20076">
    <img src="https://img.shields.io/badge/arXiv-Paper-red" alt="arXiv Paper">
  </a>
  <a href="https://github.com/hustvl/MaskAdapter">
    <img src="https://img.shields.io/badge/GitHub-Code-blue" alt="GitHub Code">
  </a>
</div>
"""
# Interface with mode selection using Tabs
with gr.Blocks() as demo:
    gr.Markdown(title)        # Title
    gr.Markdown(description)  # Description

    with gr.Tabs():
        with gr.TabItem("Automatic Mode"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(type='filepath', label="Input Image")
                    class_names = gr.Textbox(lines=1, placeholder=None, label='Class Names')
                with gr.Column():
                    output_image = gr.Image(type="pil", label='Segmentation Map')

                    # Buttons placed under the segmentation map
                    run_button = gr.Button("Run Automatic Segmentation")
                    run_button.click(inference_automatic, inputs=[input_image, class_names], outputs=output_image)

                    clear_button = gr.Button("Clear")
                    clear_button.click(lambda: None, inputs=None, outputs=output_image)

            with gr.Row():
                gr.Examples(examples=examples, inputs=[input_image, class_names], outputs=output_image)
with gr.TabItem("Point Mode"): | |
with gr.Row(): # 水平排列 | |
with gr.Column(): | |
input_image = gr.Image(type='filepath', label="Upload Image", interactive=True) # 上传图片并允许交互 | |
points_input = gr.State(value=[]) # 用于存储点击的点 | |
with gr.Column(): # 第二列:分割图输出 | |
output_image_point = gr.Image(type="pil", label='Segmentation Map') # 输出分割图 | |
# 直接使用 `SelectData` 事件触发 `inference_point` | |
input_image.select(inference_point, inputs=[input_image], outputs=output_image_point) | |
# 清除分割图的按钮 | |
clear_button_point = gr.Button("Clear Segmentation Map") | |
clear_button_point.click(lambda: None, inputs=None, outputs=output_image_point) | |
# Example images below buttons | |
demo.launch() | |