strexp / captum /attr /_utils /input_layer_wrapper.py
markytools's picture
added strexp
d61b9c7
raw
history blame
2.85 kB
#!/usr/bin/env python3
import inspect
from typing import Any
import torch.nn as nn
class InputIdentity(nn.Module):
    r"""
    A pass-through layer: ``forward`` hands back its argument untouched.

    Having a dedicated module per model input gives layer attribution methods
    something to attach to for that input.
    """

    def __init__(self, input_name: str) -> None:
        r"""
        Args:
            input_name (str):
                Name of the input this layer is associated with. It is only
                stored for debugging purposes.
        """
        super(InputIdentity, self).__init__()
        self.input_name = input_name

    def forward(self, x):
        r"""Return ``x`` as-is (the very same object, no copy is made)."""
        return x
class ModelInputWrapper(nn.Module):
    def __init__(self, module_to_wrap: nn.Module) -> None:
        r"""
        This is a convenience class. It wraps a model by first feeding each of
        the model's inputs through its own layer (one per input) and then
        passing the (unmodified) inputs on to the underlying model
        (`module_to_wrap`). Each input goes through an `InputIdentity`
        layer/module. This class does not change how you feed inputs to your
        model, so you can call the wrapped model exactly as you normally would.

        To access a wrapped input layer, use the `input_maps` ModuleDict. For
        example, the module for input "x" is
        `my_wrapped_module.input_maps["x"]`.

        The purpose is to let layer attribution methods attribute to inputs,
        so that layers and inputs can be mixed within the same attribution
        call. This is especially useful for multimodal models that take
        discrete features (mapped to embeddings, such as text) as well as
        regular continuous feature vectors.

        Notes:
            - Since inputs are mapped with the identity, attributing to the
              input/feature can be done with either the input or output of the
              layer. For example, attributing to an input/feature doesn't
              depend on whether attribute_to_layer_input is True or False for
              LayerIntegratedGradients.
            - An identity layer is created for every named positional-or-
              keyword parameter and every keyword-only parameter of
              `module_to_wrap.forward`. Extra positional values beyond the
              named parameters (e.g. captured by `*args`), and keyword
              arguments with no matching named parameter (e.g. captured by
              `**kwargs`), are forwarded unchanged without an identity layer.
            - Please refer to the multimodal tutorial or unit tests
              (test/attr/test_layer_wrapper.py) for an example.

        Args:
            module_to_wrap (nn.Module):
                The model/module you want to wrap.
        """
        super().__init__()
        self.module = module_to_wrap

        argspec = inspect.getfullargspec(module_to_wrap.forward)
        # getfullargspec on a bound method still lists `self` first; drop it.
        self.arg_name_list = argspec.args[1:]
        # Keyword-only parameters can only be passed by name. They need their
        # own identity layers, otherwise `forward` would fail to find them.
        self.kwarg_name_list = list(argspec.kwonlyargs)
        self.input_maps = nn.ModuleDict(
            {
                arg_name: InputIdentity(arg_name)
                for arg_name in self.arg_name_list + self.kwarg_name_list
            }
        )

    def forward(self, *args, **kwargs) -> Any:
        r"""
        Route every recognised input through its identity layer, then call
        the wrapped module with the same positional and keyword arguments.

        Positional values are matched to parameter names by position.
        Keyword values are matched by name. Values without a registered
        identity layer are passed through untouched.
        """
        args = list(args)
        # zip stops at the shorter sequence, so surplus positionals are kept as-is.
        for idx, (arg_name, arg) in enumerate(zip(self.arg_name_list, args)):
            args[idx] = self.input_maps[arg_name](arg)

        kwargs = {
            arg_name: (
                self.input_maps[arg_name](arg)
                if arg_name in self.input_maps
                else arg
            )
            for arg_name, arg in kwargs.items()
        }
        return self.module(*args, **kwargs)