File size: 2,991 Bytes
54a7220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py
import argparse
import glob
import multiprocessing as mp
import os
import sys
sys.path.insert(1, os.getcwd())

import tempfile
import time
import warnings

import cv2
import numpy as np
import tqdm
import torch

from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.utils.logger import setup_logger

from mask2former import add_maskformer2_config
from predictor import VisualizationDemo

from annotator.util import annotator_ckpts_path


# Hugging Face download URL for the CropFormer (HorNet, 3x) checkpoint.
# In this file it is only referenced by the commented-out auto-download in
# EntitysegDetector.__init__, which is disabled because the URL requires
# authentication.
# NOTE(review): the file name in this URL (CropFormer_hornet_3x.pth) differs
# from the one __init__ loads (CropFormer_hornet_3x_03823a.pth) — confirm the
# expected local name before re-enabling the download.
model_url = "https://huggingface.co/datasets/qqlu1992/Adobe_EntitySeg/resolve/main/CropFormer_model/Entity_Segmentation/CropFormer_hornet_3x.pth"


def make_colors():
    """Return the ``color`` entry of every item in detectron2's COCO_CATEGORIES.

    Colors are returned as a new list, in the same order as COCO_CATEGORIES.
    The detectron2 metadata import happens lazily, at call time.
    """
    from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES

    return [category["color"] for category in COCO_CATEGORIES]


class EntitysegDetector:
    """Entity-segmentation annotator backed by a CropFormer (HorNet, 3x) model.

    Calling an instance on an image runs the model through
    ``VisualizationDemo.run_on_image``. It returns a ``uint8`` color map in
    which every predicted entity whose score is at least
    ``confidence_threshold`` is painted with a flat color. The colors come
    from ``make_colors()``, i.e. detectron2's ``COCO_CATEGORIES`` palette.
    Pixels not covered by any selected entity stay 0.
    """

    def __init__(self, confidence_threshold=0.5):
        """Build the detectron2 config and the inference wrapper.

        Args:
            confidence_threshold: minimum instance score (inclusive) for a
                predicted mask to be drawn by ``__call__``.

        The config file path is built from ``os.getcwd()``, so construction
        only works when the working directory contains
        ``annotator/entityseg/configs`` (presumably the repository root).
        The checkpoint is read from ``annotator_ckpts_path`` and is NOT
        downloaded automatically.
        """
        cfg = get_cfg()
        add_deeplab_config(cfg)
        add_maskformer2_config(cfg)

        workdir = os.getcwd()
        config_file = f"{workdir}/annotator/entityseg/configs/cropformer_hornet_3x.yaml"
        model_path = f'{annotator_ckpts_path}/CropFormer_hornet_3x_03823a.pth'
        # Auto-download is disabled because the checkpoint URL requires
        # authentication; place the file at `model_path` manually.
        # if not os.path.exists(model_path):
        #     from basicsr.utils.download_util import load_file_from_url
        #     load_file_from_url(model_url, model_dir=annotator_ckpts_path)

        cfg.merge_from_file(config_file)
        opts = ['MODEL.WEIGHTS', model_path]
        cfg.merge_from_list(opts)
        cfg.freeze()

        self.confidence_threshold = confidence_threshold

        self.colors = make_colors()
        self.demo = VisualizationDemo(cfg)

    def __call__(self, image):
        """Segment ``image`` and return a per-entity color map.

        Args:
            image: array passed straight to ``VisualizationDemo.run_on_image``.
                Presumably ``(H, W, 3)`` uint8 as produced by ``read_image``
                — TODO confirm the channel order with the caller.

        Returns:
            ``np.ndarray`` of dtype ``uint8`` with the same shape as ``image``.
        """
        predictions = self.demo.run_on_image(image)
        instances = predictions["instances"]
        return self._colorize_masks(
            instances.pred_masks,
            instances.scores,
            image.shape,
            self.colors,
            self.confidence_threshold,
        )

    @staticmethod
    def _colorize_masks(pred_masks, pred_scores, image_shape, colors,
                        confidence_threshold):
        """Paint thresholded instance masks into a flat-color ``uint8`` image.

        Args:
            pred_masks: tensor of shape ``(N, H, W)``. A pixel belongs to mask
                ``i`` where ``pred_masks[i] == 1``.
            pred_scores: tensor of shape ``(N,)``, one score per mask.
            image_shape: shape of the output array. Its first two dims must
                be ``(H, W)`` and its last dim must match each color's length.
            colors: sequence of colors. The entity with 1-based id ``k`` (its
                position among the selected masks, plus one) gets
                ``colors[k % len(colors)]``.
            confidence_threshold: masks scoring below this are dropped.

        Returns:
            ``np.ndarray`` with shape ``image_shape`` and dtype ``uint8``;
            background (no selected mask) is 0.
        """
        keep = pred_scores >= confidence_threshold
        kept_masks = pred_masks[keep]
        kept_scores = pred_scores[keep]
        _, height, width = kept_masks.shape

        # Copy all masks to the host once instead of once per mask.
        masks_np = (kept_masks == 1).cpu().numpy()

        # Entity ids are 1-based (0 = background). Use int32, not uint8, so
        # more than 255 selected entities neither wrap around nor overflow.
        mask_id = np.zeros((height, width), dtype=np.int32)

        # Paint in ascending score order: where masks overlap, the
        # highest-scoring mask is written last and wins.
        _, order = torch.sort(kept_scores)
        for idx in order.tolist():
            mask_id[masks_np[idx]] = idx + 1

        color_mask = np.zeros(image_shape, dtype=np.uint8)
        foreground = mask_id > 0
        palette = np.asarray(colors, dtype=np.uint8)
        color_mask[foreground] = palette[mask_id[foreground] % len(palette)]
        return color_mask