File size: 5,504 Bytes
6c016cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import torch
import numpy as np
from torchvision import transforms
from task_adapter.utils.visualizer import Visualizer
from typing import Tuple
from PIL import Image
from detectron2.data import MetadataCatalog
# Metadata entry registered in detectron2's MetadataCatalog under the name
# 'coco_2017_train_panoptic'; it is passed to the Visualizer when masks are drawn.
metadata = MetadataCatalog.get('coco_2017_train_panoptic')


class SemanticSAMPredictor:
    """Point-prompted mask predictor.

    Wraps a model object that exposes ``model.model.evaluate_demo`` and turns
    its raw mask/IoU output into a list of rendered PIL images, one per kept
    mask, with the clicked point drawn on top.
    """

    def __init__(self, model, thresh=0.5, text_size=640, hole_scale=100, island_scale=100):
        """
        thresh: iou thresh to filter low confidence objects
        text_size: resize the input image short edge for the model to process
        hole_scale: fill in small holes as in SAM
        island_scale: remove small regions as in SAM
        """
        self.model = model
        self.thresh = thresh
        # Bug fix: this used to store ``hole_scale``, silently discarding the
        # ``text_size`` argument.
        self.text_size = text_size
        self.hole_scale = hole_scale
        self.island_scale = island_scale
        # Scaled coordinates of the last prompt point; set by ``predict``.
        self.point = None

    def predict(self, image_ori, image, point=None):
        """
        produce up to 6 prediction results for each click

        image_ori: array whose ``shape[:2]`` gives the size passed to the model
        image: model input, forwarded unchanged under the ``"image"`` key
        point: optional ``[[x, y]]`` prompt; when omitted the image centre
            ``(0.5, 0.5)`` is used

        Returns the ``(masks, ious)`` pair produced by
        ``self.model.model.evaluate_demo``. As a side effect ``self.point`` is
        set to the prompt scaled by ``(width, height)``.
        """
        # NOTE(review): for an (H, W, C) array shape[0] is the height, so these
        # two names look swapped. Left untouched because the values are consumed
        # by the model and by the marker drawing - confirm against evaluate_demo.
        width = image_ori.shape[0]
        height = image_ori.shape[1]

        if point is None:
            point = torch.tensor([[0.5, 0.5, 0.006, 0.006]]).cuda()
        else:
            # torch.tensor already copies, so no additional clone is needed.
            point = torch.tensor(point).cuda()
            # Append a fixed box extent so the prompt has four components.
            point = torch.cat([point, point.new_tensor([[0.005, 0.005]])], dim=-1)

        self.point = point[:, :2] * torch.tensor([width, height]).to(point)

        data = {
            "image": image,
            "height": height,
            "width": width,
            "targets": [{"points": point, "pb": point.new_tensor([0.])}],
        }
        masks, ious = self.model.model.evaluate_demo([data])

        return masks, ious

    def process_multi_mask(self, masks, ious, image_ori):
        """
        Filter, clean up and render the masks returned by ``predict``.

        Masks are visited from highest to lowest predicted IoU. A mask is
        dropped when its IoU (rounded to two decimals) is below ``self.thresh``
        or when it overlaps an already kept mask with IoU > 0.95. If nothing
        has been kept when the last candidate is reached, that candidate is
        kept regardless, so a non-empty input always yields one result.

        Returns a pair of lists of PIL images: the renders in descending-IoU
        order, and the same renders ordered by ascending mask area (area is
        measured before hole/island cleanup).
        """
        ious = ious[0, 0]
        order = torch.argsort(ious, descending=True)
        ranked_masks = masks[order]
        ranked_ious = ious[order]
        last_index = len(ranked_masks) - 1

        reses = []
        kept_masks = []
        areas = []
        for i, (mask_logits, iou) in enumerate(zip(ranked_masks, ranked_ious)):
            iou = round(float(iou), 2)
            mask = (mask_logits > 0.0).cpu().numpy()

            skip = iou < self.thresh or self._is_duplicate(mask, kept_masks)
            if i == last_index and not kept_masks:
                # Never return an empty result: fall back to the last candidate.
                skip = False
            if skip:
                continue

            kept_masks.append(mask)
            areas.append(int(mask.sum()))

            mask, _ = self.remove_small_regions(mask, int(self.hole_scale), mode="holes")
            mask, _ = self.remove_small_regions(mask, int(self.island_scale), mode="islands")
            # ``np.float`` was removed in NumPy 1.24; it was an alias of ``float``.
            mask = mask.astype(float)

            visual = Visualizer(image_ori, metadata=metadata)
            demo = visual.draw_binary_mask(mask, color=[0., 0., 1.0], text=f'{iou}')
            res = demo.get_image()
            self._mark_point(res, image_ori.shape[0], image_ori.shape[1])
            reses.append(Image.fromarray(res))

        # Stable sort keeps equal-area masks in their IoU order.
        by_area = sorted(range(len(areas)), key=areas.__getitem__)

        torch.cuda.empty_cache()

        return reses, [reses[i] for i in by_area]

    def predict_masks(self, image_ori, image, point=None):
        """Run ``predict`` followed by ``process_multi_mask`` and return its result."""
        masks, ious = self.predict(image_ori, image, point)
        return self.process_multi_mask(masks, ious, image_ori)

    @staticmethod
    def _is_duplicate(mask, kept_masks, iou_thresh=0.95):
        """Return True if ``mask`` overlaps any mask in ``kept_masks`` with IoU above ``iou_thresh``."""
        for other in kept_masks:
            union = np.logical_or(mask, other).sum()
            # An empty union (both masks empty) used to produce NaN plus a
            # RuntimeWarning; treat it as "not a duplicate", the same outcome.
            if union and np.logical_and(mask, other).sum() / union > iou_thresh:
                return True
        return False

    def _mark_point(self, canvas, max_y, max_x):
        """Paint a small square around ``self.point`` onto ``canvas`` in place.

        Channel 0 is set to 255 and channels 1 and 2 to 0 (red for an RGB
        canvas). The square spans 3 px either side of the point, clamped to
        ``[0, max_x]`` / ``[0, max_y]``.
        """
        px = int(self.point[0, 0])
        py = int(self.point[0, 1])
        x0, x1 = max(0, px - 3), min(max_x, px + 3)
        y0, y1 = max(0, py - 3), min(max_y, py + 3)
        canvas[y0:y1, x0:x1, :3] = (255, 0, 0)

    @staticmethod
    def remove_small_regions(
            mask: np.ndarray, area_thresh: float, mode: str
    ) -> Tuple[np.ndarray, bool]:
        """
        Removes small disconnected regions and holes in a mask. Returns the
        mask and an indicator of if the mask has been modified.

        mask: boolean array
        area_thresh: connected components with fewer pixels than this are
            filled (mode "holes") or dropped (mode "islands")
        mode: either "holes" or "islands"
        """
        import cv2  # type: ignore

        assert mode in ["holes", "islands"]
        correct_holes = mode == "holes"
        # In "holes" mode the mask is inverted so holes become the foreground
        # components that get measured.
        working_mask = (correct_holes ^ mask).astype(np.uint8)
        n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
        sizes = stats[:, -1][1:]  # Row 0 is background label
        small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
        if not small_regions:
            return mask, False
        fill_labels = [0] + small_regions
        if not correct_holes:
            fill_labels = [i for i in range(n_labels) if i not in fill_labels]
            # If every region is below threshold, keep largest
            if not fill_labels:
                fill_labels = [int(np.argmax(sizes)) + 1]
        mask = np.isin(regions, fill_labels)
        return mask, True