File size: 11,538 Bytes
6e3a622
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bcbd7b
 
6e3a622
 
 
 
 
 
1bcbd7b
6e3a622
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
from transformers import SamModel, SamProcessor, pipeline
from PIL import Image
import cv2
import random
import numpy as np
import torch
from torch.nn.functional import cosine_similarity
import gradio as gr

class RoiMatching():
    """Find corresponding regions of interest (ROIs) in two images with SAM.

    Masks are generated for both images with the ``facebook/sam-vit-huge``
    mask-generation pipeline, filtered by area and mutual overlap, and then
    paired either by cosine similarity of per-mask SAM embeddings
    (``mode='embedding'``) or by spatial overlap (``mode='overlaping'``).
    """

    def __init__(self, img1, img2, device='cuda:1', v_min=200, v_max=7000, mode='embedding'):
        """
        Initialize
        :param img1: PIL image
        :param img2: second image, handled the same way as ``img1``
        :param device: device the SAM pipeline/model is placed on
        :param v_min: smallest ``mask.sum()`` a mask may have to be kept
        :param v_max: largest ``mask.sum()`` a mask may have to be kept
        :param mode: ``'embedding'`` or ``'overlaping'`` (spelling as checked
            in ``get_paired_roi``)
        """
        self.img1 = img1
        self.img2 = img2
        self.device = device
        self.v_min = v_min
        self.v_max = v_max
        self.mode = mode

    @staticmethod
    def _image_array(img):
        """Return ``img`` as a numpy array with an explicit trailing channel axis."""
        arr = np.asarray(img)
        if arr.ndim == 2:
            # Single-channel image: add the channel axis so that an (H, W, 1)
            # mask broadcasts against it instead of against the width axis.
            arr = arr[..., np.newaxis]
        return arr

    def _sam_everything(self, imgs):
        """Run SAM automatic mask generation on ``imgs`` and return the raw pipeline output.

        The pipeline is built on every call; the caller reads the ``'masks'``
        entry of the returned object.
        """
        generator = pipeline("mask-generation", model="facebook/sam-vit-huge", device=self.device)
        outputs = generator(imgs, points_per_batch=64, pred_iou_thresh=0.90, stability_score_thresh=0.9,)
        return outputs

    def _mask_criteria(self, masks, v_min=200, v_max=7000):
        """Drop masks that are out of the area range or nearly contained in another mask.

        :param masks: sequence of masks supporting ``.sum()`` and ``&``
            (presumably boolean numpy arrays of identical shape -- confirm
            against the pipeline output)
        :param v_min: minimum allowed ``mask.sum()``
        :param v_max: maximum allowed ``mask.sum()``
        :return: list of the surviving masks, in their original order
        """
        # Stage 1: keep only masks whose area lies inside [v_min, v_max].
        kept = [mask for mask in masks if v_min <= mask.sum() <= v_max]

        # Stage 2: for every pair whose intersection covers >= 90% of the
        # smaller mask, discard the smaller one (the second one on a tie).
        areas = [mask.sum() for mask in kept]
        n = len(kept)
        redundant = set()
        for i in range(n):
            for j in range(i + 1, n):
                smaller_area = min(areas[i], areas[j])
                if smaller_area <= 0:
                    continue
                intersection = (kept[i] & kept[j]).sum()
                if (intersection / smaller_area) >= 0.9:
                    redundant.add(i if areas[i] < areas[j] else j)
        return [mask for idx, mask in enumerate(kept) if idx not in redundant]

    def _roi_proto(self, image, masks):
        """Compute one prototype embedding per mask from the SAM image embedding.

        :param image: image passed to ``SamProcessor``
        :param masks: masks with an ``astype`` method (numpy arrays)
        :return: list with one tensor per mask, the mean of the image
            embedding over the mask's footprint on the 64x64 grid
        """
        model = SamModel.from_pretrained("facebook/sam-vit-huge").to(self.device)
        processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
        inputs = processor(image, return_tensors="pt").to(self.device)
        # Per the original author's note the embedding is (1, 256, 64, 64).
        image_embeddings = model.get_image_embeddings(inputs["pixel_values"])
        embs = []
        for _m in masks:
            # cv2.resize cannot take a bool array: go through uint8, resize to
            # the 64x64 embedding grid, then back to a 0/1 float tensor.
            tmp_m = _m.astype(np.uint8)
            tmp_m = cv2.resize(tmp_m, (64, 64), interpolation=cv2.INTER_NEAREST)
            tmp_m = torch.tensor(tmp_m.astype(bool), device=self.device,
                                 dtype=torch.float32)
            tmp_m = tmp_m.unsqueeze(0).unsqueeze(0)  # add batch and channel dims for broadcasting

            # Zero out every embedding location outside the mask.
            tmp_emb = image_embeddings * tmp_m

            # Turn the zeros into NaN so nanmean averages over the remaining
            # entries only; an all-NaN channel falls back to 0.
            tmp_emb[tmp_emb == 0] = torch.nan
            emb = torch.nanmean(tmp_emb, dim=(2, 3))
            emb[torch.isnan(emb)] = 0
            embs.append(emb)
        return embs

    def _cosine_similarity(self, vec1, vec2):
        """Return the cosine similarity of two tensors, each flattened to ``[1, N]``, as a float."""
        vec1 = vec1.view(1, -1)
        vec2 = vec2.view(1, -1)
        return cosine_similarity(vec1, vec2).item()

    def _similarity_matrix(self, protos1, protos2):
        """Build the pairwise cosine-similarity matrix between two prototype lists.

        :return: tensor of shape ``(len(protos1), len(protos2))`` on
            ``self.device`` holding the raw (un-normalised) similarities
        """
        similarity_matrix = torch.zeros(len(protos1), len(protos2), device=self.device)
        for i, vec_a in enumerate(protos1):
            for j, vec_b in enumerate(protos2):
                similarity_matrix[i, j] = self._cosine_similarity(vec_a, vec_b)
        # The raw similarities are returned on purpose: ``_roi_match`` applies
        # an absolute threshold to them. (A min-max normalised copy used to be
        # computed here and thrown away; it also failed on empty input.)
        return similarity_matrix

    def _roi_match(self, matrix, masks1, masks2, sim_criteria=0.8):
        """Greedily pair masks by descending similarity, one-to-one.

        :param matrix: similarity tensor, rows indexing ``masks1`` and columns
            indexing ``masks2``; it is not modified
        :param sim_criteria: a pair is accepted only if its similarity is
            strictly greater than this value
        :return: ``(masks1_new, masks2_new)``, aligned lists of matched masks
        """
        # Work on a copy so the caller's similarity matrix stays intact.
        scores = matrix.clone()
        n_cols = scores.shape[1]
        index_pairs = []
        while torch.any(scores > sim_criteria):
            row, col = divmod(int(torch.argmax(scores)), n_cols)
            if scores[row, col] > sim_criteria:
                index_pairs.append((row, col))
            # Retire the row and column so each mask is used at most once;
            # -inf can never exceed the threshold, so the loop terminates.
            scores[row, :] = float('-inf')
            scores[:, col] = float('-inf')
        masks1_new = [masks1[i] for i, _ in index_pairs]
        masks2_new = [masks2[j] for _, j in index_pairs]
        return masks1_new, masks2_new

    def _overlap_pair(self, masks1, masks2):
        """Pair masks by spatial overlap; results go to ``masks1_cor`` / ``masks2_cor``.

        NOTE(review): the last element of ``masks2`` is used as an integer
        label map (label ``k`` is looked up as ``masks2[k - 1]``) and the last
        element of ``masks1`` is skipped -- presumably also a label map.
        ``_mask_criteria`` does not append such a map; confirm the expected
        input before relying on this mode.
        """
        self.masks1_cor = []
        self.masks2_cor = []
        if len(masks2) == 0:
            # Nothing to pair against (and no label map to index).
            return
        im1 = self._image_array(self.img1)
        im2 = self._image_array(self.img2)
        label_map = masks2[-1].astype(np.int64)
        for k, mask in enumerate(masks1[:-1], start=1):
            print('mask1 {} is finding corresponding region mask...'.format(k))
            a1 = mask.sum()
            # Mean of the masked image over all pixels and channels.
            v1 = np.mean(np.expand_dims(mask, axis=-1) * im1)
            overlap = mask * label_map
            if (overlap > 0).sum() / a1 > 0.3:
                counts = np.bincount(overlap.flatten())
                sorted_indices = np.argsort(counts)[::-1]
                # Skip the most frequent value and keep the next (up to) two.
                top_two = sorted_indices[1:3]
                if top_two[-1] == 0:
                    cor_ind = 0
                elif abs(counts[top_two[-1]] - counts[top_two[0]]) / max(counts[top_two[-1]], counts[top_two[0]]) < 0.2:
                    cor_ind = 0
                else:
                    m21 = masks2[top_two[0]-1]
                    m22 = masks2[top_two[1]-1]
                    a21 = m21.sum()
                    a22 = m22.sum()
                    v21 = np.mean(np.expand_dims(m21, axis=-1) * im2)
                    v22 = np.mean(np.expand_dims(m22, axis=-1) * im2)
                    if np.abs(a21-a1) > np.abs(a22-a1):
                        cor_ind = 0
                    else:
                        cor_ind = 1
                    print('area judge to cor_ind {}'.format(cor_ind))
                    # NOTE(review): the intensity comparison below overrides
                    # the area-based choice printed above.
                    if np.abs(v21-v1) < np.abs(v22-v1):
                        cor_ind = 0
                    else:
                        cor_ind = 1

                self.masks2_cor.append(masks2[top_two[cor_ind] - 1])
                self.masks1_cor.append(mask)

    def get_paired_roi(self):
        """Segment both images, filter the masks and pair them according to ``self.mode``.

        In ``'embedding'`` mode ``self.masks1`` / ``self.masks2`` are replaced
        by the matched, aligned mask lists (when both images yielded masks).
        In ``'overlaping'`` mode the pairs are stored in ``self.masks1_cor`` /
        ``self.masks2_cor`` and ``self.masks1`` / ``self.masks2`` keep the
        filtered masks.
        """
        self.masks1 = self._sam_everything(self.img1)
        self.masks2 = self._sam_everything(self.img2)
        self.masks1 = self._mask_criteria(self.masks1['masks'], v_min=self.v_min, v_max=self.v_max)
        self.masks2 = self._mask_criteria(self.masks2['masks'], v_min=self.v_min, v_max=self.v_max)

        if self.mode == 'embedding':
            if len(self.masks1) > 0 and len(self.masks2) > 0:
                self.embs1 = self._roi_proto(self.img1, self.masks1)
                self.embs2 = self._roi_proto(self.img2, self.masks2)
                self.sim_matrix = self._similarity_matrix(self.embs1, self.embs2)
                self.masks1, self.masks2 = self._roi_match(self.sim_matrix, self.masks1, self.masks2)
        elif self.mode == 'overlaping':
            self._overlap_pair(self.masks1, self.masks2)

