#!/usr/bin/env python3
import inspect
import math
import typing
import warnings
from typing import Any, Callable, cast, List, Optional, Tuple, Union

import torch
from captum._utils.common import (
    _expand_additional_forward_args,
    _expand_target,
    _flatten_tensor_or_tuple,
    _format_output,
    _format_tensor_into_tuples,
    _is_tuple,
    _reduce_list,
    _run_forward,
)
from captum._utils.models.linear_model import SkLearnLasso
from captum._utils.models.model import Model
from captum._utils.progress import progress
from captum._utils.typing import (
    BaselineType,
    Literal,
    TargetType,
    TensorOrTupleOfTensorsGeneric,
)
from captum.attr._utils.attribution import PerturbationAttribution
from captum.attr._utils.batching import _batch_example_iterator
from captum.attr._utils.common import (
    _construct_default_feature_mask,
    _format_input_baseline,
)
from captum.log import log_usage
from torch import Tensor
from torch.nn import CosineSimilarity
from torch.utils.data import DataLoader, TensorDataset
class LimeBase(PerturbationAttribution):
    r"""
    Lime is an interpretability method that trains an interpretable surrogate model
    by sampling points around a specified input example and using model evaluations
    at these points to train a simpler interpretable 'surrogate' model, such as a
    linear model.

    LimeBase provides a generic framework to train a surrogate interpretable model.
    This differs from most other attribution methods, since the method returns a
    representation of the interpretable model (e.g. coefficients of the linear model).
    For a similar interface to other perturbation-based attribution methods, please use
    the Lime child class, which defines specific transformations for the interpretable
    model.

    LimeBase allows sampling points in either the interpretable space or the original
    input space to train the surrogate model. The interpretable space is a feature
    vector used to train the surrogate interpretable model; this feature space is often
    of smaller dimensionality than the original feature space in order for the surrogate
    model to be more interpretable.

    If sampling in the interpretable space, a transformation function must be provided
    to define how a vector sampled in the interpretable space can be transformed into
    an example in the original input space. If sampling in the original input space, a
    transformation function must be provided to define how the input can be transformed
    into its interpretable vector representation.

    More details regarding LIME can be found in the original paper:
    https://arxiv.org/abs/1602.04938
    """

    def __init__(
        self,
        forward_func: Callable,
        interpretable_model: Model,
        similarity_func: Callable,
        perturb_func: Callable,
        perturb_interpretable_space: bool,
        from_interp_rep_transform: Optional[Callable],
        to_interp_rep_transform: Optional[Callable],
    ) -> None:
        r"""
        Args:

            forward_func (callable): The forward function of the model or any
                    modification of it. If a batch is provided as input for
                    attribution, it is expected that forward_func returns a scalar
                    representing the entire batch.
            interpretable_model (Model): Model object to train interpretable model.
                    A Model object provides a `fit` method to train the model,
                    given a dataloader, with batches containing three tensors:

                    - interpretable_inputs: Tensor
                      [2D num_samples x num_interp_features],
                    - expected_outputs: Tensor [1D num_samples],
                    - weights: Tensor [1D num_samples]

                    The model object must also provide a `representation` method to
                    access the appropriate coefficients or representation of the
                    interpretable model after fitting.
                    Some predefined interpretable linear models are provided in
                    captum._utils.models.linear_model including wrappers around
                    SkLearn linear models as well as SGD-based PyTorch linear
                    models.

                    Note that calling fit multiple times should retrain the
                    interpretable model, each attribution call reuses
                    the same given interpretable model object.
            similarity_func (callable): Function which takes a single sample
                    along with its corresponding interpretable representation
                    and returns the weight of the interpretable sample for
                    training interpretable model. Weight is generally
                    determined based on similarity to the original input.
                    The original paper refers to this as a similarity kernel.

                    The expected signature of this callable is:

                    >>> similarity_func(
                    >>>    original_input: Tensor or tuple of Tensors,
                    >>>    perturbed_input: Tensor or tuple of Tensors,
                    >>>    perturbed_interpretable_input:
                    >>>        Tensor [2D 1 x num_interp_features],
                    >>>    **kwargs: Any
                    >>> ) -> float or Tensor containing float scalar

                    perturbed_input and original_input will be the same type and
                    contain tensors of the same shape (regardless of whether or not
                    the sampling function returns inputs in the interpretable
                    space). original_input is the same as the input provided
                    when calling attribute.

                    All kwargs passed to the attribute method are
                    provided as keyword arguments (kwargs) to this callable.
            perturb_func (callable): Function which returns a single
                    sampled input, generally a perturbation of the original
                    input, which is used to train the interpretable surrogate
                    model. Function can return samples in either
                    the original input space (matching type and tensor shapes
                    of original input) or in the interpretable input space,
                    which is a vector containing the interpretable features.
                    Alternatively, this function can return a generator
                    yielding samples to train the interpretable surrogate
                    model, and n_samples perturbations will be sampled
                    from this generator.

                    The expected signature of this callable is:

                    >>> perturb_func(
                    >>>    original_input: Tensor or tuple of Tensors,
                    >>>    **kwargs: Any
                    >>> ) -> Tensor or tuple of Tensors or
                    >>>    generator yielding tensor or tuple of Tensors

                    All kwargs passed to the attribute method are
                    provided as keyword arguments (kwargs) to this callable.

                    Returned sampled input should match the input type (Tensor
                    or Tuple of Tensor and corresponding shapes) if
                    perturb_interpretable_space = False. If
                    perturb_interpretable_space = True, the return type should
                    be a single tensor of shape 1 x num_interp_features,
                    corresponding to the representation of the
                    sample to train the interpretable model.

                    All kwargs passed to the attribute method are
                    provided as keyword arguments (kwargs) to this callable.
            perturb_interpretable_space (bool): Indicates whether
                    perturb_func returns a sample in the interpretable space
                    (tensor of shape 1 x num_interp_features) or a sample
                    in the original space, matching the format of the original
                    input. Once sampled, inputs can be converted to / from
                    the interpretable representation with either
                    to_interp_rep_transform or from_interp_rep_transform.
            from_interp_rep_transform (callable): Function which takes a
                    single sampled interpretable representation (tensor
                    of shape 1 x num_interp_features) and returns
                    the corresponding representation in the input space
                    (matching shapes of original input to attribute).

                    This argument is necessary if perturb_interpretable_space
                    is True, otherwise None can be provided for this argument.

                    The expected signature of this callable is:

                    >>> from_interp_rep_transform(
                    >>>    curr_sample: Tensor [2D 1 x num_interp_features]
                    >>>    original_input: Tensor or Tuple of Tensors,
                    >>>    **kwargs: Any
                    >>> ) -> Tensor or tuple of Tensors

                    Returned sampled input should match the type of original_input
                    and corresponding tensor shapes.

                    All kwargs passed to the attribute method are
                    provided as keyword arguments (kwargs) to this callable.
            to_interp_rep_transform (callable): Function which takes a
                    sample in the original input space and converts to
                    its interpretable representation (tensor
                    of shape 1 x num_interp_features).

                    This argument is necessary if perturb_interpretable_space
                    is False, otherwise None can be provided for this argument.

                    The expected signature of this callable is:

                    >>> to_interp_rep_transform(
                    >>>    curr_sample: Tensor or Tuple of Tensors,
                    >>>    original_input: Tensor or Tuple of Tensors,
                    >>>    **kwargs: Any
                    >>> ) -> Tensor [2D 1 x num_interp_features]

                    curr_sample will match the type of original_input
                    and corresponding tensor shapes.

                    All kwargs passed to the attribute method are
                    provided as keyword arguments (kwargs) to this callable.
        """
        PerturbationAttribution.__init__(self, forward_func)
        self.interpretable_model = interpretable_model
        self.similarity_func = similarity_func
        self.perturb_func = perturb_func
        self.perturb_interpretable_space = perturb_interpretable_space
        self.from_interp_rep_transform = from_interp_rep_transform
        self.to_interp_rep_transform = to_interp_rep_transform

        # Exactly one transform is mandatory, depending on which space
        # perturb_func samples from. The messages are parenthesised so that
        # both string fragments belong to the assert (a trailing bare string
        # on its own line would be a separate no-op statement).
        if self.perturb_interpretable_space:
            assert self.from_interp_rep_transform is not None, (
                "Must provide transform from interpretable space to original input space"
                " when sampling from interpretable space."
            )
        else:
            assert self.to_interp_rep_transform is not None, (
                "Must provide transform from original input space to interpretable space"
            )

    @log_usage()
    def attribute(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        target: TargetType = None,
        additional_forward_args: Any = None,
        n_samples: int = 50,
        perturbations_per_eval: int = 1,
        show_progress: bool = False,
        **kwargs,
    ) -> Tensor:
        r"""
        This method attributes the output of the model with given target index
        (in case it is provided, otherwise it assumes that output is a
        scalar) to the inputs of the model using the approach described above.
        It trains an interpretable model and returns a representation of the
        interpretable model.

        It is recommended to only provide a single example as input (tensors
        with first dimension or batch size = 1). This is because LIME is generally
        used for sample-based interpretability, training a separate interpretable
        model to explain a model's prediction on each individual example.

        A batch of inputs can be provided as inputs only if forward_func
        returns a single value per batch (e.g. loss).
        The interpretable feature representation should still have shape
        1 x num_interp_features, corresponding to the interpretable
        representation for the full batch, and perturbations_per_eval
        must be set to 1.

        Args:

            inputs (tensor or tuple of tensors):  Input for which LIME
                        is computed. If forward_func takes a single
                        tensor as input, a single input tensor should be provided.
                        If forward_func takes multiple tensors as input, a tuple
                        of the input tensors should be provided. It is assumed
                        that for all given input tensors, dimension 0 corresponds
                        to the number of examples, and if multiple input tensors
                        are provided, the examples must be aligned appropriately.
            target (int, tuple, tensor or list, optional):  Output indices for
                        which surrogate model is trained
                        (for classification cases,
                        this is usually the target class).
                        If the network returns a scalar value per example,
                        no target index is necessary.
                        For general 2D outputs, targets can be either:

                        - a single integer or a tensor containing a single
                          integer, which is applied to all input examples

                        - a list of integers or a 1D tensor, with length matching
                          the number of examples in inputs (dim 0). Each integer
                          is applied as the target for the corresponding example.

                        For outputs with > 2 dimensions, targets can be either:

                        - A single tuple, which contains #output_dims - 1
                          elements. This target index is applied to all examples.

                        - A list of tuples with length equal to the number of
                          examples in inputs (dim 0), and each tuple containing
                          #output_dims - 1 elements. Each tuple is applied as the
                          target for the corresponding example.

                        Default: None
            additional_forward_args (any, optional): If the forward function
                        requires additional arguments other than the inputs for
                        which attributions should not be computed, this argument
                        can be provided. It must be either a single additional
                        argument of a Tensor or arbitrary (non-tuple) type or a
                        tuple containing multiple additional arguments including
                        tensors or any arbitrary python types. These arguments
                        are provided to forward_func in order following the
                        arguments in inputs.
                        For a tensor, the first dimension of the tensor must
                        correspond to the number of examples. For all other types,
                        the given argument is used for all forward evaluations.
                        Note that attributions are not computed with respect
                        to these arguments.
                        Default: None
            n_samples (int, optional):  The number of samples of the original
                        model used to train the surrogate interpretable model.
                        Default: `50` if `n_samples` is not provided.
            perturbations_per_eval (int, optional): Allows multiple samples
                        to be processed simultaneously in one call to forward_fn.
                        Each forward pass will contain a maximum of
                        perturbations_per_eval * #examples samples.
                        For DataParallel models, each batch is split among the
                        available devices, so evaluations on each available
                        device contain at most
                        (perturbations_per_eval * #examples) / num_devices
                        samples.
                        If the forward function returns a single scalar per batch,
                        perturbations_per_eval must be set to 1.
                        Default: 1
            show_progress (bool, optional): Displays the progress of computation.
                        It will try to use tqdm if available for advanced features
                        (e.g. time estimation). Otherwise, it will fallback to
                        a simple output of progress.
                        Default: False
            **kwargs (Any, optional): Any additional arguments necessary for
                        sampling and transformation functions (provided to
                        constructor).
                        Default: None

        Returns:
            **interpretable model representation**:
            - **interpretable model representation** (*Any*):
                    A representation of the interpretable model trained. The return
                    type matches the return type of train_interpretable_model_func.
                    For example, this could contain coefficients of a
                    linear surrogate model.

        Examples::

            >>> # SimpleClassifier takes a single input tensor of
            >>> # float features with size N x 5,
            >>> # and returns an Nx3 tensor of class probabilities.
            >>> net = SimpleClassifier()
            >>>
            >>> # We will train an interpretable model with the same
            >>> # features by simply sampling with added Gaussian noise
            >>> # to the inputs and training a model to predict the
            >>> # score of the target class.
            >>>
            >>> # For interpretable model training, we will use sklearn
            >>> # linear model in this example. We have provided wrappers
            >>> # around sklearn linear models to fit the Model interface.
            >>> # Any arguments provided to the sklearn constructor can also
            >>> # be provided to the wrapper, e.g.:
            >>> # SkLearnLinearModel("linear_model.Ridge", alpha=2.0)
            >>> from captum._utils.models.linear_model import SkLearnLinearModel
            >>>
            >>>
            >>> # Define similarity kernel (exponential kernel based on L2 norm)
            >>> def similarity_kernel(
            >>>     original_input: Tensor,
            >>>     perturbed_input: Tensor,
            >>>     perturbed_interpretable_input: Tensor,
            >>>     **kwargs)->Tensor:
            >>>         # kernel_width will be provided to attribute as a kwarg
            >>>         kernel_width = kwargs["kernel_width"]
            >>>         l2_dist = torch.norm(original_input - perturbed_input)
            >>>         return torch.exp(- (l2_dist**2) / (kernel_width**2))
            >>>
            >>>
            >>> # Define sampling function
            >>> # This function samples in original input space
            >>> def perturb_func(
            >>>     original_input: Tensor,
            >>>     **kwargs)->Tensor:
            >>>         return original_input + torch.randn_like(original_input)
            >>>
            >>> # For this example, we are setting the interpretable input to
            >>> # match the model input, so the to_interp_rep_transform
            >>> # function simply returns the input. In most cases, the interpretable
            >>> # input will be different and may have a smaller feature set, so
            >>> # an appropriate transformation function should be provided.
            >>>
            >>> def to_interp_transform(curr_sample, original_inp,
            >>>                                      **kwargs):
            >>>     return curr_sample
            >>>
            >>> # Generating random input with size 1 x 5
            >>> input = torch.randn(1, 5)
            >>> # Defining LimeBase interpreter
            >>> lime_attr = LimeBase(net,
            >>>                      SkLearnLinearModel("linear_model.Ridge"),
            >>>                      similarity_func=similarity_kernel,
            >>>                      perturb_func=perturb_func,
            >>>                      perturb_interpretable_space=False,
            >>>                      from_interp_rep_transform=None,
            >>>                      to_interp_rep_transform=to_interp_transform)
            >>> # Computes interpretable model, returning coefficients of linear
            >>> # model.
            >>> attr_coefs = lime_attr.attribute(input, target=1, kernel_width=1.1)
        """
        with torch.no_grad():
            inp_tensor = (
                cast(Tensor, inputs) if isinstance(inputs, Tensor) else inputs[0]
            )
            device = inp_tensor.device

            # Per-sample training data for the surrogate model.
            interpretable_inps = []
            similarities = []
            outputs = []

            # Model-space samples waiting to be evaluated as one batch.
            curr_model_inputs = []
            expanded_additional_args = None
            expanded_target = None
            perturb_generator = None
            if inspect.isgeneratorfunction(self.perturb_func):
                perturb_generator = self.perturb_func(inputs, **kwargs)

            if show_progress:
                attr_progress = progress(
                    total=math.ceil(n_samples / perturbations_per_eval),
                    desc=f"{self.get_name()} attribution",
                )
                attr_progress.update(0)

            # Number of samples actually drawn; may be below n_samples if a
            # generator runs dry. Used as the batch size when fitting.
            batch_count = 0
            for _ in range(n_samples):
                if perturb_generator:
                    try:
                        curr_sample = next(perturb_generator)
                    except StopIteration:
                        warnings.warn(
                            "Generator completed prior to given n_samples iterations!"
                        )
                        break
                else:
                    curr_sample = self.perturb_func(inputs, **kwargs)
                batch_count += 1

                # Keep both representations of the sample: the one that was
                # drawn and the one obtained through the matching transform.
                if self.perturb_interpretable_space:
                    interpretable_inps.append(curr_sample)
                    curr_model_inputs.append(
                        self.from_interp_rep_transform(  # type: ignore
                            curr_sample, inputs, **kwargs
                        )
                    )
                else:
                    curr_model_inputs.append(curr_sample)
                    interpretable_inps.append(
                        self.to_interp_rep_transform(  # type: ignore
                            curr_sample, inputs, **kwargs
                        )
                    )

                curr_sim = self.similarity_func(
                    inputs, curr_model_inputs[-1], interpretable_inps[-1], **kwargs
                )
                similarities.append(
                    curr_sim.flatten()
                    if isinstance(curr_sim, Tensor)
                    else torch.tensor([curr_sim], device=device)
                )

                if len(curr_model_inputs) == perturbations_per_eval:
                    # Full batches all have the same size, so the expanded
                    # target / extra args are computed once and reused.
                    if expanded_additional_args is None:
                        expanded_additional_args = _expand_additional_forward_args(
                            additional_forward_args, len(curr_model_inputs)
                        )
                    if expanded_target is None:
                        expanded_target = _expand_target(target, len(curr_model_inputs))

                    model_out = self._evaluate_batch(
                        curr_model_inputs,
                        expanded_target,
                        expanded_additional_args,
                        device,
                    )

                    if show_progress:
                        attr_progress.update()

                    outputs.append(model_out)
                    curr_model_inputs = []

            if len(curr_model_inputs) > 0:
                # Trailing partial batch: its size differs from the full
                # batches, so the expansions are recomputed for it.
                expanded_additional_args = _expand_additional_forward_args(
                    additional_forward_args, len(curr_model_inputs)
                )
                expanded_target = _expand_target(target, len(curr_model_inputs))
                model_out = self._evaluate_batch(
                    curr_model_inputs,
                    expanded_target,
                    expanded_additional_args,
                    device,
                )
                if show_progress:
                    attr_progress.update()
                outputs.append(model_out)

            if show_progress:
                attr_progress.close()

            combined_interp_inps = torch.cat(interpretable_inps).double()
            # 0-dim entries cannot be concatenated, so they are stacked instead.
            combined_outputs = (
                torch.cat(outputs)
                if len(outputs[0].shape) > 0
                else torch.stack(outputs)
            ).double()
            combined_sim = (
                torch.cat(similarities)
                if len(similarities[0].shape) > 0
                else torch.stack(similarities)
            ).double()
            dataset = TensorDataset(
                combined_interp_inps, combined_outputs, combined_sim
            )
            # All drawn samples are handed to the surrogate in a single batch.
            self.interpretable_model.fit(DataLoader(dataset, batch_size=batch_count))
            return self.interpretable_model.representation()

    def _evaluate_batch(
        self,
        curr_model_inputs: List[TensorOrTupleOfTensorsGeneric],
        expanded_target: TargetType,
        expanded_additional_args: Any,
        device: torch.device,
    ) -> Tensor:
        r"""
        Run the forward function on a batch of perturbed inputs and return
        the outputs as a 1D tensor.

        Args:
            curr_model_inputs (list): Perturbed inputs in the original input
                        space, one list entry per perturbation.
            expanded_target: Target already expanded to the batch size.
            expanded_additional_args: Additional forward arguments already
                        expanded to the batch size.
            device (torch.device): Device on which a non-tensor forward
                        output is materialised.

        Returns:
            Tensor: Flattened model outputs, or a one-element tensor when the
            forward function returns a non-tensor value.

        Raises:
            AssertionError: If a tensor output does not contain exactly one
                        element per perturbed input.
        """
        model_out = _run_forward(
            self.forward_func,
            _reduce_list(curr_model_inputs),
            expanded_target,
            expanded_additional_args,
        )
        if isinstance(model_out, Tensor):
            assert model_out.numel() == len(curr_model_inputs), (
                "Number of outputs is not appropriate, must return "
                "one output per perturbed input"
            )
            return model_out.flatten()
        return torch.tensor([model_out], device=device)

    def has_convergence_delta(self) -> bool:
        r"""Return False: this method reports no convergence delta."""
        return False

    # NOTE(review): restored as a property on the assumption that callers read
    # this flag as an attribute - confirm against the base attribution class.
    @property
    def multiplies_by_inputs(self):
        r"""Return False for this attribution method."""
        return False
