nowsyn's picture
upload codes
54a7220
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py
import argparse
import glob
import multiprocessing as mp
import os
import sys
sys.path.insert(1, os.getcwd())
import tempfile
import time
import warnings
import cv2
import numpy as np
import tqdm
import torch
from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.utils.logger import setup_logger
from mask2former import add_maskformer2_config
from predictor import VisualizationDemo
from annotator.util import annotator_ckpts_path
model_url = "https://huggingface.co/datasets/qqlu1992/Adobe_EntitySeg/resolve/main/CropFormer_model/Entity_Segmentation/CropFormer_hornet_3x.pth"
def make_colors():
from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
colors = []
for cate in COCO_CATEGORIES:
colors.append(cate["color"])
return colors
class EntitysegDetector:
def __init__(self, confidence_threshold=0.5):
cfg = get_cfg()
add_deeplab_config(cfg)
add_maskformer2_config(cfg)
workdir = os.getcwd()
config_file = f"{workdir}/annotator/entityseg/configs/cropformer_hornet_3x.yaml"
model_path = f'{annotator_ckpts_path}/CropFormer_hornet_3x_03823a.pth'
# Authentication required
# if not os.path.exists(model_path):
# from basicsr.utils.download_util import load_file_from_url
# load_file_from_url(model_url, model_dir=annotator_ckpts_path)
cfg.merge_from_file(config_file)
opts = ['MODEL.WEIGHTS', model_path]
cfg.merge_from_list(opts)
cfg.freeze()
self.confidence_threshold = confidence_threshold
self.colors = make_colors()
self.demo = VisualizationDemo(cfg)
def __call__(self, image):
predictions = self.demo.run_on_image(image)
##### color_mask
pred_masks = predictions["instances"].pred_masks
pred_scores = predictions["instances"].scores
# select by confidence threshold
selected_indexes = (pred_scores >= self.confidence_threshold)
selected_scores = pred_scores[selected_indexes]
selected_masks = pred_masks[selected_indexes]
_, m_H, m_W = selected_masks.shape
mask_id = np.zeros((m_H, m_W), dtype=np.uint8)
# rank
selected_scores, ranks = torch.sort(selected_scores)
ranks = ranks + 1
for index in ranks:
mask_id[(selected_masks[index-1]==1).cpu().numpy()] = int(index)
unique_mask_id = np.unique(mask_id)
color_mask = np.zeros(image.shape, dtype=np.uint8)
for count in unique_mask_id:
if count == 0:
continue
color_mask[mask_id==count] = self.colors[count % len(self.colors)]
return color_mask