import matplotlib |
from matplotlib import pyplot as plt |
from matplotlib.lines import Line2D |
import cv2 |
import numpy as np |
import torch |
from torchvision.transforms import Compose, Normalize, ToTensor |
from typing import List, Dict |
import math |
def preprocess_image(
        img: np.ndarray,
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5]) -> torch.Tensor:
    """Turn an image array into a normalized tensor with a batch dimension.

    :param img: The image as a numpy array. It is copied before conversion,
        so the caller's array is not handed to the transforms directly.
    :param mean: Per-channel mean forwarded to ``Normalize``.
    :param std: Per-channel standard deviation forwarded to ``Normalize``.
    :returns: The output of ``ToTensor`` followed by ``Normalize``, with a
        leading dimension of size 1 inserted.
    """
    to_normalized_tensor = Compose(
        [ToTensor(), Normalize(mean=mean, std=std)])
    tensor = to_normalized_tensor(img.copy())
    return tensor.unsqueeze(0)
def deprocess_image(img):
    """Standardize an array and convert it to an 8-bit image.

    see https://github.com/jacobgil/keras-grad-cam/blob/master/grad-cam.py#L65

    The input is shifted to zero mean, divided by its standard deviation
    (plus 1e-5 to avoid dividing by zero), scaled by 0.1, offset by 0.5 and
    clipped to [0, 1] before being converted to ``np.uint8`` in [0, 255].
    """
    centered = img - np.mean(img)
    standardized = centered / (np.std(centered) + 1e-5)
    squashed = np.clip(standardized * 0.1 + 0.5, 0, 1)
    return np.uint8(squashed * 255)