# Default transformations and methods
# for Lime child implementation.
def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs):
    r"""
    Map a binary interpretable sample back to the original input space.

    Every input element whose interpretable feature is 1 in ``curr_sample``
    keeps its original value; every element whose feature is 0 is replaced
    by the corresponding baseline value.

    Args:
        curr_sample (Tensor): Interpretable sample; row 0 is indexed by the
                    feature ids stored in ``feature_mask``.
        original_inputs (Tensor or tuple of Tensors): Original model input.
        **kwargs: Must contain ``feature_mask`` (Tensor or tuple of integer
                    Tensors, matching the structure of ``original_inputs``)
                    and ``baselines`` (same structure, or values that
                    broadcast against the inputs).

    Returns:
        Tensor or tuple of Tensors: Perturbed input with the same structure
        as ``original_inputs``.

    Raises:
        AssertionError: If ``feature_mask`` or ``baselines`` is missing.
    """
    assert (
        "feature_mask" in kwargs
    ), "Must provide feature_mask to use default interpretable representation transform"
    assert (
        "baselines" in kwargs
    ), "Must provide baselines to use default interpretable representation transfrom"
    feature_mask = kwargs["feature_mask"]
    baselines = kwargs["baselines"]

    if isinstance(feature_mask, Tensor):
        # Indexing the sample with the mask broadcasts each feature's on/off
        # bit to every input element belonging to that feature.
        binary_mask = curr_sample[0][feature_mask].bool()
        return (
            binary_mask.to(original_inputs.dtype) * original_inputs
            + (~binary_mask).to(original_inputs.dtype) * baselines
        )

    binary_mask = tuple(
        curr_sample[0][single_mask].bool() for single_mask in feature_mask
    )
    return tuple(
        binary_mask[j].to(original_inputs[j].dtype) * original_inputs[j]
        + (~binary_mask[j]).to(original_inputs[j].dtype) * baselines[j]
        for j in range(len(feature_mask))
    )
