File size: 9,285 Bytes
6706230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageDraw


def unload_mask(mask):
    """Convert a single mask tensor into an 8-bit PIL image.

    The tensor is moved to the CPU, turned into a NumPy array with its
    length-1 dimensions squeezed out, cast to ``uint8`` and multiplied by
    255 (so a value of 1 becomes 255) before being wrapped in a PIL image.
    """
    array = mask.cpu().numpy()
    array = array.squeeze()
    scaled = array.astype(np.uint8) * 255
    return Image.fromarray(scaled)


def unload_box(box):
    """Return a box tensor as a plain Python list of numbers."""
    box_array = box.cpu().numpy()
    return box_array.tolist()


def masks_overlap(mask1, mask2):
    """Report whether two masks share at least one non-zero position.

    The masks are combined element-wise with a logical AND; the result is
    truthy when any position is set in both.
    """
    intersection = np.logical_and(mask1, mask2)
    return np.any(intersection)


def remove_non_person_masks(person_mask, formatted_results):
    """Keep the person entries plus every entry whose mask touches the person.

    Args:
        person_mask: Mask of the person; each entry's mask is tested against it.
        formatted_results: Iterable of result dicts read via their "label"
            and "mask" keys.

    Returns:
        list: The entries labelled "person" or whose mask overlaps
        ``person_mask``, in their original order.
    """
    kept = []
    for entry in formatted_results:
        # Person entries are kept without running the overlap test.
        if entry.get("label") == "person":
            kept.append(entry)
            continue
        if masks_overlap(person_mask, entry.get("mask")):
            kept.append(entry)
    return kept


def format_masks(masks):
    """Convert every mask in ``masks`` with ``unload_mask`` and return a list."""
    return list(map(unload_mask, masks))


def format_boxes(boxes):
    """Convert every box in ``boxes`` with ``unload_box`` and return a list."""
    return list(map(unload_box, boxes))


def format_scores(scores):
    """Return a score tensor as a plain Python list of numbers."""
    score_array = scores.cpu().numpy()
    return score_array.tolist()


def unload(result, labels_dict):
    """Unpack a detection result into plain Python / PIL structures.

    Args:
        result: Object exposing ``masks.data``, ``boxes.xyxy``, ``boxes.conf``
            and ``boxes.cls``.
        labels_dict: Lookup indexed by integer class id that yields the
            label name for that id.

    Returns:
        tuple: ``(masks, boxes, scores, phrases)`` where ``masks`` comes from
        ``format_masks``, ``boxes`` from ``format_boxes``, ``scores`` from
        ``format_scores`` and ``phrases`` holds one label name per class id.
    """
    masks = format_masks(result.masks.data)
    detections = result.boxes
    boxes = format_boxes(detections.xyxy)
    scores = format_scores(detections.conf)
    # Each class id is a one-element tensor; .item() extracts the number.
    class_ids = [int(class_id.item()) for class_id in detections.cls]
    phrases = [labels_dict[class_id] for class_id in class_ids]
    return masks, boxes, scores, phrases


