|
import matplotlib.patches as patches |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from PIL import Image, ImageDraw |
|
|
|
|
|
def unload_mask(mask):
    """Convert a single mask tensor into a PIL image.

    The tensor is moved to the CPU, converted to a NumPy array with its
    singleton dimensions squeezed out, cast to ``uint8`` and scaled by 255
    before being handed to ``Image.fromarray``.
    """
    array = mask.cpu().numpy()
    array = array.squeeze().astype(np.uint8)
    return Image.fromarray(array * 255)
|
|
|
|
|
def unload_box(box):
    """Return a box tensor as a plain Python list (moved to the CPU first)."""
    box_array = box.cpu().numpy()
    return box_array.tolist()
|
|
|
|
|
def masks_overlap(mask1, mask2):
    """Report whether two masks are both non-zero at any shared position."""
    intersection = np.logical_and(mask1, mask2)
    return np.any(intersection)
|
|
|
|
|
def remove_non_person_masks(person_mask, formatted_results):
    """Keep only the results that belong to, or touch, the person.

    A result is kept when its ``"label"`` is ``"person"`` or when its
    ``"mask"`` overlaps ``person_mask`` (as decided by ``masks_overlap``).
    The relative order of the kept results is unchanged.
    """
    kept = []
    for entry in formatted_results:
        if entry.get("label") == "person":
            kept.append(entry)
        elif masks_overlap(person_mask, entry.get("mask")):
            kept.append(entry)
    return kept
|
|
|
|
|
def format_masks(masks):
    """Convert every mask in ``masks`` with ``unload_mask`` and return a list."""
    converted = []
    for mask in masks:
        converted.append(unload_mask(mask))
    return converted
|
|
|
|
|
def format_boxes(boxes):
    """Convert every box in ``boxes`` with ``unload_box`` and return a list."""
    converted = []
    for box in boxes:
        converted.append(unload_box(box))
    return converted
|
|
|
|
|
def format_scores(scores):
    """Return a score tensor as a plain Python list (moved to the CPU first)."""
    scores_array = scores.cpu().numpy()
    return scores_array.tolist()
|
|
|
|
|
def unload(result, labels_dict):
    """Unpack a detection result into plain Python containers.

    Args:
        result: Object exposing ``masks.data``, ``boxes.xyxy``, ``boxes.conf``
            and ``boxes.cls`` (presumably a segmentation model's result
            object -- confirm against the caller).
        labels_dict: Mapping from integer class id to label name.

    Returns:
        A ``(masks, boxes, scores, phrases)`` tuple: the formatted masks,
        boxes and scores, plus the label name looked up for each class id.
    """
    masks = format_masks(result.masks.data)

    detections = result.boxes
    boxes = format_boxes(detections.xyxy)
    scores = format_scores(detections.conf)

    class_ids = [int(class_id.item()) for class_id in detections.cls]
    phrases = [labels_dict[class_id] for class_id in class_ids]

    return masks, boxes, scores, phrases