def get_exp_kernel_similarity_function(
    distance_mode: str = "cosine", kernel_width: float = 1.0
) -> Callable:
    r"""
    This method constructs an appropriate similarity function to compute
    weights for perturbed sample in LIME. Distance between the original
    and perturbed inputs is computed based on the provided distance mode,
    and the distance is passed through an exponential kernel with given
    kernel width to convert to a range between 0 and 1.

    The callable returned can be provided as the similarity_fn for
    Lime or LimeBase.

    Args:

        distance_mode (str, optional): Distance mode can be either "cosine" or
                    "euclidean" corresponding to either cosine distance
                    or Euclidean distance respectively. Distance is computed
                    by flattening the original inputs and perturbed inputs
                    (concatenating tuples of inputs if necessary) and computing
                    distances between the resulting vectors.
                    Default: "cosine"
        kernel_width (float, optional):
                    Kernel width for exponential kernel applied to distance.
                    Default: 1.0

    Returns:

        *Callable*:
        - **similarity_fn** (*Callable*):
            Similarity function. This callable can be provided as the
            similarity_fn for Lime or LimeBase.

    Note:
        ``distance_mode`` is validated when the returned callable is invoked,
        not when it is constructed; an unsupported mode raises ValueError
        at that point.
    """

    def default_exp_kernel(original_inp, perturbed_inp, __, **kwargs):
        # The interpretable representation (third argument) is not used by
        # this kernel; only the original-space tensors are compared.
        flattened_original_inp = _flatten_tensor_or_tuple(original_inp).float()
        flattened_perturbed_inp = _flatten_tensor_or_tuple(perturbed_inp).float()
        if distance_mode == "cosine":
            cos_sim = CosineSimilarity(dim=0)
            distance = 1 - cos_sim(flattened_original_inp, flattened_perturbed_inp)
        elif distance_mode == "euclidean":
            distance = torch.norm(flattened_original_inp - flattened_perturbed_inp)
        else:
            raise ValueError("distance_mode must be either cosine or euclidean.")
        return math.exp(-1 * (distance ** 2) / (2 * (kernel_width ** 2)))

    return default_exp_kernel