def format_results(masks, boxes, scores, labels, labels_dict, person_masks_only=True):
    """Bundle parallel detection outputs into a label-sorted list of dicts.

    Args:
        masks: Sequence of masks, one per detection.
        boxes: Sequence of boxes, one per detection.
        scores: Sequence of scores, one per detection.
        labels: Sequence of label names, one per detection.
        labels_dict: Mapping between label ids and label names. When its
            first key is an ``int`` it is treated as ``id -> name`` and
            inverted to ``name -> id``; otherwise it is used as given.
        person_masks_only: When true, a "person" detection is required and
            only entries labelled "person" or overlapping the first person
            mask (after sorting) are returned.

    Returns:
        list[dict]: One dict per detection with the keys ``label``,
        ``score``, ``mask``, ``box`` and ``label_id``, sorted by ``label``.

    Raises:
        AssertionError: If ``person_masks_only`` is true and no "person"
            label is present, or the person entry's mask is ``None``.
        KeyError: If a label has no entry in ``labels_dict``.
    """
    # Peek at one key without building a list; an empty mapping is simply
    # left as-is instead of raising IndexError.
    first_key = next(iter(labels_dict), None)
    if isinstance(first_key, int):
        labels_dict = {v: k for k, v in labels_dict.items()}

    # Raised explicitly (rather than via ``assert``) so the check survives
    # ``python -O`` while callers still see the same exception type.
    if person_masks_only and "person" not in labels:
        raise AssertionError("Person mask not present in results")

    results = [
        dict(
            label=label,
            score=score,
            mask=mask,
            box=box,
            label_id=labels_dict[label],
        )
        for label, score, box, mask in zip(labels, scores, boxes, masks)
    ]
    results.sort(key=lambda row: row["label"])

    if person_masks_only:
        # Use the first person entry's mask as the reference region.
        person_mask = next(
            (row["mask"] for row in results if row["label"] == "person"), None
        )
        if person_mask is None:
            raise AssertionError("Person mask not found in results")

        # Remove any results that do not overlap with the person
        results = remove_non_person_masks(person_mask, results)
    return results


def filter_highest_score(results, labels):
    """
    Filter results to remove entries with lower scores for specified labels.

    For every label listed in ``labels`` only the highest-scoring entry is
    kept (the earliest one on ties); entries with other labels pass through
    untouched. The relative order of the surviving entries is preserved.

    Args:
        results (list): List of dictionaries containing 'label', 'score', and other keys.
        labels (list): List of labels to filter.

    Returns:
        list: Filtered results with only the highest score for each specified label.
    """
    # First pass: remember the best entry per filtered label. The strict
    # ``>`` keeps the earliest entry when scores tie.
    best_by_label = {}
    for result in results:
        label = result["label"]
        if label not in labels:
            continue
        current = best_by_label.get(label)
        if current is None or result["score"] > current["score"]:
            best_by_label[label] = result

    # Second pass: rebuild the list in order. The winner is matched by
    # identity, not ``==``: dict equality would deep-compare the mask
    # payloads, which is slow for images and raises for NumPy arrays.
    filtered_results = []
    emitted_labels = set()
    for result in results:
        label = result["label"]
        if label not in labels:
            filtered_results.append(result)
        elif label not in emitted_labels and result is best_by_label[label]:
            filtered_results.append(result)
            emitted_labels.add(label)

    return filtered_results


def display_image_with_masks(image, results, cols=4, return_images=False):
    """Show the image once per result with that result's mask tinted blue.

    Each result gets its own subplot, titled with the label and score, laid
    out in a grid with ``cols`` columns. The figure is displayed with
    ``plt.show()``.

    Args:
        image: PIL image (or array-like) that converts to an (H, W, 3) array.
        results: Sequence of dicts with "mask", "label" and "score" keys.
        cols: Number of subplot columns in the grid.
        return_images: Accepted for backward compatibility; it is currently
            not used and the function always returns ``None``.

    Raises:
        ValueError: If the image is not a 3-dimensional array with 3 channels.
    """
    # Convert PIL Image to numpy array
    image_np = np.array(image)

    # Check image dimensions
    if image_np.ndim != 3 or image_np.shape[2] != 3:
        raise ValueError("Image must be a 3-dimensional array with 3 color channels")

    n = len(results)
    if n == 0:
        # Nothing to draw; a grid with zero rows cannot be created.
        return None

    rows = (n + cols - 1) // cols  # Ceiling division: rows needed for n plots
    height, width = image_np.shape[:2]
    tint = np.array([0, 0, 255])  # Blue used for the masked area

    # Setting up the plot
    fig, axs = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows))
    axs = np.array(axs).reshape(-1)  # Flatten axs array for easy indexing

    for ax, result in zip(axs, results):
        mask = result["mask"]
        label = result["label"]
        score = float(result["score"])

        # Convert the mask to a numpy array, resizing it to the image if needed
        mask_np = np.array(mask)
        if mask_np.shape != (height, width):
            # Nearest-neighbour resampling keeps the mask binary, so no
            # re-thresholding is needed afterwards.
            mask_image = mask if isinstance(mask, Image.Image) else Image.fromarray(mask_np)
            mask_np = np.array(mask_image.resize((width, height), Image.NEAREST))

        # Blend the image 50/50 with the tint wherever the mask is set
        selected = mask_np > 0
        combined = image_np.copy()
        combined[selected] = combined[selected] * 0.5 + tint * 0.5

        # Show the combined image
        ax.imshow(combined)
        ax.axis("off")
        ax.set_title(f"Label: {label}, Score: {score:.2f}", fontsize=12)
        rect = patches.Rectangle(
            (0, 0),
            image_np.shape[1],
            image_np.shape[0],
            linewidth=1,
            edgecolor="r",
            facecolor="none",
        )
        ax.add_patch(rect)

    # Hide unused subplots if the total number of masks is not a multiple of cols
    for ax in axs[n:]:
        ax.axis("off")
    plt.tight_layout()
    plt.show()


