Spaces:
Build error
Build error
#!/usr/bin/env python3 | |
import warnings | |
from enum import Enum | |
from typing import Any, Iterable, List, Tuple, Union | |
import numpy as np | |
from matplotlib import pyplot as plt | |
from matplotlib.colors import LinearSegmentedColormap | |
from matplotlib.figure import Figure | |
from matplotlib.pyplot import axis, figure | |
from mpl_toolkits.axes_grid1 import make_axes_locatable | |
from numpy import ndarray | |
try:
    # ``IPython.display`` is the public location of these helpers. Importing
    # them from ``IPython.core.display`` has been deprecated since IPython 7.14
    # and may be unavailable in newer releases, which would make this module
    # report IPython as missing even though it is installed.
    from IPython.display import HTML, display

    HAS_IPYTHON = True
except ImportError:
    # Text visualization is optional; image helpers work without IPython.
    HAS_IPYTHON = False
class ImageVisualizationMethod(Enum):
    """Rendering modes for image attributions.

    ``visualize_image_attr`` resolves its ``method`` string argument by member
    name (``ImageVisualizationMethod[method]``), so the member names below are
    the accepted string values.
    """

    # Heat map of the normalized attributions only.
    heat_map = 1
    # Heat map drawn over a greyscale version of the original image.
    blended_heat_map = 2
    # The original image, with no attribution overlay.
    original_image = 3
    # Original image multiplied pixel-wise by the normalized attributions.
    masked_image = 4
    # Original image with the normalized attributions as an alpha channel.
    alpha_scaling = 5
class VisualizeSign(Enum):
    """Which signed part of the attributions is visualized.

    The ``sign`` string arguments in this module are resolved by member name
    (``VisualizeSign[sign]``), so the member names below are the accepted
    string values.
    """

    # Keep only attributions greater than zero.
    positive = 1
    # Use the magnitude of every attribution.
    absolute_value = 2
    # Keep only attributions less than zero.
    negative = 3
    # Keep both positive and negative attributions.
    all = 4
def _prepare_image(attr_visual: ndarray):
    """Cast an array to integers and limit every value to the range 0-255.

    Args:
        attr_visual: Array of pixel-like values.

    Returns:
        Integer array of the same shape, clipped to ``[0, 255]``.
    """
    as_integers = attr_visual.astype(int)
    return np.clip(as_integers, 0, 255)
def _normalize_scale(attr: ndarray, scale_factor: float):
    """Divide attributions by ``scale_factor`` and clip the result to [-1, 1].

    Args:
        attr: Array of attribution values.
        scale_factor: Non-zero value to divide by.

    Returns:
        ``attr / scale_factor`` with every element clipped to ``[-1, 1]``.

    Raises:
        AssertionError: If ``scale_factor`` is exactly 0.

    Warns:
        UserWarning: If ``abs(scale_factor) < 1e-5``, since dividing by a value
            that small can make the visualization misleading.
    """
    assert scale_factor != 0, "Cannot normalize by scale factor = 0"
    if abs(scale_factor) < 1e-5:
        # Each fragment ends with a space so the concatenated message reads as
        # separate words ("results may", "all close").
        warnings.warn(
            "Attempting to normalize by value approximately 0, visualized results "
            "may be misleading. This likely means that attribution values are all "
            "close to 0."
        )
    attr_norm = attr / scale_factor
    return np.clip(attr_norm, -1, 1)
def _cumulative_sum_threshold(values: ndarray, percentile: Union[int, float]):
    """Return the smallest value at which the running sum reaches a percentile.

    The values are flattened and sorted ascending; the result is the first
    sorted value whose cumulative sum is at least ``percentile`` percent of the
    total sum. The given values should be non-negative.

    Args:
        values: Array of values.
        percentile: Percentage of the total sum, between 0 and 100 inclusive.

    Returns:
        The element of ``values`` at which the threshold is reached.

    Raises:
        AssertionError: If ``percentile`` is outside ``[0, 100]``.
    """
    assert 0 <= percentile <= 100, (
        "Percentile for thresholding must be " "between 0 and 100 inclusive."
    )
    # ``flatten`` returns a copy, so sorting it in place leaves ``values`` intact.
    ordered = values.flatten()
    ordered.sort()
    running_total = np.cumsum(ordered)
    target = running_total[-1] * 0.01 * percentile
    first_reached = np.flatnonzero(running_total >= target)[0]
    return ordered[first_reached]
