#!/usr/bin/env python3 import inspect from typing import Any import torch.nn as nn class InputIdentity(nn.Module): def __init__(self, input_name: str) -> None: r""" The identity operation Args: input_name (str) The name of the input this layer is associated to. For debugging purposes. """ super().__init__() self.input_name = input_name def forward(self, x): return x class ModelInputWrapper(nn.Module): def __init__(self, module_to_wrap: nn.Module) -> None: r""" This is a convenience class. This wraps a model via first feeding the model's inputs to separate layers (one for each input) and then feeding the (unmodified) inputs to the underlying model (`module_to_wrap`). Each input is fed through an `InputIdentity` layer/module. This class does not change how you feed inputs to your model, so feel free to use your model as you normally would. To access a wrapped input layer, simply access it via the `input_maps` ModuleDict, e.g. to get the corresponding module for input "x", simply provide/write `my_wrapped_module.input_maps["x"]` This is done such that one can use layer attribution methods on inputs. Which should allow you to use mix layers with inputs with these attribution methods. This is especially useful multimodal models which input discrete features (mapped to embeddings, such as text) and regular continuous feature vectors. Notes: - Since inputs are mapped with the identity, attributing to the input/feature can be done with either the input or output of the layer, e.g. attributing to an input/feature doesn't depend on whether attribute_to_layer_input is True or False for LayerIntegratedGradients. - Please refer to the multimodal tutorial or unit tests (test/attr/test_layer_wrapper.py) for an example. Args: module_to_wrap (nn.Module): The model/module you want to wrap """ super().__init__() self.module = module_to_wrap # ignore self self.arg_name_list = inspect.getfullargspec(module_to_wrap.forward).args[1:] self.input_maps = nn.ModuleDict( {arg_name: InputIdentity(arg_name) for arg_name in self.arg_name_list} ) def forward(self, *args, **kwargs) -> Any: args = list(args) for idx, (arg_name, arg) in enumerate(zip(self.arg_name_list, args)): args[idx] = self.input_maps[arg_name](arg) for arg_name in kwargs.keys(): kwargs[arg_name] = self.input_maps[arg_name](kwargs[arg_name]) return self.module(*tuple(args), **kwargs)