|
|
|
|
|
def format_results(masks, boxes, scores, labels, labels_dict, person_masks_only=True):
    """Bundle parallel detection outputs into a list of per-detection dicts.

    Args:
        masks: Sequence of masks, one per detection.
        boxes: Sequence of boxes, one per detection.
        scores: Sequence of scores, one per detection.
        labels: Sequence of label names, one per detection.
        labels_dict: Mapping between label ids and label names. When its
            first key is an ``int`` it is treated as ``id -> name`` and
            inverted, so that it can be used for ``name -> id`` lookups.
            An empty mapping is accepted and left untouched.
        person_masks_only: When true, a ``"person"`` detection is required
            and only results that are a person, or whose mask overlaps the
            first person mask, are returned.

    Returns:
        A list of dicts with the keys ``label``, ``score``, ``mask``, ``box``
        and ``label_id``, sorted by label.

    Raises:
        AssertionError: If ``person_masks_only`` is true and no usable person
            mask is available.
        KeyError: If a label has no entry in ``labels_dict``.
    """
    # Peek at one key without materialising the key list; this also keeps an
    # empty mapping from raising IndexError.
    first_key = next(iter(labels_dict), None)
    if isinstance(first_key, int):
        labels_dict = {name: label_id for label_id, name in labels_dict.items()}

    if person_masks_only and "person" not in labels:
        # Raised explicitly (not via `assert`) so the check still runs
        # when Python is started with -O.
        raise AssertionError("Person mask not present in results")

    results = [
        dict(label=label, score=score, mask=mask, box=box, label_id=labels_dict[label])
        for label, score, box, mask in zip(labels, scores, boxes, masks)
    ]
    # list.sort is stable, so detections sharing a label keep their input order.
    results.sort(key=lambda row: row["label"])

    if person_masks_only:
        person_rows = [row for row in results if row["label"] == "person"]
        if not person_rows or person_rows[0]["mask"] is None:
            raise AssertionError("Person mask not found in results")
        # Only the first person detection is used as the overlap reference.
        results = remove_non_person_masks(person_rows[0]["mask"], results)

    return results
|
|
|
|
|
def filter_highest_score(results, labels):
    """
    Filter results to remove entries with lower scores for specified labels.

    Results whose label is not in ``labels`` are passed through untouched.
    For each label in ``labels`` only the highest-scoring result is kept;
    when scores tie, the earliest result wins. Input order is preserved.

    Args:
        results (list): List of dictionaries containing 'label', 'score', and other keys.
        labels (list): List of labels to filter.

    Returns:
        list: Filtered results with only the highest score for each specified label.
    """
    # First pass: remember the best-scoring entry per label of interest.
    # The strict ">" keeps the earliest entry on ties.
    best_by_label = {}
    for result in results:
        label = result["label"]
        if label not in labels:
            continue
        current_best = best_by_label.get(label)
        if current_best is None or result["score"] > current_best["score"]:
            best_by_label[label] = result

    # Second pass: rebuild the list in the original order.
    filtered_results = []
    emitted_labels = set()
    for result in results:
        label = result["label"]
        if label not in labels:
            filtered_results.append(result)
            continue
        if label in emitted_labels:
            continue
        # Identity, not equality: comparing whole result dicts with "=="
        # would deep-compare their masks, which is slow and raises for
        # NumPy arrays ("truth value of an array is ambiguous").
        if result is best_by_label[label]:
            filtered_results.append(result)
            emitted_labels.add(label)

    return filtered_results
