#!/usr/bin/env python3
import inspect
import math
import typing
import warnings
from typing import Any, Callable, cast, List, Optional, Tuple, Union
import torch
from captum._utils.common import (
_expand_additional_forward_args,
_expand_target,
_flatten_tensor_or_tuple,
_format_output,
_format_tensor_into_tuples,
_is_tuple,
_reduce_list,
_run_forward,
)
from captum._utils.models.linear_model import SkLearnLasso
from captum._utils.models.model import Model
from captum._utils.progress import progress
from captum._utils.typing import (
BaselineType,
Literal,
TargetType,
TensorOrTupleOfTensorsGeneric,
)
from captum.attr._utils.attribution import PerturbationAttribution
from captum.attr._utils.batching import _batch_example_iterator
from captum.attr._utils.common import (
_construct_default_feature_mask,
_format_input_baseline,
)
from captum.log import log_usage
from torch import Tensor
from torch.nn import CosineSimilarity
from torch.utils.data import DataLoader, TensorDataset
class LimeBase(PerturbationAttribution):
r"""
Lime is an interpretability method that trains an interpretable surrogate model
by sampling points around a specified input example and using model evaluations
at these points to train a simpler interpretable 'surrogate' model, such as a
linear model.
LimeBase provides a generic framework to train a surrogate interpretable model.
This differs from most other attribution methods, since the method returns a
representation of the interpretable model (e.g. coefficients of the linear model).
For a similar interface to other perturbation-based attribution methods, please use
the Lime child class, which defines specific transformations for the interpretable
model.
LimeBase allows sampling points in either the interpretable space or the original
input space to train the surrogate model. The interpretable space is a feature
vector used to train the surrogate interpretable model; this feature space is often
of smaller dimensionality than the original feature space in order for the surrogate
model to be more interpretable.
If sampling in the interpretable space, a transformation function must be provided
to define how a vector sampled in the interpretable space can be transformed into
an example in the original input space. If sampling in the original input space, a
transformation function must be provided to define how the input can be transformed
into its interpretable vector representation.
More details regarding LIME can be found in the original paper:
https://arxiv.org/abs/1602.04938
"""
def __init__(
self,
forward_func: Callable,
interpretable_model: Model,
similarity_func: Callable,
perturb_func: Callable,
perturb_interpretable_space: bool,
from_interp_rep_transform: Optional[Callable],
to_interp_rep_transform: Optional[Callable],
) -> None:
r"""
Args:
forward_func (callable): The forward function of the model or any
modification of it. If a batch is provided as input for
attribution, it is expected that forward_func returns a scalar
representing the entire batch.
interpretable_model (Model): Model object to train interpretable model.
A Model object provides a `fit` method to train the model,
given a dataloader, with batches containing three tensors:
- interpretable_inputs: Tensor
[2D num_samples x num_interp_features],
- expected_outputs: Tensor [1D num_samples],
- weights: Tensor [1D num_samples]
The model object must also provide a `representation` method to
access the appropriate coefficients or representation of the
interpretable model after fitting.
Some predefined interpretable linear models are provided in
captum._utils.models.linear_model including wrappers around
SkLearn linear models as well as SGD-based PyTorch linear
models.
                        Note that calling fit multiple times should retrain the
                        interpretable model, since each attribution call reuses
                        the same interpretable model object.
similarity_func (callable): Function which takes a single sample
along with its corresponding interpretable representation
and returns the weight of the interpretable sample for
                        training the interpretable model. The weight is generally
determined based on similarity to the original input.
The original paper refers to this as a similarity kernel.
The expected signature of this callable is:
>>> similarity_func(
>>> original_input: Tensor or tuple of Tensors,
>>> perturbed_input: Tensor or tuple of Tensors,
>>> perturbed_interpretable_input:
>>> Tensor [2D 1 x num_interp_features],
>>> **kwargs: Any
>>> ) -> float or Tensor containing float scalar
perturbed_input and original_input will be the same type and
contain tensors of the same shape (regardless of whether or not
the sampling function returns inputs in the interpretable
space). original_input is the same as the input provided
when calling attribute.
All kwargs passed to the attribute method are
provided as keyword arguments (kwargs) to this callable.
perturb_func (callable): Function which returns a single
sampled input, generally a perturbation of the original
input, which is used to train the interpretable surrogate
model. Function can return samples in either
the original input space (matching type and tensor shapes
of original input) or in the interpretable input space,
                        which is a vector containing the interpretable features.
Alternatively, this function can return a generator
yielding samples to train the interpretable surrogate
model, and n_samples perturbations will be sampled
from this generator.
The expected signature of this callable is:
>>> perturb_func(
>>> original_input: Tensor or tuple of Tensors,
>>> **kwargs: Any
>>> ) -> Tensor or tuple of Tensors or
>>> generator yielding tensor or tuple of Tensors
All kwargs passed to the attribute method are
provided as keyword arguments (kwargs) to this callable.
Returned sampled input should match the input type (Tensor
or Tuple of Tensor and corresponding shapes) if
perturb_interpretable_space = False. If
perturb_interpretable_space = True, the return type should
be a single tensor of shape 1 x num_interp_features,
corresponding to the representation of the
sample to train the interpretable model.
perturb_interpretable_space (bool): Indicates whether
perturb_func returns a sample in the interpretable space
(tensor of shape 1 x num_interp_features) or a sample
in the original space, matching the format of the original
input. Once sampled, inputs can be converted to / from
the interpretable representation with either
to_interp_rep_transform or from_interp_rep_transform.
from_interp_rep_transform (callable): Function which takes a
single sampled interpretable representation (tensor
of shape 1 x num_interp_features) and returns
the corresponding representation in the input space
(matching shapes of original input to attribute).
This argument is necessary if perturb_interpretable_space
is True, otherwise None can be provided for this argument.
The expected signature of this callable is:
>>> from_interp_rep_transform(
>>> curr_sample: Tensor [2D 1 x num_interp_features]
>>> original_input: Tensor or Tuple of Tensors,
>>> **kwargs: Any
>>> ) -> Tensor or tuple of Tensors
Returned sampled input should match the type of original_input
and corresponding tensor shapes.
All kwargs passed to the attribute method are
provided as keyword arguments (kwargs) to this callable.
to_interp_rep_transform (callable): Function which takes a
sample in the original input space and converts to
its interpretable representation (tensor
of shape 1 x num_interp_features).
This argument is necessary if perturb_interpretable_space
is False, otherwise None can be provided for this argument.
The expected signature of this callable is:
>>> to_interp_rep_transform(
>>> curr_sample: Tensor or Tuple of Tensors,
>>> original_input: Tensor or Tuple of Tensors,
>>> **kwargs: Any
>>> ) -> Tensor [2D 1 x num_interp_features]
curr_sample will match the type of original_input
and corresponding tensor shapes.
All kwargs passed to the attribute method are
provided as keyword arguments (kwargs) to this callable.
"""
PerturbationAttribution.__init__(self, forward_func)
self.interpretable_model = interpretable_model
self.similarity_func = similarity_func
self.perturb_func = perturb_func
self.perturb_interpretable_space = perturb_interpretable_space
self.from_interp_rep_transform = from_interp_rep_transform
self.to_interp_rep_transform = to_interp_rep_transform
if self.perturb_interpretable_space:
            assert self.from_interp_rep_transform is not None, (
                "Must provide transform from interpretable space to original"
                " input space when sampling from interpretable space."
            )
else:
assert (
self.to_interp_rep_transform is not None
), "Must provide transform from original input space to interpretable space"
@log_usage()
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
target: TargetType = None,
additional_forward_args: Any = None,
n_samples: int = 50,
perturbations_per_eval: int = 1,
show_progress: bool = False,
**kwargs,
) -> Tensor:
r"""
This method attributes the output of the model with given target index
(in case it is provided, otherwise it assumes that output is a
scalar) to the inputs of the model using the approach described above.
It trains an interpretable model and returns a representation of the
interpretable model.
It is recommended to only provide a single example as input (tensors
with first dimension or batch size = 1). This is because LIME is generally
used for sample-based interpretability, training a separate interpretable
model to explain a model's prediction on each individual example.
A batch of inputs can be provided as inputs only if forward_func
returns a single value per batch (e.g. loss).
The interpretable feature representation should still have shape
1 x num_interp_features, corresponding to the interpretable
representation for the full batch, and perturbations_per_eval
must be set to 1.
Args:
inputs (tensor or tuple of tensors): Input for which LIME
is computed. If forward_func takes a single
tensor as input, a single input tensor should be provided.
If forward_func takes multiple tensors as input, a tuple
of the input tensors should be provided. It is assumed
that for all given input tensors, dimension 0 corresponds
to the number of examples, and if multiple input tensors
are provided, the examples must be aligned appropriately.
target (int, tuple, tensor or list, optional): Output indices for
which surrogate model is trained
(for classification cases,
this is usually the target class).
If the network returns a scalar value per example,
no target index is necessary.
For general 2D outputs, targets can be either:
- a single integer or a tensor containing a single
integer, which is applied to all input examples
- a list of integers or a 1D tensor, with length matching
the number of examples in inputs (dim 0). Each integer
is applied as the target for the corresponding example.
For outputs with > 2 dimensions, targets can be either:
- A single tuple, which contains #output_dims - 1
elements. This target index is applied to all examples.
- A list of tuples with length equal to the number of
examples in inputs (dim 0), and each tuple containing
#output_dims - 1 elements. Each tuple is applied as the
target for the corresponding example.
Default: None
additional_forward_args (any, optional): If the forward function
requires additional arguments other than the inputs for
which attributions should not be computed, this argument
can be provided. It must be either a single additional
argument of a Tensor or arbitrary (non-tuple) type or a
tuple containing multiple additional arguments including
tensors or any arbitrary python types. These arguments
are provided to forward_func in order following the
arguments in inputs.
For a tensor, the first dimension of the tensor must
correspond to the number of examples. For all other types,
the given argument is used for all forward evaluations.
Note that attributions are not computed with respect
to these arguments.
Default: None
n_samples (int, optional): The number of samples of the original
model used to train the surrogate interpretable model.
                        Default: 50
perturbations_per_eval (int, optional): Allows multiple samples
to be processed simultaneously in one call to forward_fn.
Each forward pass will contain a maximum of
perturbations_per_eval * #examples samples.
For DataParallel models, each batch is split among the
available devices, so evaluations on each available
device contain at most
(perturbations_per_eval * #examples) / num_devices
samples.
If the forward function returns a single scalar per batch,
perturbations_per_eval must be set to 1.
Default: 1
show_progress (bool, optional): Displays the progress of computation.
It will try to use tqdm if available for advanced features
(e.g. time estimation). Otherwise, it will fallback to
a simple output of progress.
Default: False
**kwargs (Any, optional): Any additional arguments necessary for
sampling and transformation functions (provided to
constructor).
Default: None
Returns:
**interpretable model representation**:
            - **interpretable model representation** (*Any*):
                A representation of the trained interpretable model. The return
                type matches the return type of
                interpretable_model.representation(). For example, this could
                contain the coefficients of a linear surrogate model.
Examples::
>>> # SimpleClassifier takes a single input tensor of
>>> # float features with size N x 5,
>>> # and returns an Nx3 tensor of class probabilities.
>>> net = SimpleClassifier()
>>>
>>> # We will train an interpretable model with the same
>>> # features by simply sampling with added Gaussian noise
>>> # to the inputs and training a model to predict the
>>> # score of the target class.
>>>
>>> # For interpretable model training, we will use sklearn
>>> # linear model in this example. We have provided wrappers
>>> # around sklearn linear models to fit the Model interface.
>>> # Any arguments provided to the sklearn constructor can also
>>> # be provided to the wrapper, e.g.:
>>> # SkLearnLinearModel("linear_model.Ridge", alpha=2.0)
>>> from captum._utils.models.linear_model import SkLearnLinearModel
>>>
>>>
>>> # Define similarity kernel (exponential kernel based on L2 norm)
>>> def similarity_kernel(
>>> original_input: Tensor,
>>> perturbed_input: Tensor,
>>> perturbed_interpretable_input: Tensor,
>>> **kwargs)->Tensor:
>>> # kernel_width will be provided to attribute as a kwarg
>>> kernel_width = kwargs["kernel_width"]
>>> l2_dist = torch.norm(original_input - perturbed_input)
>>> return torch.exp(- (l2_dist**2) / (kernel_width**2))
>>>
>>>
>>> # Define sampling function
>>> # This function samples in original input space
>>> def perturb_func(
>>> original_input: Tensor,
>>> **kwargs)->Tensor:
>>> return original_input + torch.randn_like(original_input)
>>>
>>> # For this example, we are setting the interpretable input to
>>> # match the model input, so the to_interp_rep_transform
>>> # function simply returns the input. In most cases, the interpretable
>>> # input will be different and may have a smaller feature set, so
>>> # an appropriate transformation function should be provided.
>>>
>>> def to_interp_transform(curr_sample, original_inp,
>>> **kwargs):
>>> return curr_sample
>>>
>>> # Generating random input with size 1 x 5
>>> input = torch.randn(1, 5)
>>> # Defining LimeBase interpreter
            >>> lime_attr = LimeBase(net,
            >>>                      SkLearnLinearModel("linear_model.Ridge"),
            >>>                      similarity_func=similarity_kernel,
            >>>                      perturb_func=perturb_func,
            >>>                      perturb_interpretable_space=False,
            >>>                      from_interp_rep_transform=None,
            >>>                      to_interp_rep_transform=to_interp_transform)
>>> # Computes interpretable model, returning coefficients of linear
>>> # model.
>>> attr_coefs = lime_attr.attribute(input, target=1, kernel_width=1.1)
"""
with torch.no_grad():
inp_tensor = (
cast(Tensor, inputs) if isinstance(inputs, Tensor) else inputs[0]
)
device = inp_tensor.device
interpretable_inps = []
similarities = []
outputs = []
curr_model_inputs = []
expanded_additional_args = None
expanded_target = None
perturb_generator = None
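            # If perturb_func is a generator function, instantiate the
            # generator once; n_samples perturbations are drawn from it below.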
if inspect.isgeneratorfunction(self.perturb_func):
perturb_generator = self.perturb_func(inputs, **kwargs)
if show_progress:
attr_progress = progress(
total=math.ceil(n_samples / perturbations_per_eval),
desc=f"{self.get_name()} attribution",
)
attr_progress.update(0)
batch_count = 0
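            # Draw n_samples perturbed samples, convert each between the input
            # and interpretable spaces, and weight it via similarity_func.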
for _ in range(n_samples):
if perturb_generator:
try:
curr_sample = next(perturb_generator)
except StopIteration:
warnings.warn(
"Generator completed prior to given n_samples iterations!"
)
break
else:
curr_sample = self.perturb_func(inputs, **kwargs)
batch_count += 1
if self.perturb_interpretable_space:
interpretable_inps.append(curr_sample)
curr_model_inputs.append(
self.from_interp_rep_transform( # type: ignore
curr_sample, inputs, **kwargs
)
)
else:
curr_model_inputs.append(curr_sample)
interpretable_inps.append(
self.to_interp_rep_transform( # type: ignore
curr_sample, inputs, **kwargs
)
)
curr_sim = self.similarity_func(
inputs, curr_model_inputs[-1], interpretable_inps[-1], **kwargs
)
similarities.append(
curr_sim.flatten()
if isinstance(curr_sim, Tensor)
else torch.tensor([curr_sim], device=device)
)
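                # Once perturbations_per_eval samples have accumulated,
                # evaluate them in a single forward pass.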
if len(curr_model_inputs) == perturbations_per_eval:
if expanded_additional_args is None:
expanded_additional_args = _expand_additional_forward_args(
additional_forward_args, len(curr_model_inputs)
)
if expanded_target is None:
expanded_target = _expand_target(target, len(curr_model_inputs))
model_out = self._evaluate_batch(
curr_model_inputs,
expanded_target,
expanded_additional_args,
device,
)
if show_progress:
attr_progress.update()
outputs.append(model_out)
curr_model_inputs = []
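            # Evaluate any leftover samples that did not fill a complete batch.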
if len(curr_model_inputs) > 0:
expanded_additional_args = _expand_additional_forward_args(
additional_forward_args, len(curr_model_inputs)
)
expanded_target = _expand_target(target, len(curr_model_inputs))
model_out = self._evaluate_batch(
curr_model_inputs,
expanded_target,
expanded_additional_args,
device,
)
if show_progress:
attr_progress.update()
outputs.append(model_out)
if show_progress:
attr_progress.close()
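            # Assemble the collected interpretable inputs, model outputs, and
            # similarity weights into a dataset and fit the surrogate model.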
combined_interp_inps = torch.cat(interpretable_inps).double()
combined_outputs = (
torch.cat(outputs)
if len(outputs[0].shape) > 0
else torch.stack(outputs)
).double()
combined_sim = (
torch.cat(similarities)
if len(similarities[0].shape) > 0
else torch.stack(similarities)
).double()
dataset = TensorDataset(
combined_interp_inps, combined_outputs, combined_sim
)
self.interpretable_model.fit(DataLoader(dataset, batch_size=batch_count))
return self.interpretable_model.representation()
def _evaluate_batch(
self,
curr_model_inputs: List[TensorOrTupleOfTensorsGeneric],
expanded_target: TargetType,
expanded_additional_args: Any,
device: torch.device,
):
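        # Stack the accumulated perturbed inputs into a single batch and run
        # one forward pass, expecting one scalar output per perturbation.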
model_out = _run_forward(
self.forward_func,
_reduce_list(curr_model_inputs),
expanded_target,
expanded_additional_args,
)
        if isinstance(model_out, Tensor):
            assert model_out.numel() == len(curr_model_inputs), (
                "Number of outputs is not appropriate, must return "
                "one output per perturbed input"
            )
            return model_out.flatten()
        return torch.tensor([model_out], device=device)
def has_convergence_delta(self) -> bool:
return False
@property
def multiplies_by_inputs(self):
return False
# Default transformations and methods
# for Lime child implementation.
def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs):
assert (
"feature_mask" in kwargs
), "Must provide feature_mask to use default interpretable representation transform"
    assert (
        "baselines" in kwargs
    ), "Must provide baselines to use default interpretable representation transform"
feature_mask = kwargs["feature_mask"]
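    # curr_sample is a binary vector over interpretable features; indexing it
    # with the feature mask expands it to a per-element on/off mask that keeps
    # the original input where 1 and substitutes the baseline where 0.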
if isinstance(feature_mask, Tensor):
binary_mask = curr_sample[0][feature_mask].bool()
return (
binary_mask.to(original_inputs.dtype) * original_inputs
+ (~binary_mask).to(original_inputs.dtype) * kwargs["baselines"]
)
else:
binary_mask = tuple(
curr_sample[0][feature_mask[j]].bool() for j in range(len(feature_mask))
)
return tuple(
binary_mask[j].to(original_inputs[j].dtype) * original_inputs[j]
+ (~binary_mask[j]).to(original_inputs[j].dtype) * kwargs["baselines"][j]
for j in range(len(feature_mask))
)
def get_exp_kernel_similarity_function(
distance_mode: str = "cosine", kernel_width: float = 1.0
) -> Callable:
r"""
This method constructs an appropriate similarity function to compute
    weights for perturbed samples in LIME. Distance between the original
and perturbed inputs is computed based on the provided distance mode,
and the distance is passed through an exponential kernel with given
kernel width to convert to a range between 0 and 1.
    The callable returned can be provided as the similarity_func argument
    for Lime or LimeBase.
Args:
distance_mode (str, optional): Distance mode can be either "cosine" or
"euclidean" corresponding to either cosine distance
or Euclidean distance respectively. Distance is computed
by flattening the original inputs and perturbed inputs
(concatenating tuples of inputs if necessary) and computing
distances between the resulting vectors.
Default: "cosine"
kernel_width (float, optional):
Kernel width for exponential kernel applied to distance.
Default: 1.0
Returns:
        *Callable*:
        - **similarity_func** (*Callable*):
            Similarity function. This callable can be provided as the
            similarity_func argument for Lime or LimeBase.
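
    Examples::

        >>> # A minimal sketch: `net` stands in for any PyTorch model and
        >>> # `input` for a matching input tensor.
        >>> sim_fn = get_exp_kernel_similarity_function(
        >>>     distance_mode="euclidean", kernel_width=0.75
        >>> )
        >>> lime = Lime(net, similarity_func=sim_fn)
        >>> attr = lime.attribute(input, target=1)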
"""
def default_exp_kernel(original_inp, perturbed_inp, __, **kwargs):
flattened_original_inp = _flatten_tensor_or_tuple(original_inp).float()
flattened_perturbed_inp = _flatten_tensor_or_tuple(perturbed_inp).float()
if distance_mode == "cosine":
cos_sim = CosineSimilarity(dim=0)
distance = 1 - cos_sim(flattened_original_inp, flattened_perturbed_inp)
elif distance_mode == "euclidean":
distance = torch.norm(flattened_original_inp - flattened_perturbed_inp)
else:
raise ValueError("distance_mode must be either cosine or euclidean.")
return math.exp(-1 * (distance ** 2) / (2 * (kernel_width ** 2)))
return default_exp_kernel
def default_perturb_func(original_inp, **kwargs):
assert (
"num_interp_features" in kwargs
), "Must provide num_interp_features to use default interpretable sampling function"
if isinstance(original_inp, Tensor):
device = original_inp.device
else:
device = original_inp[0].device
probs = torch.ones(1, kwargs["num_interp_features"]) * 0.5
return torch.bernoulli(probs).to(device=device).long()
def construct_feature_mask(feature_mask, formatted_inputs):
if feature_mask is None:
feature_mask, num_interp_features = _construct_default_feature_mask(
formatted_inputs
)
else:
feature_mask = _format_tensor_into_tuples(feature_mask)
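        # Normalize user-provided masks so feature indices start at 0;
        # num_interp_features is then the largest index + 1.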
min_interp_features = int(
min(
torch.min(single_mask).item()
for single_mask in feature_mask
if single_mask.numel()
)
)
if min_interp_features != 0:
warnings.warn(
"Minimum element in feature mask is not 0, shifting indices to"
" start at 0."
)
feature_mask = tuple(
single_mask - min_interp_features for single_mask in feature_mask
)
num_interp_features = int(
max(
torch.max(single_mask).item()
for single_mask in feature_mask
if single_mask.numel()
)
+ 1
)
return feature_mask, num_interp_features
class Lime(LimeBase):
r"""
Lime is an interpretability method that trains an interpretable surrogate model
by sampling points around a specified input example and using model evaluations
at these points to train a simpler interpretable 'surrogate' model, such as a
linear model.
Lime provides a more specific implementation than LimeBase in order to expose
a consistent API with other perturbation-based algorithms. For more general
use of the LIME framework, consider using the LimeBase class directly and
defining custom sampling and transformation to / from interpretable
representation functions.
Lime assumes that the interpretable representation is a binary vector,
corresponding to some elements in the input being set to their baseline value
if the corresponding binary interpretable feature value is 0 or being set
to the original input value if the corresponding binary interpretable
feature value is 1. Input values can be grouped to correspond to the same
binary interpretable feature using a feature mask provided when calling
attribute, similar to other perturbation-based attribution methods.
One example of this setting is when applying Lime to an image classifier.
Pixels in an image can be grouped into super-pixels or segments, which
correspond to interpretable features, provided as a feature_mask when
calling attribute. Sampled binary vectors convey whether a super-pixel
is on (retains the original input values) or off (set to the corresponding
baseline value, e.g. black image). An interpretable linear model is trained
with input being the binary vectors and outputs as the corresponding scores
of the image classifier with the appropriate super-pixels masked based on the
binary vector. Coefficients of the trained surrogate
linear model convey the importance of each super-pixel.
More details regarding LIME can be found in the original paper:
https://arxiv.org/abs/1602.04938
"""
def __init__(
self,
forward_func: Callable,
interpretable_model: Optional[Model] = None,
similarity_func: Optional[Callable] = None,
perturb_func: Optional[Callable] = None,
) -> None:
r"""
Args:
forward_func (callable): The forward function of the model or any
modification of it
            interpretable_model (Model, optional): Model object to train
interpretable model.
This argument is optional and defaults to SkLearnLasso(alpha=0.01),
which is a wrapper around the Lasso linear model in SkLearn.
This requires having sklearn version >= 0.23 available.
Other predefined interpretable linear models are provided in
captum._utils.models.linear_model.
Alternatively, a custom model object must provide a `fit` method to
train the model, given a dataloader, with batches containing
three tensors:
- interpretable_inputs: Tensor
[2D num_samples x num_interp_features],
- expected_outputs: Tensor [1D num_samples],
- weights: Tensor [1D num_samples]
The model object must also provide a `representation` method to
access the appropriate coefficients or representation of the
interpretable model after fitting.
                        Note that calling fit multiple times should retrain the
                        interpretable model, since each attribution call reuses
                        the same interpretable model object.
            similarity_func (callable, optional): Function which takes a single sample
along with its corresponding interpretable representation
and returns the weight of the interpretable sample for
training the interpretable model.
This is often referred to as a similarity kernel.
                        This argument is optional and defaults to a function
                        which applies an exponential kernel to the cosine
                        distance between the original input and perturbed
                        input, with a kernel width of 1.0.
A similarity function applying an exponential
kernel to cosine / euclidean distances can be constructed
using the provided get_exp_kernel_similarity_function in
captum.attr._core.lime.
                        Alternatively, a custom callable can also be provided.
The expected signature of this callable is:
>>> def similarity_func(
>>> original_input: Tensor or tuple of Tensors,
>>> perturbed_input: Tensor or tuple of Tensors,
>>> perturbed_interpretable_input:
>>> Tensor [2D 1 x num_interp_features],
>>> **kwargs: Any
>>> ) -> float or Tensor containing float scalar
perturbed_input and original_input will be the same type and
contain tensors of the same shape, with original_input
being the same as the input provided when calling attribute.
kwargs includes baselines, feature_mask, num_interp_features
(integer, determined from feature mask).
            perturb_func (callable, optional): Function which returns a single
sampled input, which is a binary vector of length
num_interp_features, or a generator of such tensors.
                        This argument is optional; the default function returns
a binary vector where each element is selected
independently and uniformly at random. Custom
logic for selecting sampled binary vectors can
be implemented by providing a function with the
following expected signature:
>>> perturb_func(
>>> original_input: Tensor or tuple of Tensors,
>>> **kwargs: Any
>>> ) -> Tensor [Binary 2D Tensor 1 x num_interp_features]
>>> or generator yielding such tensors
kwargs includes baselines, feature_mask, num_interp_features
(integer, determined from feature mask).
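        Examples::

            >>> # A minimal sketch of a custom configuration; `net` stands in
            >>> # for any PyTorch model.
            >>> from captum._utils.models.linear_model import SkLearnLinearModel
            >>> lime = Lime(
            >>>     net,
            >>>     interpretable_model=SkLearnLinearModel("linear_model.Ridge"),
            >>>     similarity_func=get_exp_kernel_similarity_function("euclidean"),
            >>> )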
"""
if interpretable_model is None:
interpretable_model = SkLearnLasso(alpha=0.01)
if similarity_func is None:
similarity_func = get_exp_kernel_similarity_function()
if perturb_func is None:
perturb_func = default_perturb_func
LimeBase.__init__(
self,
forward_func,
interpretable_model,
similarity_func,
perturb_func,
True,
default_from_interp_rep_transform,
None,
)
@log_usage()
def attribute( # type: ignore
self,
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
additional_forward_args: Any = None,
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
n_samples: int = 50,
perturbations_per_eval: int = 1,
return_input_shape: bool = True,
show_progress: bool = False,
) -> TensorOrTupleOfTensorsGeneric:
r"""
This method attributes the output of the model with given target index
(in case it is provided, otherwise it assumes that output is a
scalar) to the inputs of the model using the approach described above,
training an interpretable model and returning a representation of the
interpretable model.
It is recommended to only provide a single example as input (tensors
with first dimension or batch size = 1). This is because LIME is generally
used for sample-based interpretability, training a separate interpretable
model to explain a model's prediction on each individual example.
        A batch of inputs can also be provided, similar to other
        perturbation-based attribution methods. In this case, if forward_func
        returns a scalar per example, attributions will be computed for each
        example independently, with a separate interpretable model trained for
        each example. Note that the provided similarity and perturbation
        functions will be given each example separately (first dimension = 1)
        in this case.
        If forward_func returns a scalar per batch (e.g. loss), attributions
        will still be computed using a single interpretable model for the full
        batch. In this case, the similarity and perturbation functions will be
        given the same original input containing the full batch.
The number of interpretable features is determined from the provided
feature mask, or if none is provided, from the default feature mask,
which considers each scalar input as a separate feature. It is
generally recommended to provide a feature mask which groups features
into a small number of interpretable features / components (e.g.
superpixels in images).
Args:
inputs (tensor or tuple of tensors): Input for which LIME
is computed. If forward_func takes a single
tensor as input, a single input tensor should be provided.
If forward_func takes multiple tensors as input, a tuple
of the input tensors should be provided. It is assumed
that for all given input tensors, dimension 0 corresponds
to the number of examples, and if multiple input tensors
are provided, the examples must be aligned appropriately.
baselines (scalar, tensor, tuple of scalars or tensors, optional):
Baselines define reference value which replaces each
feature when the corresponding interpretable feature
is set to 0.
Baselines can be provided as:
- a single tensor, if inputs is a single tensor, with
exactly the same dimensions as inputs or the first
dimension is one and the remaining dimensions match
with inputs.
- a single scalar, if inputs is a single tensor, which will
be broadcasted for each input value in input tensor.
- a tuple of tensors or scalars, the baseline corresponding
to each tensor in the inputs' tuple can be:
- either a tensor with matching dimensions to
corresponding tensor in the inputs' tuple
or the first dimension is one and the remaining
dimensions match with the corresponding
input tensor.
- or a scalar, corresponding to a tensor in the
inputs' tuple. This scalar value is broadcasted
for corresponding input tensor.
In the cases when `baselines` is not provided, we internally
use zero scalar corresponding to each input tensor.
Default: None
target (int, tuple, tensor or list, optional): Output indices for
which surrogate model is trained
(for classification cases,
this is usually the target class).
If the network returns a scalar value per example,
no target index is necessary.
For general 2D outputs, targets can be either:
- a single integer or a tensor containing a single
integer, which is applied to all input examples
- a list of integers or a 1D tensor, with length matching
the number of examples in inputs (dim 0). Each integer
is applied as the target for the corresponding example.
For outputs with > 2 dimensions, targets can be either:
- A single tuple, which contains #output_dims - 1
elements. This target index is applied to all examples.
- A list of tuples with length equal to the number of
examples in inputs (dim 0), and each tuple containing
#output_dims - 1 elements. Each tuple is applied as the
target for the corresponding example.
Default: None
additional_forward_args (any, optional): If the forward function
requires additional arguments other than the inputs for
which attributions should not be computed, this argument
can be provided. It must be either a single additional
argument of a Tensor or arbitrary (non-tuple) type or a
tuple containing multiple additional arguments including
tensors or any arbitrary python types. These arguments
are provided to forward_func in order following the
arguments in inputs.
                        For a tensor, the first dimension of the tensor must
                        correspond to the number of examples. For all other
                        types, the given argument is used for all forward
                        evaluations.
Note that attributions are not computed with respect
to these arguments.
Default: None
feature_mask (tensor or tuple of tensors, optional):
feature_mask defines a mask for the input, grouping
features which correspond to the same
interpretable feature. feature_mask
should contain the same number of tensors as inputs.
Each tensor should
be the same size as the corresponding input or
broadcastable to match the input tensor. Values across
all tensors should be integers in the range 0 to
num_interp_features - 1, and indices corresponding to the
same feature should have the same value.
Note that features are grouped across tensors
(unlike feature ablation and occlusion), so
if the same index is used in different tensors, those
features are still grouped and added simultaneously.
If None, then a feature mask is constructed which assigns
each scalar within a tensor as a separate feature.
Default: None
n_samples (int, optional): The number of samples of the original
model used to train the surrogate interpretable model.
                        Default: 50
perturbations_per_eval (int, optional): Allows multiple samples
to be processed simultaneously in one call to forward_fn.
Each forward pass will contain a maximum of
perturbations_per_eval * #examples samples.
For DataParallel models, each batch is split among the
available devices, so evaluations on each available
device contain at most
(perturbations_per_eval * #examples) / num_devices
samples.
If the forward function returns a single scalar per batch,
perturbations_per_eval must be set to 1.
Default: 1
            return_input_shape (bool, optional): Determines whether the returned
                        tensor(s) only contain the coefficients for each
                        interpretable feature from the trained surrogate model,
                        or whether the returned attributions match the input
                        shape. When return_input_shape is True, the return type
                        of attribute matches the input shape, with each element
                        containing the coefficient of the corresponding
                        interpretable feature. All elements with the same value
                        in the feature mask will contain the same coefficient in
                        the returned attributions. If return_input_shape is
                        False, a 1D tensor is returned, containing only the
                        coefficients of the trained interpretable model, with
                        length num_interp_features.
                        Default: True
show_progress (bool, optional): Displays the progress of computation.
It will try to use tqdm if available for advanced features
(e.g. time estimation). Otherwise, it will fallback to
a simple output of progress.
Default: False
Returns:
*tensor* or tuple of *tensors* of **attributions**:
- **attributions** (*tensor* or tuple of *tensors*):
The attributions with respect to each input feature.
If return_input_shape = True, attributions will be
the same size as the provided inputs, with each value
providing the coefficient of the corresponding
                interpretable feature.
If return_input_shape is False, a 1D
tensor is returned, containing only the coefficients
                of the trained interpretable model, with length
num_interp_features.
Examples::
>>> # SimpleClassifier takes a single input tensor of size Nx4x4,
>>> # and returns an Nx3 tensor of class probabilities.
>>> net = SimpleClassifier()
>>> # Generating random input with size 1 x 4 x 4
>>> input = torch.randn(1, 4, 4)
>>> # Defining Lime interpreter
>>> lime = Lime(net)
>>> # Computes attribution, with each of the 4 x 4 = 16
>>> # features as a separate interpretable feature
>>> attr = lime.attribute(input, target=1, n_samples=200)
>>> # Alternatively, we can group each 2x2 square of the inputs
>>> # as one 'interpretable' feature and perturb them together.
>>> # This can be done by creating a feature mask as follows, which
>>> # defines the feature groups, e.g.:
>>> # +---+---+---+---+
>>> # | 0 | 0 | 1 | 1 |
>>> # +---+---+---+---+
>>> # | 0 | 0 | 1 | 1 |
>>> # +---+---+---+---+
>>> # | 2 | 2 | 3 | 3 |
>>> # +---+---+---+---+
>>> # | 2 | 2 | 3 | 3 |
>>> # +---+---+---+---+
>>> # With this mask, all inputs with the same value are set to their
>>> # baseline value, when the corresponding binary interpretable
>>> # feature is set to 0.
>>> # The attributions can be calculated as follows:
>>> # feature mask has dimensions 1 x 4 x 4
>>> feature_mask = torch.tensor([[[0,0,1,1],[0,0,1,1],
>>> [2,2,3,3],[2,2,3,3]]])
>>> # Computes interpretable model and returning attributions
>>> # matching input shape.
>>> attr = lime.attribute(input, target=1, feature_mask=feature_mask)
"""
return self._attribute_kwargs(
inputs=inputs,
baselines=baselines,
target=target,
additional_forward_args=additional_forward_args,
feature_mask=feature_mask,
n_samples=n_samples,
perturbations_per_eval=perturbations_per_eval,
return_input_shape=return_input_shape,
show_progress=show_progress,
)
def _attribute_kwargs( # type: ignore
self,
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
additional_forward_args: Any = None,
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
n_samples: int = 25,
perturbations_per_eval: int = 1,
return_input_shape: bool = True,
show_progress: bool = False,
**kwargs,
) -> TensorOrTupleOfTensorsGeneric:
is_inputs_tuple = _is_tuple(inputs)
formatted_inputs, baselines = _format_input_baseline(inputs, baselines)
bsz = formatted_inputs[0].shape[0]
feature_mask, num_interp_features = construct_feature_mask(
feature_mask, formatted_inputs
)
if num_interp_features > 10000:
            warnings.warn(
                "Attempting to construct interpretable model with > 10000 "
                "features. This can be very slow or lead to OOM issues. Please "
                "provide a feature mask which groups input features to reduce "
                "the number of interpretable features."
            )
coefs: Tensor
if bsz > 1:
test_output = _run_forward(
self.forward_func, inputs, target, additional_forward_args
)
if isinstance(test_output, Tensor) and torch.numel(test_output) > 1:
if torch.numel(test_output) == bsz:
warnings.warn(
"You are providing multiple inputs for Lime / Kernel SHAP "
"attributions. This trains a separate interpretable model "
"for each example, which can be time consuming. It is "
"recommended to compute attributions for one example at a time."
)
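                    # forward_func returns one scalar per example: split the
                    # batch and train a separate surrogate model per example.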
output_list = []
for (
curr_inps,
curr_target,
curr_additional_args,
curr_baselines,
curr_feature_mask,
) in _batch_example_iterator(
bsz,
formatted_inputs,
target,
additional_forward_args,
baselines,
feature_mask,
):
coefs = super().attribute.__wrapped__(
self,
inputs=curr_inps if is_inputs_tuple else curr_inps[0],
target=curr_target,
additional_forward_args=curr_additional_args,
n_samples=n_samples,
perturbations_per_eval=perturbations_per_eval,
baselines=curr_baselines
if is_inputs_tuple
else curr_baselines[0],
feature_mask=curr_feature_mask
if is_inputs_tuple
else curr_feature_mask[0],
num_interp_features=num_interp_features,
show_progress=show_progress,
**kwargs,
)
if return_input_shape:
output_list.append(
self._convert_output_shape(
curr_inps,
curr_feature_mask,
coefs,
num_interp_features,
is_inputs_tuple,
)
)
else:
output_list.append(coefs.reshape(1, -1)) # type: ignore
return _reduce_list(output_list)
else:
raise AssertionError(
"Invalid number of outputs, forward function should return a"
"scalar per example or a scalar per input batch."
)
else:
            assert perturbations_per_eval == 1, (
                "Perturbations per eval must be 1 when forward function "
                "returns a single value per batch!"
            )
coefs = super().attribute.__wrapped__(
self,
inputs=inputs,
target=target,
additional_forward_args=additional_forward_args,
n_samples=n_samples,
perturbations_per_eval=perturbations_per_eval,
baselines=baselines if is_inputs_tuple else baselines[0],
feature_mask=feature_mask if is_inputs_tuple else feature_mask[0],
num_interp_features=num_interp_features,
show_progress=show_progress,
**kwargs,
)
if return_input_shape:
return self._convert_output_shape(
formatted_inputs,
feature_mask,
coefs,
num_interp_features,
is_inputs_tuple,
)
else:
return coefs
@typing.overload
def _convert_output_shape(
self,
formatted_inp: Tuple[Tensor, ...],
feature_mask: Tuple[Tensor, ...],
coefs: Tensor,
num_interp_features: int,
is_inputs_tuple: Literal[True],
) -> Tuple[Tensor, ...]:
...
@typing.overload
def _convert_output_shape(
self,
formatted_inp: Tuple[Tensor, ...],
feature_mask: Tuple[Tensor, ...],
coefs: Tensor,
num_interp_features: int,
is_inputs_tuple: Literal[False],
) -> Tensor:
...
def _convert_output_shape(
self,
formatted_inp: Tuple[Tensor, ...],
feature_mask: Tuple[Tensor, ...],
coefs: Tensor,
num_interp_features: int,
is_inputs_tuple: bool,
) -> Union[Tensor, Tuple[Tensor, ...]]:
coefs = coefs.flatten()
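        # Scatter each learned coefficient back onto every input element whose
        # feature mask value matches that interpretable feature index.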
attr = [
torch.zeros_like(single_inp, dtype=torch.float)
for single_inp in formatted_inp
]
for tensor_ind in range(len(formatted_inp)):
for single_feature in range(num_interp_features):
attr[tensor_ind] += (
coefs[single_feature].item()
* (feature_mask[tensor_ind] == single_feature).float()
)
return _format_output(is_inputs_tuple, tuple(attr))