markytools's picture
added strexp
d61b9c7
#!/usr/bin/env python3
import base64
import warnings
from collections import namedtuple
from io import BytesIO
from typing import Callable, List, Optional, Union
from captum._utils.common import safe_div
from captum.attr._utils import visualization as viz
from captum.insights.attr_vis._utils.transforms import format_transforms
# Record returned by every ``visualize`` implementation in this module.
FeatureOutput = namedtuple(
    "FeatureOutput",
    ["name", "base", "modified", "type", "contribution"],
)
def _convert_figure_base64(fig):
buff = BytesIO()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
fig.tight_layout() # removes padding
fig.savefig(buff, format="png")
base64img = base64.b64encode(buff.getvalue()).decode("utf-8")
return base64img
class BaseFeature:
    r"""
    All Feature classes extend this class to implement custom visualizations in
    Insights.

    Child classes are expected to override ``visualization_type`` and
    ``visualize``. The base implementations only raise ``NotImplementedError``
    when called; nothing is checked when a subclass is instantiated.
    """

    def __init__(
        self,
        name: str,
        baseline_transforms: Optional[Union[Callable, List[Callable]]],
        input_transforms: Optional[Union[Callable, List[Callable]]],
        visualization_transform: Optional[Callable],
    ) -> None:
        r"""
        Args:
            name (str): The label of the specific feature. For example, an
                        ImageFeature's name can be "Photo".
            baseline_transforms (list, callable, optional): Optional list of
                        callables (e.g. functions) to be called on the input tensor
                        to construct multiple baselines. Currently only one baseline
                        is supported. See
                        :py:class:`.IntegratedGradients` for more
                        information about baselines.
            input_transforms (list, callable, optional): Optional list of callables
                        (e.g. functions) called on the input tensor sequentially to
                        convert it into the format expected by the model.
            visualization_transform (callable, optional): Optional callable (e.g.
                        function) applied as a postprocessing step of the original
                        input data (before ``input_transforms``) to convert it to a
                        format to be understood by the frontend visualizer as
                        specified in ``captum/captum/insights/frontend/App.js``.
        """
        self.name = name
        # Both transform arguments are normalized through ``format_transforms``
        # so a single callable and a list of callables can both be accepted.
        self.baseline_transforms = format_transforms(baseline_transforms)
        self.input_transforms = format_transforms(input_transforms)
        self.visualization_transform = visualization_transform

    @staticmethod
    def visualization_type() -> str:
        """Return the frontend visualization type string for this feature."""
        raise NotImplementedError(
            "Subclasses of BaseFeature must implement visualization_type()"
        )

    def visualize(self, attribution, data, contribution_frac) -> FeatureOutput:
        """Build the ``FeatureOutput`` displayed by the frontend for this feature."""
        raise NotImplementedError(
            "Subclasses of BaseFeature must implement visualize()"
        )
