# Page-extraction residue: "File size: 9,285 Bytes", id 6706230, and a 1-292 line-number gutter (not part of the module).
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageDraw
def unload_mask(mask):
    """Convert a single mask tensor into a PIL image.

    The tensor is moved to the CPU, converted to NumPy, stripped of size-1
    dimensions, cast to ``uint8`` and multiplied by 255 before being handed
    to ``Image.fromarray``.

    Args:
        mask: Tensor-like object exposing ``.cpu().numpy()``.
            NOTE(review): presumably a 0/1 (or boolean) mask with singleton
            leading dimensions -- confirm against the caller.

    Returns:
        PIL.Image.Image: The mask as an image.
    """
    mask_array = mask.cpu().numpy().squeeze()
    # Cast before scaling so a value of 1 becomes 255.
    mask_array = mask_array.astype(np.uint8) * 255
    return Image.fromarray(mask_array)
def unload_box(box):
    """Convert a box tensor into a plain (nested) Python list.

    Args:
        box: Tensor-like object exposing ``.cpu().numpy()``.

    Returns:
        list: The tensor contents as Python numbers.
    """
    return box.cpu().numpy().tolist()
def masks_overlap(mask1, mask2):
    """Return whether two masks are both nonzero at any common position.

    Args:
        mask1: Array-like mask (anything ``np.logical_and`` accepts).
        mask2: Array-like mask broadcast-compatible with ``mask1``.

    Returns:
        numpy.bool_: True if at least one position is truthy in both masks.
    """
    return np.any(np.logical_and(mask1, mask2))
def remove_non_person_masks(person_mask, formatted_results):
    """Keep only results that are a person or that overlap the person mask.

    Args:
        person_mask: Mask of the person, compared with ``masks_overlap``.
        formatted_results (list): Dictionaries with at least ``"label"`` and
            ``"mask"`` keys.

    Returns:
        list: The entries whose label is ``"person"`` or whose mask overlaps
        ``person_mask``, in their original order.
    """
    return [
        f
        for f in formatted_results
        if f.get("label") == "person" or masks_overlap(person_mask, f.get("mask"))
    ]
def format_masks(masks):
    """Convert every mask in ``masks`` with ``unload_mask``.

    Args:
        masks: Iterable of mask tensors.

    Returns:
        list: One converted mask per input, in the same order.
    """
    return [unload_mask(mask) for mask in masks]
def format_boxes(boxes):
    """Convert every box in ``boxes`` with ``unload_box``.

    Args:
        boxes: Iterable of box tensors.

    Returns:
        list: One plain-Python box per input, in the same order.
    """
    return [unload_box(box) for box in boxes]
def format_scores(scores):
    """Convert a score tensor into a plain Python list.

    Args:
        scores: Tensor-like object exposing ``.cpu().numpy()``.

    Returns:
        list: The scores as Python numbers.
    """
    return scores.cpu().numpy().tolist()
def unload(result, labels_dict):
    """Unpack a detection/segmentation result into plain Python structures.

    Args:
        result: Result object exposing ``boxes.xyxy``, ``boxes.conf`` and
            ``boxes.cls`` tensors, plus the masks.
        labels_dict (dict): Maps integer class ids to label names.

    Returns:
        tuple: ``(masks, boxes, scores, phrases)`` where ``phrases`` holds the
        label name looked up for each detected class id.
    """
    # NOTE(review): the argument of this call was lost in the source text;
    # ``result.masks.data`` is a reconstruction -- confirm against the
    # result type actually passed in.
    masks = format_masks(result.masks.data)
    boxes = format_boxes(result.boxes.xyxy)
    scores = format_scores(result.boxes.conf)
    # Class ids arrive as tensor elements; turn them into ints for the lookup.
    labels = [int(label.item()) for label in result.boxes.cls]
    phrases = [labels_dict[label] for label in labels]
    return masks, boxes, scores, phrases