def get_bounding_box(mask):
    """
    Locate the tight bounding box around the non-zero region of a 2-D mask.

    Returns a tuple ``(col_min, row_min, col_max, row_max)``, i.e. the
    smallest column and row index followed by the largest column and row
    index at which the mask is non-zero. A mask with no non-zero entries
    has no coordinates to reduce, so the min/max reduction raises.
    """
    # argwhere yields one (row, col) pair per non-zero element
    coords = np.argwhere(mask)
    row_min, col_min = coords.min(axis=0)
    row_max, col_max = coords.max(axis=0)
    return (col_min, row_min, col_max, row_max)


def get_bounding_box_mask(segmentation_mask, widen=0, elongate=0):
    """Build a mask that fills the bounding box of a segmentation mask.

    Args:
        segmentation_mask: Single-channel PIL mask image; its non-zero
            pixels define the region to enclose.
        widen: Extra pixels added to the left and right of the box
            (ignored unless positive).
        elongate: Extra pixels added above and below the box (ignored
            unless positive).

    Returns:
        A mode "1" PIL image of the same size as ``segmentation_mask`` with
        the (optionally padded) bounding box filled in. If the input has no
        non-zero pixels the returned mask is entirely empty.
    """
    # Convert the PIL segmentation mask to a NumPy array
    mask_array = np.array(segmentation_mask)

    # Create a new blank image for the bounding box mask
    bounding_box_mask = Image.new("1", segmentation_mask.size)

    # Find the coordinates of the non-zero pixels
    non_zero_y, non_zero_x = np.nonzero(mask_array)
    if non_zero_x.size == 0:
        # Nothing is segmented, so there is no box to draw.
        return bounding_box_mask

    height, width = mask_array.shape

    # Bounding box as plain ints (inclusive pixel coordinates)
    min_x, max_x = int(non_zero_x.min()), int(non_zero_x.max())
    min_y, max_y = int(non_zero_y.min()), int(non_zero_y.max())

    # Pad the box, clamping to the last valid pixel index on each axis
    if widen > 0:
        min_x = max(0, min_x - widen)
        max_x = min(width - 1, max_x + widen)

    if elongate > 0:
        min_y = max(0, min_y - elongate)
        max_y = min(height - 1, max_y + elongate)

    # Draw the filled bounding box on the blank image
    draw = ImageDraw.Draw(bounding_box_mask)
    draw.rectangle([(min_x, min_y), (max_x, max_y)], fill=1)

    return bounding_box_mask