def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET,
                      image_weight: float = 0.5) -> np.ndarray:
    """ This function overlays the cam mask on the image as a heatmap.
    By default the heatmap is in BGR format.

    The blend is not a single fixed weight. With
    ``s = (1 - image_weight) / 2``, every heatmap value ``h`` (in [0, 1])
    contributes with weight ``s * (1 + h)`` and the image with weight
    ``1 - s * (1 + h)``, so brighter heatmap values cover more of the image.
    The blended result is then divided by its maximum value.

    :param img: The base image in RGB or BGR format, with values in [0, 1].
    :param mask: The cam mask. It is multiplied by 255 and cast to uint8,
        so it is presumably expected to be in [0, 1].
    :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :param image_weight: Controls how much of the image shows through, in
        the range [0, 1]. 1 returns the (max-normalized) image only.
    :returns: The image with the cam overlay, as uint8 in [0, 255].
    :raises ValueError: If ``img`` has values above 1 or ``image_weight`` is
        outside [0, 1].
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    if np.max(img) > 1:
        raise ValueError(
            "The input image should be np.float32 in the range [0, 1]")

    if image_weight < 0 or image_weight > 1:
        # Built from adjacent literals so that no source indentation leaks
        # into the message (a backslash continuation inside a string does).
        raise ValueError(
            "image_weight should be in the range [0, 1]. "
            f"Got: {image_weight}")

    # Per-pixel image weight: the stronger the heatmap, the less of the
    # underlying image is kept at that pixel/channel.
    scalar = (1 - image_weight) / 2
    pixel_weight = 1 - scalar - scalar * heatmap
    cam = (1 - pixel_weight) * heatmap + pixel_weight * img

    # Skip the normalization for an all-zero blend; dividing by zero would
    # fill the result with NaNs before the uint8 conversion.
    peak = np.max(cam)
    if peak != 0:
        cam = cam / peak
    return np.uint8(255 * cam)
def create_labels_legend(concept_scores: np.ndarray,
                         labels: Dict[int, str],
                         top_k=2):
    """Build one multi-line legend entry per row of ``concept_scores``.

    For every row, the ``top_k`` highest-scoring category indices are
    looked up in ``labels``. Each label is cut down to its first three
    comma-separated parts and suffixed with the score formatted to two
    decimals; the entries of a row are joined with newlines.

    :param concept_scores: 2D array indexed as [concept, category].
    :param labels: Mapping from category index to its label text.
    :param top_k: Number of best categories to list for every concept.
    :returns: A list of strings, one per row of ``concept_scores``.
    """
    # Sort ascending, flip to descending, then keep the best top_k columns.
    descending = np.argsort(concept_scores, axis=1)[:, ::-1]
    best_categories = descending[:, :top_k]

    def _entry(row, category):
        short_name = ','.join(labels[category].split(',')[:3])
        return f"{short_name}:{concept_scores[row, category]:.2f}"

    return [
        "\n".join(_entry(row, category) for category in row_categories)
        for row, row_categories in enumerate(best_categories)
    ]
def show_factorization_on_image(img: np.ndarray,
                                explanations: np.ndarray,
                                colors: List[np.ndarray] = None,
                                image_weight: float = 0.5,
                                concept_labels: List = None) -> np.ndarray:
    """ Color code the different component heatmaps on top of the image.
        Every component color code will be magnified according to the heatmap intensity
        (by modifying the V channel in the HSV color space),
        and optionally create a legend that shows the labels.

        Since different factorization component heatmaps can overlap in principle,
        we need a strategy to decide how to deal with the overlaps.
        This keeps the component that has a higher value in its heatmap.

        The ``explanations`` array passed in is not modified.

    :param img: The base image RGB format.
    :param explanations: A tensor of shape num_components x height x width, with the component visualizations.
        Values are multiplied by 255 and cast to uint8, so they are presumably expected to be in [0, 1].
    :param colors: List of R, G, B colors to be used for the components.
                   If None, will use the gist_rainbow cmap as a default.
    :param image_weight: The final result is image_weight * img + (1-image_weight) * visualization.
    :param concept_labels: A list of strings for every component. If this is passed, a legend that shows
                     the labels and their colors will be added to the image.
    :returns: The visualized image. When a legend is requested it is stacked
              to the right of the visualization.
    """
    n_components = explanations.shape[0]
    if colors is None:
        # pyplot.get_cmap is used because plt.cm.get_cmap was removed in
        # recent matplotlib releases.
        _cmap = plt.get_cmap('gist_rainbow')
        colors = [
            np.array(_cmap(i))
            for i in np.arange(0, 1, 1.0 / n_components)]

    # Index of the strongest component at every pixel.
    concept_per_pixel = explanations.argmax(axis=0)
    masks = []
    for i in range(n_components):
        mask = np.zeros(shape=(img.shape[0], img.shape[1], 3))
        mask[:, :, :] = colors[i][:3]
        # Work on a copy: explanations[i] is a view, and zeroing it in place
        # would silently overwrite the caller's array.
        explanation = explanations[i].copy()
        explanation[concept_per_pixel != i] = 0
        mask = np.uint8(mask * 255)
        # Replace the V channel with the heatmap so that the component color
        # gets brighter where the component is stronger.
        mask = cv2.cvtColor(mask, cv2.COLOR_RGB2HSV)
        mask[:, :, 2] = np.uint8(255 * explanation)
        mask = cv2.cvtColor(mask, cv2.COLOR_HSV2RGB)
        mask = np.float32(mask) / 255
        masks.append(mask)

    # Each pixel is non-zero in at most one mask, so the sum merges them.
    mask = np.sum(np.float32(masks), axis=0)
    result = img * image_weight + mask * (1 - image_weight)
    result = np.uint8(result * 255)

    if concept_labels is not None:
        # Size the legend figure (in inches) to match the image in pixels.
        px = 1 / plt.rcParams['figure.dpi']
        fig = plt.figure(figsize=(result.shape[1] * px, result.shape[0] * px))
        try:
            # Passed to legend() directly instead of writing to
            # plt.rcParams, which would change every later legend too.
            legend_fontsize = int(
                14 * result.shape[0] / 256 / max(1, n_components / 6))
            lw = 5 * result.shape[0] / 256
            lines = [Line2D([0], [0], color=colors[i], lw=lw)
                     for i in range(n_components)]
            plt.legend(lines,
                       concept_labels,
                       mode="expand",
                       fancybox=True,
                       shadow=True,
                       fontsize=legend_fontsize)

            plt.tight_layout(pad=0, w_pad=0, h_pad=0)
            plt.axis('off')
            fig.canvas.draw()
            # buffer_rgba replaces the removed canvas.tostring_rgb. The
            # alpha channel is dropped and the pixels are copied out before
            # the figure is closed.
            rgba = np.asarray(fig.canvas.buffer_rgba())
            data = np.ascontiguousarray(rgba[:, :, :3])
        finally:
            # Always release the figure, even if drawing fails.
            plt.close(fig=fig)
        data = cv2.resize(data, (result.shape[1], result.shape[0]))
        result = np.hstack((result, data))
    return result
def scale_cam_image(cam, target_size=None):
    """Min-max scale every image in ``cam`` and optionally resize it.

    Each image is shifted so that its minimum is 0 and divided by its new
    maximum (plus 1e-7 to avoid dividing by zero).

    :param cam: An iterable of images (e.g. an array whose first axis
        indexes the images).
    :param target_size: If not None, passed to ``cv2.resize`` as the output
        size for every scaled image.
    :returns: The scaled images stacked into one ``np.float32`` array.
    """
    def _rescale(single):
        shifted = single - np.min(single)
        scaled = shifted / (1e-7 + np.max(shifted))
        if target_size is None:
            return scaled
        return cv2.resize(scaled, target_size)

    return np.float32([_rescale(single) for single in cam])
def scale_accross_batch_and_channels(tensor, target_size):
    """Scale every (batch, channel) slice of ``tensor`` independently.

    The first two axes are merged so that ``scale_cam_image`` handles each
    slice as a separate image; the result is then split back into batch and
    channel axes.

    :param tensor: Array whose first two axes are batch and channel.
    :param target_size: Size forwarded to ``scale_cam_image``; indexed as
        (target_size[0], target_size[1]).
    :returns: Array of shape
        (batch, channel, target_size[1], target_size[0]).
    """
    batch_size, channel_size = tensor.shape[:2]
    spatial_shape = tensor.shape[2:]
    flattened = tensor.reshape(batch_size * channel_size, *spatial_shape)
    scaled = scale_cam_image(flattened, target_size)
    # target_size is ordered (width, height), while the array layout is
    # rows (height) first.
    width, height = target_size[0], target_size[1]
    return scaled.reshape(batch_size, channel_size, height, width)