def default_perturb_func(original_inp, **kwargs):
    r"""
    Sample a random binary interpretable vector.

    Each of the ``num_interp_features`` entries is drawn independently from a
    Bernoulli distribution with probability 0.5.

    Args:
        original_inp (Tensor or tuple of Tensors): Original input; only used
                    to pick the device of the returned sample (the first
                    tensor's device for tuple inputs).
        **kwargs: Must contain ``num_interp_features`` (int).

    Returns:
        Tensor: Long tensor of shape 1 x num_interp_features holding 0 / 1.

    Raises:
        AssertionError: If ``num_interp_features`` is not provided.
    """
    assert (
        "num_interp_features" in kwargs
    ), "Must provide num_interp_features to use default interpretable sampling function"
    if isinstance(original_inp, Tensor):
        device = original_inp.device
    else:
        device = original_inp[0].device

    probs = torch.ones(1, kwargs["num_interp_features"]) * 0.5
    return torch.bernoulli(probs).to(device=device).long()
def construct_feature_mask(feature_mask, formatted_inputs):
    r"""
    Normalise a feature mask and count the interpretable features it defines.

    Args:
        feature_mask (None, Tensor or tuple of Tensors): Mask assigning each
                    input element to an interpretable feature id. If None, a
                    default mask is built from ``formatted_inputs``.
        formatted_inputs: Inputs handed to the default-mask constructor when
                    no mask is given.
                    NOTE(review): presumably a tuple of tensors - confirm
                    against the caller.

    Returns:
        tuple: ``(feature_mask, num_interp_features)``. For a provided mask,
        this is the mask as a tuple of tensors, shifted so that the smallest
        id is 0, and the largest id plus one.
    """
    if feature_mask is None:
        feature_mask, num_interp_features = _construct_default_feature_mask(
            formatted_inputs
        )
    else:
        feature_mask = _format_tensor_into_tuples(feature_mask)
        # Empty mask tensors are skipped: torch.min / torch.max cannot reduce
        # a tensor with no elements.
        min_interp_features = int(
            min(
                torch.min(single_mask).item()
                for single_mask in feature_mask
                if single_mask.numel()
            )
        )
        if min_interp_features != 0:
            warnings.warn(
                "Minimum element in feature mask is not 0, shifting indices to"
                " start at 0."
            )
            feature_mask = tuple(
                single_mask - min_interp_features for single_mask in feature_mask
            )

        num_interp_features = int(
            max(
                torch.max(single_mask).item()
                for single_mask in feature_mask
                if single_mask.numel()
            )
            + 1
        )
    return feature_mask, num_interp_features