def _normalize_image_attr(
    attr: ndarray, sign: str, outlier_perc: Union[int, float] = 2
):
    """Sum attributions over the last axis and rescale them for display.

    The attributions are summed over axis 2, filtered according to ``sign``,
    and divided by a threshold computed by ``_cumulative_sum_threshold`` at
    ``100 - outlier_perc`` percent, then clipped by ``_normalize_scale``.

    Args:
        attr: Attribution array with at least three dimensions; axis 2 is
            summed away.
        sign: Name of a ``VisualizeSign`` member.
        outlier_perc: Percentage of total attribution treated as outliers.

    Returns:
        Normalized attribution array with values in ``[-1, 1]``.

    Raises:
        KeyError: If ``sign`` is not the name of a ``VisualizeSign`` member.
    """
    attr_combined = np.sum(attr, axis=2)
    sign_kind = VisualizeSign[sign]
    kept_perc = 100 - outlier_perc

    # Select the signed values to keep and derive the scale from them.
    if sign_kind is VisualizeSign.all:
        # Signs are preserved; only the scale uses absolute values.
        scale = _cumulative_sum_threshold(np.abs(attr_combined), kept_perc)
    elif sign_kind is VisualizeSign.positive:
        attr_combined = (attr_combined > 0) * attr_combined
        scale = _cumulative_sum_threshold(attr_combined, kept_perc)
    elif sign_kind is VisualizeSign.negative:
        attr_combined = (attr_combined < 0) * attr_combined
        # A negative scale maps the kept negative values onto [0, 1].
        scale = -1 * _cumulative_sum_threshold(np.abs(attr_combined), kept_perc)
    elif sign_kind is VisualizeSign.absolute_value:
        attr_combined = np.abs(attr_combined)
        scale = _cumulative_sum_threshold(attr_combined, kept_perc)
    else:
        raise AssertionError("Visualize Sign type is not valid.")
    return _normalize_scale(attr_combined, scale)