def format_results(masks, boxes, scores, labels, labels_dict, person_masks_only=True):
    """Bundle parallel mask/box/score/label sequences into sorted row dicts.

    Args:
        masks: Sequence of masks, parallel to ``labels``.
        boxes: Sequence of boxes, parallel to ``labels``.
        scores: Sequence of scores, parallel to ``labels``.
        labels: Sequence of label names.
        labels_dict (dict): Either ``{id: name}`` or ``{name: id}``. A mapping
            whose first key is an ``int`` is inverted to ``{name: id}``.
        person_masks_only (bool): When true, require a ``"person"`` label and
            drop every row that is neither a person nor overlapping the
            first person's mask.

    Returns:
        list: Dictionaries with keys ``label``, ``score``, ``mask``, ``box``
        and ``label_id``, sorted by ``label``.

    Raises:
        AssertionError: If ``person_masks_only`` is set and no usable person
            mask is present.
        KeyError: If a label is missing from ``labels_dict``.
    """
    # ``next(iter(...), None)`` tolerates an empty mapping, unlike indexing
    # into ``list(labels_dict.keys())``.
    first_key = next(iter(labels_dict), None)
    if isinstance(first_key, int):
        labels_dict = {name: label_id for label_id, name in labels_dict.items()}

    # Check that the person mask is present before doing any work.
    if person_masks_only:
        assert "person" in labels, "Person mask not present in results"

    results_dict = [
        dict(label=label, score=score, mask=mask, box=box, label_id=labels_dict[label])
        for label, score, box, mask in zip(labels, scores, boxes, masks)
    ]
    results_dict.sort(key=lambda row: row["label"])

    if person_masks_only:
        # The first person row (after sorting) supplies the reference mask.
        person_mask = next(
            (row["mask"] for row in results_dict if row.get("label") == "person"),
            None,
        )
        assert person_mask is not None, "Person mask not found in results"
        # Remove any results that do not overlap with the person.
        results_dict = remove_non_person_masks(person_mask, results_dict)
    return results_dict
def filter_highest_score(results, labels):
    """Filter results to remove entries with lower scores for specified labels.

    For every label in ``labels`` only the highest-scoring entry is kept (the
    earliest one wins a tie). Entries whose label is not in ``labels`` are
    passed through untouched. Relative order is preserved.

    NOTE(review): the branch bodies of the second pass were missing from the
    source text; the pass-through of unlisted labels is a reconstruction --
    confirm against the intended behaviour.

    Args:
        results (list): List of dictionaries containing 'label', 'score', and other keys.
        labels (list): List of labels to filter.

    Returns:
        list: Filtered results with only the highest score for each specified label.
    """
    # First pass: identify the highest-scoring entry for each listed label.
    label_highest = {}
    for result in results:
        label = result["label"]
        if label not in labels:
            continue
        best = label_highest.get(label)
        if best is None or result["score"] > best["score"]:
            label_highest[label] = result

    # Second pass: construct the filtered list while preserving the order.
    filtered_results = []
    seen_labels = set()
    for result in results:
        label = result["label"]
        if label not in labels:
            filtered_results.append(result)
            continue
        if label in seen_labels:
            continue
        # Identity, not ``==``: comparing the dicts by value would compare
        # their masks element-wise, which raises for NumPy arrays.
        if result is label_highest[label]:
            filtered_results.append(result)
            seen_labels.add(label)
    return filtered_results