def visualize_masks(image1, masks1, image2, masks2):
    """Overlay paired masks on their images, one shared random colour per pair.

    :param image1: first image (PIL image or anything ``np.array`` accepts), RGB
    :param masks1: masks for ``image1``; paired positionally with ``masks2``
    :param image2: second image, RGB
    :param masks2: masks for ``image2``
    :return: ``(result1, result2)`` -- the two blended images as numpy arrays
        in OpenCV's BGR channel order
    """

    def random_color():
        """Generate a random color with high saturation and value in HSV color space."""
        hue = random.randint(0, 179)  # OpenCV 8-bit hue range is 0-179
        saturation = random.randint(200, 255)
        value = random.randint(200, 255)
        color = np.array([[[hue, saturation, value]]], dtype=np.uint8)
        return cv2.cvtColor(color, cv2.COLOR_HSV2BGR)[0][0]

    # OpenCV works in BGR, so convert the RGB inputs first.
    background1 = cv2.cvtColor(np.array(image1), cv2.COLOR_RGB2BGR)
    background2 = cv2.cvtColor(np.array(image2), cv2.COLOR_RGB2BGR)

    # Colour layers that accumulate every mask; black where no mask applies.
    overlay1 = np.zeros_like(background1)
    overlay2 = np.zeros_like(background2)

    # zip stops at the shorter list, so unpaired trailing masks are not drawn.
    for mask1_item, mask2_item in zip(masks1, masks2):
        color = random_color()
        # Paint the same colour on both sides so matched ROIs are recognisable.
        overlay1[np.uint8(mask1_item) > 0] = color
        overlay2[np.uint8(mask2_item) > 0] = color

    # Add the colour layers on top of the backgrounds at half strength.
    result1 = cv2.addWeighted(background1, 1, overlay1, 0.5, 0)
    result2 = cv2.addWeighted(background2, 1, overlay2, 0.5, 0)

    return result1, result2