def visualize_image_attr(
    attr: ndarray,
    original_image: Union[None, ndarray] = None,
    method: str = "heat_map",
    sign: str = "absolute_value",
    plt_fig_axis: Union[None, Tuple[figure, axis]] = None,
    outlier_perc: Union[int, float] = 2,
    cmap: Union[None, str] = None,
    alpha_overlay: float = 0.5,
    show_colorbar: bool = False,
    title: Union[None, str] = None,
    fig_size: Tuple[int, int] = (6, 6),
    use_pyplot: bool = True,
):
    r"""
    Visualizes attribution for a given image by normalizing attribution values
    of the desired sign (positive, negative, absolute value, or all) and displaying
    them using the desired mode in a matplotlib figure.

    Args:

        attr (numpy.array): Numpy array corresponding to attributions to be
                    visualized. Shape must be in the form (H, W, C), with
                    channels as last dimension. Shape must also match that of
                    the original image if provided.
        original_image (numpy.array, optional): Numpy array corresponding to
                    original image. Shape must be in the form (H, W, C), with
                    channels as the last dimension. Image can be provided either
                    with float values in range 0-1 or int values between 0-255.
                    This is a necessary argument for any visualization method
                    which utilizes the original image.
                    Default: None
        method (string, optional): Chosen method for visualizing attribution.
                    Supported options are:

                    1. `heat_map` - Display heat map of chosen attributions

                    2. `blended_heat_map` - Overlay heat map over greyscale
                       version of original image. Parameter alpha_overlay
                       corresponds to alpha of heat map.

                    3. `original_image` - Only display original image.

                    4. `masked_image` - Mask image (pixel-wise multiply)
                       by normalized attribution values.

                    5. `alpha_scaling` - Sets alpha channel of each pixel
                       to be equal to normalized attribution value.
                    Default: `heat_map`
        sign (string, optional): Chosen sign of attributions to visualize. Supported
                    options are:

                    1. `positive` - Displays only positive pixel attributions.

                    2. `absolute_value` - Displays absolute value of
                       attributions.

                    3. `negative` - Displays only negative pixel attributions.

                    4. `all` - Displays both positive and negative attribution
                       values. This is not supported for `masked_image` or
                       `alpha_scaling` modes, since signed information cannot
                       be represented in these modes.
                    Default: `absolute_value`
        plt_fig_axis (tuple, optional): Tuple of matplotlib.pyplot.figure and axis
                    on which to visualize. If None is provided, then a new figure
                    and axis are created.
                    Default: None
        outlier_perc (float or int, optional): Top attribution values which
                    correspond to a total of outlier_perc percentage of the
                    total attribution are set to 1 and scaling is performed
                    using the minimum of these values. For sign=`all`, outliers
                    and scale value are computed using absolute value of
                    attributions.
                    Default: 2
        cmap (string, optional): String corresponding to desired colormap for
                    heatmap visualization. This defaults to "Reds" for negative
                    sign, "Blues" for absolute value, "Greens" for positive sign,
                    and a spectrum from red to green for all. Note that this
                    argument is only used for visualizations displaying heatmaps.
                    Default: None
        alpha_overlay (float, optional): Alpha to set for heatmap when using
                    `blended_heat_map` visualization mode, which overlays the
                    heat map over the greyscaled original image.
                    Default: 0.5
        show_colorbar (boolean, optional): Displays colorbar for heatmap below
                    the visualization. If given method does not use a heatmap,
                    then a colormap axis is created and hidden. This is
                    necessary for appropriate alignment when visualizing
                    multiple plots, some with colorbars and some without.
                    Default: False
        title (string, optional): Title string for plot. If None, no title is
                    set.
                    Default: None
        fig_size (tuple, optional): Size of figure created.
                    Default: (6,6)
        use_pyplot (boolean, optional): If true, uses pyplot to create and show
                    figure and displays the figure after creating. If False,
                    uses Matplotlib object oriented API and simply returns a
                    figure object without showing.
                    Default: True.

    Returns:
        2-element tuple of **figure**, **axis**:
        - **figure** (*matplotlib.pyplot.figure*):
                    Figure object on which visualization
                    is created. If plt_fig_axis argument is given, this is the
                    same figure provided.
        - **axis** (*matplotlib.pyplot.axis*):
                    Axis object on which visualization
                    is created. If plt_fig_axis argument is given, this is the
                    same axis provided.

    Raises:
        KeyError: If `method` or `sign` is not one of the supported option
                    names. An invalid `sign` is only detected for methods
                    other than `original_image`.
        AssertionError: If `original_image` is None and `method` is not
                    `heat_map`, or if sign `all` is combined with
                    `masked_image` or `alpha_scaling`.

    Examples::

        >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
        >>> # and returns an Nx10 tensor of class probabilities.
        >>> net = ImageClassifier()
        >>> ig = IntegratedGradients(net)
        >>> # Computes integrated gradients for class 3 for a given image .
        >>> attribution, delta = ig.attribute(orig_image, target=3)
        >>> # Displays blended heat map visualization of computed attributions.
        >>> _ = visualize_image_attr(attribution, orig_image, "blended_heat_map")
    """
    # Resolve the method name once, before any figure is created, so that an
    # invalid name fails without leaving a new pyplot figure behind.
    vis_method = ImageVisualizationMethod[method]

    # Create plot if figure, axis not provided
    if plt_fig_axis is not None:
        plt_fig, plt_axis = plt_fig_axis
    elif use_pyplot:
        plt_fig, plt_axis = plt.subplots(figsize=fig_size)
    else:
        plt_fig = Figure(figsize=fig_size)
        plt_axis = plt_fig.subplots()

    if original_image is not None:
        # Images given in the 0-1 range are rescaled to 0-255 integers.
        if np.max(original_image) <= 1.0:
            original_image = _prepare_image(original_image * 255)
    else:
        assert (
            vis_method == ImageVisualizationMethod.heat_map
        ), "Original Image must be provided for any visualization other than heatmap."

    # Remove ticks and tick labels from plot.
    plt_axis.xaxis.set_ticks_position("none")
    plt_axis.yaxis.set_ticks_position("none")
    plt_axis.set_yticklabels([])
    plt_axis.set_xticklabels([])
    # The flag is passed positionally: its keyword was renamed from ``b`` to
    # ``visible`` in Matplotlib, so neither keyword spelling works on every
    # version.
    plt_axis.grid(False)

    heat_map = None
    # Show original image
    if vis_method == ImageVisualizationMethod.original_image:
        # imshow expects single-channel images without the trailing channel axis.
        if len(original_image.shape) > 2 and original_image.shape[2] == 1:
            original_image = np.squeeze(original_image, axis=2)
        plt_axis.imshow(original_image)
    else:
        # Choose appropriate signed attributions and normalize.
        norm_attr = _normalize_image_attr(attr, sign, outlier_perc)
        vis_sign = VisualizeSign[sign]

        # Set default colormap and bounds based on sign.
        if vis_sign == VisualizeSign.all:
            default_cmap = LinearSegmentedColormap.from_list(
                "RdWhGn", ["red", "white", "green"]
            )
            vmin, vmax = -1, 1
        elif vis_sign == VisualizeSign.positive:
            default_cmap = "Greens"
            vmin, vmax = 0, 1
        elif vis_sign == VisualizeSign.negative:
            default_cmap = "Reds"
            vmin, vmax = 0, 1
        elif vis_sign == VisualizeSign.absolute_value:
            default_cmap = "Blues"
            vmin, vmax = 0, 1
        else:
            raise AssertionError("Visualize Sign type is not valid.")
        cmap = cmap if cmap is not None else default_cmap

        # Show appropriate image visualization.
        if vis_method == ImageVisualizationMethod.heat_map:
            heat_map = plt_axis.imshow(norm_attr, cmap=cmap, vmin=vmin, vmax=vmax)
        elif vis_method == ImageVisualizationMethod.blended_heat_map:
            # Greyscale background: mean over the channel axis.
            plt_axis.imshow(np.mean(original_image, axis=2), cmap="gray")
            heat_map = plt_axis.imshow(
                norm_attr, cmap=cmap, vmin=vmin, vmax=vmax, alpha=alpha_overlay
            )
        elif vis_method == ImageVisualizationMethod.masked_image:
            assert vis_sign != VisualizeSign.all, (
                "Cannot display masked image with both positive and negative "
                "attributions, choose a different sign option."
            )
            plt_axis.imshow(
                _prepare_image(original_image * np.expand_dims(norm_attr, 2))
            )
        elif vis_method == ImageVisualizationMethod.alpha_scaling:
            assert vis_sign != VisualizeSign.all, (
                "Cannot display alpha scaling with both positive and negative "
                "attributions, choose a different sign option."
            )
            # Append the scaled attributions as an extra (alpha) channel.
            plt_axis.imshow(
                np.concatenate(
                    [
                        original_image,
                        _prepare_image(np.expand_dims(norm_attr, 2) * 255),
                    ],
                    axis=2,
                )
            )
        else:
            raise AssertionError("Visualize Method type is not valid.")

    # Add colorbar. If given method is not a heatmap and no colormap is relevant,
    # then a colormap axis is created and hidden. This is necessary for appropriate
    # alignment when visualizing multiple plots, some with heatmaps and some
    # without.
    if show_colorbar:
        axis_separator = make_axes_locatable(plt_axis)
        colorbar_axis = axis_separator.append_axes("bottom", size="5%", pad=0.1)
        if heat_map is not None:
            plt_fig.colorbar(heat_map, orientation="horizontal", cax=colorbar_axis)
        else:
            colorbar_axis.axis("off")
    if title:
        plt_axis.set_title(title)

    if use_pyplot:
        plt.show()

    return plt_fig, plt_axis
