Spaces:
Build error
Build error
""" | |
Mask R-CNN
Display and Visualization Functions.
Copyright (c) 2017 Matterport, Inc.
Licensed under the MIT License (see LICENSE for details)
Written by Waleed Abdulla
""" | |
import os | |
import cv2 | |
import sys | |
import random | |
import itertools | |
import colorsys | |
import numpy as np | |
from skimage.measure import find_contours | |
import matplotlib.pyplot as plt | |
from matplotlib import patches, lines | |
from matplotlib.patches import Polygon | |
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas | |
import IPython.display | |
# Root directory of the project | |
ROOT_DIR = os.path.abspath("../") | |
# Import Mask RCNN | |
sys.path.append(ROOT_DIR) # To find local version of the library | |
from mrcnn import utils | |
############################################################ | |
# Visualization | |
############################################################ | |
def display_images(images, titles=None, cols=4, cmap=None, norm=None,
                   interpolation=None):
    """Display the given set of images in a grid, optionally with titles.

    images: list or array of image tensors in HWC format.
    titles: optional. A list of titles to display with each image.
    cols: number of images per row.
    cmap: Optional. Color map to use. For example, "Blues".
    norm: Optional. A Normalize instance to map values to colors.
    interpolation: Optional. Image interpolation to use for display.
    """
    titles = titles if titles is not None else [""] * len(images)
    # Ceiling division: a completely filled last row must not add an extra,
    # empty row. Keep at least one row so the grid is never degenerate.
    rows = max(1, -(-len(images) // cols))
    # Guard against a zero-height figure when cols is large relative to rows.
    fig_height = max(1, 14 * rows // cols)
    plt.figure(figsize=(14, fig_height))
    for i, (image, title) in enumerate(zip(images, titles), start=1):
        plt.subplot(rows, cols, i)
        plt.title(title, fontsize=9)
        plt.axis('off')
        plt.imshow(image.astype(np.uint8), cmap=cmap,
                   norm=norm, interpolation=interpolation)
    plt.show()
def random_colors(N, bright=True):
    """Return N visually distinct RGB colors in random order.

    Hues are spaced evenly in HSV space at full saturation (value 1.0 when
    `bright`, 0.7 otherwise), converted to RGB tuples, then shuffled.
    """
    value = 1.0 if bright else 0.7
    palette = [colorsys.hsv_to_rgb(index / N, 1, value) for index in range(N)]
    random.shuffle(palette)
    return palette
def apply_mask(image, mask, color, alpha=0.5):
    """Blend `color` into `image` wherever `mask` equals 1.

    image: array indexed as [row, col, channel]; its first three channels
        are modified in place.
    mask: 2-D array; pixels equal to 1 are tinted, all others are untouched.
    color: sequence whose first three entries are channel intensities,
        scaled by 255 before blending.
    alpha: weight of the color in the blend (the image keeps 1 - alpha).

    Returns the same `image` object that was passed in.
    """
    selected = mask == 1
    keep = 1 - alpha
    for channel in range(3):
        plane = image[:, :, channel]
        tinted = plane * keep + alpha * color[channel] * 255
        image[:, :, channel] = np.where(selected, tinted, plane)
    return image
def display_instances(image, boxes, masks, class_ids, class_names,
                      scores=None, title="",
                      figsize=(16, 16), ax=None,
                      show_mask=True, show_bbox=True,
                      colors=None, captions=None):
    """Draw instance boxes, captions and masks on top of an image.

    image: image array; its first two dimensions are height and width.
    boxes: [num_instances, (y1, x1, y2, x2)] in image coordinates.
    masks: [height, width, num_instances]
    class_ids: [num_instances]
    class_names: list of class names of the dataset
    scores: (optional) confidence scores for each box
    title: (optional) Figure title
    show_mask, show_bbox: To show masks and bounding boxes or not
    figsize: (optional) the size of the image
    ax: (optional) Matplotlib axis to draw on. When omitted, a new figure
        is created and plt.show() is called at the end.
    colors: (optional) An array or colors to use with each object
    captions: (optional) A list of strings to use as captions for each object
    """
    # Number of instances
    N = boxes.shape[0]
    if not N:
        print("\n*** No instances to display *** \n")
    else:
        assert boxes.shape[0] == masks.shape[-1] == class_ids.shape[0]

    # If no axis is passed, create one and automatically call show().
    # Compare against None: truthiness of arbitrary objects is not reliable.
    auto_show = False
    if ax is None:
        _, ax = plt.subplots(1, figsize=figsize)
        auto_show = True

    # Generate random colors. `colors or ...` would raise on a NumPy array
    # (ambiguous truth value), so test for None / empty explicitly.
    if colors is None or len(colors) == 0:
        colors = random_colors(N)
    use_captions = captions is not None and len(captions) > 0

    # Show area outside image boundaries.
    height, width = image.shape[:2]
    ax.set_ylim(height + 10, -10)
    ax.set_xlim(-10, width + 10)
    ax.axis('off')
    ax.set_title(title)

    # Work on a wider integer copy so the caller's image is not modified.
    masked_image = image.astype(np.uint32).copy()
    for i in range(N):
        color = colors[i]

        # Bounding box
        if not np.any(boxes[i]):
            # Skip this instance. Has no bbox. Likely lost in image cropping.
            continue
        y1, x1, y2, x2 = boxes[i]
        if show_bbox:
            p = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2,
                                  alpha=0.7, linestyle="dashed",
                                  edgecolor=color, facecolor='none')
            ax.add_patch(p)

        # Label
        if use_captions:
            caption = captions[i]
        else:
            class_id = class_ids[i]
            score = scores[i] if scores is not None else None
            label = class_names[class_id]
            # A missing or zero score shows the class name only.
            caption = "{} {:.3f}".format(label, score) if score else label
        ax.text(x1, y1 + 8, caption,
                color='w', size=11, backgroundcolor="none")

        # Mask
        mask = masks[:, :, i]
        if show_mask:
            masked_image = apply_mask(masked_image, mask, color)

        # Mask Polygon (outlines are drawn even when show_mask is False).
        # Pad to ensure proper polygons for masks that touch image edges.
        padded_mask = np.zeros(
            (mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8)
        padded_mask[1:-1, 1:-1] = mask
        contours = find_contours(padded_mask, 0.5)
        for verts in contours:
            # Subtract the padding and flip (y, x) to (x, y)
            verts = np.fliplr(verts) - 1
            p = Polygon(verts, facecolor="none", edgecolor=color)
            ax.add_patch(p)
    ax.imshow(masked_image.astype(np.uint8))
    if auto_show:
        plt.show()
def display_instances_(image, boxes, masks, class_ids, class_names,
                       scores=None, title="",
                       figsize=(16, 16), ax=None,
                       show_mask=True, show_bbox=True,
                       colors=None, captions=None,
                       making_video=False, making_image=False, detect=False, hc=False, real_time=False):
    """Draw instances on an image and optionally return the rendered result.

    image: image array; its first two dimensions are height and width.
    boxes: [num_instances, (y1, x1, y2, x2)] in image coordinates.
    masks: [height, width, num_instances]
    class_ids: [num_instances]
    class_names: list of class names of the dataset
    scores: (optional) confidence scores for each box
    title: (optional) Figure title
    show_mask, show_bbox: To show masks and bounding boxes or not
    figsize: (optional) the size of the image
    ax: (optional) Matplotlib axis to draw on; its figure is used for
        rendering. When omitted a new figure is created.
    colors: (optional) An array or colors to use with each object. With
        making_video / real_time / hc the colors are looked up per class
        instead of per instance.
    captions: (optional) A list of strings to use as captions for each object
    making_video, real_time: pick the color as colors[class_id - 1], close
        the figure and return the rendered frame as a uint8 array of shape
        [height, width, 3]. With real_time the channels stay in RGB order;
        otherwise they are reversed (BGR, as OpenCV expects).
    making_image: write the rendered (BGR) frame to 'splash.png'.
    detect: close the figure and return its canvas without rendering.
    hc: hard-coded coloring: class 14 uses colors[0], every other class
        uses colors[class_id].

    Returns the canvas (detect), the frame array (making_video or
    real_time), or None after calling plt.show().
    """
    # Number of instances
    N = boxes.shape[0]
    if not N:
        print("\n*** No instances to display *** \n")
    else:
        assert boxes.shape[0] == masks.shape[-1] == class_ids.shape[0]

    # Always bind both `fig` and `canvas`, also when the caller passes an
    # axis; otherwise the rendering code below fails with a NameError.
    if ax is None:
        fig, ax = plt.subplots(1, figsize=figsize)
        canvas = FigureCanvas(fig)
    else:
        fig = ax.figure
        canvas = fig.canvas
        if not hasattr(canvas, "buffer_rgba"):
            # The figure's canvas cannot hand out pixels; attach an Agg one.
            canvas = FigureCanvas(fig)

    # Generate random colors when none are supplied. Class-indexed modes
    # look colors up by class id, so the palette must cover every class,
    # not just the number of instances.
    class_indexed = making_video or real_time or hc
    if colors is None or len(colors) == 0:
        palette_size = max(N, len(class_names)) if class_indexed else N
        colors = random_colors(palette_size)
    use_captions = captions is not None and len(captions) > 0

    # Show area outside image boundaries.
    height, width = image.shape[:2]
    ax.set_ylim(height + 1, -1)
    ax.set_xlim(-1, width + 1)
    ax.axis('off')
    ax.set_title(title)

    # Work on a wider integer copy so the caller's image is not modified.
    masked_image = image.astype(np.uint32).copy()
    for i in range(N):
        class_id = class_ids[i]
        if making_video or real_time:
            # One color per class so an object keeps its color across frames.
            color = colors[class_id - 1]
        elif hc:
            # Hard-coded mask colors (used for paper figures).
            if class_id == 14:
                color = colors[0]
            else:
                color = colors[class_id]
        else:
            color = colors[i]

        # Bounding box
        if not np.any(boxes[i]):
            # Skip this instance. Has no bbox. Likely lost in image cropping.
            continue
        y1, x1, y2, x2 = boxes[i]
        if show_bbox:
            p = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2,
                                  alpha=0.7, linestyle="dashed",
                                  edgecolor=color, facecolor='none')
            ax.add_patch(p)

        # Label
        if use_captions:
            caption = captions[i]
        else:
            score = scores[i] if scores is not None else None
            label = class_names[class_id]
            # A missing or zero score shows the class name only.
            caption = "{} {:.3f}".format(label, score) if score else label
        ax.text(x1, y1 + 8, caption,
                color='w', size=14, backgroundcolor="none")

        # Mask
        mask = masks[:, :, i]
        if show_mask:
            masked_image = apply_mask(masked_image, mask, color)

        # Mask Polygon (outlines are drawn even when show_mask is False).
        # Pad to ensure proper polygons for masks that touch image edges.
        padded_mask = np.zeros(
            (mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8)
        padded_mask[1:-1, 1:-1] = mask
        contours = find_contours(padded_mask, 0.5)
        for verts in contours:
            # Subtract the padding and flip (y, x) to (x, y)
            verts = np.fliplr(verts) - 1
            p = Polygon(verts, facecolor="none", edgecolor=color)
            ax.add_patch(p)
    ax.imshow(masked_image.astype(np.uint8))

    if detect:
        # Close this specific figure rather than whichever one is current.
        plt.close(fig)
        return canvas

    # Render the figure and copy its pixels into an ndarray X. The RGBA
    # buffer is read with frombuffer (binary-mode fromstring is deprecated)
    # and the alpha channel dropped; copy() gives a writable array that
    # does not alias the canvas buffer.
    canvas.draw()
    canvas_width, canvas_height = canvas.get_width_height()
    rgba = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
    X = rgba.reshape(canvas_height, canvas_width, 4)[..., :3].copy()
    # open cv's RGB style: BGR
    if not real_time:
        X = X[..., ::-1]
    if making_video or real_time:
        plt.close(fig)
        return X
    if making_image:
        cv2.imwrite('splash.png', X)
    plt.show()
def display_differences(image,
                        gt_box, gt_class_id, gt_mask,
                        pred_box, pred_class_id, pred_score, pred_mask,
                        class_names, title="", ax=None,
                        show_mask=True, show_box=True,
                        iou_threshold=0.5, score_threshold=0.5):
    """Display ground truth and prediction instances on the same image.

    Ground truth instances are drawn in green with empty captions and
    predictions in red with a "score / IoU" caption. The IoU shown is the
    overlap with the matched ground truth instance, or the largest overlap
    in that prediction's row when it has no match.

    gt_box, gt_class_id, gt_mask: ground truth boxes, class IDs and masks.
    pred_box, pred_class_id, pred_score, pred_mask: predicted boxes, class
        IDs, scores and masks.
    class_names: list of class names of the dataset
    title: (optional) Figure title; a default one is used when empty.
    ax: (optional) Matplotlib axis forwarded to display_instances.
    show_mask, show_box: To show masks and bounding boxes or not
    iou_threshold, score_threshold: forwarded to utils.compute_matches.
    """
    # Match predictions to ground truth
    gt_match, pred_match, overlaps = utils.compute_matches(
        gt_box, gt_class_id, gt_mask,
        pred_box, pred_class_id, pred_score, pred_mask,
        iou_threshold=iou_threshold, score_threshold=score_threshold)
    # Ground truth = green. Predictions = red
    colors = [(0, 1, 0, .8)] * len(gt_match)\
        + [(1, 0, 0, 1)] * len(pred_match)
    # Concatenate GT and predictions
    class_ids = np.concatenate([gt_class_id, pred_class_id])
    scores = np.concatenate([np.zeros([len(gt_match)]), pred_score])
    boxes = np.concatenate([gt_box, pred_box])
    masks = np.concatenate([gt_mask, pred_mask], axis=-1)
    # Captions per instance show score/IoU
    captions = [""] * len(gt_match)
    for i in range(len(pred_match)):
        if pred_match[i] > -1:
            iou = overlaps[i, int(pred_match[i])]
        else:
            # max() of an empty row raises; with no ground truth to
            # overlap with, report an IoU of zero instead.
            row = overlaps[i]
            iou = row.max() if row.size else 0.0
        captions.append("{:.2f} / {:.2f}".format(pred_score[i], iou))
    # Set title if not provided
    title = title or "Ground Truth and Detections\n GT=green, pred=red, captions: score/IoU"
    # Display
    display_instances(
        image,
        boxes, masks, class_ids,
        class_names, scores, ax=ax,
        show_bbox=show_box, show_mask=show_mask,
        colors=colors, captions=captions,
        title=title)
def draw_rois(image, rois, refined_rois, mask, class_ids, class_names, limit=10):
    """Draw a (random) subset of ROIs with their refinements and masks.

    image: image array to draw on; a copy is used for the mask overlay.
    rois: [n, (y1, x1, y2, x2)] ROIs in image coordinates, drawn dashed.
    refined_rois: [n, (y1, x1, y2, x2)] refined boxes, drawn solid for
        ROIs with a non-zero class ID.
    mask: per-ROI masks; mask[i] is passed to utils.unmold_mask together
        with the ROI and the image shape.
    class_ids: [n] class ID per ROI; 0 marks a negative ROI (drawn gray,
        without refinement, label or mask).
    class_names: list of class names of the dataset
    limit: maximum number of ROIs to draw; when there are more, `limit`
        of them are picked at random.

    Also prints the positive/negative ROI counts over all of class_ids.
    """
    masked_image = image.copy()

    # Pick random ROIs in case there are too many.
    total = rois.shape[0]
    ids = np.arange(total, dtype=np.int32)
    if total > limit:
        ids = np.random.choice(ids, limit, replace=False)

    _, ax = plt.subplots(1, figsize=(12, 12))
    if total > limit:
        ax.set_title("Showing {} random ROIs out of {}".format(
            len(ids), total))
    else:
        ax.set_title("{} ROIs".format(len(ids)))

    # Show area outside image boundaries.
    ax.set_ylim(image.shape[0] + 20, -20)
    ax.set_xlim(-50, image.shape[1] + 20)
    ax.axis('off')

    for roi_id in ids:
        color = np.random.rand(3)
        class_id = class_ids[roi_id]
        # ROI
        y1, x1, y2, x2 = rois[roi_id]
        p = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2,
                              edgecolor=color if class_id else "gray",
                              facecolor='none', linestyle="dashed")
        ax.add_patch(p)
        # Refined ROI
        if class_id:
            ry1, rx1, ry2, rx2 = refined_rois[roi_id]
            p = patches.Rectangle((rx1, ry1), rx2 - rx1, ry2 - ry1, linewidth=2,
                                  edgecolor=color, facecolor='none')
            ax.add_patch(p)
            # Connect the top-left corners of the ROI and its refinement
            # for easy visualization
            ax.add_line(lines.Line2D([x1, rx1], [y1, ry1], color=color))
            # Label
            label = class_names[class_id]
            ax.text(rx1, ry1 + 8, "{}".format(label),
                    color='w', size=11, backgroundcolor="none")
            # Mask
            m = utils.unmold_mask(mask[roi_id], rois[roi_id]
                                  [:4].astype(np.int32), image.shape)
            masked_image = apply_mask(masked_image, m, color)
    ax.imshow(masked_image)

    # Print stats
    positives = class_ids[class_ids > 0].shape[0]
    negatives = class_ids[class_ids == 0].shape[0]
    count = class_ids.shape[0]
    print("Positive ROIs: ", positives)
    print("Negative ROIs: ", negatives)
    # Avoid a ZeroDivisionError when there are no ROIs at all.
    print("Positive Ratio: {:.2f}".format(positives / count if count else 0.0))
# TODO: Replace with matplotlib equivalent? | |
def draw_box(image, box, color):
    """Paint the four edges of a bounding box directly into an image array.

    image: array indexed as [row, col, ...]; modified in place.
    box: (y1, x1, y2, x2) integer coordinates.
    color: value assigned to the edge pixels, e.g. 3 int values for RGB.

    Each edge is a 2-pixel strip starting at the box coordinate: rows
    y1..y1+1 and y2..y2+1 across columns x1..x2-1, and columns x1..x1+1
    and x2..x2+1 across rows y1..y2-1.

    Returns the same `image` object that was passed in.
    """
    top, left, bottom, right = box
    edge_strips = (
        (slice(top, top + 2), slice(left, right)),        # top edge
        (slice(bottom, bottom + 2), slice(left, right)),  # bottom edge
        (slice(top, bottom), slice(left, left + 2)),      # left edge
        (slice(top, bottom), slice(right, right + 2)),    # right edge
    )
    for row_span, col_span in edge_strips:
        image[row_span, col_span] = color
    return image
def display_top_masks(image, mask, class_ids, class_names, limit=4):
    """Show the image next to the masks of its most prominent classes.

    image: image array; its first two dimensions are reported in the title.
    mask: [height, width, num_instances] instance masks.
    class_ids: [num_instances] class ID of each mask.
    class_names: list of class names of the dataset
    limit: number of class panels to show after the image. Classes are
        ranked by total mask area; unused panels are blank and titled "-".
    """
    panels = [image]
    panel_titles = ["H x W={}x{}".format(image.shape[0], image.shape[1])]

    # Rank the classes present in this image by their total mask area,
    # dropping classes whose masks are empty.
    present_ids = np.unique(class_ids)
    areas = [np.sum(mask[:, :, np.where(class_ids == cid)[0]])
             for cid in present_ids]
    ranked = sorted(zip(present_ids, areas),
                    key=lambda pair: pair[1], reverse=True)
    top_ids = [cid for cid, area in ranked if area > 0]

    # Build one panel per slot.
    for slot in range(limit):
        if slot < len(top_ids):
            class_id = top_ids[slot]
        else:
            class_id = -1
        # Pull masks of instances belonging to the same class.
        layers = mask[:, :, np.where(class_ids == class_id)[0]]
        # Weight each instance by its 1-based index before summing so the
        # instances of a class show up with different intensities.
        weights = np.arange(1, layers.shape[-1] + 1)
        panels.append(np.sum(layers * weights, -1))
        panel_titles.append(class_names[class_id] if class_id != -1 else "-")
    display_images(panels, titles=panel_titles, cols=limit + 1, cmap="Blues_r")
def plot_precision_recall(AP, precisions, recalls):
    """Plot precision against recall on a new figure.

    AP: Average precision at IoU >= 0.5, shown in the plot title.
    precisions: list of precision values (y axis)
    recalls: list of recall values (x axis)
    """
    axes = plt.subplots(1)[1]
    heading = "Precision-Recall Curve. AP@50 = {:.3f}".format(AP)
    axes.set_title(heading)
    # Leave a little headroom above 1.0 on both axes.
    axes.set_ylim(0, 1.1)
    axes.set_xlim(0, 1.1)
    axes.plot(recalls, precisions)
def plot_overlaps(gt_class_ids, pred_class_ids, pred_scores,
                  overlaps, class_names, threshold=0.5):
    """Draw a grid of IoU overlaps between predictions and ground truth.

    gt_class_ids: [N] int. Ground truth class IDs
    pred_class_ids: [N] int. Predicted class IDs
    pred_scores: [N] float. The probability scores of predicted classes
    overlaps: [pred_boxes, gt_boxes] IoU overlaps of predictions and GT boxes.
    class_names: list of all class names in the dataset
    threshold: Float. Cells whose overlap exceeds this value are labelled
        "match" or "wrong" depending on whether the class IDs agree.

    Entries with class ID 0 are dropped from both ID arrays before use.
    """
    gt_class_ids = gt_class_ids[gt_class_ids != 0]
    pred_class_ids = pred_class_ids[pred_class_ids != 0]

    plt.figure(figsize=(12, 10))
    plt.imshow(overlaps, interpolation='nearest', cmap=plt.cm.Blues)

    # Rows are predictions (name and score), columns are ground truth.
    row_labels = ["{} ({:.2f})".format(class_names[int(cid)], pred_scores[row])
                  for row, cid in enumerate(pred_class_ids)]
    plt.yticks(np.arange(len(pred_class_ids)), row_labels)
    col_labels = [class_names[int(cid)] for cid in gt_class_ids]
    plt.xticks(np.arange(len(gt_class_ids)), col_labels, rotation=90)

    # Cells above half of the maximum overlap are dark, so use white text.
    half_max = overlaps.max() / 2.
    for i in range(overlaps.shape[0]):
        for j in range(overlaps.shape[1]):
            iou = overlaps[i, j]
            verdict = ""
            if iou > threshold:
                if gt_class_ids[j] == pred_class_ids[i]:
                    verdict = "match"
                else:
                    verdict = "wrong"
            if iou > half_max:
                text_color = "white"
            elif iou > 0:
                text_color = "black"
            else:
                text_color = "grey"
            plt.text(j, i, "{:.3f}\n{}".format(iou, verdict),
                     horizontalalignment="center", verticalalignment="center",
                     fontsize=9, color=text_color)

    plt.tight_layout()
    plt.xlabel("Ground Truth")
    plt.ylabel("Predictions")
def draw_boxes(image, boxes=None, refined_boxes=None,
               masks=None, captions=None, visibilities=None,
               title="", ax=None):
    """Draw bounding boxes and segmentation masks with different
    customizations.

    image: image array; its first two dimensions are height and width.
    boxes: [N, (y1, x1, y2, x2)] in image coordinates.
    refined_boxes: Like boxes, but draw with solid lines to show
        that they're the result of refining 'boxes'.
    masks: [height, width, N]
    captions: List of N titles to display on each box
    visibilities: (optional) List of values of 0, 1, or 2. Determine how
        prominent each bounding box should be: 0 is gray, dotted and
        half transparent (its refined box is not drawn), 1 is colored
        and dotted, 2 is colored and solid.
    title: An optional title to show over the image
    ax: (optional) Matplotlib axis to draw on.

    At least one of boxes and refined_boxes must be given.
    """
    # Number of boxes
    assert boxes is not None or refined_boxes is not None
    N = boxes.shape[0] if boxes is not None else refined_boxes.shape[0]

    # Matplotlib Axis
    if ax is None:
        _, ax = plt.subplots(1, figsize=(12, 12))

    # Generate random colors
    colors = random_colors(N)

    # Show area outside image boundaries.
    margin = image.shape[0] // 10
    ax.set_ylim(image.shape[0] + margin, -margin)
    ax.set_xlim(-margin, image.shape[1] + margin)
    ax.axis('off')
    ax.set_title(title)

    # Work on a wider integer copy so the caller's image is not modified.
    masked_image = image.astype(np.uint32).copy()
    for i in range(N):
        # Box visibility
        visibility = visibilities[i] if visibilities is not None else 1
        if visibility == 0:
            # Gray as an RGB tuple rather than a color name: apply_mask
            # multiplies the color components, which fails on a string.
            color = (0.5, 0.5, 0.5)
            style = "dotted"
            alpha = 0.5
        elif visibility == 1:
            color = colors[i]
            style = "dotted"
            alpha = 1
        else:
            # visibility == 2; any other value is also drawn solid so the
            # style variables are always defined.
            color = colors[i]
            style = "solid"
            alpha = 1

        # Boxes
        if boxes is not None:
            if not np.any(boxes[i]):
                # Skip this instance. Has no bbox. Likely lost in cropping.
                continue
            y1, x1, y2, x2 = boxes[i]
            p = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2,
                                  alpha=alpha, linestyle=style,
                                  edgecolor=color, facecolor='none')
            ax.add_patch(p)

        # Refined boxes. Read the coordinates even when the box is hidden
        # (visibility 0) because the caption below is positioned on them.
        if refined_boxes is not None:
            ry1, rx1, ry2, rx2 = refined_boxes[i].astype(np.int32)
            if visibility > 0:
                p = patches.Rectangle((rx1, ry1), rx2 - rx1, ry2 - ry1, linewidth=2,
                                      edgecolor=color, facecolor='none')
                ax.add_patch(p)
                # Connect the top-left corners of the anchor and proposal
                if boxes is not None:
                    ax.add_line(lines.Line2D([x1, rx1], [y1, ry1], color=color))

        # Captions
        if captions is not None:
            caption = captions[i]
            # If there are refined boxes, display captions on them
            if refined_boxes is not None:
                y1, x1, y2, x2 = ry1, rx1, ry2, rx2
            ax.text(x1, y1, caption, size=11, verticalalignment='top',
                    color='w', backgroundcolor="none",
                    bbox={'facecolor': color, 'alpha': 0.5,
                          'pad': 2, 'edgecolor': 'none'})

        # Masks
        if masks is not None:
            mask = masks[:, :, i]
            masked_image = apply_mask(masked_image, mask, color)
            # Mask Polygon
            # Pad to ensure proper polygons for masks that touch image edges.
            padded_mask = np.zeros(
                (mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8)
            padded_mask[1:-1, 1:-1] = mask
            contours = find_contours(padded_mask, 0.5)
            for verts in contours:
                # Subtract the padding and flip (y, x) to (x, y)
                verts = np.fliplr(verts) - 1
                p = Polygon(verts, facecolor="none", edgecolor=color)
                ax.add_patch(p)
    ax.imshow(masked_image.astype(np.uint8))
def display_table(table):
    """Render rows of values as an HTML table in the IPython display.

    table: an iterable of rows, and each row is an iterable of values.
        Every value is converted with str(), padded to a width of 40
        characters and inserted into the markup as-is (not HTML-escaped).
    """
    row_markup = []
    for row in table:
        cells = "".join("<td>{:40}</td>".format(str(value)) for value in row)
        row_markup.append("<tr>" + cells + "</tr>")
    markup = "<table>" + "".join(row_markup) + "</table>"
    IPython.display.display(IPython.display.HTML(markup))
def display_weight_stats(model):
    """Display a table of statistics for every trainable weight.

    model: object exposing get_trainable_layers(); each returned layer must
        provide get_weights() (list of Numpy arrays) and weights (objects
        with a name attribute, in the same order).

    Each row lists the weight name, shape, min, max and standard deviation.
    Weights whose min equals their max are flagged "dead?" (except the
    second weight of Conv2D layers), and weights whose min or max exceeds
    1000 in magnitude are flagged "Overflow?". Nothing is returned; the
    table is rendered through display_table.
    """
    rows = [["WEIGHT NAME", "SHAPE", "MIN", "MAX", "STD"]]
    for layer in model.get_trainable_layers():
        values = layer.get_weights()  # list of Numpy arrays
        tensors = layer.weights  # list of TF tensors
        is_conv = layer.__class__.__name__ == "Conv2D"
        for index, weight in enumerate(values):
            name = tensors[index].name
            low = weight.min()
            high = weight.max()
            # Detect problematic layers. Exclude biases of conv layers.
            flags = ""
            if low == high and (not is_conv or index != 1):
                flags += "<span style='color:red'>*** dead?</span>"
            if np.abs(low) > 1000 or np.abs(high) > 1000:
                flags += "<span style='color:red'>*** Overflow?</span>"
            # Add row
            rows.append([
                name + flags,
                str(weight.shape),
                "{:+9.4f}".format(low),
                "{:+10.4f}".format(high),
                "{:+9.4f}".format(weight.std()),
            ])
    display_table(rows)