def display_image_with_masks(image, results, cols=4, return_images=False):
    """Plot the image once per result with that result's mask tinted blue.

    NOTE(review): large parts of this function were missing from the source
    text (the resize call's import, the Rectangle arguments, the imshow /
    axis handling and the ``return_images`` branch). Those parts are
    reconstructed here and should be confirmed against the original.

    Args:
        image: RGB image (anything ``np.array`` turns into an H x W x 3 array).
        results (list): Dictionaries with ``"mask"``, ``"label"`` and
            ``"score"`` keys.
        cols (int): Number of subplot columns.
        return_images (bool): When true, return the blended images.

    Returns:
        list | None: The blended H x W x 3 arrays (one per result) when
        ``return_images`` is true, otherwise ``None``.

    Raises:
        ValueError: If the image is not a 3-dimensional, 3-channel array.
    """
    # Convert PIL Image to numpy array
    image_np = np.array(image)

    # Check image dimensions
    if image_np.ndim != 3 or image_np.shape[2] != 3:
        raise ValueError("Image must be a 3-dimensional array with 3 color channels")
    height, width = image_np.shape[:2]

    # Number of masks
    n = len(results)
    if n == 0:
        # Nothing to draw; a zero-row subplot grid would raise.
        return [] if return_images else None
    rows = (n + cols - 1) // cols  # Calculate required number of rows

    # Setting up the plot
    _, axs = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows))
    axs = np.array(axs).reshape(-1)  # Flatten axs array for easy indexing

    blue = np.array([0, 0, 255])
    combined_images = []
    for ax, result in zip(axs, results):
        label = result["label"]
        score = float(result["score"])

        # Binarise the mask, then resize it (nearest neighbour keeps it
        # binary) if it does not match the image. PIL is used so that no
        # extra import is needed.
        mask_np = np.array(result["mask"]) > 0
        if mask_np.shape != (height, width):
            resized = Image.fromarray(mask_np.astype(np.uint8)).resize(
                (width, height), Image.NEAREST
            )
            mask_np = np.array(resized) > 0

        # Blend 50% image with 50% blue wherever the mask is set.
        combined = image_np.copy()
        combined[mask_np] = (combined[mask_np] * 0.5 + blue * 0.5).astype(
            image_np.dtype
        )
        combined_images.append(combined)

        # Show the combined image
        ax.imshow(combined)
        ax.axis("off")
        ax.set_title(f"Label: {label}, Score: {score:.2f}", fontsize=12)
        # NOTE(review): only the (0, 0) anchor survived in the source; the
        # size and styling of this frame are assumptions.
        rect = patches.Rectangle(
            (0, 0),
            width,
            height,
            linewidth=1,
            edgecolor="r",
            facecolor="none",
        )
        ax.add_patch(rect)

    # Hide unused subplots if the total number of masks is not a multiple of cols
    for ax in axs[n:]:
        ax.axis("off")

    plt.tight_layout()
    plt.show()
    return combined_images if return_images else None
def get_bounding_box(mask):
    """Given a segmentation mask, return the bounding box for the mask object.

    Args:
        mask: 2-D array-like; nonzero entries belong to the object.

    Returns:
        tuple: ``(x_min, y_min, x_max, y_max)`` in pixel indices, where x is
        the column index and y the row index. Both corners are inclusive.

    Raises:
        ValueError: If the mask has no nonzero entries.
    """
    # Find indices where the mask is non-zero; each entry is (row, column).
    coords = np.argwhere(mask)
    if coords.size == 0:
        raise ValueError("Cannot compute a bounding box for an empty mask")

    # Rows are the y axis and columns the x axis.
    row_min, col_min = np.min(coords, axis=0)
    row_max, col_max = np.max(coords, axis=0)

    # Return the bounding box coordinates in x/y order.
    return (col_min, row_min, col_max, row_max)
def get_bounding_box_mask(segmentation_mask, widen=0, elongate=0):
    """Build a binary mask covering the bounding box of a segmentation mask.

    Args:
        segmentation_mask: PIL image (single channel) whose nonzero pixels
            mark the object.
        widen (int): Pixels to add on the left and right of the box. Values
            that are not positive are ignored.
        elongate (int): Pixels to add above and below the box. Values that
            are not positive are ignored.

    Returns:
        PIL.Image.Image: Mode ``"1"`` image of the same size with the
        (optionally enlarged) bounding box filled in. If the input mask has
        no nonzero pixel, the returned mask is entirely blank.
    """
    # Convert the PIL segmentation mask to a NumPy array
    mask_array = np.array(segmentation_mask)

    # Create a new blank image for the bounding box mask
    bounding_box_mask = Image.new("1", segmentation_mask.size)

    # Find the coordinates of the non-zero pixels
    non_zero_y, non_zero_x = np.nonzero(mask_array)
    if non_zero_x.size == 0:
        # No object: min/max of an empty array would raise.
        return bounding_box_mask

    # Calculate the bounding box coordinates as plain ints for PIL.
    min_x, max_x = int(non_zero_x.min()), int(non_zero_x.max())
    min_y, max_y = int(non_zero_y.min()), int(non_zero_y.max())

    # Grow the box, keeping it inside the last valid pixel index.
    if widen > 0:
        min_x = max(0, min_x - widen)
        max_x = min(mask_array.shape[1] - 1, max_x + widen)
    if elongate > 0:
        min_y = max(0, min_y - elongate)
        max_y = min(mask_array.shape[0] - 1, max_y + elongate)

    # Draw the filled bounding box on the blank image
    draw = ImageDraw.Draw(bounding_box_mask)
    draw.rectangle([(min_x, min_y), (max_x, max_y)], fill=1)
    return bounding_box_mask