def visualize_image_attr_multiple(
    attr: ndarray,
    original_image: Union[None, ndarray],
    methods: List[str],
    signs: List[str],
    titles: Union[None, List[str]] = None,
    fig_size: Tuple[int, int] = (8, 6),
    use_pyplot: bool = True,
    **kwargs: Any,
):
    r"""
    Visualizes attribution using multiple visualization methods displayed
    in a 1 x k grid, where k is the number of desired visualizations.

    Args:

        attr (numpy.array): Numpy array corresponding to attributions to be
                    visualized. Shape must be in the form (H, W, C), with
                    channels as last dimension. Shape must also match that of
                    the original image if provided.
        original_image (numpy.array, optional): Numpy array corresponding to
                    original image. Shape must be in the form (H, W, C), with
                    channels as the last dimension. Image can be provided either
                    with values in range 0-1 or 0-255. This is a necessary
                    argument for any visualization method which utilizes
                    the original image.
        methods (list of strings): List of strings of length k, defining method
                    for each visualization. Each method must be a valid
                    string argument for method to visualize_image_attr.
        signs (list of strings): List of strings of length k, defining signs for
                    each visualization. Each sign must be a valid
                    string argument for sign to visualize_image_attr.
        titles (list of strings, optional): List of strings of length k, providing
                    a title string for each plot. If None is provided, no titles
                    are added to subplots.
                    Default: None
        fig_size (tuple, optional): Size of figure created.
                    Default: (8, 6)
        use_pyplot (boolean, optional): If true, uses pyplot to create and show
                    figure and displays the figure after creating. If False,
                    uses Matplotlib object oriented API and simply returns a
                    figure object without showing.
                    Default: True.
        **kwargs (Any, optional): Any additional arguments which will be passed
                    to every individual visualization. Such arguments include
                    `show_colorbar`, `alpha_overlay`, `cmap`, etc.

    Returns:
        2-element tuple of **figure**, **axis**:
        - **figure** (*matplotlib.pyplot.figure*):
                    Figure object created by this call, holding all k
                    visualizations.
        - **axis**:
                    The axes the visualizations were drawn on: the result of
                    ``figure.subplots(1, k)``, wrapped in a one-element list
                    when k is 1.

    Examples::

        >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
        >>> # and returns an Nx10 tensor of class probabilities.
        >>> net = ImageClassifier()
        >>> ig = IntegratedGradients(net)
        >>> # Computes integrated gradients for class 3 for a given image .
        >>> attribution, delta = ig.attribute(orig_image, target=3)
        >>> # Displays original image and heat map visualization of
        >>> # computed attributions side by side.
        >>> _ = visualize_image_attr_multiple(attribution, orig_image,
        >>>                     ["original_image", "heat_map"], ["all", "positive"])
    """
    num_plots = len(methods)
    assert num_plots == len(signs), "Methods and signs array lengths must match."
    if titles is not None:
        assert num_plots == len(titles), (
            "If titles list is given, length must " "match that of methods list."
        )

    plt_fig = plt.figure(figsize=fig_size) if use_pyplot else Figure(figsize=fig_size)
    plt_axis = plt_fig.subplots(1, num_plots)

    # With a single visualization, subplots returns one axis rather than a
    # sequence; wrap it so it can be indexed like the multi-plot case.
    if num_plots == 1:
        plt_axis = [plt_axis]

    for idx, method in enumerate(methods):
        visualize_image_attr(
            attr,
            original_image=original_image,
            method=method,
            sign=signs[idx],
            plt_fig_axis=(plt_fig, plt_axis[idx]),
            # Showing is deferred until every subplot has been drawn.
            use_pyplot=False,
            title=titles[idx] if titles else None,
            **kwargs,
        )

    plt_fig.tight_layout()
    if use_pyplot:
        plt.show()
    return plt_fig, plt_axis