# Named (R, G, B) triples, 0-255 per channel, looked up by overlay_mask via
# its ``color`` argument. "blue" is the light blue #88CFF9, not pure blue.
colors = {
    "blue": (136, 207, 249),
    "red": (255, 0, 0),
    "green": (0, 255, 0),
    "yellow": (255, 255, 0),
    "purple": (128, 0, 128),
    "cyan": (0, 255, 255),
    "magenta": (255, 0, 255),
    "orange": (255, 165, 0),
    "lime": (50, 205, 50),
    "pink": (255, 192, 203),
    "brown": (139, 69, 19),
    "gray": (128, 128, 128),
    "black": (0, 0, 0),
    "white": (255, 255, 255),
    "gold": (255, 215, 0),
    "silver": (192, 192, 192),
    "beige": (245, 245, 220),
    "navy": (0, 0, 128),
    "maroon": (128, 0, 0),
    "olive": (128, 128, 0),
}


def overlay_mask(image, mask, opacity=0.5, color="blue"):
    """
    Tint the masked region of a PIL image with a semi-transparent color.

    Args:
        image: PIL image to draw on. It is not modified; a new image is
            returned.
        mask: PIL mask image. Only pixels whose value is exactly 255 after
            conversion to mode "L" are tinted.
        opacity: Strength of the tint; it is multiplied by 255 to obtain
            the alpha value of the color layer.
        color: Either the name of an entry in the module-level ``colors``
            table (default "blue", i.e. #88CFF9) or an ``(r, g, b)`` tuple.

    Returns:
        A new PIL image in the same mode as ``image``.

    Raises:
        KeyError: If ``color`` is a name that is not in ``colors``.
    """
    # Binary selection mask: on where the input mask is exactly 255
    alpha = mask.convert("L").point(lambda x: 255 if x == 255 else 0, mode="1")

    # Resolve the color: a name is looked up, anything else is taken as RGB
    if isinstance(color, str):
        r, g, b = colors[color]
    else:
        r, g, b = color

    # Solid color layer, kept only where the selection mask is on
    color_mask = Image.new("RGBA", mask.size, (r, g, b, int(opacity * 255)))
    mask_rgba = Image.composite(
        color_mask, Image.new("RGBA", mask.size, (0, 0, 0, 0)), alpha
    )

    # Transparent canvas of the image's size, so a mask of a different size
    # can still be composited; the mask is anchored at the top-left corner
    overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
    overlay.paste(mask_rgba, (0, 0))

    # Blend the tint over the original image
    result = Image.alpha_composite(image.convert("RGBA"), overlay)

    # Convert the result back to the original mode and return it
    return result.convert(image.mode)


def resize_preserve_aspect_ratio(image, max_side=512):
    """Resize an image so its longer side equals ``max_side``.

    The shorter side is scaled by the same factor and rounded down, but is
    never allowed to drop below one pixel. Images smaller than ``max_side``
    are scaled up.

    Args:
        image: Object with a ``size`` attribute of ``(width, height)`` and a
            ``resize((width, height))`` method, such as a PIL image.
        max_side: Target length of the longer side, in pixels.

    Returns:
        Whatever ``image.resize`` returns for the computed size.
    """
    width, height = image.size
    longest = int(max_side)
    # The long side is set directly and the short side uses floor division,
    # because multiplying by a float scale can land just below the exact
    # value (e.g. 49 * (512 / 49) truncates to 511).
    if width >= height:
        new_width = longest
        new_height = max(1, int(height * max_side // width))
    else:
        new_height = longest
        new_width = max(1, int(width * max_side // height))
    return image.resize((new_width, new_height))


def round_to_nearest_eigth(value):
    """Round ``value`` down to a multiple of 8 and return it as an int.

    Despite the name, the value is floored rather than rounded to the
    nearest multiple: 15 becomes 8, not 16.
    """
    whole_eights = value // 8
    return int(whole_eights * 8)


def resize_image_to_nearest_eight(image):
    """Resize an image so both sides are multiples of 8.

    Each side of ``image.size`` is passed through ``round_to_nearest_eigth``
    and the image is resized to the resulting ``(width, height)``.
    """
    original_width, original_height = image.size
    target_size = (
        round_to_nearest_eigth(original_width),
        round_to_nearest_eigth(original_height),
    )
    return image.resize(target_size)