|
|
|
|
|
def display_image_with_masks(image, results, cols=4, return_images=False):
    """Show every result's mask blended over the image in a grid of subplots.

    Each subplot shows a copy of the image in which the masked pixels are
    mixed 50/50 with pure blue, titled with the result's label and score and
    framed by a thin red rectangle.

    Args:
        image: Image (PIL image or array-like) that converts to an
            ``(H, W, 3)`` NumPy array.
        results: List of dicts providing ``"mask"``, ``"label"`` and
            ``"score"``. A mask whose shape differs from ``(H, W)`` is
            binarised and resized with nearest-neighbour sampling.
        cols: Number of subplot columns.
        return_images: Accepted for backward compatibility; it currently has
            no effect and the function always returns ``None``.

    Raises:
        ValueError: If the image is not a 3-dimensional, 3-channel array.
    """
    image_np = np.array(image)
    if image_np.ndim != 3 or image_np.shape[2] != 3:
        raise ValueError("Image must be a 3-dimensional array with 3 color channels")

    n = len(results)
    if n == 0:
        # Nothing to draw, and plt.subplots cannot build a zero-row grid.
        return

    rows = (n + cols - 1) // cols
    height, width = image_np.shape[:2]

    _, axs = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows))
    # plt.subplots returns a scalar, 1-D or 2-D result depending on the grid
    # size; flatten so the axes can be indexed uniformly.
    axs = np.array(axs).reshape(-1)

    for ax, result in zip(axs, results):
        label = result["label"]
        score = float(result["score"])

        mask_np = np.array(result["mask"])
        if mask_np.shape != (height, width):
            # Resize with PIL (already a dependency of this module);
            # nearest-neighbour sampling keeps the mask strictly binary.
            binary = (mask_np > 0).astype(np.uint8) * 255
            resized = Image.fromarray(binary).resize((width, height), Image.NEAREST)
            mask_np = (np.array(resized) > 0).astype(np.uint8)

        selected = mask_np > 0

        # Pure-blue layer, present only where the mask is set.
        overlay = np.zeros_like(image_np)
        overlay[selected] = [0, 0, 255]

        # Blend image and overlay 50/50 inside the mask only.
        combined = image_np.copy()
        combined[selected] = combined[selected] * 0.5 + overlay[selected] * 0.5

        ax.imshow(combined)
        ax.axis("off")
        ax.set_title(f"Label: {label}, Score: {score:.2f}", fontsize=12)
        rect = patches.Rectangle(
            (0, 0),
            width,
            height,
            linewidth=1,
            edgecolor="r",
            facecolor="none",
        )
        ax.add_patch(rect)

    # Hide the grid cells that have no result to show.
    for ax in axs[n:]:
        ax.axis("off")

    plt.tight_layout()
    plt.show()
|
|
|
|
|
def get_bounding_box(mask):
    """
    Given a segmentation mask, return the bounding box for the mask object.

    Args:
        mask: 2-D array-like; non-zero entries belong to the object.

    Returns:
        tuple: ``(x_min, y_min, x_max, y_max)`` where x is the column index
        and y is the row index. Both corners are inclusive.

    Raises:
        ValueError: If the mask has no non-zero entry.
    """
    # argwhere yields one (row, column) pair per non-zero entry.
    coords = np.argwhere(mask)
    if coords.size == 0:
        raise ValueError("Cannot compute a bounding box for an empty mask")

    row_min, col_min = coords.min(axis=0)
    row_max, col_max = coords.max(axis=0)

    # Columns are x and rows are y, hence the swapped order on return.
    return (col_min, row_min, col_max, row_max)
|
|
|
|
|
def get_bounding_box_mask(segmentation_mask, widen=0, elongate=0):
    """Return a mask that is a filled rectangle around the segmented object.

    Args:
        segmentation_mask: PIL image that converts to a 2-D array; non-zero
            pixels belong to the object.
        widen: Pixels to add on both the left and the right of the box.
            Values <= 0 leave the box width unchanged.
        elongate: Pixels to add on both the top and the bottom of the box.
            Values <= 0 leave the box height unchanged.

    Returns:
        A mode ``"1"`` PIL image of the same size as ``segmentation_mask``
        with the (optionally enlarged) bounding box filled in.

    Raises:
        ValueError: If the mask has no non-zero pixel.
    """
    mask_array = np.array(segmentation_mask)

    # np.nonzero returns one index array per axis: rows (y) first, then columns (x).
    non_zero_y, non_zero_x = np.nonzero(mask_array)
    if non_zero_x.size == 0:
        raise ValueError("Cannot compute a bounding box for an empty mask")

    min_x, max_x = int(non_zero_x.min()), int(non_zero_x.max())
    min_y, max_y = int(non_zero_y.min()), int(non_zero_y.max())

    # The rectangle corners are inclusive, so the largest valid coordinate
    # along each axis is size - 1.
    if widen > 0:
        min_x = max(0, min_x - widen)
        max_x = min(mask_array.shape[1] - 1, max_x + widen)

    if elongate > 0:
        min_y = max(0, min_y - elongate)
        max_y = min(mask_array.shape[0] - 1, max_y + elongate)

    bounding_box_mask = Image.new("1", segmentation_mask.size)
    draw = ImageDraw.Draw(bounding_box_mask)
    draw.rectangle([(min_x, min_y), (max_x, max_y)], fill=1)

    return bounding_box_mask