# These visualization methods are for text and are partially copied from | |
# experiments conducted by Davide Testuggine at Facebook. | |
class VisualizationDataRecord:
    r"""
    A data record for storing attribution relevant information

    Each constructor argument is stored unchanged on the attribute of the same
    name. ``visualize_text`` renders one table row per record using
    ``true_class``, ``pred_class`` with ``pred_prob``, ``attr_class``,
    ``attr_score``, and the ``raw_input_ids`` / ``word_attributions`` pair as
    the words and their importances.
    """

    # Fixed attribute set; instances carry no per-instance ``__dict__``.
    __slots__ = [
        "word_attributions",
        "pred_prob",
        "pred_class",
        "true_class",
        "attr_class",
        "attr_score",
        "raw_input_ids",
        "convergence_score",
    ]

    def __init__(
        self,
        word_attributions,
        pred_prob,
        pred_class,
        true_class,
        attr_class,
        attr_score,
        raw_input_ids,
        convergence_score,
    ) -> None:
        # Prediction and label information.
        self.pred_prob = pred_prob
        self.pred_class = pred_class
        self.true_class = true_class

        # Attribution summary.
        self.attr_class = attr_class
        self.attr_score = attr_score
        self.convergence_score = convergence_score

        # Per-token data.
        self.raw_input_ids = raw_input_ids
        self.word_attributions = word_attributions
def _get_color(attr):
    """Map an attribution score to a CSS ``hsl(...)`` color string.

    Scores are limited to ``[-1, 1]`` first. Positive scores use hue 120 and
    all other scores use hue 0; lightness is 100% at a score of 0 and drops
    as the magnitude grows (to 50% at +1 and 60% at -1).
    """
    # clip values to prevent CSS errors (Values should be from [-1,1])
    bounded = max(-1, min(1, attr))
    if bounded > 0:
        hue, lightness = 120, 100 - int(50 * bounded)
    else:
        hue, lightness = 0, 100 - int(-40 * bounded)
    return f"hsl({hue}, 75%, {lightness}%)"