class ImageFeature(BaseFeature):
    r"""
    ImageFeature is used to visualize image features in Insights. It expects an
    image in NCHW format. If C has a dimension of 1, it is assumed to be a
    greyscale image. If it has a dimension of 3, it is expected to be in RGB
    format.
    """

    def __init__(
        self,
        name: str,
        baseline_transforms: Union[Callable, List[Callable]],
        input_transforms: Union[Callable, List[Callable]],
        visualization_transform: Optional[Callable] = None,
    ) -> None:
        r"""
        Args:
            name (str): The label of the specific feature. For example, an
                        ImageFeature's name can be "Photo".
            baseline_transforms (list, callable, optional): Optional list of
                        callables (e.g. functions) to be called on the input tensor
                        to construct multiple baselines. Currently only one baseline
                        is supported. See
                        :py:class:`.IntegratedGradients` for more
                        information about baselines.
            input_transforms (list, callable, optional): A list of transforms
                        or transform to be applied to the input. For images,
                        normalization is often applied here.
            visualization_transform (callable, optional): Optional callable (e.g.
                        function) applied as a postprocessing step of the original
                        input data (before input_transforms) to convert it to a
                        format to be visualized.
        """
        super().__init__(
            name,
            baseline_transforms=baseline_transforms,
            input_transforms=input_transforms,
            visualization_transform=visualization_transform,
        )

    @staticmethod
    def visualization_type() -> str:
        return "image"

    @staticmethod
    def _to_hwc_numpy(t):
        """Convert a (1, C, H, W) or (C, H, W) tensor to an (H, W, C) numpy array."""
        t = t.detach()
        # Drop only leading singleton (batch) dimensions. A blanket
        # ``squeeze()`` would also remove a singleton channel dimension
        # (greyscale, C == 1), after which the 3-axis permute below fails.
        while t.dim() > 3 and t.size(0) == 1:
            t = t.squeeze(0)
        return t.permute(1, 2, 0).cpu().numpy()

    def visualize(self, attribution, data, contribution_frac) -> FeatureOutput:
        r"""
        Render the original image and an absolute-value attribution heat map.

        Args:
            attribution: Attribution tensor for this feature (assumed
                        (1, C, H, W) or (C, H, W)).
            data: Input image tensor. It is passed through
                        ``visualization_transform`` first when one is set.
            contribution_frac: Passed through unchanged as ``contribution``.

        Returns:
            FeatureOutput whose ``base`` and ``modified`` fields are base64
            encoded PNGs of the original image and of the heat map.
        """
        if self.visualization_transform:
            data = self.visualization_transform(data)

        data_t = self._to_hwc_numpy(data)
        attribution_t = self._to_hwc_numpy(attribution)

        orig_fig, _ = viz.visualize_image_attr(
            attribution_t, data_t, method="original_image", use_pyplot=False
        )
        attr_fig, _ = viz.visualize_image_attr(
            attribution_t,
            data_t,
            method="heat_map",
            sign="absolute_value",
            use_pyplot=False,
        )

        img_64 = _convert_figure_base64(orig_fig)
        attr_img_64 = _convert_figure_base64(attr_fig)

        return FeatureOutput(
            name=self.name,
            base=img_64,
            modified=attr_img_64,
            type=self.visualization_type(),
            contribution=contribution_frac,
        )
class TextFeature(BaseFeature):
    r"""
    TextFeature is used to visualize text (e.g. sentences) in Insights.
    It expects the visualization transform to convert the input data (e.g. index to
    string) to the raw text.
    """

    def __init__(
        self,
        name: str,
        baseline_transforms: Union[Callable, List[Callable]],
        input_transforms: Union[Callable, List[Callable]],
        visualization_transform: Callable,
    ) -> None:
        r"""
        Args:
            name (str): The label of the specific feature. For example, a
                        TextFeature's name can be "Sentence".
            baseline_transforms (list, callable, optional): Optional list of
                        callables (e.g. functions) to be called on the input tensor
                        to construct multiple baselines. Currently only one baseline
                        is supported. See
                        :py:class:`.IntegratedGradients` for more
                        information about baselines.
                        For text features, a common baseline is a tensor of indices
                        corresponding to PAD with the same size as the input
                        tensor. See :py:class:`.TokenReferenceBase` for more
                        information.
            input_transforms (list, callable, optional): A list of transforms
                        or transform to be applied to the input. For text, a common
                        transform is to convert the tokenized input tensor into an
                        interpretable embedding. See
                        :py:class:`.InterpretableEmbeddingBase`
                        and
                        :py:func:`~.configure_interpretable_embedding_layer`
                        for more information.
            visualization_transform (callable): Callable (e.g. function)
                        applied as a postprocessing step of the original
                        input data (before ``input_transforms``) to convert it to a
                        suitable format for visualization. For text features,
                        a common function is to convert the token indices to their
                        corresponding (sub)words. If it is falsy (e.g. ``None``),
                        ``visualize`` uses the raw input data as the text.
        """
        super().__init__(
            name,
            baseline_transforms=baseline_transforms,
            input_transforms=input_transforms,
            visualization_transform=visualization_transform,
        )

    @staticmethod
    def visualization_type() -> str:
        return "text"

    def visualize(self, attribution, data, contribution_frac) -> FeatureOutput:
        r"""
        Build per-token attribution scores for the frontend.

        Args:
            attribution: Attribution tensor for this feature. A leading
                        dimension of size 1 is squeezed. If more than one
                        dimension remains, it is summed over dimension 1
                        (presumably the embedding dimension; TODO confirm).
            data: The original input. It is converted to text with
                        ``visualization_transform`` when one is set. Otherwise
                        it is used as the text unchanged.
            contribution_frac: Passed through unchanged as ``contribution``.

        Returns:
            FeatureOutput whose ``base`` is the text and whose ``modified`` is
            a list of scores divided by the largest absolute attribution and
            multiplied by 100, so that they lie in [-100, 100].
        """
        if self.visualization_transform:
            text = self.visualization_transform(data)
        else:
            text = data

        attribution = attribution.squeeze(0)
        if attribution.dim() > 1:
            attribution = attribution.sum(dim=1)

        # L-Infinity norm. It is 0 only when all attribution elements are 0,
        # which is why ``safe_div`` is used.
        attr_max = attribution.abs().max()
        normalized_attribution = safe_div(attribution, attr_max)
        modified = [x * 100 for x in normalized_attribution.tolist()]

        return FeatureOutput(
            name=self.name,
            base=text,
            modified=modified,
            type=self.visualization_type(),
            contribution=contribution_frac,
        )