class Lime(LimeBase):
Lime is an interpretability method that trains an interpretable surrogate model
by sampling points around a specified input example and using model evaluations
at these points to train a simpler interpretable 'surrogate' model, such as a
linear model.
Lime provides a more specific implementation than LimeBase in order to expose
a consistent API with other perturbation-based algorithms. For more general
use of the LIME framework, consider using the LimeBase class directly and
defining custom sampling and transformation to / from interpretable
representation functions.
Lime assumes that the interpretable representation is a binary vector,
corresponding to some elements in the input being set to their baseline value
if the corresponding binary interpretable feature value is 0 or being set
to the original input value if the corresponding binary interpretable
feature value is 1. Input values can be grouped to correspond to the same
binary interpretable feature using a feature mask provided when calling
attribute, similar to other perturbation-based attribution methods.
One example of this setting is when applying Lime to an image classifier.
Pixels in an image can be grouped into super-pixels or segments, which
correspond to interpretable features, provided as a feature_mask when
calling attribute. Sampled binary vectors convey whether a super-pixel
is on (retains the original input values) or off (set to the corresponding
baseline value, e.g. black image). An interpretable linear model is trained
with input being the binary vectors and outputs as the corresponding scores
of the image classifier with the appropriate super-pixels masked based on the
binary vector. Coefficients of the trained surrogate
linear model convey the importance of each super-pixel.
More details regarding LIME can be found in the original paper:
def __init__(
    self,
    forward_func: Callable,
    interpretable_model: Optional[Model] = None,
    similarity_func: Optional[Callable] = None,
    perturb_func: Optional[Callable] = None,
) -> None:
    r"""

    Args:

        forward_func (callable): The forward function of the model or any
                modification of it.
        interpretable_model (optional, Model): Model object to train
                interpretable model.

                This argument is optional and defaults to SkLearnLasso(alpha=0.01),
                which is a wrapper around the Lasso linear model in SkLearn.
                This requires having sklearn version >= 0.23 available.

                Other predefined interpretable linear models are provided in
                captum._utils.models.linear_model.

                Alternatively, a custom model object must provide a `fit` method to
                train the model, given a dataloader, with batches containing
                three tensors:

                - interpretable_inputs: Tensor
                  [2D num_samples x num_interp_features],
                - expected_outputs: Tensor [1D num_samples],
                - weights: Tensor [1D num_samples]

                The model object must also provide a `representation` method to
                access the appropriate coefficients or representation of the
                interpretable model after fitting.

                Note that calling fit multiple times should retrain the
                interpretable model, each attribution call reuses
                the same given interpretable model object.
        similarity_func (optional, callable): Function which takes a single sample
                along with its corresponding interpretable representation
                and returns the weight of the interpretable sample for
                training the interpretable model.
                This is often referred to as a similarity kernel.

                This argument is optional and defaults to a function which
                applies an exponential kernel to the cosine distance between
                the original input and perturbed input, with a kernel width
                of 1.0.

                A similarity function applying an exponential
                kernel to cosine / euclidean distances can be constructed
                using the provided get_exp_kernel_similarity_function in
                this module.

                Alternately, a custom callable can also be provided.
                The expected signature of this callable is:

                >>> def similarity_func(
                >>>    original_input: Tensor or tuple of Tensors,
                >>>    perturbed_input: Tensor or tuple of Tensors,
                >>>    perturbed_interpretable_input:
                >>>        Tensor [2D 1 x num_interp_features],
                >>>    **kwargs: Any
                >>> ) -> float or Tensor containing float scalar

                perturbed_input and original_input will be the same type and
                contain tensors of the same shape, with original_input
                being the same as the input provided when calling attribute.

                kwargs includes baselines, feature_mask, num_interp_features
                (integer, determined from feature mask).
        perturb_func (optional, callable): Function which returns a single
                sampled input, which is a binary vector of length
                num_interp_features, or a generator of such tensors.

                This function is optional, the default function returns
                a binary vector where each element is selected
                independently and uniformly at random. Custom
                logic for selecting sampled binary vectors can
                be implemented by providing a function with the
                following expected signature:

                >>> perturb_func(
                >>>    original_input: Tensor or tuple of Tensors,
                >>>    **kwargs: Any
                >>> ) -> Tensor [Binary 2D Tensor 1 x num_interp_features]
                >>>  or generator yielding such tensors

                kwargs includes baselines, feature_mask, num_interp_features
                (integer, determined from feature mask).

    """
    if interpretable_model is None:
        interpretable_model = SkLearnLasso(alpha=0.01)

    if similarity_func is None:
        similarity_func = get_exp_kernel_similarity_function()

    if perturb_func is None:
        perturb_func = default_perturb_func

    # Lime always samples binary vectors in the interpretable space, so only
    # the interpretable -> input transform is supplied to the base class.
    # NOTE(review): this delegation was missing from the reviewed text and is
    # restored from the LimeBase constructor signature - confirm.
    LimeBase.__init__(
        self,
        forward_func,
        interpretable_model,
        similarity_func,
        perturb_func,
        True,
        default_from_interp_rep_transform,
        None,
    )