def format_classname(classname):
    """Wrap ``classname`` in a bold, right-padded HTML table cell."""
    cell_open = '<td><text style="padding-right:2em"><b>'
    cell_close = "</b></text></td>"
    return f"{cell_open}{classname}{cell_close}"
def format_special_tokens(token):
    """Rewrite angle-bracketed tokens such as ``<pad>`` as ``#pad``.

    A token that both starts with ``<`` and ends with ``>`` has all leading
    and trailing ``<`` / ``>`` characters stripped and is prefixed with ``#``.
    Any other token is returned unchanged.
    """
    is_bracketed = token.startswith("<") and token.endswith(">")
    return "#" + token.strip("<>") if is_bracketed else token
def format_tooltip(item, text):
    """Return an HTML ``tooltip`` div showing ``item`` with ``text`` as its tooltip."""
    template = (
        '<div class="tooltip">{item}'
        '<span class="tooltiptext">{text}</span>'
        "</div>"
    )
    return template.format(item=item, text=text)
def format_word_importances(words, importances):
    """Render words as one HTML table cell, each highlighted by its importance.

    Args:
        words: Sequence of token strings.
        importances: Sequence of scores, at least as long as ``words``; only
            the first ``len(words)`` entries are used.

    Returns:
        A ``<td>...</td>`` string with one ``<mark>`` element per word, or an
        empty cell when ``importances`` is None or empty.

    Raises:
        AssertionError: If there are more words than importances.
    """
    if importances is None or len(importances) == 0:
        return "<td></td>"
    assert len(words) <= len(importances)

    mark_template = (
        '<mark style="background-color: {color}; opacity:1.0; '
        'line-height:1.75"><font color="black"> {word}'
        "</font></mark>"
    )
    marks = [
        mark_template.format(
            word=format_special_tokens(token), color=_get_color(score)
        )
        for token, score in zip(words, importances[: len(words)])
    ]
    return "".join(["<td>", *marks, "</td>"])
def visualize_text(
    datarecords: Iterable[VisualizationDataRecord], legend: bool = True
) -> "HTML":  # In quotes because this type doesn't exist in standalone mode
    """Display attribution records as an HTML table, one row per record.

    Each row shows the true label, the predicted label with its probability,
    the attribution label, the attribution score, and the input words
    highlighted by their attributions.

    Args:
        datarecords: Records to render, in order.
        legend: If True, a Negative / Neutral / Positive color legend is
            shown above the table.

    Returns:
        The ``HTML`` object that was passed to ``display``.

    Raises:
        AssertionError: If IPython could not be imported.
    """
    assert HAS_IPYTHON, (
        "IPython must be available to visualize text. "
        "Please run 'pip install ipython'."
    )
    dom = []

    # The legend is emitted before the table rather than inside it: a <div>
    # is not valid as a direct child of <table>.
    if legend:
        dom.append(
            '<div style="border-top: 1px solid; margin-top: 5px; '
            'padding-top: 5px; display: inline-block">'
        )
        dom.append("<b>Legend: </b>")

        for value, label in zip([-1, 0, 1], ["Negative", "Neutral", "Positive"]):
            dom.append(
                '<span style="display: inline-block; width: 10px; height: 10px; '
                "border: 1px solid; background-color: "
                '{value}"></span> {label}  '.format(
                    value=_get_color(value), label=label
                )
            )
        dom.append("</div>")

    # Width is set through a style attribute so the browser actually applies it.
    dom.append('<table style="width: 100%">')
    rows = [
        "<tr><th>True Label</th>"
        "<th>Predicted Label</th>"
        "<th>Attribution Label</th>"
        "<th>Attribution Score</th>"
        "<th>Word Importance</th></tr>"
    ]
    for datarecord in datarecords:
        rows.append(
            "".join(
                [
                    "<tr>",
                    format_classname(datarecord.true_class),
                    format_classname(
                        "{0} ({1:.2f})".format(
                            datarecord.pred_class, datarecord.pred_prob
                        )
                    ),
                    format_classname(datarecord.attr_class),
                    format_classname("{0:.2f}".format(datarecord.attr_score)),
                    format_word_importances(
                        datarecord.raw_input_ids, datarecord.word_attributions
                    ),
                    # Close the row; an opening tag here would start a new one.
                    "</tr>",
                ]
            )
        )

    dom.extend(rows)
    dom.append("</table>")
    html = HTML("".join(dom))
    display(html)

    return html