# Named RGB colours available to the overlay helpers.
colors = {
    "blue": (136, 207, 249),
    "red": (255, 0, 0),
    "green": (0, 255, 0),
    "yellow": (255, 255, 0),
    "purple": (128, 0, 128),
    "cyan": (0, 255, 255),
    "magenta": (255, 0, 255),
    "orange": (255, 165, 0),
    "lime": (50, 205, 50),
    "pink": (255, 192, 203),
    "brown": (139, 69, 19),
    "gray": (128, 128, 128),
    "black": (0, 0, 0),
    "white": (255, 255, 255),
    "gold": (255, 215, 0),
    "silver": (192, 192, 192),
    "beige": (245, 245, 220),
    "navy": (0, 0, 128),
    "maroon": (128, 0, 0),
    "olive": (128, 128, 0),
}
def overlay_mask(image, mask, opacity=0.5, color="blue"):
    """Overlay a tinted, semi-transparent copy of a mask on an image.

    Args:
        image: PIL image to draw on.
        mask: PIL mask image; only pixels that are 255 after conversion to
            mode ``"L"`` are tinted.
        opacity (float): Opacity of the tint; clamped to the range 0..1.
        color: Name of an entry in the module-level ``colors`` table, or an
            ``(r, g, b)`` tuple. The default "blue" is hex #88CFF9.

    Returns:
        PIL.Image.Image: A new image in the same mode as ``image``.

    Raises:
        KeyError: If ``color`` is a name that is not in ``colors``.
    """
    # Convert the boolean mask to an image usable as a compositing mask
    alpha = mask.convert("L").point(lambda x: 255 if x == 255 else 0, mode="1")

    # Choose the color
    r, g, b = colors[color] if isinstance(color, str) else color
    tint_alpha = int(max(0.0, min(1.0, opacity)) * 255)

    # Tint where the mask is set, fully transparent everywhere else.
    color_mask = Image.new("RGBA", mask.size, (r, g, b, tint_alpha))
    transparent = Image.new("RGBA", mask.size, (0, 0, 0, 0))
    mask_rgba = Image.composite(color_mask, transparent, alpha)

    # Create a new RGBA image to overlay the mask on; alpha_composite needs
    # both layers to have the image's size.
    overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))

    # Paste the mask onto the overlay
    overlay.paste(mask_rgba, (0, 0))

    # Create a new image to return by blending the original image and the overlay
    result = Image.alpha_composite(image.convert("RGBA"), overlay)

    # Convert the result back to the original mode and return it
    return result.convert(image.mode)
def resize_preserve_aspect_ratio(image, max_side=512):
    """Resize an image so that its longer side equals ``max_side``.

    The image is scaled up or down; the shorter side is scaled by the same
    factor and truncated to a whole pixel, but never below 1.

    Args:
        image: Object with a ``size`` of ``(width, height)`` and a
            ``resize((width, height))`` method, such as a PIL image.
        max_side (int): Target length of the longer side.

    Returns:
        Whatever ``image.resize`` returns for the new size.
    """
    width, height = image.size
    longest = max(width, height)
    # Floor division on the product avoids the float round-off of
    # ``int(width * (max_side / width))``, which can land one pixel short.
    new_width = max(1, int(width * max_side // longest))
    new_height = max(1, int(height * max_side // longest))
    return image.resize((new_width, new_height))
def round_to_nearest_eigth(value):
    """Round ``value`` down to a multiple of 8.

    Despite the name, this floors rather than rounds to the nearest
    multiple: 15 becomes 8, not 16, and anything below 8 becomes 0.

    Args:
        value: Non-negative number (int or float).

    Returns:
        int: The largest multiple of 8 that is not greater than ``value``.
    """
    return int(value // 8 * 8)
def resize_image_to_nearest_eight(image):
    """Resize an image so both sides are multiples of 8.

    Each side is passed through ``round_to_nearest_eigth`` and then clamped
    to at least 8, so a side shorter than 8 pixels is not turned into a
    zero-length side.

    Args:
        image: Image with a ``size`` of ``(width, height)`` and a ``resize``
            method, such as a PIL image.

    Returns:
        The resized image returned by ``image.resize``.
    """
    width, height = image.size
    new_size = (
        max(8, round_to_nearest_eigth(width)),
        max(8, round_to_nearest_eigth(height)),
    )
    return image.resize(new_size)