File size: 2,848 Bytes
d61b9c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#!/usr/bin/env python3

import inspect
from typing import Any

import torch.nn as nn


class InputIdentity(nn.Module):
    def __init__(self, input_name: str) -> None:
        r"""
        A no-op layer that hands back whatever it is given.

        Routing a model input through an instance of this module gives layer
        attribution methods a concrete layer to attach to for that input.

        Args:
            input_name (str):
                Name of the model input this layer stands in for. It is only
                kept around for debugging purposes.
        """
        super().__init__()
        self.input_name = input_name

    def forward(self, x: Any) -> Any:
        r"""
        Return ``x`` exactly as received (the same object, no copy).
        """
        return x


class ModelInputWrapper(nn.Module):
    def __init__(self, module_to_wrap: nn.Module) -> None:
        r"""
        This is a convenience class. This wraps a model via first feeding the
        model's inputs to separate layers (one for each input) and then feeding
        the (unmodified) inputs to the underlying model (`module_to_wrap`). Each
        input is fed through an `InputIdentity` layer/module. This class does
        not change how you feed inputs to your model, so feel free to use your
        model as you normally would.

        To access a wrapped input layer, simply access it via the `input_maps`
        ModuleDict, e.g. to get the corresponding module for input "x", simply
        provide/write `my_wrapped_module.input_maps["x"]`

        This is done such that one can use layer attribution methods on inputs,
        which allows you to mix layers and inputs with these attribution
        methods. This is especially useful for multimodal models which take
        discrete features (mapped to embeddings, such as text) alongside
        regular continuous feature vectors.

        Notes:
        - Since inputs are mapped with the identity, attributing to the
          input/feature can be done with either the input or output of the
          layer, e.g. attributing to an input/feature doesn't depend on whether
          attribute_to_layer_input is True or False for
          LayerIntegratedGradients.
        - Every named positional(-or-keyword) parameter and every keyword-only
          parameter of `module_to_wrap.forward` gets its own identity layer.
          Values captured by a `*args` / `**kwargs` catch-all in the wrapped
          forward have no named layer and are passed through unchanged.
        - Please refer to the multimodal tutorial or unit tests
          (test/attr/test_layer_wrapper.py) for an example.

        Args:
            module_to_wrap (nn.Module):
                The model/module you want to wrap
        """
        super().__init__()
        self.module = module_to_wrap

        spec = inspect.getfullargspec(module_to_wrap.forward)
        # getfullargspec keeps the already-bound `self` of a bound method, so
        # drop it. Only positional parameters belong in `arg_name_list`:
        # `forward` below matches incoming positional arguments against it by
        # position.
        self.arg_name_list = spec.args[1:]
        # Keyword-only parameters (declared after `*` / `*args`) can only be
        # passed by name; they need identity layers too, otherwise passing one
        # would fail the lookup in `forward`.
        self.input_maps = nn.ModuleDict(
            {
                arg_name: InputIdentity(arg_name)
                for arg_name in self.arg_name_list + spec.kwonlyargs
            }
        )

    def forward(self, *args, **kwargs) -> Any:
        r"""
        Route each input through its identity layer, then call the wrapped
        module.

        Positional arguments are matched to `arg_name_list` by position and
        keyword arguments are matched to `input_maps` by name. Arguments with
        no dedicated layer (extra positionals or keywords absorbed by
        `*args` / `**kwargs` in the wrapped forward) are forwarded as-is, so
        the wrapped module receives the same call it would get unwrapped.

        Returns:
            Whatever `module_to_wrap` returns.
        """
        # zip stops at the shorter sequence, so only named positionals are
        # routed through identity layers here.
        wrapped_args = [
            self.input_maps[arg_name](arg)
            for arg_name, arg in zip(self.arg_name_list, args)
        ]
        # Surplus positionals (for a wrapped forward with `*args`) pass through.
        wrapped_args.extend(args[len(wrapped_args) :])

        # Unknown keyword names are passed through rather than raising KeyError;
        # if the wrapped forward doesn't accept them it raises its own error.
        wrapped_kwargs = {
            arg_name: (
                self.input_maps[arg_name](value)
                if arg_name in self.input_maps
                else value
            )
            for arg_name, value in kwargs.items()
        }

        return self.module(*wrapped_args, **wrapped_kwargs)