|
|
|
|
|
# Named RGB triples accepted by `overlay_mask` through its `color` argument.
colors = {
    "blue": (136, 207, 249),
    "red": (255, 0, 0),
    "green": (0, 255, 0),
    "yellow": (255, 255, 0),
    "purple": (128, 0, 128),
    "cyan": (0, 255, 255),
    "magenta": (255, 0, 255),
    "orange": (255, 165, 0),
    "lime": (50, 205, 50),
    "pink": (255, 192, 203),
    "brown": (139, 69, 19),
    "gray": (128, 128, 128),
    "black": (0, 0, 0),
    "white": (255, 255, 255),
    "gold": (255, 215, 0),
    "silver": (192, 192, 192),
    "beige": (245, 245, 220),
    "navy": (0, 0, 128),
    "maroon": (128, 0, 0),
    "olive": (128, 128, 0),
}
|
|
|
|
|
def overlay_mask(image, mask, opacity=0.5, color="blue"):
    """
    Tint the masked region of a PIL image with a semi-transparent color.

    Only mask pixels that equal 255 after conversion to mode "L" are tinted.
    ``color`` must be a key of the module-level ``colors`` table (the default
    "blue" is (136, 207, 249), i.e. hex #88CFF9) and ``opacity`` scales the
    alpha of the tint. The result is returned in the mode of ``image``.
    """
    # 1-bit selector: set exactly where the mask is fully on.
    selector = mask.convert("L").point(
        lambda value: 255 if value == 255 else 0, mode="1"
    )

    red, green, blue = colors[color]
    tint = Image.new("RGBA", mask.size, (red, green, blue, int(opacity * 255)))
    transparent = Image.new("RGBA", mask.size, (0, 0, 0, 0))
    tinted_region = Image.composite(tint, transparent, selector)

    # Place the tinted region on a fully transparent layer sized like the image.
    layer = Image.new("RGBA", image.size, (0, 0, 0, 0))
    layer.paste(tinted_region, (0, 0))

    blended = Image.alpha_composite(image.convert("RGBA"), layer)
    return blended.convert(image.mode)
|
|
|
|
|
def resize_preserve_aspect_ratio(image, max_side=512):
    """Resize an image so that its longer side becomes ``max_side``.

    The same scale factor is applied to both sides, so the aspect ratio is
    kept (up to integer truncation). Images smaller than ``max_side`` are
    scaled up.

    Args:
        image: Object with a ``size`` attribute of ``(width, height)`` and a
            ``resize((width, height))`` method, such as a PIL image.
        max_side: Target length of the longer side, in pixels.

    Returns:
        Whatever ``image.resize`` returns for the computed size.
    """
    width, height = image.size
    scale = min(max_side / width, max_side / height)
    # Truncation can drive the short side of a very elongated image down to
    # 0, which resize cannot handle; keep every side at least 1 pixel.
    new_width = max(1, int(width * scale))
    new_height = max(1, int(height * scale))
    return image.resize((new_width, new_height))
|
|
|
|
|
def round_to_nearest_eigth(value):
    """Return ``value`` rounded down to a multiple of 8, as an ``int``.

    Despite the name, this floors rather than rounds to the nearest
    multiple (15 -> 8, 7 -> 0).
    """
    whole_eights = value // 8
    return int(whole_eights * 8)
|
|
|
|
|
def resize_image_to_nearest_eight(image):
    """Resize an image so both sides are multiples of 8.

    Each side is passed through ``round_to_nearest_eigth`` and the image is
    resized to the resulting ``(width, height)``.
    """
    target_size = tuple(round_to_nearest_eigth(side) for side in image.size)
    return image.resize(target_size)
|
|