class GeneralFeature(BaseFeature):
    r"""
    GeneralFeature is used for non-specified feature visualization in Insights.
    It can be used for dense or sparse features.

    Currently general features are only supported for 2-d tensors, in the format
    (N, C) where N is the number of samples and C is the number of categories.
    """

    def __init__(self, name: str, categories: List[str]) -> None:
        r"""
        Args:
            name (str): The label of the specific feature. For example, a
                        GeneralFeature's name can be "Metadata".
            categories (list[str]): Category labels for the general feature. The
                        order and size should match the second dimension of the
                        ``data`` tensor parameter in ``visualize``.
        """
        # General features use no baseline, input or visualization transforms.
        super().__init__(
            name,
            baseline_transforms=None,
            input_transforms=None,
            visualization_transform=None,
        )
        self.categories = categories

    @staticmethod
    def visualization_type() -> str:
        return "general"

    def visualize(self, attribution, data, contribution_frac) -> FeatureOutput:
        r"""
        Pair each category with its value and an L2-normalized attribution.

        Args:
            attribution: Attribution tensor. A leading dimension of size 1 is
                        squeezed.
            data: Feature values. A leading dimension of size 1 is squeezed.
            contribution_frac: Passed through unchanged as ``contribution``.

        Returns:
            FeatureOutput whose ``base`` is a list of ``"<category>: <value>"``
            strings (value to 2 decimals) and whose ``modified`` is the
            attribution divided by its L2 norm and multiplied by 100.
        """
        attr_vec = attribution.squeeze(0)
        # L-2 norm. It is 0 only when every attribution element is 0, which
        # is why ``safe_div`` is used.
        unit_attr = safe_div(attr_vec, attr_vec.norm())
        percentages = [value * 100 for value in unit_attr.tolist()]

        values = data.squeeze(0).tolist()
        # NOTE: ``zip`` stops at the shorter of ``categories`` and ``values``.
        labelled = [
            "{}: {:.2f}".format(category, value)
            for category, value in zip(self.categories, values)
        ]

        return FeatureOutput(
            name=self.name,
            base=labelled,
            modified=percentages,
            type=self.visualization_type(),
            contribution=contribution_frac,
        )
class EmptyFeature(BaseFeature):
    r"""
    Placeholder feature with no visual content.

    ``visualize`` returns a ``FeatureOutput`` whose ``base`` and ``modified``
    fields are ``None``; only the name, type and contribution are populated.
    """

    def __init__(
        self,
        name: str = "empty",
        baseline_transforms: Optional[Union[Callable, List[Callable]]] = None,
        input_transforms: Optional[Union[Callable, List[Callable]]] = None,
        visualization_transform: Optional[Callable] = None,
    ) -> None:
        r"""
        Args:
            name (str): Label of the feature. Defaults to ``"empty"``.
            baseline_transforms (list, callable, optional): Forwarded to
                        ``BaseFeature``.
            input_transforms (list, callable, optional): Forwarded to
                        ``BaseFeature``.
            visualization_transform (callable, optional): Forwarded to
                        ``BaseFeature``.
        """
        super().__init__(
            name, baseline_transforms, input_transforms, visualization_transform
        )

    @staticmethod
    def visualization_type() -> str:
        return "empty"

    def visualize(self, _attribution, _data, contribution_frac) -> FeatureOutput:
        """Return an output with no content; attribution and data are ignored."""
        kind = self.visualization_type()
        return FeatureOutput(self.name, None, None, kind, contribution_frac)