def attribute( # type: ignore
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
additional_forward_args: Any = None,
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
n_samples: int = 50,
perturbations_per_eval: int = 1,
return_input_shape: bool = True,
show_progress: bool = False,
) -> TensorOrTupleOfTensorsGeneric:
This method attributes the output of the model with given target index
(in case it is provided, otherwise it assumes that output is a
scalar) to the inputs of the model using the approach described above,
training an interpretable model and returning a representation of the
interpretable model.
It is recommended to only provide a single example as input (tensors
with first dimension or batch size = 1). This is because LIME is generally
used for sample-based interpretability, training a separate interpretable
model to explain a model's prediction on each individual example.
A batch of inputs can also be provided as inputs, similar to
other perturbation-based attribution methods. In this case, if forward_fn
returns a scalar per example, attributions will be computed for each
example independently, with a separate interpretable model trained for each
example. Note that provided similarity and perturbation functions will be
provided each example separately (first dimension = 1) in this case.
If forward_fn returns a scalar per batch (e.g. loss), attributions will
still be computed using a single interpretable model for the full batch.
In this case, similarity and perturbation functions will be provided the
same original input containing the full batch.
The number of interpretable features is determined from the provided
feature mask, or if none is provided, from the default feature mask,
which considers each scalar input as a separate feature. It is
generally recommended to provide a feature mask which groups features
into a small number of interpretable features / components (e.g.
superpixels in images).
inputs (tensor or tuple of tensors): Input for which LIME
is computed. If forward_func takes a single
tensor as input, a single input tensor should be provided.
If forward_func takes multiple tensors as input, a tuple
of the input tensors should be provided. It is assumed
that for all given input tensors, dimension 0 corresponds
to the number of examples, and if multiple input tensors
are provided, the examples must be aligned appropriately.
baselines (scalar, tensor, tuple of scalars or tensors, optional):
Baselines define reference value which replaces each
feature when the corresponding interpretable feature
is set to 0.
Baselines can be provided as:
- a single tensor, if inputs is a single tensor, with
exactly the same dimensions as inputs or the first
dimension is one and the remaining dimensions match
with inputs.
- a single scalar, if inputs is a single tensor, which will
be broadcasted for each input value in input tensor.
- a tuple of tensors or scalars, the baseline corresponding
to each tensor in the inputs' tuple can be:
- either a tensor with matching dimensions to
corresponding tensor in the inputs' tuple
or the first dimension is one and the remaining
dimensions match with the corresponding
input tensor.
- or a scalar, corresponding to a tensor in the
inputs' tuple. This scalar value is broadcasted
for corresponding input tensor.
In the cases when `baselines` is not provided, we internally
use zero scalar corresponding to each input tensor.
Default: None
target (int, tuple, tensor or list, optional): Output indices for
which surrogate model is trained
(for classification cases,
this is usually the target class).
If the network returns a scalar value per example,
no target index is necessary.
For general 2D outputs, targets can be either:
- a single integer or a tensor containing a single
integer, which is applied to all input examples
- a list of integers or a 1D tensor, with length matching
the number of examples in inputs (dim 0). Each integer
is applied as the target for the corresponding example.
For outputs with > 2 dimensions, targets can be either:
- A single tuple, which contains #output_dims - 1
elements. This target index is applied to all examples.
- A list of tuples with length equal to the number of
examples in inputs (dim 0), and each tuple containing
#output_dims - 1 elements. Each tuple is applied as the
target for the corresponding example.
Default: None
additional_forward_args (any, optional): If the forward function
requires additional arguments other than the inputs for
which attributions should not be computed, this argument
can be provided. It must be either a single additional
argument of a Tensor or arbitrary (non-tuple) type or a
tuple containing multiple additional arguments including
tensors or any arbitrary python types. These arguments
are provided to forward_func in order following the
arguments in inputs.
For a tensor, the first dimension of the tensor must
correspond to the number of examples. It will be
repeated as needed when several perturbed samples are
evaluated in one forward pass (see
`perturbations_per_eval`). For all other types, the given
argument is used for all forward evaluations.
Note that attributions are not computed with respect
to these arguments.
Default: None
feature_mask (tensor or tuple of tensors, optional):
feature_mask defines a mask for the input, grouping
features which correspond to the same
interpretable feature. feature_mask
should contain the same number of tensors as inputs.
Each tensor should
be the same size as the corresponding input or
broadcastable to match the input tensor. Values across
all tensors should be integers in the range 0 to
num_interp_features - 1, and indices corresponding to the
same feature should have the same value.
Note that features are grouped across tensors
(unlike feature ablation and occlusion), so
if the same index is used in different tensors, those
features are still grouped and added simultaneously.
If None, then a feature mask is constructed which assigns
each scalar within a tensor as a separate feature.
Default: None
n_samples (int, optional): The number of samples of the original
model used to train the surrogate interpretable model.
Default: `25` if `n_samples` is not provided.
perturbations_per_eval (int, optional): Allows multiple samples
to be processed simultaneously in one call to forward_fn.
Each forward pass will contain a maximum of
perturbations_per_eval * #examples samples.
For DataParallel models, each batch is split among the
available devices, so evaluations on each available
device contain at most
(perturbations_per_eval * #examples) / num_devices
If the forward function returns a single scalar per batch,
perturbations_per_eval must be set to 1.
Default: 1
return_input_shape (bool, optional): Determines whether the returned
tensor(s) only contain the coefficients for each interp-
retable feature from the trained surrogate model, or
whether the returned attributions match the input shape.
When return_input_shape is True, the return type of attribute
matches the input shape, with each element containing the
coefficient of the corresponding interpretable feature.
All elements with the same value in the feature mask
will contain the same coefficient in the returned
attributions. If return_input_shape is False, a 1D
tensor is returned, containing only the coefficients
of the trained interpretable models, with length
num_interp_features.
show_progress (bool, optional): Displays the progress of computation.
It will try to use tqdm if available for advanced features
(e.g. time estimation). Otherwise, it will fallback to
a simple output of progress.
Default: False
*tensor* or tuple of *tensors* of **attributions**:
- **attributions** (*tensor* or tuple of *tensors*):
The attributions with respect to each input feature.
If return_input_shape is True, attributions will be
the same size as the provided inputs, with each value
providing the coefficient of the corresponding
interpretable feature.
If return_input_shape is False, a 1D
tensor is returned, containing only the coefficients
of the trained interpretable models, with length
num_interp_features.
>>> # SimpleClassifier takes a single input tensor of size Nx4x4,
>>> # and returns an Nx3 tensor of class probabilities.
>>> net = SimpleClassifier()
>>> # Generating random input with size 1 x 4 x 4
>>> input = torch.randn(1, 4, 4)
>>> # Defining Lime interpreter
>>> lime = Lime(net)
>>> # Computes attribution, with each of the 4 x 4 = 16
>>> # features as a separate interpretable feature
>>> attr = lime.attribute(input, target=1, n_samples=200)
>>> # Alternatively, we can group each 2x2 square of the inputs
>>> # as one 'interpretable' feature and perturb them together.
>>> # This can be done by creating a feature mask as follows, which
>>> # defines the feature groups, e.g.:
>>> # +---+---+---+---+
>>> # | 0 | 0 | 1 | 1 |
>>> # +---+---+---+---+
>>> # | 0 | 0 | 1 | 1 |
>>> # +---+---+---+---+
>>> # | 2 | 2 | 3 | 3 |
>>> # +---+---+---+---+
>>> # | 2 | 2 | 3 | 3 |
>>> # +---+---+---+---+
>>> # With this mask, all inputs with the same value are set to their
>>> # baseline value, when the corresponding binary interpretable
>>> # feature is set to 0.
>>> # The attributions can be calculated as follows:
>>> # feature mask has dimensions 1 x 4 x 4
>>> feature_mask = torch.tensor([[[0,0,1,1],[0,0,1,1],
>>> [2,2,3,3],[2,2,3,3]]])
>>> # Computes the interpretable model and returns attributions
>>> # matching the input shape.
>>> attr = lime.attribute(input, target=1, feature_mask=feature_mask)
return self._attribute_kwargs(
def _attribute_kwargs( # type: ignore
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
additional_forward_args: Any = None,
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
n_samples: int = 25,
perturbations_per_eval: int = 1,
return_input_shape: bool = True,
show_progress: bool = False,
) -> TensorOrTupleOfTensorsGeneric:
# Implementation that `attribute` returns through: it obtains coefficients
# from the parent class's `attribute` and, when `return_input_shape` is
# True, passes them to `self._convert_output_shape`; otherwise the
# coefficients themselves are returned.
#
# NOTE(review): this excerpt has lost its indentation and several lines
# (closing brackets, some call arguments, branch keywords). The comments
# below state only what the visible lines establish; anything else is
# marked as an assumption.

# Used below to choose between passing a tuple through unchanged and
# unwrapping its first element.
is_inputs_tuple = _is_tuple(inputs)
# `_format_input_baseline` yields the formatted inputs and baselines used
# for the rest of the method.
formatted_inputs, baselines = _format_input_baseline(inputs, baselines)
# Batch size is read from dim 0 of the first formatted input.
bsz = formatted_inputs[0].shape[0]
# `construct_feature_mask` returns both the mask and the number of
# interpretable features.
feature_mask, num_interp_features = construct_feature_mask(
feature_mask, formatted_inputs
# More than 10000 interpretable features triggers the message below
# (presumably via `warnings.warn`; the emitting call is not visible here).
# NOTE(review): the adjacent literals are joined without separating spaces,
# so the emitted text reads "features.This", "featuremask" and
# "interpretablefeatures" -- fix in a code change.
if num_interp_features > 10000:
"Attempting to construct interpretable model with > 10000 features."
"This can be very slow or lead to OOM issues. Please provide a feature"
"mask which groups input features to reduce the number of interpretable"
"features. "
coefs: Tensor
# With more than one example, the forward function is run once on the
# caller's inputs so the number of returned values can be inspected.
if bsz > 1:
test_output = _run_forward(
self.forward_func, inputs, target, additional_forward_args
if isinstance(test_output, Tensor) and torch.numel(test_output) > 1:
# Exactly one output value per example: handle examples one at a time.
if torch.numel(test_output) == bsz:
"You are providing multiple inputs for Lime / Kernel SHAP "
"attributions. This trains a separate interpretable model "
"for each example, which can be time consuming. It is "
"recommended to compute attributions for one example at a time."
output_list = []
# `_batch_example_iterator` supplies the per-example values; the loop
# targets and iterator arguments are not visible in this excerpt.
for (
) in _batch_example_iterator(
# Calls the parent class's `attribute` through `__wrapped__`.
# NOTE(review): presumably this reaches the undecorated method (e.g. past
# a `log_usage` wrapper) -- confirm against the parent class.
coefs = super().attribute.__wrapped__(
inputs=curr_inps if is_inputs_tuple else curr_inps[0],
if is_inputs_tuple
else curr_baselines[0],
if is_inputs_tuple
else curr_feature_mask[0],
# NOTE(review): the branch taken when `return_input_shape` is True is not
# visible; the visible append stores the coefficients as a single row,
# presumably in the `else` branch -- confirm.
if return_input_shape:
output_list.append(coefs.reshape(1, -1)) # type: ignore
# The per-example results are combined by `_reduce_list`.
return _reduce_list(output_list)
# Raised for an unsupported number of forward outputs (presumably the
# `else` of the `numel == bsz` check -- the keyword is not visible).
# NOTE(review): "return a" + "scalar" concatenates to "ascalar".
raise AssertionError(
"Invalid number of outputs, forward function should return a"
"scalar per example or a scalar per input batch."
# Per the message, a single value per batch requires
# perturbations_per_eval == 1.
# NOTE(review): `assert` is stripped under `python -O`, and the message
# concatenates to "functionreturns".
assert perturbations_per_eval == 1, (
"Perturbations per eval must be 1 when forward function"
"returns single value per batch!"
# Single call to the parent class's `attribute` for the remaining cases;
# baselines and feature_mask are unwrapped when the caller passed a single
# tensor rather than a tuple.
coefs = super().attribute.__wrapped__(
baselines=baselines if is_inputs_tuple else baselines[0],
feature_mask=feature_mask if is_inputs_tuple else feature_mask[0],
# Either convert the coefficients via `_convert_output_shape` or return
# them as they are.
if return_input_shape:
return self._convert_output_shape(
return coefs
def _convert_output_shape(
formatted_inp: Tuple[Tensor, ...],
feature_mask: Tuple[Tensor, ...],
coefs: Tensor,
num_interp_features: int,
is_inputs_tuple: Literal[True],
) -> Tuple[Tensor, ...]:
# Signature variant for tuple inputs: with `is_inputs_tuple` typed as
# Literal[True], the declared return type is a tuple of tensors.
# NOTE(review): presumably a `typing.overload` stub -- its decorator and
# body are not visible in this excerpt.
def _convert_output_shape(
formatted_inp: Tuple[Tensor, ...],
feature_mask: Tuple[Tensor, ...],
coefs: Tensor,
num_interp_features: int,
is_inputs_tuple: Literal[False],
) -> Tensor:
# Signature variant for a single (non-tuple) input: with `is_inputs_tuple`
# typed as Literal[False], the declared return type is a single Tensor.
# NOTE(review): presumably a `typing.overload` stub -- its decorator and
# body are not visible in this excerpt.
def _convert_output_shape(
formatted_inp: Tuple[Tensor, ...],
feature_mask: Tuple[Tensor, ...],
coefs: Tensor,
num_interp_features: int,
is_inputs_tuple: bool,
) -> Union[Tensor, Tuple[Tensor, ...]]:
# Builds one float tensor per input, shaped like that input, by
# accumulating a term for every interpretable feature id wherever
# `feature_mask` equals that id; the result goes through `_format_output`.
#
# formatted_inp: input tensors whose shapes the outputs copy.
# feature_mask: one mask per input, compared against each feature id.
# coefs: coefficient tensor; flattened before use.
# num_interp_features: number of feature ids iterated (0 .. n - 1).
# is_inputs_tuple: forwarded to `_format_output` with the result tuple.
coefs = coefs.flatten()
# One zero-initialised float accumulator per input tensor.
attr = [
torch.zeros_like(single_inp, dtype=torch.float)
for single_inp in formatted_inp
# For each input tensor and each feature id, add a term scaled by the
# 0/1 float mask of positions assigned to that id. Positions whose mask
# value lies outside range(num_interp_features) receive no contribution.
# NOTE(review): the left operand of the product is not visible in this
# excerpt; presumably it is the coefficient for `single_feature` --
# confirm. The loop makes one full-tensor pass per feature id.
for tensor_ind in range(len(formatted_inp)):
for single_feature in range(num_interp_features):
attr[tensor_ind] += (
* (feature_mask[tensor_ind] == single_feature).float()
# `_format_output` receives the tuple flag together with the attributions
# packed as a tuple.
return _format_output(is_inputs_tuple, tuple(attr))