def predict(im1,im2):
    """Gradio callback: match ROIs between two images and return the two overlays.

    :param im1: first input image (PIL, as delivered by the ``gr.Image`` inputs)
    :param im2: second input image
    :return: two RGB numpy arrays with the matched ROIs coloured identically
    """
    RM = RoiMatching(im1,im2,device='cpu')
    RM.get_paired_roi()
    visualized_image1, visualized_image2 = visualize_masks(im1, RM.masks1, im2, RM.masks2)
    # visualize_masks returns BGR (OpenCV order) while Gradio displays numpy
    # arrays as RGB; convert back so red and blue are not swapped.
    visualized_image1 = cv2.cvtColor(visualized_image1, cv2.COLOR_BGR2RGB)
    visualized_image2 = cv2.cvtColor(visualized_image2, cv2.COLOR_BGR2RGB)
    return visualized_image1, visualized_image2

# Image pairs offered as one-click examples in the UI (paths relative to the app).
examples = [
            ['./example/prostate_2d/image1.png', './example/prostate_2d/image2.png'],
            ['./example/cardiac_2d/image1.png', './example/cardiac_2d/image2.png'],
            ['./example/pathology/1B_B7_R.png', './example/pathology/1B_B7_T.png'],
           ]


# Two PIL image inputs go to `predict`, which returns two overlay images.
# NOTE(review): the description still mentions SegGPT and scribble prompts,
# which this file does not use -- presumably copied from another demo; confirm
# the intended wording.
gradio_app = gr.Interface(
    predict,
    inputs=[gr.Image(label="img1", sources=['upload', 'webcam'], type="pil"), gr.Image(label="img2", sources=['upload', 'webcam'], type="pil")],
    outputs=[gr.Image(label="ROIs in img1"), gr.Image(label="ROIs in img2")],
    title="SAMReg: One Registration is Worth Two Segmentations",
    examples=examples,
    description="<p> \
                    <strong>Register anything with ROI-based registration representation.</strong> <br>\
                    Choose an example below &#128293; &#128293;  &#128293; <br>\
                    Or, upload by yourself: <br>\
                    1. Upload images to be valided to 'img1' and 'img2'. <br>2. Upload images to 'im1' and 'im2'.  <br>\
                            <br> \
                            &#128142; SAM segments the target with any point or scribble, then SegGPT segments all other images. <br>\
                            &#128142; Examples below were never trained and are randomly selected for testing in the wild. <br>\
                            &#128142; Current UI interface only unleashes a small part of the capabilities of SegGPT, i.e., 1-shot case. \
</p>",
)